// Copyright 2019 The libgav1 Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "src/dsp/mask_blend.h"
#include "src/utils/cpu.h"

#if LIBGAV1_TARGETING_SSE4_1

#include <smmintrin.h>

#include <cassert>
#include <cstddef>
#include <cstdint>

#include "src/dsp/constants.h"
#include "src/dsp/dsp.h"
#include "src/dsp/x86/common_sse4.h"
#include "src/utils/common.h"

namespace libgav1 {
namespace dsp {
namespace {

template <int subsampling_x, int subsampling_y>
inline __m128i GetMask8(const uint8_t* mask, const ptrdiff_t stride) {
  if (subsampling_x == 1 && subsampling_y == 1) {
    const __m128i one = _mm_set1_epi8(1);
    const __m128i mask_val_0 = LoadUnaligned16(mask);
    const __m128i mask_val_1 = LoadUnaligned16(mask + stride);
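    // Sum the two mask rows with saturating adds; mask values are at most 64,
    // so the 8-bit sums cannot saturate. _mm_maddubs_epi16 with a vector of
    // ones then adds adjacent bytes into 16-bit lanes, completing the 2x2 sum
    // that the rounded shift by 2 turns into an average.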
    const __m128i add_0 = _mm_adds_epu8(mask_val_0, mask_val_1);
    const __m128i mask_0 = _mm_maddubs_epi16(add_0, one);
    return RightShiftWithRounding_U16(mask_0, 2);
  }
  if (subsampling_x == 1) {
    const __m128i row_vals = LoadUnaligned16(mask);
    const __m128i mask_val_0 = _mm_cvtepu8_epi16(row_vals);
    const __m128i mask_val_1 = _mm_cvtepu8_epi16(_mm_srli_si128(row_vals, 8));
    __m128i subsampled_mask = _mm_hadd_epi16(mask_val_0, mask_val_1);
    return RightShiftWithRounding_U16(subsampled_mask, 1);
  }
  assert(subsampling_y == 0 && subsampling_x == 0);
  const __m128i mask_val = LoadLo8(mask);
  return _mm_cvtepu8_epi16(mask_val);
}

// Imitate behavior of ARM vtrn1q_u64.
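// Returns the low 64-bit halves of |a| and |b|: { a[0], b[0] }.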
inline __m128i Transpose1_U64(const __m128i a, const __m128i b) {
  return _mm_castps_si128(
      _mm_movelh_ps(_mm_castsi128_ps(a), _mm_castsi128_ps(b)));
}

// Imitate behavior of ARM vtrn2q_u64.
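// Returns the high 64-bit halves of |b| and |a|: { b[1], a[1] }.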
inline __m128i Transpose2_U64(const __m128i a, const __m128i b) {
  return _mm_castps_si128(
      _mm_movehl_ps(_mm_castsi128_ps(a), _mm_castsi128_ps(b)));
}

// Width can only be 4 when it is subsampled from a block of width 8, hence
// subsampling_x is always 1 when this function is called.
template <int subsampling_x, int subsampling_y>
inline __m128i GetMask4x2(const uint8_t* mask) {
  if (subsampling_x == 1 && subsampling_y == 1) {
    const __m128i mask_val_01 = LoadUnaligned16(mask);
    // Stride is fixed because this is the smallest block size.
    const __m128i mask_val_23 = LoadUnaligned16(mask + 16);
    // Transpose rows to add row 0 to row 1, and row 2 to row 3.
    const __m128i mask_val_02 = Transpose1_U64(mask_val_01, mask_val_23);
    const __m128i mask_val_13 = Transpose2_U64(mask_val_23, mask_val_01);
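    // mask_val_02 = { row0, row2 } and mask_val_13 = { row1, row3 }, so the
    // saturating add below pairs each even row with the odd row beneath it.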
    const __m128i add_0 = _mm_adds_epu8(mask_val_02, mask_val_13);
    const __m128i one = _mm_set1_epi8(1);
    const __m128i mask_0 = _mm_maddubs_epi16(add_0, one);
    return RightShiftWithRounding_U16(mask_0, 2);
  }
  return GetMask8<subsampling_x, 0>(mask, 0);
}

template <int subsampling_x, int subsampling_y>
inline __m128i GetInterIntraMask4x2(const uint8_t* mask,
                                    ptrdiff_t mask_stride) {
  if (subsampling_x == 1) {
    return GetMask4x2<subsampling_x, subsampling_y>(mask);
  }
  // When using intra or difference weighted masks, the function doesn't use
  // subsampling, so |mask_stride| may be 4 or 8.
  assert(subsampling_y == 0 && subsampling_x == 0);
  const __m128i mask_val_0 = Load4(mask);
  const __m128i mask_val_1 = Load4(mask + mask_stride);
  return _mm_cvtepu8_epi16(
      _mm_or_si128(mask_val_0, _mm_slli_si128(mask_val_1, 4)));
}

}  // namespace

namespace low_bitdepth {
namespace {

// This function returns a 16-bit packed mask to fit in _mm_madd_epi16.
// 16-bit is also the lowest packing for hadd, but without subsampling there is
// an unfortunate conversion required.
template <int subsampling_x, int subsampling_y>
inline __m128i GetMask8(const uint8_t* LIBGAV1_RESTRICT mask,
                        ptrdiff_t stride) {
  if (subsampling_x == 1) {
    const __m128i row_vals = LoadUnaligned16(mask);

    const __m128i mask_val_0 = _mm_cvtepu8_epi16(row_vals);
    const __m128i mask_val_1 = _mm_cvtepu8_epi16(_mm_srli_si128(row_vals, 8));
    __m128i subsampled_mask = _mm_hadd_epi16(mask_val_0, mask_val_1);

    if (subsampling_y == 1) {
      const __m128i next_row_vals = LoadUnaligned16(mask + stride);
      const __m128i next_mask_val_0 = _mm_cvtepu8_epi16(next_row_vals);
      const __m128i next_mask_val_1 =
          _mm_cvtepu8_epi16(_mm_srli_si128(next_row_vals, 8));
      subsampled_mask = _mm_add_epi16(
          subsampled_mask, _mm_hadd_epi16(next_mask_val_0, next_mask_val_1));
    }
    return RightShiftWithRounding_U16(subsampled_mask, 1 + subsampling_y);
  }
  assert(subsampling_y == 0 && subsampling_x == 0);
  const __m128i mask_val = LoadLo8(mask);
  return _mm_cvtepu8_epi16(mask_val);
}

inline void WriteMaskBlendLine4x2(const int16_t* LIBGAV1_RESTRICT const pred_0,
                                  const int16_t* LIBGAV1_RESTRICT const pred_1,
                                  const __m128i pred_mask_0,
                                  const __m128i pred_mask_1,
                                  uint8_t* LIBGAV1_RESTRICT dst,
                                  const ptrdiff_t dst_stride) {
  const __m128i pred_val_0 = LoadAligned16(pred_0);
  const __m128i pred_val_1 = LoadAligned16(pred_1);
  const __m128i mask_lo = _mm_unpacklo_epi16(pred_mask_0, pred_mask_1);
  const __m128i mask_hi = _mm_unpackhi_epi16(pred_mask_0, pred_mask_1);
  const __m128i pred_lo = _mm_unpacklo_epi16(pred_val_0, pred_val_1);
  const __m128i pred_hi = _mm_unpackhi_epi16(pred_val_0, pred_val_1);
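  // Interleaving the masks with their complements, and the two predictions
  // with each other, lets _mm_madd_epi16 compute
  // mask * pred_0 + (64 - mask) * pred_1 directly in 32-bit lanes.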

  // int res = (mask_value * prediction_0[x] +
  //            (64 - mask_value) * prediction_1[x]) >> 6;
  const __m128i compound_pred_lo = _mm_madd_epi16(pred_lo, mask_lo);
  const __m128i compound_pred_hi = _mm_madd_epi16(pred_hi, mask_hi);
  const __m128i compound_pred = _mm_packus_epi32(
      _mm_srli_epi32(compound_pred_lo, 6), _mm_srli_epi32(compound_pred_hi, 6));

  // dst[x] = static_cast<Pixel>(
  //     Clip3(RightShiftWithRounding(res, inter_post_round_bits), 0,
  //           (1 << kBitdepth8) - 1));
  const __m128i result = RightShiftWithRounding_S16(compound_pred, 4);
  const __m128i res = _mm_packus_epi16(result, result);
  Store4(dst, res);
  Store4(dst + dst_stride, _mm_srli_si128(res, 4));
}

template <int subsampling_x, int subsampling_y>
inline void MaskBlending4x4_SSE4_1(const int16_t* LIBGAV1_RESTRICT pred_0,
                                   const int16_t* LIBGAV1_RESTRICT pred_1,
                                   const uint8_t* LIBGAV1_RESTRICT mask,
                                   uint8_t* LIBGAV1_RESTRICT dst,
                                   const ptrdiff_t dst_stride) {
  constexpr ptrdiff_t mask_stride = 4 << subsampling_x;
  const __m128i mask_inverter = _mm_set1_epi16(64);
  __m128i pred_mask_0 = GetMask4x2<subsampling_x, subsampling_y>(mask);
  __m128i pred_mask_1 = _mm_sub_epi16(mask_inverter, pred_mask_0);
  WriteMaskBlendLine4x2(pred_0, pred_1, pred_mask_0, pred_mask_1, dst,
                        dst_stride);
  pred_0 += 4 << 1;
  pred_1 += 4 << 1;
  mask += mask_stride << (1 + subsampling_y);
  dst += dst_stride << 1;

  pred_mask_0 = GetMask4x2<subsampling_x, subsampling_y>(mask);
  pred_mask_1 = _mm_sub_epi16(mask_inverter, pred_mask_0);
  WriteMaskBlendLine4x2(pred_0, pred_1, pred_mask_0, pred_mask_1, dst,
                        dst_stride);
}

template <int subsampling_x, int subsampling_y>
inline void MaskBlending4xH_SSE4_1(
    const int16_t* LIBGAV1_RESTRICT pred_0,
    const int16_t* LIBGAV1_RESTRICT pred_1,
    const uint8_t* LIBGAV1_RESTRICT const mask_ptr, const int height,
    uint8_t* LIBGAV1_RESTRICT dst, const ptrdiff_t dst_stride) {
  assert(subsampling_x == 1);
  const uint8_t* mask = mask_ptr;
  constexpr ptrdiff_t mask_stride = 4 << subsampling_x;
  if (height == 4) {
    MaskBlending4x4_SSE4_1<subsampling_x, subsampling_y>(pred_0, pred_1, mask,
                                                         dst, dst_stride);
    return;
  }
  const __m128i mask_inverter = _mm_set1_epi16(64);
  int y = 0;
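  // Heights larger than 4 are multiples of 8, so each iteration below can
  // safely blend four 4x2 blocks.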
  do {
    __m128i pred_mask_0 = GetMask4x2<subsampling_x, subsampling_y>(mask);
    __m128i pred_mask_1 = _mm_sub_epi16(mask_inverter, pred_mask_0);

    WriteMaskBlendLine4x2(pred_0, pred_1, pred_mask_0, pred_mask_1, dst,
                          dst_stride);
    pred_0 += 4 << 1;
    pred_1 += 4 << 1;
    mask += mask_stride << (1 + subsampling_y);
    dst += dst_stride << 1;

    pred_mask_0 = GetMask4x2<subsampling_x, subsampling_y>(mask);
    pred_mask_1 = _mm_sub_epi16(mask_inverter, pred_mask_0);
    WriteMaskBlendLine4x2(pred_0, pred_1, pred_mask_0, pred_mask_1, dst,
                          dst_stride);
    pred_0 += 4 << 1;
    pred_1 += 4 << 1;
    mask += mask_stride << (1 + subsampling_y);
    dst += dst_stride << 1;

    pred_mask_0 = GetMask4x2<subsampling_x, subsampling_y>(mask);
    pred_mask_1 = _mm_sub_epi16(mask_inverter, pred_mask_0);
    WriteMaskBlendLine4x2(pred_0, pred_1, pred_mask_0, pred_mask_1, dst,
                          dst_stride);
    pred_0 += 4 << 1;
    pred_1 += 4 << 1;
    mask += mask_stride << (1 + subsampling_y);
    dst += dst_stride << 1;

    pred_mask_0 = GetMask4x2<subsampling_x, subsampling_y>(mask);
    pred_mask_1 = _mm_sub_epi16(mask_inverter, pred_mask_0);
    WriteMaskBlendLine4x2(pred_0, pred_1, pred_mask_0, pred_mask_1, dst,
                          dst_stride);
    pred_0 += 4 << 1;
    pred_1 += 4 << 1;
    mask += mask_stride << (1 + subsampling_y);
    dst += dst_stride << 1;
    y += 8;
  } while (y < height);
}

template <int subsampling_x, int subsampling_y>
inline void MaskBlend_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
                             const void* LIBGAV1_RESTRICT prediction_1,
                             const ptrdiff_t /*prediction_stride_1*/,
                             const uint8_t* LIBGAV1_RESTRICT const mask_ptr,
                             const ptrdiff_t mask_stride, const int width,
                             const int height, void* LIBGAV1_RESTRICT dest,
                             const ptrdiff_t dst_stride) {
  auto* dst = static_cast<uint8_t*>(dest);
  const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
  const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
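  // Both compound prediction buffers are tightly packed, so their strides
  // equal the block width.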
  const ptrdiff_t pred_stride_0 = width;
  const ptrdiff_t pred_stride_1 = width;
  if (width == 4) {
    MaskBlending4xH_SSE4_1<subsampling_x, subsampling_y>(
        pred_0, pred_1, mask_ptr, height, dst, dst_stride);
    return;
  }
  const uint8_t* mask = mask_ptr;
  const __m128i mask_inverter = _mm_set1_epi16(64);
  int y = 0;
  do {
    int x = 0;
    do {
      const __m128i pred_mask_0 = GetMask8<subsampling_x, subsampling_y>(
          mask + (x << subsampling_x), mask_stride);
      // 64 - mask
      const __m128i pred_mask_1 = _mm_sub_epi16(mask_inverter, pred_mask_0);
      const __m128i mask_lo = _mm_unpacklo_epi16(pred_mask_0, pred_mask_1);
      const __m128i mask_hi = _mm_unpackhi_epi16(pred_mask_0, pred_mask_1);

      const __m128i pred_val_0 = LoadAligned16(pred_0 + x);
      const __m128i pred_val_1 = LoadAligned16(pred_1 + x);
      const __m128i pred_lo = _mm_unpacklo_epi16(pred_val_0, pred_val_1);
      const __m128i pred_hi = _mm_unpackhi_epi16(pred_val_0, pred_val_1);
      // int res = (mask_value * prediction_0[x] +
      //            (64 - mask_value) * prediction_1[x]) >> 6;
      const __m128i compound_pred_lo = _mm_madd_epi16(pred_lo, mask_lo);
      const __m128i compound_pred_hi = _mm_madd_epi16(pred_hi, mask_hi);

      const __m128i res = _mm_packus_epi32(_mm_srli_epi32(compound_pred_lo, 6),
                                           _mm_srli_epi32(compound_pred_hi, 6));
      // dst[x] = static_cast<Pixel>(
      //     Clip3(RightShiftWithRounding(res, inter_post_round_bits), 0,
      //           (1 << kBitdepth8) - 1));
      const __m128i result = RightShiftWithRounding_S16(res, 4);
      StoreLo8(dst + x, _mm_packus_epi16(result, result));

      x += 8;
    } while (x < width);
    dst += dst_stride;
    pred_0 += pred_stride_0;
    pred_1 += pred_stride_1;
    mask += mask_stride << subsampling_y;
  } while (++y < height);
}

inline void InterIntraWriteMaskBlendLine8bpp4x2(
    const uint8_t* LIBGAV1_RESTRICT const pred_0,
    uint8_t* LIBGAV1_RESTRICT const pred_1, const ptrdiff_t pred_stride_1,
    const __m128i pred_mask_0, const __m128i pred_mask_1) {
  const __m128i pred_mask = _mm_unpacklo_epi8(pred_mask_0, pred_mask_1);

  const __m128i pred_val_0 = LoadLo8(pred_0);
  __m128i pred_val_1 = Load4(pred_1);
  pred_val_1 = _mm_or_si128(_mm_slli_si128(Load4(pred_1 + pred_stride_1), 4),
                            pred_val_1);
  const __m128i pred = _mm_unpacklo_epi8(pred_val_0, pred_val_1);
  // int res = (mask_value * prediction_1[x] +
  //            (64 - mask_value) * prediction_0[x]) >> 6;
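  // _mm_maddubs_epi16 treats |pred| as unsigned and |pred_mask| as signed;
  // the weights are at most 64, so they fit in signed 8 bits and each pair of
  // products sums to at most 64 * 255, well within int16_t range.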
  const __m128i compound_pred = _mm_maddubs_epi16(pred, pred_mask);
  const __m128i result = RightShiftWithRounding_U16(compound_pred, 6);
  const __m128i res = _mm_packus_epi16(result, result);

  Store4(pred_1, res);
  Store4(pred_1 + pred_stride_1, _mm_srli_si128(res, 4));
}

template <int subsampling_x, int subsampling_y>
inline void InterIntraMaskBlending8bpp4x4_SSE4_1(
    const uint8_t* LIBGAV1_RESTRICT pred_0, uint8_t* LIBGAV1_RESTRICT pred_1,
    const ptrdiff_t pred_stride_1, const uint8_t* LIBGAV1_RESTRICT mask,
    const ptrdiff_t mask_stride) {
  const __m128i mask_inverter = _mm_set1_epi8(64);
  const __m128i pred_mask_u16_first =
      GetInterIntraMask4x2<subsampling_x, subsampling_y>(mask, mask_stride);
  mask += mask_stride << (1 + subsampling_y);
  const __m128i pred_mask_u16_second =
      GetInterIntraMask4x2<subsampling_x, subsampling_y>(mask, mask_stride);
  mask += mask_stride << (1 + subsampling_y);
  __m128i pred_mask_1 =
      _mm_packus_epi16(pred_mask_u16_first, pred_mask_u16_second);
  __m128i pred_mask_0 = _mm_sub_epi8(mask_inverter, pred_mask_1);
  InterIntraWriteMaskBlendLine8bpp4x2(pred_0, pred_1, pred_stride_1,
                                      pred_mask_0, pred_mask_1);
  pred_0 += 4 << 1;
  pred_1 += pred_stride_1 << 1;

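  // The mask for the second 4x2 block sits in the upper 8 bytes of
  // |pred_mask_1|; shift it down before recomputing the complement.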
  pred_mask_1 = _mm_srli_si128(pred_mask_1, 8);
  pred_mask_0 = _mm_sub_epi8(mask_inverter, pred_mask_1);
  InterIntraWriteMaskBlendLine8bpp4x2(pred_0, pred_1, pred_stride_1,
                                      pred_mask_0, pred_mask_1);
}

template <int subsampling_x, int subsampling_y>
inline void InterIntraMaskBlending8bpp4xH_SSE4_1(
    const uint8_t* LIBGAV1_RESTRICT pred_0, uint8_t* LIBGAV1_RESTRICT pred_1,
    const ptrdiff_t pred_stride_1,
    const uint8_t* LIBGAV1_RESTRICT const mask_ptr, const ptrdiff_t mask_stride,
    const int height) {
  const uint8_t* mask = mask_ptr;
  if (height == 4) {
    InterIntraMaskBlending8bpp4x4_SSE4_1<subsampling_x, subsampling_y>(
        pred_0, pred_1, pred_stride_1, mask, mask_stride);
    return;
  }
  int y = 0;
  do {
    InterIntraMaskBlending8bpp4x4_SSE4_1<subsampling_x, subsampling_y>(
        pred_0, pred_1, pred_stride_1, mask, mask_stride);
    pred_0 += 4 << 2;
    pred_1 += pred_stride_1 << 2;
    mask += mask_stride << (2 + subsampling_y);

    InterIntraMaskBlending8bpp4x4_SSE4_1<subsampling_x, subsampling_y>(
        pred_0, pred_1, pred_stride_1, mask, mask_stride);
    pred_0 += 4 << 2;
    pred_1 += pred_stride_1 << 2;
    mask += mask_stride << (2 + subsampling_y);
    y += 8;
  } while (y < height);
}

// This version returns 8-bit packed values to fit in _mm_maddubs_epi16
// because, when is_inter_intra is true, the prediction values are brought to
// 8-bit packing as well.
template <int subsampling_x, int subsampling_y>
inline __m128i GetInterIntraMask8bpp8(const uint8_t* LIBGAV1_RESTRICT mask,
                                      ptrdiff_t stride) {
  if (subsampling_x == 1) {
    const __m128i ret = GetMask8<subsampling_x, subsampling_y>(mask, stride);
    return _mm_packus_epi16(ret, ret);
  }
  assert(subsampling_y == 0 && subsampling_x == 0);
  // Unfortunately there is no shift operation for 8-bit packing, or else we
  // could return everything with 8-bit packing.
  const __m128i mask_val = LoadLo8(mask);
  return mask_val;
}

template <int subsampling_x, int subsampling_y>
void InterIntraMaskBlend8bpp_SSE4_1(
    const uint8_t* LIBGAV1_RESTRICT prediction_0,
    uint8_t* LIBGAV1_RESTRICT prediction_1, const ptrdiff_t prediction_stride_1,
    const uint8_t* LIBGAV1_RESTRICT const mask_ptr, const ptrdiff_t mask_stride,
    const int width, const int height) {
  if (width == 4) {
    InterIntraMaskBlending8bpp4xH_SSE4_1<subsampling_x, subsampling_y>(
        prediction_0, prediction_1, prediction_stride_1, mask_ptr, mask_stride,
        height);
    return;
  }
  const uint8_t* mask = mask_ptr;
  const __m128i mask_inverter = _mm_set1_epi8(64);
  int y = 0;
  do {
    int x = 0;
    do {
      const __m128i pred_mask_1 =
          GetInterIntraMask8bpp8<subsampling_x, subsampling_y>(
              mask + (x << subsampling_x), mask_stride);
      // 64 - mask
      const __m128i pred_mask_0 = _mm_sub_epi8(mask_inverter, pred_mask_1);
      const __m128i pred_mask = _mm_unpacklo_epi8(pred_mask_0, pred_mask_1);

      const __m128i pred_val_0 = LoadLo8(prediction_0 + x);
      const __m128i pred_val_1 = LoadLo8(prediction_1 + x);
      const __m128i pred = _mm_unpacklo_epi8(pred_val_0, pred_val_1);
      // int res = (mask_value * prediction_1[x] +
      //            (64 - mask_value) * prediction_0[x]) >> 6;
      const __m128i compound_pred = _mm_maddubs_epi16(pred, pred_mask);
      const __m128i result = RightShiftWithRounding_U16(compound_pred, 6);
      const __m128i res = _mm_packus_epi16(result, result);

      StoreLo8(prediction_1 + x, res);

      x += 8;
    } while (x < width);
    prediction_0 += width;
    prediction_1 += prediction_stride_1;
    mask += mask_stride << subsampling_y;
  } while (++y < height);
}

void Init8bpp() {
  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
  assert(dsp != nullptr);
#if DSP_ENABLED_8BPP_SSE4_1(MaskBlend444)
  dsp->mask_blend[0][0] = MaskBlend_SSE4_1<0, 0>;
#endif
#if DSP_ENABLED_8BPP_SSE4_1(MaskBlend422)
  dsp->mask_blend[1][0] = MaskBlend_SSE4_1<1, 0>;
#endif
#if DSP_ENABLED_8BPP_SSE4_1(MaskBlend420)
  dsp->mask_blend[2][0] = MaskBlend_SSE4_1<1, 1>;
#endif
  // The is_inter_intra index of mask_blend[][] is replaced by
  // inter_intra_mask_blend_8bpp[] in 8-bit.
#if DSP_ENABLED_8BPP_SSE4_1(InterIntraMaskBlend8bpp444)
  dsp->inter_intra_mask_blend_8bpp[0] = InterIntraMaskBlend8bpp_SSE4_1<0, 0>;
#endif
#if DSP_ENABLED_8BPP_SSE4_1(InterIntraMaskBlend8bpp422)
  dsp->inter_intra_mask_blend_8bpp[1] = InterIntraMaskBlend8bpp_SSE4_1<1, 0>;
#endif
#if DSP_ENABLED_8BPP_SSE4_1(InterIntraMaskBlend8bpp420)
  dsp->inter_intra_mask_blend_8bpp[2] = InterIntraMaskBlend8bpp_SSE4_1<1, 1>;
#endif
}

}  // namespace
}  // namespace low_bitdepth

#if LIBGAV1_MAX_BITDEPTH >= 10
namespace high_bitdepth {
namespace {

constexpr int kMax10bppSample = (1 << 10) - 1;
constexpr int kMaskInverse = 64;
constexpr int kRoundBitsMaskBlend = 4;

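// Shifts right with rounding; |shift| must hold the rounding constant
// 1 << (bits - 1) so that the add implements round-to-nearest.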
inline __m128i RightShiftWithRoundingConst_S32(const __m128i v_val_d, int bits,
                                               const __m128i shift) {
  const __m128i v_tmp_d = _mm_add_epi32(v_val_d, shift);
  return _mm_srai_epi32(v_tmp_d, bits);
}

template <int subsampling_x, int subsampling_y>
inline __m128i GetMask4x2(const uint8_t* mask) {
  if (subsampling_x == 1 && subsampling_y == 1) {
    const __m128i mask_row_01 = LoadUnaligned16(mask);
    const __m128i mask_row_23 = LoadUnaligned16(mask + 16);
    const __m128i mask_val_0 = _mm_cvtepu8_epi16(mask_row_01);
    const __m128i mask_val_1 =
        _mm_cvtepu8_epi16(_mm_srli_si128(mask_row_01, 8));
    const __m128i mask_val_2 = _mm_cvtepu8_epi16(mask_row_23);
    const __m128i mask_val_3 =
        _mm_cvtepu8_epi16(_mm_srli_si128(mask_row_23, 8));
    const __m128i subsampled_mask_02 = _mm_hadd_epi16(mask_val_0, mask_val_2);
    const __m128i subsampled_mask_13 = _mm_hadd_epi16(mask_val_1, mask_val_3);
    const __m128i subsampled_mask =
        _mm_add_epi16(subsampled_mask_02, subsampled_mask_13);
    return RightShiftWithRounding_U16(subsampled_mask, 2);
  }
  if (subsampling_x == 1) {
    const __m128i mask_row_01 = LoadUnaligned16(mask);
    const __m128i mask_val_0 = _mm_cvtepu8_epi16(mask_row_01);
    const __m128i mask_val_1 =
        _mm_cvtepu8_epi16(_mm_srli_si128(mask_row_01, 8));
    const __m128i subsampled_mask = _mm_hadd_epi16(mask_val_0, mask_val_1);
    return RightShiftWithRounding_U16(subsampled_mask, 1);
  }
  return _mm_cvtepu8_epi16(LoadLo8(mask));
}

inline void WriteMaskBlendLine10bpp4x2_SSE4_1(
    const uint16_t* LIBGAV1_RESTRICT pred_0,
    const uint16_t* LIBGAV1_RESTRICT pred_1, const ptrdiff_t pred_stride_1,
    const __m128i& pred_mask_0, const __m128i& pred_mask_1,
    const __m128i& offset, const __m128i& max, const __m128i& shift4,
    uint16_t* LIBGAV1_RESTRICT dst, const ptrdiff_t dst_stride) {
  const __m128i pred_val_0 = LoadUnaligned16(pred_0);
  const __m128i pred_val_1 = LoadHi8(LoadLo8(pred_1), pred_1 + pred_stride_1);

  // int res = (mask_value * pred_0[x] + (64 - mask_value) * pred_1[x]) >> 6;
  const __m128i compound_pred_lo_0 = _mm_mullo_epi16(pred_val_0, pred_mask_0);
  const __m128i compound_pred_hi_0 = _mm_mulhi_epu16(pred_val_0, pred_mask_0);
  const __m128i compound_pred_lo_1 = _mm_mullo_epi16(pred_val_1, pred_mask_1);
  const __m128i compound_pred_hi_1 = _mm_mulhi_epu16(pred_val_1, pred_mask_1);
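  // Interleave the low and high halves of each 16x16->32 bit multiply to
  // reassemble the full 32-bit products, then sum the two weighted
  // predictions in 32 bits.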
  const __m128i pack0_lo =
      _mm_unpacklo_epi16(compound_pred_lo_0, compound_pred_hi_0);
  const __m128i pack0_hi =
      _mm_unpackhi_epi16(compound_pred_lo_0, compound_pred_hi_0);
  const __m128i pack1_lo =
      _mm_unpacklo_epi16(compound_pred_lo_1, compound_pred_hi_1);
  const __m128i pack1_hi =
      _mm_unpackhi_epi16(compound_pred_lo_1, compound_pred_hi_1);
  const __m128i compound_pred_lo = _mm_add_epi32(pack0_lo, pack1_lo);
  const __m128i compound_pred_hi = _mm_add_epi32(pack0_hi, pack1_hi);
  // res -= (bitdepth == 8) ? 0 : kCompoundOffset;
  const __m128i sub_0 =
      _mm_sub_epi32(_mm_srli_epi32(compound_pred_lo, 6), offset);
  const __m128i sub_1 =
      _mm_sub_epi32(_mm_srli_epi32(compound_pred_hi, 6), offset);

  // dst[x] = static_cast<Pixel>(
  //     Clip3(RightShiftWithRounding(res, inter_post_round_bits), 0,
  //           (1 << kBitdepth10) - 1));
  const __m128i shift_0 =
      RightShiftWithRoundingConst_S32(sub_0, kRoundBitsMaskBlend, shift4);
  const __m128i shift_1 =
      RightShiftWithRoundingConst_S32(sub_1, kRoundBitsMaskBlend, shift4);
  const __m128i result =
      _mm_min_epi16(_mm_packus_epi32(shift_0, shift_1), max);
  StoreLo8(dst, result);
  StoreHi8(dst + dst_stride, result);
}

template <int subsampling_x, int subsampling_y>
inline void MaskBlend10bpp4x4_SSE4_1(const uint16_t* LIBGAV1_RESTRICT pred_0,
                                     const uint16_t* LIBGAV1_RESTRICT pred_1,
                                     const ptrdiff_t pred_stride_1,
                                     const uint8_t* LIBGAV1_RESTRICT mask,
                                     const ptrdiff_t mask_stride,
                                     uint16_t* LIBGAV1_RESTRICT dst,
                                     const ptrdiff_t dst_stride) {
  const __m128i mask_inverter = _mm_set1_epi16(kMaskInverse);
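  // Rounding constant for the final right shift by |kRoundBitsMaskBlend|.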
  const __m128i shift4 = _mm_set1_epi32((1 << kRoundBitsMaskBlend) >> 1);
  const __m128i offset = _mm_set1_epi32(kCompoundOffset);
  const __m128i max = _mm_set1_epi16(kMax10bppSample);
  __m128i pred_mask_0 = GetMask4x2<subsampling_x, subsampling_y>(mask);
  __m128i pred_mask_1 = _mm_sub_epi16(mask_inverter, pred_mask_0);
  WriteMaskBlendLine10bpp4x2_SSE4_1(pred_0, pred_1, pred_stride_1, pred_mask_0,
                                    pred_mask_1, offset, max, shift4, dst,
                                    dst_stride);
  pred_0 += 4 << 1;
  pred_1 += pred_stride_1 << 1;
  mask += mask_stride << (1 + subsampling_y);
  dst += dst_stride << 1;

  pred_mask_0 = GetMask4x2<subsampling_x, subsampling_y>(mask);
  pred_mask_1 = _mm_sub_epi16(mask_inverter, pred_mask_0);
  WriteMaskBlendLine10bpp4x2_SSE4_1(pred_0, pred_1, pred_stride_1, pred_mask_0,
                                    pred_mask_1, offset, max, shift4, dst,
                                    dst_stride);
}

template <int subsampling_x, int subsampling_y>
inline void MaskBlend10bpp4xH_SSE4_1(
    const uint16_t* LIBGAV1_RESTRICT pred_0,
    const uint16_t* LIBGAV1_RESTRICT pred_1, const ptrdiff_t pred_stride_1,
    const uint8_t* LIBGAV1_RESTRICT const mask_ptr, const ptrdiff_t mask_stride,
    const int height, uint16_t* LIBGAV1_RESTRICT dst,
    const ptrdiff_t dst_stride) {
  const uint8_t* mask = mask_ptr;
  if (height == 4) {
    MaskBlend10bpp4x4_SSE4_1<subsampling_x, subsampling_y>(
        pred_0, pred_1, pred_stride_1, mask, mask_stride, dst, dst_stride);
    return;
  }
  const __m128i mask_inverter = _mm_set1_epi16(kMaskInverse);
  const uint8_t pred0_stride2 = 4 << 1;
  const ptrdiff_t pred1_stride2 = pred_stride_1 << 1;
  const ptrdiff_t mask_stride2 = mask_stride << (1 + subsampling_y);
  const ptrdiff_t dst_stride2 = dst_stride << 1;
  const __m128i offset = _mm_set1_epi32(kCompoundOffset);
  const __m128i max = _mm_set1_epi16(kMax10bppSample);
  const __m128i shift4 = _mm_set1_epi32((1 << kRoundBitsMaskBlend) >> 1);
  int y = height;
  do {
    __m128i pred_mask_0 = GetMask4x2<subsampling_x, subsampling_y>(mask);
    __m128i pred_mask_1 = _mm_sub_epi16(mask_inverter, pred_mask_0);

    WriteMaskBlendLine10bpp4x2_SSE4_1(pred_0, pred_1, pred_stride_1,
                                      pred_mask_0, pred_mask_1, offset, max,
                                      shift4, dst, dst_stride);
    pred_0 += pred0_stride2;
    pred_1 += pred1_stride2;
    mask += mask_stride2;
    dst += dst_stride2;

    pred_mask_0 = GetMask4x2<subsampling_x, subsampling_y>(mask);
    pred_mask_1 = _mm_sub_epi16(mask_inverter, pred_mask_0);
    WriteMaskBlendLine10bpp4x2_SSE4_1(pred_0, pred_1, pred_stride_1,
                                      pred_mask_0, pred_mask_1, offset, max,
                                      shift4, dst, dst_stride);
    pred_0 += pred0_stride2;
    pred_1 += pred1_stride2;
    mask += mask_stride2;
    dst += dst_stride2;

    pred_mask_0 = GetMask4x2<subsampling_x, subsampling_y>(mask);
    pred_mask_1 = _mm_sub_epi16(mask_inverter, pred_mask_0);
    WriteMaskBlendLine10bpp4x2_SSE4_1(pred_0, pred_1, pred_stride_1,
                                      pred_mask_0, pred_mask_1, offset, max,
                                      shift4, dst, dst_stride);
    pred_0 += pred0_stride2;
    pred_1 += pred1_stride2;
    mask += mask_stride2;
    dst += dst_stride2;

    pred_mask_0 = GetMask4x2<subsampling_x, subsampling_y>(mask);
    pred_mask_1 = _mm_sub_epi16(mask_inverter, pred_mask_0);
    WriteMaskBlendLine10bpp4x2_SSE4_1(pred_0, pred_1, pred_stride_1,
                                      pred_mask_0, pred_mask_1, offset, max,
                                      shift4, dst, dst_stride);
    pred_0 += pred0_stride2;
    pred_1 += pred1_stride2;
    mask += mask_stride2;
    dst += dst_stride2;
    y -= 8;
  } while (y != 0);
}

template <int subsampling_x, int subsampling_y>
inline void MaskBlend10bpp_SSE4_1(
    const void* LIBGAV1_RESTRICT prediction_0,
    const void* LIBGAV1_RESTRICT prediction_1,
    const ptrdiff_t prediction_stride_1,
    const uint8_t* LIBGAV1_RESTRICT const mask_ptr, const ptrdiff_t mask_stride,
    const int width, const int height, void* LIBGAV1_RESTRICT dest,
    const ptrdiff_t dest_stride) {
  auto* dst = static_cast<uint16_t*>(dest);
  const ptrdiff_t dst_stride = dest_stride / sizeof(dst[0]);
  const auto* pred_0 = static_cast<const uint16_t*>(prediction_0);
  const auto* pred_1 = static_cast<const uint16_t*>(prediction_1);
  const ptrdiff_t pred_stride_0 = width;
  const ptrdiff_t pred_stride_1 = prediction_stride_1;
  if (width == 4) {
    MaskBlend10bpp4xH_SSE4_1<subsampling_x, subsampling_y>(
        pred_0, pred_1, pred_stride_1, mask_ptr, mask_stride, height, dst,
        dst_stride);
    return;
  }
  const uint8_t* mask = mask_ptr;
  const __m128i mask_inverter = _mm_set1_epi16(kMaskInverse);
  const ptrdiff_t mask_stride_ss = mask_stride << subsampling_y;
  const __m128i offset = _mm_set1_epi32(kCompoundOffset);
  const __m128i max = _mm_set1_epi16(kMax10bppSample);
  const __m128i shift4 = _mm_set1_epi32((1 << kRoundBitsMaskBlend) >> 1);
  int y = height;
  do {
    int x = 0;
    do {
      const __m128i pred_mask_0 = GetMask8<subsampling_x, subsampling_y>(
          mask + (x << subsampling_x), mask_stride);
      const __m128i pred_val_0 = LoadUnaligned16(pred_0 + x);
      const __m128i pred_val_1 = LoadUnaligned16(pred_1 + x);
      // 64 - mask
      const __m128i pred_mask_1 = _mm_sub_epi16(mask_inverter, pred_mask_0);

      const __m128i compound_pred_lo_0 =
          _mm_mullo_epi16(pred_val_0, pred_mask_0);
      const __m128i compound_pred_hi_0 =
          _mm_mulhi_epu16(pred_val_0, pred_mask_0);
      const __m128i compound_pred_lo_1 =
          _mm_mullo_epi16(pred_val_1, pred_mask_1);
      const __m128i compound_pred_hi_1 =
          _mm_mulhi_epu16(pred_val_1, pred_mask_1);
      const __m128i pack0_lo =
          _mm_unpacklo_epi16(compound_pred_lo_0, compound_pred_hi_0);
      const __m128i pack0_hi =
          _mm_unpackhi_epi16(compound_pred_lo_0, compound_pred_hi_0);
      const __m128i pack1_lo =
          _mm_unpacklo_epi16(compound_pred_lo_1, compound_pred_hi_1);
      const __m128i pack1_hi =
          _mm_unpackhi_epi16(compound_pred_lo_1, compound_pred_hi_1);
      const __m128i compound_pred_lo = _mm_add_epi32(pack0_lo, pack1_lo);
      const __m128i compound_pred_hi = _mm_add_epi32(pack0_hi, pack1_hi);

      const __m128i sub_0 =
          _mm_sub_epi32(_mm_srli_epi32(compound_pred_lo, 6), offset);
      const __m128i sub_1 =
          _mm_sub_epi32(_mm_srli_epi32(compound_pred_hi, 6), offset);
      const __m128i shift_0 =
          RightShiftWithRoundingConst_S32(sub_0, kRoundBitsMaskBlend, shift4);
      const __m128i shift_1 =
          RightShiftWithRoundingConst_S32(sub_1, kRoundBitsMaskBlend, shift4);
      const __m128i result =
          _mm_min_epi16(_mm_packus_epi32(shift_0, shift_1), max);
      StoreUnaligned16(dst + x, result);
      x += 8;
    } while (x < width);
    dst += dst_stride;
    pred_0 += pred_stride_0;
    pred_1 += pred_stride_1;
    mask += mask_stride_ss;
  } while (--y != 0);
}

inline void InterIntraWriteMaskBlendLine10bpp4x2_SSE4_1(
    const uint16_t* LIBGAV1_RESTRICT prediction_0,
    const uint16_t* LIBGAV1_RESTRICT prediction_1,
    const ptrdiff_t pred_stride_1, const __m128i& pred_mask_0,
    const __m128i& pred_mask_1, const __m128i& shift6,
    uint16_t* LIBGAV1_RESTRICT dst, const ptrdiff_t dst_stride) {
  const __m128i pred_val_0 = LoadUnaligned16(prediction_0);
  const __m128i pred_val_1 =
      LoadHi8(LoadLo8(prediction_1), prediction_1 + pred_stride_1);

  const __m128i mask_0 = _mm_unpacklo_epi16(pred_mask_1, pred_mask_0);
  const __m128i mask_1 = _mm_unpackhi_epi16(pred_mask_1, pred_mask_0);
  const __m128i pred_0 = _mm_unpacklo_epi16(pred_val_0, pred_val_1);
  const __m128i pred_1 = _mm_unpackhi_epi16(pred_val_0, pred_val_1);

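  // int res = (mask_value * prediction_1[x] +
  //            (64 - mask_value) * prediction_0[x]) >> 6;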
  const __m128i compound_pred_0 = _mm_madd_epi16(pred_0, mask_0);
  const __m128i compound_pred_1 = _mm_madd_epi16(pred_1, mask_1);
  const __m128i shift_0 =
      RightShiftWithRoundingConst_S32(compound_pred_0, 6, shift6);
  const __m128i shift_1 =
      RightShiftWithRoundingConst_S32(compound_pred_1, 6, shift6);
  const __m128i res = _mm_packus_epi32(shift_0, shift_1);
  StoreLo8(dst, res);
  StoreHi8(dst + dst_stride, res);
}

template <int subsampling_x, int subsampling_y>
inline void InterIntraMaskBlend10bpp4x4_SSE4_1(
    const uint16_t* LIBGAV1_RESTRICT pred_0,
    const uint16_t* LIBGAV1_RESTRICT pred_1, const ptrdiff_t pred_stride_1,
    const uint8_t* LIBGAV1_RESTRICT mask, const ptrdiff_t mask_stride,
    uint16_t* LIBGAV1_RESTRICT dst, const ptrdiff_t dst_stride) {
  const __m128i mask_inverter = _mm_set1_epi16(kMaskInverse);
  const __m128i shift6 = _mm_set1_epi32((1 << 6) >> 1);
  __m128i pred_mask_0 =
      GetInterIntraMask4x2<subsampling_x, subsampling_y>(mask, mask_stride);
  __m128i pred_mask_1 = _mm_sub_epi16(mask_inverter, pred_mask_0);
  InterIntraWriteMaskBlendLine10bpp4x2_SSE4_1(pred_0, pred_1, pred_stride_1,
                                              pred_mask_0, pred_mask_1, shift6,
                                              dst, dst_stride);
  pred_0 += 4 << 1;
  pred_1 += pred_stride_1 << 1;
  mask += mask_stride << (1 + subsampling_y);
  dst += dst_stride << 1;

  pred_mask_0 =
      GetInterIntraMask4x2<subsampling_x, subsampling_y>(mask, mask_stride);
  pred_mask_1 = _mm_sub_epi16(mask_inverter, pred_mask_0);
  InterIntraWriteMaskBlendLine10bpp4x2_SSE4_1(pred_0, pred_1, pred_stride_1,
                                              pred_mask_0, pred_mask_1, shift6,
                                              dst, dst_stride);
}

template <int subsampling_x, int subsampling_y>
inline void InterIntraMaskBlend10bpp4xH_SSE4_1(
    const uint16_t* LIBGAV1_RESTRICT pred_0,
    const uint16_t* LIBGAV1_RESTRICT pred_1, const ptrdiff_t pred_stride_1,
    const uint8_t* LIBGAV1_RESTRICT const mask_ptr, const ptrdiff_t mask_stride,
    const int height, uint16_t* LIBGAV1_RESTRICT dst,
    const ptrdiff_t dst_stride) {
  const uint8_t* mask = mask_ptr;
  if (height == 4) {
    InterIntraMaskBlend10bpp4x4_SSE4_1<subsampling_x, subsampling_y>(
        pred_0, pred_1, pred_stride_1, mask, mask_stride, dst, dst_stride);
    return;
  }
  const __m128i mask_inverter = _mm_set1_epi16(kMaskInverse);
  const __m128i shift6 = _mm_set1_epi32((1 << 6) >> 1);
  const uint8_t pred0_stride2 = 4 << 1;
  const ptrdiff_t pred1_stride2 = pred_stride_1 << 1;
  const ptrdiff_t mask_stride2 = mask_stride << (1 + subsampling_y);
  const ptrdiff_t dst_stride2 = dst_stride << 1;
  int y = height;
  do {
    __m128i pred_mask_0 =
        GetInterIntraMask4x2<subsampling_x, subsampling_y>(mask, mask_stride);
    __m128i pred_mask_1 = _mm_sub_epi16(mask_inverter, pred_mask_0);
    InterIntraWriteMaskBlendLine10bpp4x2_SSE4_1(pred_0, pred_1, pred_stride_1,
                                                pred_mask_0, pred_mask_1,
                                                shift6, dst, dst_stride);
    pred_0 += pred0_stride2;
    pred_1 += pred1_stride2;
    mask += mask_stride2;
    dst += dst_stride2;

    pred_mask_0 =
        GetInterIntraMask4x2<subsampling_x, subsampling_y>(mask, mask_stride);
    pred_mask_1 = _mm_sub_epi16(mask_inverter, pred_mask_0);
    InterIntraWriteMaskBlendLine10bpp4x2_SSE4_1(pred_0, pred_1, pred_stride_1,
                                                pred_mask_0, pred_mask_1,
                                                shift6, dst, dst_stride);
    pred_0 += pred0_stride2;
    pred_1 += pred1_stride2;
    mask += mask_stride2;
    dst += dst_stride2;

    pred_mask_0 =
        GetInterIntraMask4x2<subsampling_x, subsampling_y>(mask, mask_stride);
    pred_mask_1 = _mm_sub_epi16(mask_inverter, pred_mask_0);
    InterIntraWriteMaskBlendLine10bpp4x2_SSE4_1(pred_0, pred_1, pred_stride_1,
                                                pred_mask_0, pred_mask_1,
                                                shift6, dst, dst_stride);
    pred_0 += pred0_stride2;
    pred_1 += pred1_stride2;
    mask += mask_stride2;
    dst += dst_stride2;

    pred_mask_0 =
        GetInterIntraMask4x2<subsampling_x, subsampling_y>(mask, mask_stride);
    pred_mask_1 = _mm_sub_epi16(mask_inverter, pred_mask_0);
    InterIntraWriteMaskBlendLine10bpp4x2_SSE4_1(pred_0, pred_1, pred_stride_1,
                                                pred_mask_0, pred_mask_1,
                                                shift6, dst, dst_stride);
    pred_0 += pred0_stride2;
    pred_1 += pred1_stride2;
    mask += mask_stride2;
    dst += dst_stride2;
    y -= 8;
  } while (y != 0);
}

template <int subsampling_x, int subsampling_y>
inline void InterIntraMaskBlend10bpp_SSE4_1(
    const void* LIBGAV1_RESTRICT prediction_0,
    const void* LIBGAV1_RESTRICT prediction_1,
    const ptrdiff_t prediction_stride_1,
    const uint8_t* LIBGAV1_RESTRICT const mask_ptr, const ptrdiff_t mask_stride,
    const int width, const int height, void* LIBGAV1_RESTRICT dest,
    const ptrdiff_t dest_stride) {
  auto* dst = static_cast<uint16_t*>(dest);
  const ptrdiff_t dst_stride = dest_stride / sizeof(dst[0]);
  const auto* pred_0 = static_cast<const uint16_t*>(prediction_0);
  const auto* pred_1 = static_cast<const uint16_t*>(prediction_1);
  const ptrdiff_t pred_stride_0 = width;
  const ptrdiff_t pred_stride_1 = prediction_stride_1;
  if (width == 4) {
    InterIntraMaskBlend10bpp4xH_SSE4_1<subsampling_x, subsampling_y>(
        pred_0, pred_1, pred_stride_1, mask_ptr, mask_stride, height, dst,
        dst_stride);
    return;
  }
  const uint8_t* mask = mask_ptr;
  const __m128i mask_inverter = _mm_set1_epi16(kMaskInverse);
  const __m128i shift6 = _mm_set1_epi32((1 << 6) >> 1);
  const ptrdiff_t mask_stride_ss = mask_stride << subsampling_y;
  int y = height;
  do {
    int x = 0;
    do {
      const __m128i pred_mask_0 = GetMask8<subsampling_x, subsampling_y>(
          mask + (x << subsampling_x), mask_stride);
      const __m128i pred_val_0 = LoadUnaligned16(pred_0 + x);
      const __m128i pred_val_1 = LoadUnaligned16(pred_1 + x);
      // 64 - mask
      const __m128i pred_mask_1 = _mm_sub_epi16(mask_inverter, pred_mask_0);
      const __m128i mask_0 = _mm_unpacklo_epi16(pred_mask_1, pred_mask_0);
      const __m128i mask_1 = _mm_unpackhi_epi16(pred_mask_1, pred_mask_0);
      const __m128i pred_0 = _mm_unpacklo_epi16(pred_val_0, pred_val_1);
      const __m128i pred_1 = _mm_unpackhi_epi16(pred_val_0, pred_val_1);

      const __m128i compound_pred_0 = _mm_madd_epi16(pred_0, mask_0);
      const __m128i compound_pred_1 = _mm_madd_epi16(pred_1, mask_1);
      const __m128i shift_0 =
          RightShiftWithRoundingConst_S32(compound_pred_0, 6, shift6);
      const __m128i shift_1 =
          RightShiftWithRoundingConst_S32(compound_pred_1, 6, shift6);
      StoreUnaligned16(dst + x, _mm_packus_epi32(shift_0, shift_1));
      x += 8;
    } while (x < width);
    dst += dst_stride;
    pred_0 += pred_stride_0;
    pred_1 += pred_stride_1;
    mask += mask_stride_ss;
  } while (--y != 0);
}

void Init10bpp() {
  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10);
  assert(dsp != nullptr);

#if DSP_ENABLED_10BPP_SSE4_1(MaskBlend444)
  dsp->mask_blend[0][0] = MaskBlend10bpp_SSE4_1<0, 0>;
#endif
#if DSP_ENABLED_10BPP_SSE4_1(MaskBlend422)
  dsp->mask_blend[1][0] = MaskBlend10bpp_SSE4_1<1, 0>;
#endif
#if DSP_ENABLED_10BPP_SSE4_1(MaskBlend420)
  dsp->mask_blend[2][0] = MaskBlend10bpp_SSE4_1<1, 1>;
#endif
#if DSP_ENABLED_10BPP_SSE4_1(MaskBlendInterIntra444)
  dsp->mask_blend[0][1] = InterIntraMaskBlend10bpp_SSE4_1<0, 0>;
#endif
#if DSP_ENABLED_10BPP_SSE4_1(MaskBlendInterIntra422)
  dsp->mask_blend[1][1] = InterIntraMaskBlend10bpp_SSE4_1<1, 0>;
#endif
#if DSP_ENABLED_10BPP_SSE4_1(MaskBlendInterIntra420)
  dsp->mask_blend[2][1] = InterIntraMaskBlend10bpp_SSE4_1<1, 1>;
#endif
}

}  // namespace
}  // namespace high_bitdepth
#endif  // LIBGAV1_MAX_BITDEPTH >= 10

void MaskBlendInit_SSE4_1() {
  low_bitdepth::Init8bpp();
#if LIBGAV1_MAX_BITDEPTH >= 10
  high_bitdepth::Init10bpp();
#endif  // LIBGAV1_MAX_BITDEPTH >= 10
}

}  // namespace dsp
}  // namespace libgav1

#else  // !LIBGAV1_TARGETING_SSE4_1

namespace libgav1 {
namespace dsp {

void MaskBlendInit_SSE4_1() {}

}  // namespace dsp
}  // namespace libgav1
#endif  // LIBGAV1_TARGETING_SSE4_1