// Copyright 2019 The libgav1 Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "src/dsp/mask_blend.h"
#include "src/utils/cpu.h"

#if LIBGAV1_TARGETING_SSE4_1

#include <smmintrin.h>

#include <cassert>
#include <cstddef>
#include <cstdint>

#include "src/dsp/constants.h"
#include "src/dsp/dsp.h"
#include "src/dsp/x86/common_sse4.h"
#include "src/utils/common.h"

namespace libgav1 {
namespace dsp {
namespace {

template <int subsampling_x, int subsampling_y>
inline __m128i GetMask8(const uint8_t* mask, const ptrdiff_t stride) {
  if (subsampling_x == 1 && subsampling_y == 1) {
    const __m128i one = _mm_set1_epi8(1);
    const __m128i mask_val_0 = LoadUnaligned16(mask);
    const __m128i mask_val_1 = LoadUnaligned16(mask + stride);
    const __m128i add_0 = _mm_adds_epu8(mask_val_0, mask_val_1);
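    // _mm_maddubs_epi16 with an all-ones vector horizontally adds adjacent
    // byte pairs into 16-bit lanes, completing the 2x2 sum of mask values;
    // the rounding shift by 2 then averages the four samples.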
    const __m128i mask_0 = _mm_maddubs_epi16(add_0, one);
    return RightShiftWithRounding_U16(mask_0, 2);
  }
  if (subsampling_x == 1) {
    const __m128i row_vals = LoadUnaligned16(mask);
    const __m128i mask_val_0 = _mm_cvtepu8_epi16(row_vals);
    const __m128i mask_val_1 = _mm_cvtepu8_epi16(_mm_srli_si128(row_vals, 8));
    __m128i subsampled_mask = _mm_hadd_epi16(mask_val_0, mask_val_1);
    return RightShiftWithRounding_U16(subsampled_mask, 1);
  }
  assert(subsampling_y == 0 && subsampling_x == 0);
  const __m128i mask_val = LoadLo8(mask);
  return _mm_cvtepu8_epi16(mask_val);
}

// Imitate behavior of ARM vtrn1q_u64.
inline __m128i Transpose1_U64(const __m128i a, const __m128i b) {
  return _mm_castps_si128(
      _mm_movelh_ps(_mm_castsi128_ps(a), _mm_castsi128_ps(b)));
}

// Imitate behavior of ARM vtrn2q_u64.
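// Note that _mm_movehl_ps returns { b_hi, a_hi }, so this matches vtrn2q_u64
// with the operands swapped; callers pass them in reverse order.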
inline __m128i Transpose2_U64(const __m128i a, const __m128i b) {
  return _mm_castps_si128(
      _mm_movehl_ps(_mm_castsi128_ps(a), _mm_castsi128_ps(b)));
}

// Width can only be 4 when it is subsampled from a block of width 8, hence
// subsampling_x is always 1 when this function is called.
template <int subsampling_x, int subsampling_y>
inline __m128i GetMask4x2(const uint8_t* mask) {
  if (subsampling_x == 1 && subsampling_y == 1) {
    const __m128i mask_val_01 = LoadUnaligned16(mask);
    // Stride is fixed because this is the smallest block size.
    const __m128i mask_val_23 = LoadUnaligned16(mask + 16);
    // Transpose rows to add row 0 to row 1, and row 2 to row 3.
    const __m128i mask_val_02 = Transpose1_U64(mask_val_01, mask_val_23);
    const __m128i mask_val_13 = Transpose2_U64(mask_val_23, mask_val_01);
    const __m128i add_0 = _mm_adds_epu8(mask_val_02, mask_val_13);
    const __m128i one = _mm_set1_epi8(1);
    const __m128i mask_0 = _mm_maddubs_epi16(add_0, one);
    return RightShiftWithRounding_U16(mask_0, 2);
  }
  return GetMask8<subsampling_x, 0>(mask, 0);
}

template <int subsampling_x, int subsampling_y>
inline __m128i GetInterIntraMask4x2(const uint8_t* mask,
                                    ptrdiff_t mask_stride) {
  if (subsampling_x == 1) {
    return GetMask4x2<subsampling_x, subsampling_y>(mask);
  }
  // When using intra or difference weighted masks, the function doesn't use
  // subsampling, so |mask_stride| may be 4 or 8.
  assert(subsampling_y == 0 && subsampling_x == 0);
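  // Pack the two 4-byte mask rows into the low 8 bytes of one register, then
  // zero-extend the bytes to 16 bits.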
  const __m128i mask_val_0 = Load4(mask);
  const __m128i mask_val_1 = Load4(mask + mask_stride);
  return _mm_cvtepu8_epi16(
      _mm_or_si128(mask_val_0, _mm_slli_si128(mask_val_1, 4)));
}

}  // namespace

namespace low_bitdepth {
namespace {

// This function returns a 16-bit packed mask to fit in _mm_madd_epi16.
// 16-bit is also the lowest packing for hadd, but without subsampling there is
// an unfortunate conversion required.
template <int subsampling_x, int subsampling_y>
inline __m128i GetMask8(const uint8_t* LIBGAV1_RESTRICT mask,
                        ptrdiff_t stride) {
  if (subsampling_x == 1) {
    const __m128i row_vals = LoadUnaligned16(mask);

    const __m128i mask_val_0 = _mm_cvtepu8_epi16(row_vals);
    const __m128i mask_val_1 = _mm_cvtepu8_epi16(_mm_srli_si128(row_vals, 8));
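    // _mm_hadd_epi16 sums adjacent 16-bit pairs, halving 16 mask values to 8.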
    __m128i subsampled_mask = _mm_hadd_epi16(mask_val_0, mask_val_1);

    if (subsampling_y == 1) {
      const __m128i next_row_vals = LoadUnaligned16(mask + stride);
      const __m128i next_mask_val_0 = _mm_cvtepu8_epi16(next_row_vals);
      const __m128i next_mask_val_1 =
          _mm_cvtepu8_epi16(_mm_srli_si128(next_row_vals, 8));
      subsampled_mask = _mm_add_epi16(
          subsampled_mask, _mm_hadd_epi16(next_mask_val_0, next_mask_val_1));
    }
    return RightShiftWithRounding_U16(subsampled_mask, 1 + subsampling_y);
  }
  assert(subsampling_y == 0 && subsampling_x == 0);
  const __m128i mask_val = LoadLo8(mask);
  return _mm_cvtepu8_epi16(mask_val);
}

inline void WriteMaskBlendLine4x2(const int16_t* LIBGAV1_RESTRICT const pred_0,
                                  const int16_t* LIBGAV1_RESTRICT const pred_1,
                                  const __m128i pred_mask_0,
                                  const __m128i pred_mask_1,
                                  uint8_t* LIBGAV1_RESTRICT dst,
                                  const ptrdiff_t dst_stride) {
  const __m128i pred_val_0 = LoadAligned16(pred_0);
  const __m128i pred_val_1 = LoadAligned16(pred_1);
  const __m128i mask_lo = _mm_unpacklo_epi16(pred_mask_0, pred_mask_1);
  const __m128i mask_hi = _mm_unpackhi_epi16(pred_mask_0, pred_mask_1);
  const __m128i pred_lo = _mm_unpacklo_epi16(pred_val_0, pred_val_1);
  const __m128i pred_hi = _mm_unpackhi_epi16(pred_val_0, pred_val_1);

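  // Interleaving the masks with the predictions lets a single _mm_madd_epi16
  // evaluate mask * pred_0 + (64 - mask) * pred_1 in each 32-bit lane.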
  // int res = (mask_value * prediction_0[x] +
  //      (64 - mask_value) * prediction_1[x]) >> 6;
  const __m128i compound_pred_lo = _mm_madd_epi16(pred_lo, mask_lo);
  const __m128i compound_pred_hi = _mm_madd_epi16(pred_hi, mask_hi);
  const __m128i compound_pred = _mm_packus_epi32(
      _mm_srli_epi32(compound_pred_lo, 6), _mm_srli_epi32(compound_pred_hi, 6));

  // dst[x] = static_cast<Pixel>(
  //     Clip3(RightShiftWithRounding(res, inter_post_round_bits), 0,
  //           (1 << kBitdepth8) - 1));
  const __m128i result = RightShiftWithRounding_S16(compound_pred, 4);
  const __m128i res = _mm_packus_epi16(result, result);
  Store4(dst, res);
  Store4(dst + dst_stride, _mm_srli_si128(res, 4));
}

template <int subsampling_x, int subsampling_y>
inline void MaskBlending4x4_SSE4_1(const int16_t* LIBGAV1_RESTRICT pred_0,
                                   const int16_t* LIBGAV1_RESTRICT pred_1,
                                   const uint8_t* LIBGAV1_RESTRICT mask,
                                   uint8_t* LIBGAV1_RESTRICT dst,
                                   const ptrdiff_t dst_stride) {
  constexpr ptrdiff_t mask_stride = 4 << subsampling_x;
  const __m128i mask_inverter = _mm_set1_epi16(64);
  __m128i pred_mask_0 = GetMask4x2<subsampling_x, subsampling_y>(mask);
  __m128i pred_mask_1 = _mm_sub_epi16(mask_inverter, pred_mask_0);
  WriteMaskBlendLine4x2(pred_0, pred_1, pred_mask_0, pred_mask_1, dst,
                        dst_stride);
  pred_0 += 4 << 1;
  pred_1 += 4 << 1;
  mask += mask_stride << (1 + subsampling_y);
  dst += dst_stride << 1;

  pred_mask_0 = GetMask4x2<subsampling_x, subsampling_y>(mask);
  pred_mask_1 = _mm_sub_epi16(mask_inverter, pred_mask_0);
  WriteMaskBlendLine4x2(pred_0, pred_1, pred_mask_0, pred_mask_1, dst,
                        dst_stride);
}

template <int subsampling_x, int subsampling_y>
inline void MaskBlending4xH_SSE4_1(
    const int16_t* LIBGAV1_RESTRICT pred_0,
    const int16_t* LIBGAV1_RESTRICT pred_1,
    const uint8_t* LIBGAV1_RESTRICT const mask_ptr, const int height,
    uint8_t* LIBGAV1_RESTRICT dst, const ptrdiff_t dst_stride) {
  assert(subsampling_x == 1);
  const uint8_t* mask = mask_ptr;
  constexpr ptrdiff_t mask_stride = 4 << subsampling_x;
  if (height == 4) {
    MaskBlending4x4_SSE4_1<subsampling_x, subsampling_y>(pred_0, pred_1, mask,
                                                         dst, dst_stride);
    return;
  }
  const __m128i mask_inverter = _mm_set1_epi16(64);
  int y = 0;
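  // The loop is unrolled to blend eight rows per iteration; beyond the
  // height-4 special case, 4xH blocks have heights that are multiples of 8.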
  do {
    __m128i pred_mask_0 = GetMask4x2<subsampling_x, subsampling_y>(mask);
    __m128i pred_mask_1 = _mm_sub_epi16(mask_inverter, pred_mask_0);

    WriteMaskBlendLine4x2(pred_0, pred_1, pred_mask_0, pred_mask_1, dst,
                          dst_stride);
    pred_0 += 4 << 1;
    pred_1 += 4 << 1;
    mask += mask_stride << (1 + subsampling_y);
    dst += dst_stride << 1;

    pred_mask_0 = GetMask4x2<subsampling_x, subsampling_y>(mask);
    pred_mask_1 = _mm_sub_epi16(mask_inverter, pred_mask_0);
    WriteMaskBlendLine4x2(pred_0, pred_1, pred_mask_0, pred_mask_1, dst,
                          dst_stride);
    pred_0 += 4 << 1;
    pred_1 += 4 << 1;
    mask += mask_stride << (1 + subsampling_y);
    dst += dst_stride << 1;

    pred_mask_0 = GetMask4x2<subsampling_x, subsampling_y>(mask);
    pred_mask_1 = _mm_sub_epi16(mask_inverter, pred_mask_0);
    WriteMaskBlendLine4x2(pred_0, pred_1, pred_mask_0, pred_mask_1, dst,
                          dst_stride);
    pred_0 += 4 << 1;
    pred_1 += 4 << 1;
    mask += mask_stride << (1 + subsampling_y);
    dst += dst_stride << 1;

    pred_mask_0 = GetMask4x2<subsampling_x, subsampling_y>(mask);
    pred_mask_1 = _mm_sub_epi16(mask_inverter, pred_mask_0);
    WriteMaskBlendLine4x2(pred_0, pred_1, pred_mask_0, pred_mask_1, dst,
                          dst_stride);
    pred_0 += 4 << 1;
    pred_1 += 4 << 1;
    mask += mask_stride << (1 + subsampling_y);
    dst += dst_stride << 1;
    y += 8;
  } while (y < height);
}

template <int subsampling_x, int subsampling_y>
inline void MaskBlend_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
                             const void* LIBGAV1_RESTRICT prediction_1,
                             const ptrdiff_t /*prediction_stride_1*/,
                             const uint8_t* LIBGAV1_RESTRICT const mask_ptr,
                             const ptrdiff_t mask_stride, const int width,
                             const int height, void* LIBGAV1_RESTRICT dest,
                             const ptrdiff_t dst_stride) {
  auto* dst = static_cast<uint8_t*>(dest);
  const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
  const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
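  // The compound prediction buffers are tightly packed, so their stride is
  // simply the block width.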
  const ptrdiff_t pred_stride_0 = width;
  const ptrdiff_t pred_stride_1 = width;
  if (width == 4) {
    MaskBlending4xH_SSE4_1<subsampling_x, subsampling_y>(
        pred_0, pred_1, mask_ptr, height, dst, dst_stride);
    return;
  }
  const uint8_t* mask = mask_ptr;
  const __m128i mask_inverter = _mm_set1_epi16(64);
  int y = 0;
  do {
    int x = 0;
    do {
      const __m128i pred_mask_0 = GetMask8<subsampling_x, subsampling_y>(
          mask + (x << subsampling_x), mask_stride);
      // 64 - mask
      const __m128i pred_mask_1 = _mm_sub_epi16(mask_inverter, pred_mask_0);
      const __m128i mask_lo = _mm_unpacklo_epi16(pred_mask_0, pred_mask_1);
      const __m128i mask_hi = _mm_unpackhi_epi16(pred_mask_0, pred_mask_1);

      const __m128i pred_val_0 = LoadAligned16(pred_0 + x);
      const __m128i pred_val_1 = LoadAligned16(pred_1 + x);
      const __m128i pred_lo = _mm_unpacklo_epi16(pred_val_0, pred_val_1);
      const __m128i pred_hi = _mm_unpackhi_epi16(pred_val_0, pred_val_1);
      // int res = (mask_value * prediction_0[x] +
      //      (64 - mask_value) * prediction_1[x]) >> 6;
      const __m128i compound_pred_lo = _mm_madd_epi16(pred_lo, mask_lo);
      const __m128i compound_pred_hi = _mm_madd_epi16(pred_hi, mask_hi);

      const __m128i res = _mm_packus_epi32(_mm_srli_epi32(compound_pred_lo, 6),
                                           _mm_srli_epi32(compound_pred_hi, 6));
      // dst[x] = static_cast<Pixel>(
      //     Clip3(RightShiftWithRounding(res, inter_post_round_bits), 0,
      //           (1 << kBitdepth8) - 1));
      const __m128i result = RightShiftWithRounding_S16(res, 4);
      StoreLo8(dst + x, _mm_packus_epi16(result, result));

      x += 8;
    } while (x < width);
    dst += dst_stride;
    pred_0 += pred_stride_0;
    pred_1 += pred_stride_1;
    mask += mask_stride << subsampling_y;
  } while (++y < height);
}

inline void InterIntraWriteMaskBlendLine8bpp4x2(
    const uint8_t* LIBGAV1_RESTRICT const pred_0,
    uint8_t* LIBGAV1_RESTRICT const pred_1, const ptrdiff_t pred_stride_1,
    const __m128i pred_mask_0, const __m128i pred_mask_1) {
  const __m128i pred_mask = _mm_unpacklo_epi8(pred_mask_0, pred_mask_1);

  const __m128i pred_val_0 = LoadLo8(pred_0);
  __m128i pred_val_1 = Load4(pred_1);
  pred_val_1 = _mm_or_si128(_mm_slli_si128(Load4(pred_1 + pred_stride_1), 4),
                            pred_val_1);
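  // |pred_val_1| now holds both 4-pixel rows of |pred_1| in its low 8 bytes.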
  const __m128i pred = _mm_unpacklo_epi8(pred_val_0, pred_val_1);
  // int res = (mask_value * prediction_1[x] +
  //      (64 - mask_value) * prediction_0[x]) >> 6;
  const __m128i compound_pred = _mm_maddubs_epi16(pred, pred_mask);
  const __m128i result = RightShiftWithRounding_U16(compound_pred, 6);
  const __m128i res = _mm_packus_epi16(result, result);

  Store4(pred_1, res);
  Store4(pred_1 + pred_stride_1, _mm_srli_si128(res, 4));
}

template <int subsampling_x, int subsampling_y>
inline void InterIntraMaskBlending8bpp4x4_SSE4_1(
    const uint8_t* LIBGAV1_RESTRICT pred_0, uint8_t* LIBGAV1_RESTRICT pred_1,
    const ptrdiff_t pred_stride_1, const uint8_t* LIBGAV1_RESTRICT mask,
    const ptrdiff_t mask_stride) {
  const __m128i mask_inverter = _mm_set1_epi8(64);
  const __m128i pred_mask_u16_first =
      GetInterIntraMask4x2<subsampling_x, subsampling_y>(mask, mask_stride);
  mask += mask_stride << (1 + subsampling_y);
  const __m128i pred_mask_u16_second =
      GetInterIntraMask4x2<subsampling_x, subsampling_y>(mask, mask_stride);
  mask += mask_stride << (1 + subsampling_y);
  __m128i pred_mask_1 =
      _mm_packus_epi16(pred_mask_u16_first, pred_mask_u16_second);
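  // All four mask rows are now packed as bytes in a single register; the
  // second pair of rows is recovered below by shifting down 8 bytes.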
  __m128i pred_mask_0 = _mm_sub_epi8(mask_inverter, pred_mask_1);
  InterIntraWriteMaskBlendLine8bpp4x2(pred_0, pred_1, pred_stride_1,
                                      pred_mask_0, pred_mask_1);
  pred_0 += 4 << 1;
  pred_1 += pred_stride_1 << 1;

  pred_mask_1 = _mm_srli_si128(pred_mask_1, 8);
  pred_mask_0 = _mm_sub_epi8(mask_inverter, pred_mask_1);
  InterIntraWriteMaskBlendLine8bpp4x2(pred_0, pred_1, pred_stride_1,
                                      pred_mask_0, pred_mask_1);
}

template <int subsampling_x, int subsampling_y>
inline void InterIntraMaskBlending8bpp4xH_SSE4_1(
    const uint8_t* LIBGAV1_RESTRICT pred_0, uint8_t* LIBGAV1_RESTRICT pred_1,
    const ptrdiff_t pred_stride_1,
    const uint8_t* LIBGAV1_RESTRICT const mask_ptr, const ptrdiff_t mask_stride,
    const int height) {
  const uint8_t* mask = mask_ptr;
  if (height == 4) {
    InterIntraMaskBlending8bpp4x4_SSE4_1<subsampling_x, subsampling_y>(
        pred_0, pred_1, pred_stride_1, mask, mask_stride);
    return;
  }
  int y = 0;
  do {
    InterIntraMaskBlending8bpp4x4_SSE4_1<subsampling_x, subsampling_y>(
        pred_0, pred_1, pred_stride_1, mask, mask_stride);
    pred_0 += 4 << 2;
    pred_1 += pred_stride_1 << 2;
    mask += mask_stride << (2 + subsampling_y);

    InterIntraMaskBlending8bpp4x4_SSE4_1<subsampling_x, subsampling_y>(
        pred_0, pred_1, pred_stride_1, mask, mask_stride);
    pred_0 += 4 << 2;
    pred_1 += pred_stride_1 << 2;
    mask += mask_stride << (2 + subsampling_y);
    y += 8;
  } while (y < height);
}

// This version returns 8-bit packed values to fit in _mm_maddubs_epi16 because,
// when is_inter_intra is true, the prediction values are brought to 8-bit
// packing as well.
template <int subsampling_x, int subsampling_y>
inline __m128i GetInterIntraMask8bpp8(const uint8_t* LIBGAV1_RESTRICT mask,
                                      ptrdiff_t stride) {
  if (subsampling_x == 1) {
    const __m128i ret = GetMask8<subsampling_x, subsampling_y>(mask, stride);
    return _mm_packus_epi16(ret, ret);
  }
  assert(subsampling_y == 0 && subsampling_x == 0);
  // Unfortunately there is no shift operation for 8-bit packing, or else we
  // could return everything with 8-bit packing.
  const __m128i mask_val = LoadLo8(mask);
  return mask_val;
}

template <int subsampling_x, int subsampling_y>
void InterIntraMaskBlend8bpp_SSE4_1(
    const uint8_t* LIBGAV1_RESTRICT prediction_0,
    uint8_t* LIBGAV1_RESTRICT prediction_1, const ptrdiff_t prediction_stride_1,
    const uint8_t* LIBGAV1_RESTRICT const mask_ptr, const ptrdiff_t mask_stride,
    const int width, const int height) {
  if (width == 4) {
    InterIntraMaskBlending8bpp4xH_SSE4_1<subsampling_x, subsampling_y>(
        prediction_0, prediction_1, prediction_stride_1, mask_ptr, mask_stride,
        height);
    return;
  }
  const uint8_t* mask = mask_ptr;
  const __m128i mask_inverter = _mm_set1_epi8(64);
  int y = 0;
  do {
    int x = 0;
    do {
      const __m128i pred_mask_1 =
          GetInterIntraMask8bpp8<subsampling_x, subsampling_y>(
              mask + (x << subsampling_x), mask_stride);
      // 64 - mask
      const __m128i pred_mask_0 = _mm_sub_epi8(mask_inverter, pred_mask_1);
      const __m128i pred_mask = _mm_unpacklo_epi8(pred_mask_0, pred_mask_1);

      const __m128i pred_val_0 = LoadLo8(prediction_0 + x);
      const __m128i pred_val_1 = LoadLo8(prediction_1 + x);
      const __m128i pred = _mm_unpacklo_epi8(pred_val_0, pred_val_1);
      // int res = (mask_value * prediction_1[x] +
      //      (64 - mask_value) * prediction_0[x]) >> 6;
      const __m128i compound_pred = _mm_maddubs_epi16(pred, pred_mask);
      const __m128i result = RightShiftWithRounding_U16(compound_pred, 6);
      const __m128i res = _mm_packus_epi16(result, result);

      StoreLo8(prediction_1 + x, res);

      x += 8;
    } while (x < width);
    prediction_0 += width;
    prediction_1 += prediction_stride_1;
    mask += mask_stride << subsampling_y;
  } while (++y < height);
}

void Init8bpp() {
  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
  assert(dsp != nullptr);
#if DSP_ENABLED_8BPP_SSE4_1(MaskBlend444)
  dsp->mask_blend[0][0] = MaskBlend_SSE4_1<0, 0>;
#endif
#if DSP_ENABLED_8BPP_SSE4_1(MaskBlend422)
  dsp->mask_blend[1][0] = MaskBlend_SSE4_1<1, 0>;
#endif
#if DSP_ENABLED_8BPP_SSE4_1(MaskBlend420)
  dsp->mask_blend[2][0] = MaskBlend_SSE4_1<1, 1>;
#endif
  // The is_inter_intra index of mask_blend[][] is replaced by
  // inter_intra_mask_blend_8bpp[] in 8-bit.
#if DSP_ENABLED_8BPP_SSE4_1(InterIntraMaskBlend8bpp444)
  dsp->inter_intra_mask_blend_8bpp[0] = InterIntraMaskBlend8bpp_SSE4_1<0, 0>;
#endif
#if DSP_ENABLED_8BPP_SSE4_1(InterIntraMaskBlend8bpp422)
  dsp->inter_intra_mask_blend_8bpp[1] = InterIntraMaskBlend8bpp_SSE4_1<1, 0>;
#endif
#if DSP_ENABLED_8BPP_SSE4_1(InterIntraMaskBlend8bpp420)
  dsp->inter_intra_mask_blend_8bpp[2] = InterIntraMaskBlend8bpp_SSE4_1<1, 1>;
#endif
}

}  // namespace
}  // namespace low_bitdepth

#if LIBGAV1_MAX_BITDEPTH >= 10
namespace high_bitdepth {
namespace {

constexpr int kMax10bppSample = (1 << 10) - 1;
constexpr int kMaskInverse = 64;
constexpr int kRoundBitsMaskBlend = 4;

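// Round-half-up shift: the rounding constant (1 << (bits - 1)) is precomputed
// in |shift| so that it can be hoisted out of the blending loops.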
inline __m128i RightShiftWithRoundingConst_S32(const __m128i v_val_d, int bits,
                                               const __m128i shift) {
  const __m128i v_tmp_d = _mm_add_epi32(v_val_d, shift);
  return _mm_srai_epi32(v_tmp_d, bits);
}

template <int subsampling_x, int subsampling_y>
inline __m128i GetMask4x2(const uint8_t* mask) {
  if (subsampling_x == 1 && subsampling_y == 1) {
    const __m128i mask_row_01 = LoadUnaligned16(mask);
    const __m128i mask_row_23 = LoadUnaligned16(mask + 16);
    const __m128i mask_val_0 = _mm_cvtepu8_epi16(mask_row_01);
    const __m128i mask_val_1 =
        _mm_cvtepu8_epi16(_mm_srli_si128(mask_row_01, 8));
    const __m128i mask_val_2 = _mm_cvtepu8_epi16(mask_row_23);
    const __m128i mask_val_3 =
        _mm_cvtepu8_epi16(_mm_srli_si128(mask_row_23, 8));
    const __m128i subsampled_mask_02 = _mm_hadd_epi16(mask_val_0, mask_val_2);
    const __m128i subsampled_mask_13 = _mm_hadd_epi16(mask_val_1, mask_val_3);
    const __m128i subsampled_mask =
        _mm_add_epi16(subsampled_mask_02, subsampled_mask_13);
    return RightShiftWithRounding_U16(subsampled_mask, 2);
  }
  if (subsampling_x == 1) {
    const __m128i mask_row_01 = LoadUnaligned16(mask);
    const __m128i mask_val_0 = _mm_cvtepu8_epi16(mask_row_01);
    const __m128i mask_val_1 =
        _mm_cvtepu8_epi16(_mm_srli_si128(mask_row_01, 8));
    const __m128i subsampled_mask = _mm_hadd_epi16(mask_val_0, mask_val_1);
    return RightShiftWithRounding_U16(subsampled_mask, 1);
  }
  return _mm_cvtepu8_epi16(LoadLo8(mask));
}

inline void WriteMaskBlendLine10bpp4x2_SSE4_1(
    const uint16_t* LIBGAV1_RESTRICT pred_0,
    const uint16_t* LIBGAV1_RESTRICT pred_1, const ptrdiff_t pred_stride_1,
    const __m128i& pred_mask_0, const __m128i& pred_mask_1,
    const __m128i& offset, const __m128i& max, const __m128i& shift4,
    uint16_t* LIBGAV1_RESTRICT dst, const ptrdiff_t dst_stride) {
  const __m128i pred_val_0 = LoadUnaligned16(pred_0);
  const __m128i pred_val_1 = LoadHi8(LoadLo8(pred_1), pred_1 + pred_stride_1);

  // int res = (mask_value * pred_0[x] + (64 - mask_value) * pred_1[x]) >> 6;
  const __m128i compound_pred_lo_0 = _mm_mullo_epi16(pred_val_0, pred_mask_0);
  const __m128i compound_pred_hi_0 = _mm_mulhi_epu16(pred_val_0, pred_mask_0);
  const __m128i compound_pred_lo_1 = _mm_mullo_epi16(pred_val_1, pred_mask_1);
  const __m128i compound_pred_hi_1 = _mm_mulhi_epu16(pred_val_1, pred_mask_1);
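  // _mm_mullo_epi16/_mm_mulhi_epu16 produce the low and high 16 bits of each
  // 16x16-bit product; interleaving them reassembles the full 32-bit values,
  // which 10bpp compound predictions require.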
  const __m128i pack0_lo =
      _mm_unpacklo_epi16(compound_pred_lo_0, compound_pred_hi_0);
  const __m128i pack0_hi =
      _mm_unpackhi_epi16(compound_pred_lo_0, compound_pred_hi_0);
  const __m128i pack1_lo =
      _mm_unpacklo_epi16(compound_pred_lo_1, compound_pred_hi_1);
  const __m128i pack1_hi =
      _mm_unpackhi_epi16(compound_pred_lo_1, compound_pred_hi_1);
  const __m128i compound_pred_lo = _mm_add_epi32(pack0_lo, pack1_lo);
  const __m128i compound_pred_hi = _mm_add_epi32(pack0_hi, pack1_hi);
  // res -= (bitdepth == 8) ? 0 : kCompoundOffset;
  const __m128i sub_0 =
      _mm_sub_epi32(_mm_srli_epi32(compound_pred_lo, 6), offset);
  const __m128i sub_1 =
      _mm_sub_epi32(_mm_srli_epi32(compound_pred_hi, 6), offset);

  // dst[x] = static_cast<Pixel>(
  //     Clip3(RightShiftWithRounding(res, inter_post_round_bits), 0,
  //           (1 << kBitdepth8) - 1));
  const __m128i shift_0 =
      RightShiftWithRoundingConst_S32(sub_0, kRoundBitsMaskBlend, shift4);
  const __m128i shift_1 =
      RightShiftWithRoundingConst_S32(sub_1, kRoundBitsMaskBlend, shift4);
  const __m128i result = _mm_min_epi16(_mm_packus_epi32(shift_0, shift_1), max);
  StoreLo8(dst, result);
  StoreHi8(dst + dst_stride, result);
}

template <int subsampling_x, int subsampling_y>
inline void MaskBlend10bpp4x4_SSE4_1(const uint16_t* LIBGAV1_RESTRICT pred_0,
                                     const uint16_t* LIBGAV1_RESTRICT pred_1,
                                     const ptrdiff_t pred_stride_1,
                                     const uint8_t* LIBGAV1_RESTRICT mask,
                                     const ptrdiff_t mask_stride,
                                     uint16_t* LIBGAV1_RESTRICT dst,
                                     const ptrdiff_t dst_stride) {
  const __m128i mask_inverter = _mm_set1_epi16(kMaskInverse);
  const __m128i shift4 = _mm_set1_epi32((1 << kRoundBitsMaskBlend) >> 1);
  const __m128i offset = _mm_set1_epi32(kCompoundOffset);
  const __m128i max = _mm_set1_epi16(kMax10bppSample);
  __m128i pred_mask_0 = GetMask4x2<subsampling_x, subsampling_y>(mask);
  __m128i pred_mask_1 = _mm_sub_epi16(mask_inverter, pred_mask_0);
  WriteMaskBlendLine10bpp4x2_SSE4_1(pred_0, pred_1, pred_stride_1, pred_mask_0,
                                    pred_mask_1, offset, max, shift4, dst,
                                    dst_stride);
  pred_0 += 4 << 1;
  pred_1 += pred_stride_1 << 1;
  mask += mask_stride << (1 + subsampling_y);
  dst += dst_stride << 1;

  pred_mask_0 = GetMask4x2<subsampling_x, subsampling_y>(mask);
  pred_mask_1 = _mm_sub_epi16(mask_inverter, pred_mask_0);
  WriteMaskBlendLine10bpp4x2_SSE4_1(pred_0, pred_1, pred_stride_1, pred_mask_0,
                                    pred_mask_1, offset, max, shift4, dst,
                                    dst_stride);
}

template <int subsampling_x, int subsampling_y>
inline void MaskBlend10bpp4xH_SSE4_1(
    const uint16_t* LIBGAV1_RESTRICT pred_0,
    const uint16_t* LIBGAV1_RESTRICT pred_1, const ptrdiff_t pred_stride_1,
    const uint8_t* LIBGAV1_RESTRICT const mask_ptr, const ptrdiff_t mask_stride,
    const int height, uint16_t* LIBGAV1_RESTRICT dst,
    const ptrdiff_t dst_stride) {
  const uint8_t* mask = mask_ptr;
  if (height == 4) {
    MaskBlend10bpp4x4_SSE4_1<subsampling_x, subsampling_y>(
        pred_0, pred_1, pred_stride_1, mask, mask_stride, dst, dst_stride);
    return;
  }
  const __m128i mask_inverter = _mm_set1_epi16(kMaskInverse);
  const uint8_t pred0_stride2 = 4 << 1;
  const ptrdiff_t pred1_stride2 = pred_stride_1 << 1;
  const ptrdiff_t mask_stride2 = mask_stride << (1 + subsampling_y);
  const ptrdiff_t dst_stride2 = dst_stride << 1;
  const __m128i offset = _mm_set1_epi32(kCompoundOffset);
  const __m128i max = _mm_set1_epi16(kMax10bppSample);
  const __m128i shift4 = _mm_set1_epi32((1 << kRoundBitsMaskBlend) >> 1);
  int y = height;
  do {
    __m128i pred_mask_0 = GetMask4x2<subsampling_x, subsampling_y>(mask);
    __m128i pred_mask_1 = _mm_sub_epi16(mask_inverter, pred_mask_0);

    WriteMaskBlendLine10bpp4x2_SSE4_1(pred_0, pred_1, pred_stride_1,
                                      pred_mask_0, pred_mask_1, offset, max,
                                      shift4, dst, dst_stride);
    pred_0 += pred0_stride2;
    pred_1 += pred1_stride2;
    mask += mask_stride2;
    dst += dst_stride2;

    pred_mask_0 = GetMask4x2<subsampling_x, subsampling_y>(mask);
    pred_mask_1 = _mm_sub_epi16(mask_inverter, pred_mask_0);
    WriteMaskBlendLine10bpp4x2_SSE4_1(pred_0, pred_1, pred_stride_1,
                                      pred_mask_0, pred_mask_1, offset, max,
                                      shift4, dst, dst_stride);
    pred_0 += pred0_stride2;
    pred_1 += pred1_stride2;
    mask += mask_stride2;
    dst += dst_stride2;

    pred_mask_0 = GetMask4x2<subsampling_x, subsampling_y>(mask);
    pred_mask_1 = _mm_sub_epi16(mask_inverter, pred_mask_0);
    WriteMaskBlendLine10bpp4x2_SSE4_1(pred_0, pred_1, pred_stride_1,
                                      pred_mask_0, pred_mask_1, offset, max,
                                      shift4, dst, dst_stride);
    pred_0 += pred0_stride2;
    pred_1 += pred1_stride2;
    mask += mask_stride2;
    dst += dst_stride2;

    pred_mask_0 = GetMask4x2<subsampling_x, subsampling_y>(mask);
    pred_mask_1 = _mm_sub_epi16(mask_inverter, pred_mask_0);
    WriteMaskBlendLine10bpp4x2_SSE4_1(pred_0, pred_1, pred_stride_1,
                                      pred_mask_0, pred_mask_1, offset, max,
                                      shift4, dst, dst_stride);
    pred_0 += pred0_stride2;
    pred_1 += pred1_stride2;
    mask += mask_stride2;
    dst += dst_stride2;
    y -= 8;
  } while (y != 0);
}

template <int subsampling_x, int subsampling_y>
inline void MaskBlend10bpp_SSE4_1(
    const void* LIBGAV1_RESTRICT prediction_0,
    const void* LIBGAV1_RESTRICT prediction_1,
    const ptrdiff_t prediction_stride_1,
    const uint8_t* LIBGAV1_RESTRICT const mask_ptr, const ptrdiff_t mask_stride,
    const int width, const int height, void* LIBGAV1_RESTRICT dest,
    const ptrdiff_t dest_stride) {
  auto* dst = static_cast<uint16_t*>(dest);
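  // |dest_stride| is given in bytes; convert it to uint16_t units.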
  const ptrdiff_t dst_stride = dest_stride / sizeof(dst[0]);
  const auto* pred_0 = static_cast<const uint16_t*>(prediction_0);
  const auto* pred_1 = static_cast<const uint16_t*>(prediction_1);
  const ptrdiff_t pred_stride_0 = width;
  const ptrdiff_t pred_stride_1 = prediction_stride_1;
  if (width == 4) {
    MaskBlend10bpp4xH_SSE4_1<subsampling_x, subsampling_y>(
        pred_0, pred_1, pred_stride_1, mask_ptr, mask_stride, height, dst,
        dst_stride);
    return;
  }
  const uint8_t* mask = mask_ptr;
  const __m128i mask_inverter = _mm_set1_epi16(kMaskInverse);
  const ptrdiff_t mask_stride_ss = mask_stride << subsampling_y;
  const __m128i offset = _mm_set1_epi32(kCompoundOffset);
  const __m128i max = _mm_set1_epi16(kMax10bppSample);
  const __m128i shift4 = _mm_set1_epi32((1 << kRoundBitsMaskBlend) >> 1);
  int y = height;
  do {
    int x = 0;
    do {
      const __m128i pred_mask_0 = GetMask8<subsampling_x, subsampling_y>(
          mask + (x << subsampling_x), mask_stride);
      const __m128i pred_val_0 = LoadUnaligned16(pred_0 + x);
      const __m128i pred_val_1 = LoadUnaligned16(pred_1 + x);
      // 64 - mask
      const __m128i pred_mask_1 = _mm_sub_epi16(mask_inverter, pred_mask_0);

      const __m128i compound_pred_lo_0 =
          _mm_mullo_epi16(pred_val_0, pred_mask_0);
      const __m128i compound_pred_hi_0 =
          _mm_mulhi_epu16(pred_val_0, pred_mask_0);
      const __m128i compound_pred_lo_1 =
          _mm_mullo_epi16(pred_val_1, pred_mask_1);
      const __m128i compound_pred_hi_1 =
          _mm_mulhi_epu16(pred_val_1, pred_mask_1);
      const __m128i pack0_lo =
          _mm_unpacklo_epi16(compound_pred_lo_0, compound_pred_hi_0);
      const __m128i pack0_hi =
          _mm_unpackhi_epi16(compound_pred_lo_0, compound_pred_hi_0);
      const __m128i pack1_lo =
          _mm_unpacklo_epi16(compound_pred_lo_1, compound_pred_hi_1);
      const __m128i pack1_hi =
          _mm_unpackhi_epi16(compound_pred_lo_1, compound_pred_hi_1);
      const __m128i compound_pred_lo = _mm_add_epi32(pack0_lo, pack1_lo);
      const __m128i compound_pred_hi = _mm_add_epi32(pack0_hi, pack1_hi);

      const __m128i sub_0 =
          _mm_sub_epi32(_mm_srli_epi32(compound_pred_lo, 6), offset);
      const __m128i sub_1 =
          _mm_sub_epi32(_mm_srli_epi32(compound_pred_hi, 6), offset);
      const __m128i shift_0 =
          RightShiftWithRoundingConst_S32(sub_0, kRoundBitsMaskBlend, shift4);
      const __m128i shift_1 =
          RightShiftWithRoundingConst_S32(sub_1, kRoundBitsMaskBlend, shift4);
      const __m128i result =
          _mm_min_epi16(_mm_packus_epi32(shift_0, shift_1), max);
      StoreUnaligned16(dst + x, result);
      x += 8;
    } while (x < width);
    dst += dst_stride;
    pred_0 += pred_stride_0;
    pred_1 += pred_stride_1;
    mask += mask_stride_ss;
  } while (--y != 0);
}

inline void InterIntraWriteMaskBlendLine10bpp4x2_SSE4_1(
    const uint16_t* LIBGAV1_RESTRICT prediction_0,
    const uint16_t* LIBGAV1_RESTRICT prediction_1,
    const ptrdiff_t pred_stride_1, const __m128i& pred_mask_0,
    const __m128i& pred_mask_1, const __m128i& shift6,
    uint16_t* LIBGAV1_RESTRICT dst, const ptrdiff_t dst_stride) {
  const __m128i pred_val_0 = LoadUnaligned16(prediction_0);
  const __m128i pred_val_1 =
      LoadHi8(LoadLo8(prediction_1), prediction_1 + pred_stride_1);

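  // int res = (mask_value * prediction_1[x] +
  //      (64 - mask_value) * prediction_0[x]) >> 6;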
  const __m128i mask_0 = _mm_unpacklo_epi16(pred_mask_1, pred_mask_0);
  const __m128i mask_1 = _mm_unpackhi_epi16(pred_mask_1, pred_mask_0);
  const __m128i pred_0 = _mm_unpacklo_epi16(pred_val_0, pred_val_1);
  const __m128i pred_1 = _mm_unpackhi_epi16(pred_val_0, pred_val_1);

  const __m128i compound_pred_0 = _mm_madd_epi16(pred_0, mask_0);
  const __m128i compound_pred_1 = _mm_madd_epi16(pred_1, mask_1);
  const __m128i shift_0 =
      RightShiftWithRoundingConst_S32(compound_pred_0, 6, shift6);
  const __m128i shift_1 =
      RightShiftWithRoundingConst_S32(compound_pred_1, 6, shift6);
  const __m128i res = _mm_packus_epi32(shift_0, shift_1);
  StoreLo8(dst, res);
  StoreHi8(dst + dst_stride, res);
}

template <int subsampling_x, int subsampling_y>
inline void InterIntraMaskBlend10bpp4x4_SSE4_1(
    const uint16_t* LIBGAV1_RESTRICT pred_0,
    const uint16_t* LIBGAV1_RESTRICT pred_1, const ptrdiff_t pred_stride_1,
    const uint8_t* LIBGAV1_RESTRICT mask, const ptrdiff_t mask_stride,
    uint16_t* LIBGAV1_RESTRICT dst, const ptrdiff_t dst_stride) {
  const __m128i mask_inverter = _mm_set1_epi16(kMaskInverse);
  const __m128i shift6 = _mm_set1_epi32((1 << 6) >> 1);
  __m128i pred_mask_0 =
      GetInterIntraMask4x2<subsampling_x, subsampling_y>(mask, mask_stride);
  __m128i pred_mask_1 = _mm_sub_epi16(mask_inverter, pred_mask_0);
  InterIntraWriteMaskBlendLine10bpp4x2_SSE4_1(pred_0, pred_1, pred_stride_1,
                                              pred_mask_0, pred_mask_1, shift6,
                                              dst, dst_stride);
  pred_0 += 4 << 1;
  pred_1 += pred_stride_1 << 1;
  mask += mask_stride << (1 + subsampling_y);
  dst += dst_stride << 1;

  pred_mask_0 =
      GetInterIntraMask4x2<subsampling_x, subsampling_y>(mask, mask_stride);
  pred_mask_1 = _mm_sub_epi16(mask_inverter, pred_mask_0);
  InterIntraWriteMaskBlendLine10bpp4x2_SSE4_1(pred_0, pred_1, pred_stride_1,
                                              pred_mask_0, pred_mask_1, shift6,
                                              dst, dst_stride);
}

template <int subsampling_x, int subsampling_y>
inline void InterIntraMaskBlend10bpp4xH_SSE4_1(
    const uint16_t* LIBGAV1_RESTRICT pred_0,
    const uint16_t* LIBGAV1_RESTRICT pred_1, const ptrdiff_t pred_stride_1,
    const uint8_t* LIBGAV1_RESTRICT const mask_ptr, const ptrdiff_t mask_stride,
    const int height, uint16_t* LIBGAV1_RESTRICT dst,
    const ptrdiff_t dst_stride) {
  const uint8_t* mask = mask_ptr;
  if (height == 4) {
    InterIntraMaskBlend10bpp4x4_SSE4_1<subsampling_x, subsampling_y>(
        pred_0, pred_1, pred_stride_1, mask, mask_stride, dst, dst_stride);
    return;
  }
  const __m128i mask_inverter = _mm_set1_epi16(kMaskInverse);
  const __m128i shift6 = _mm_set1_epi32((1 << 6) >> 1);
  const uint8_t pred0_stride2 = 4 << 1;
  const ptrdiff_t pred1_stride2 = pred_stride_1 << 1;
  const ptrdiff_t mask_stride2 = mask_stride << (1 + subsampling_y);
  const ptrdiff_t dst_stride2 = dst_stride << 1;
  int y = height;
  do {
    __m128i pred_mask_0 =
        GetInterIntraMask4x2<subsampling_x, subsampling_y>(mask, mask_stride);
    __m128i pred_mask_1 = _mm_sub_epi16(mask_inverter, pred_mask_0);
    InterIntraWriteMaskBlendLine10bpp4x2_SSE4_1(pred_0, pred_1, pred_stride_1,
                                                pred_mask_0, pred_mask_1,
                                                shift6, dst, dst_stride);
    pred_0 += pred0_stride2;
    pred_1 += pred1_stride2;
    mask += mask_stride2;
    dst += dst_stride2;

    pred_mask_0 =
        GetInterIntraMask4x2<subsampling_x, subsampling_y>(mask, mask_stride);
    pred_mask_1 = _mm_sub_epi16(mask_inverter, pred_mask_0);
    InterIntraWriteMaskBlendLine10bpp4x2_SSE4_1(pred_0, pred_1, pred_stride_1,
                                                pred_mask_0, pred_mask_1,
                                                shift6, dst, dst_stride);
    pred_0 += pred0_stride2;
    pred_1 += pred1_stride2;
    mask += mask_stride2;
    dst += dst_stride2;

    pred_mask_0 =
        GetInterIntraMask4x2<subsampling_x, subsampling_y>(mask, mask_stride);
    pred_mask_1 = _mm_sub_epi16(mask_inverter, pred_mask_0);
    InterIntraWriteMaskBlendLine10bpp4x2_SSE4_1(pred_0, pred_1, pred_stride_1,
                                                pred_mask_0, pred_mask_1,
                                                shift6, dst, dst_stride);
    pred_0 += pred0_stride2;
    pred_1 += pred1_stride2;
    mask += mask_stride2;
    dst += dst_stride2;

    pred_mask_0 =
        GetInterIntraMask4x2<subsampling_x, subsampling_y>(mask, mask_stride);
    pred_mask_1 = _mm_sub_epi16(mask_inverter, pred_mask_0);
    InterIntraWriteMaskBlendLine10bpp4x2_SSE4_1(pred_0, pred_1, pred_stride_1,
                                                pred_mask_0, pred_mask_1,
                                                shift6, dst, dst_stride);
    pred_0 += pred0_stride2;
    pred_1 += pred1_stride2;
    mask += mask_stride2;
    dst += dst_stride2;
    y -= 8;
  } while (y != 0);
}

template <int subsampling_x, int subsampling_y>
inline void InterIntraMaskBlend10bpp_SSE4_1(
    const void* LIBGAV1_RESTRICT prediction_0,
    const void* LIBGAV1_RESTRICT prediction_1,
    const ptrdiff_t prediction_stride_1,
    const uint8_t* LIBGAV1_RESTRICT const mask_ptr, const ptrdiff_t mask_stride,
    const int width, const int height, void* LIBGAV1_RESTRICT dest,
    const ptrdiff_t dest_stride) {
  auto* dst = static_cast<uint16_t*>(dest);
  const ptrdiff_t dst_stride = dest_stride / sizeof(dst[0]);
  const auto* pred_0 = static_cast<const uint16_t*>(prediction_0);
  const auto* pred_1 = static_cast<const uint16_t*>(prediction_1);
  const ptrdiff_t pred_stride_0 = width;
  const ptrdiff_t pred_stride_1 = prediction_stride_1;
  if (width == 4) {
    InterIntraMaskBlend10bpp4xH_SSE4_1<subsampling_x, subsampling_y>(
        pred_0, pred_1, pred_stride_1, mask_ptr, mask_stride, height, dst,
        dst_stride);
    return;
  }
  const uint8_t* mask = mask_ptr;
  const __m128i mask_inverter = _mm_set1_epi16(kMaskInverse);
  const __m128i shift6 = _mm_set1_epi32((1 << 6) >> 1);
  const ptrdiff_t mask_stride_ss = mask_stride << subsampling_y;
  int y = height;
  do {
    int x = 0;
    do {
      const __m128i pred_mask_0 = GetMask8<subsampling_x, subsampling_y>(
          mask + (x << subsampling_x), mask_stride);
      const __m128i pred_val_0 = LoadUnaligned16(pred_0 + x);
      const __m128i pred_val_1 = LoadUnaligned16(pred_1 + x);
      // 64 - mask
      const __m128i pred_mask_1 = _mm_sub_epi16(mask_inverter, pred_mask_0);
      const __m128i mask_0 = _mm_unpacklo_epi16(pred_mask_1, pred_mask_0);
      const __m128i mask_1 = _mm_unpackhi_epi16(pred_mask_1, pred_mask_0);
      const __m128i pred_0 = _mm_unpacklo_epi16(pred_val_0, pred_val_1);
      const __m128i pred_1 = _mm_unpackhi_epi16(pred_val_0, pred_val_1);

      const __m128i compound_pred_0 = _mm_madd_epi16(pred_0, mask_0);
      const __m128i compound_pred_1 = _mm_madd_epi16(pred_1, mask_1);
      const __m128i shift_0 =
          RightShiftWithRoundingConst_S32(compound_pred_0, 6, shift6);
      const __m128i shift_1 =
          RightShiftWithRoundingConst_S32(compound_pred_1, 6, shift6);
      StoreUnaligned16(dst + x, _mm_packus_epi32(shift_0, shift_1));
      x += 8;
    } while (x < width);
    dst += dst_stride;
    pred_0 += pred_stride_0;
    pred_1 += pred_stride_1;
    mask += mask_stride_ss;
  } while (--y != 0);
}

void Init10bpp() {
  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10);
  assert(dsp != nullptr);

#if DSP_ENABLED_10BPP_SSE4_1(MaskBlend444)
  dsp->mask_blend[0][0] = MaskBlend10bpp_SSE4_1<0, 0>;
#endif
#if DSP_ENABLED_10BPP_SSE4_1(MaskBlend422)
  dsp->mask_blend[1][0] = MaskBlend10bpp_SSE4_1<1, 0>;
#endif
#if DSP_ENABLED_10BPP_SSE4_1(MaskBlend420)
  dsp->mask_blend[2][0] = MaskBlend10bpp_SSE4_1<1, 1>;
#endif
#if DSP_ENABLED_10BPP_SSE4_1(MaskBlendInterIntra444)
  dsp->mask_blend[0][1] = InterIntraMaskBlend10bpp_SSE4_1<0, 0>;
#endif
#if DSP_ENABLED_10BPP_SSE4_1(MaskBlendInterIntra422)
  dsp->mask_blend[1][1] = InterIntraMaskBlend10bpp_SSE4_1<1, 0>;
#endif
#if DSP_ENABLED_10BPP_SSE4_1(MaskBlendInterIntra420)
  dsp->mask_blend[2][1] = InterIntraMaskBlend10bpp_SSE4_1<1, 1>;
#endif
}

}  // namespace
}  // namespace high_bitdepth
#endif  // LIBGAV1_MAX_BITDEPTH >= 10

void MaskBlendInit_SSE4_1() {
  low_bitdepth::Init8bpp();
#if LIBGAV1_MAX_BITDEPTH >= 10
  high_bitdepth::Init10bpp();
#endif  // LIBGAV1_MAX_BITDEPTH >= 10
}

}  // namespace dsp
}  // namespace libgav1

#else   // !LIBGAV1_TARGETING_SSE4_1

namespace libgav1 {
namespace dsp {

void MaskBlendInit_SSE4_1() {}

}  // namespace dsp
}  // namespace libgav1
#endif  // LIBGAV1_TARGETING_SSE4_1