xref: /aosp_15_r20/external/libgav1/src/dsp/x86/obmc_sse4.cc (revision 095378508e87ed692bf8dfeb34008b65b3735891)
1 // Copyright 2019 The libgav1 Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "src/dsp/obmc.h"
16 #include "src/utils/cpu.h"
17 
18 #if LIBGAV1_TARGETING_SSE4_1
19 
20 #include <xmmintrin.h>
21 
22 #include <cassert>
23 #include <cstddef>
24 #include <cstdint>
25 
26 #include "src/dsp/constants.h"
27 #include "src/dsp/dsp.h"
28 #include "src/dsp/x86/common_sse4.h"
29 #include "src/utils/common.h"
30 #include "src/utils/constants.h"
31 
32 namespace libgav1 {
33 namespace dsp {
34 namespace low_bitdepth {
35 namespace {
36 
37 #include "src/dsp/obmc.inc"
38 
OverlapBlendFromLeft2xH_SSE4_1(uint8_t * LIBGAV1_RESTRICT const prediction,const ptrdiff_t prediction_stride,const int height,const uint8_t * LIBGAV1_RESTRICT const obmc_prediction)39 inline void OverlapBlendFromLeft2xH_SSE4_1(
40     uint8_t* LIBGAV1_RESTRICT const prediction,
41     const ptrdiff_t prediction_stride, const int height,
42     const uint8_t* LIBGAV1_RESTRICT const obmc_prediction) {
43   constexpr int obmc_prediction_stride = 2;
44   uint8_t* pred = prediction;
45   const uint8_t* obmc_pred = obmc_prediction;
46   const __m128i mask_inverter = _mm_cvtsi32_si128(0x40404040);
47   const __m128i mask_val = _mm_shufflelo_epi16(Load4(kObmcMask), 0);
48   // 64 - mask
49   const __m128i obmc_mask_val = _mm_sub_epi8(mask_inverter, mask_val);
50   const __m128i masks = _mm_unpacklo_epi8(mask_val, obmc_mask_val);
51   int y = height;
52   do {
53     const __m128i pred_val = Load2x2(pred, pred + prediction_stride);
54     const __m128i obmc_pred_val = Load4(obmc_pred);
55 
56     const __m128i terms = _mm_unpacklo_epi8(pred_val, obmc_pred_val);
57     const __m128i result =
58         RightShiftWithRounding_U16(_mm_maddubs_epi16(terms, masks), 6);
59     const __m128i packed_result = _mm_packus_epi16(result, result);
60     Store2(pred, packed_result);
61     pred += prediction_stride;
62     const int16_t second_row_result = _mm_extract_epi16(packed_result, 1);
63     memcpy(pred, &second_row_result, sizeof(second_row_result));
64     pred += prediction_stride;
65     obmc_pred += obmc_prediction_stride << 1;
66     y -= 2;
67   } while (y != 0);
68 }
69 
OverlapBlendFromLeft4xH_SSE4_1(uint8_t * LIBGAV1_RESTRICT const prediction,const ptrdiff_t prediction_stride,const int height,const uint8_t * LIBGAV1_RESTRICT const obmc_prediction)70 inline void OverlapBlendFromLeft4xH_SSE4_1(
71     uint8_t* LIBGAV1_RESTRICT const prediction,
72     const ptrdiff_t prediction_stride, const int height,
73     const uint8_t* LIBGAV1_RESTRICT const obmc_prediction) {
74   constexpr int obmc_prediction_stride = 4;
75   uint8_t* pred = prediction;
76   const uint8_t* obmc_pred = obmc_prediction;
77   const __m128i mask_inverter = _mm_cvtsi32_si128(0x40404040);
78   const __m128i mask_val = Load4(kObmcMask + 2);
79   // 64 - mask
80   const __m128i obmc_mask_val = _mm_sub_epi8(mask_inverter, mask_val);
81   // Duplicate first half of vector.
82   const __m128i masks =
83       _mm_shuffle_epi32(_mm_unpacklo_epi8(mask_val, obmc_mask_val), 0x44);
84   int y = height;
85   do {
86     const __m128i pred_val0 = Load4(pred);
87     pred += prediction_stride;
88 
89     // Place the second row of each source in the second four bytes.
90     const __m128i pred_val =
91         _mm_alignr_epi8(Load4(pred), _mm_slli_si128(pred_val0, 12), 12);
92     const __m128i obmc_pred_val = LoadLo8(obmc_pred);
93     const __m128i terms = _mm_unpacklo_epi8(pred_val, obmc_pred_val);
94     const __m128i result =
95         RightShiftWithRounding_U16(_mm_maddubs_epi16(terms, masks), 6);
96     const __m128i packed_result = _mm_packus_epi16(result, result);
97     Store4(pred - prediction_stride, packed_result);
98     const int second_row_result = _mm_extract_epi32(packed_result, 1);
99     memcpy(pred, &second_row_result, sizeof(second_row_result));
100     pred += prediction_stride;
101     obmc_pred += obmc_prediction_stride << 1;
102     y -= 2;
103   } while (y != 0);
104 }
105 
OverlapBlendFromLeft8xH_SSE4_1(uint8_t * LIBGAV1_RESTRICT const prediction,const ptrdiff_t prediction_stride,const int height,const uint8_t * LIBGAV1_RESTRICT const obmc_prediction)106 inline void OverlapBlendFromLeft8xH_SSE4_1(
107     uint8_t* LIBGAV1_RESTRICT const prediction,
108     const ptrdiff_t prediction_stride, const int height,
109     const uint8_t* LIBGAV1_RESTRICT const obmc_prediction) {
110   constexpr int obmc_prediction_stride = 8;
111   uint8_t* pred = prediction;
112   const uint8_t* obmc_pred = obmc_prediction;
113   const __m128i mask_inverter = _mm_set1_epi8(64);
114   const __m128i mask_val = LoadLo8(kObmcMask + 6);
115   // 64 - mask
116   const __m128i obmc_mask_val = _mm_sub_epi8(mask_inverter, mask_val);
117   const __m128i masks = _mm_unpacklo_epi8(mask_val, obmc_mask_val);
118   int y = height;
119   do {
120     const __m128i pred_val = LoadHi8(LoadLo8(pred), pred + prediction_stride);
121     const __m128i obmc_pred_val = LoadUnaligned16(obmc_pred);
122 
123     const __m128i terms_lo = _mm_unpacklo_epi8(pred_val, obmc_pred_val);
124     const __m128i result_lo =
125         RightShiftWithRounding_U16(_mm_maddubs_epi16(terms_lo, masks), 6);
126 
127     const __m128i terms_hi = _mm_unpackhi_epi8(pred_val, obmc_pred_val);
128     const __m128i result_hi =
129         RightShiftWithRounding_U16(_mm_maddubs_epi16(terms_hi, masks), 6);
130 
131     const __m128i result = _mm_packus_epi16(result_lo, result_hi);
132     StoreLo8(pred, result);
133     pred += prediction_stride;
134     StoreHi8(pred, result);
135     pred += prediction_stride;
136     obmc_pred += obmc_prediction_stride << 1;
137     y -= 2;
138   } while (y != 0);
139 }
140 
OverlapBlendFromLeft_SSE4_1(void * LIBGAV1_RESTRICT const prediction,const ptrdiff_t prediction_stride,const int width,const int height,const void * LIBGAV1_RESTRICT const obmc_prediction,const ptrdiff_t obmc_prediction_stride)141 void OverlapBlendFromLeft_SSE4_1(
142     void* LIBGAV1_RESTRICT const prediction, const ptrdiff_t prediction_stride,
143     const int width, const int height,
144     const void* LIBGAV1_RESTRICT const obmc_prediction,
145     const ptrdiff_t obmc_prediction_stride) {
146   auto* pred = static_cast<uint8_t*>(prediction);
147   const auto* obmc_pred = static_cast<const uint8_t*>(obmc_prediction);
148   assert(width >= 2);
149   assert(height >= 4);
150 
151   if (width == 2) {
152     OverlapBlendFromLeft2xH_SSE4_1(pred, prediction_stride, height, obmc_pred);
153     return;
154   }
155   if (width == 4) {
156     OverlapBlendFromLeft4xH_SSE4_1(pred, prediction_stride, height, obmc_pred);
157     return;
158   }
159   if (width == 8) {
160     OverlapBlendFromLeft8xH_SSE4_1(pred, prediction_stride, height, obmc_pred);
161     return;
162   }
163   const __m128i mask_inverter = _mm_set1_epi8(64);
164   const uint8_t* mask = kObmcMask + width - 2;
165   int x = 0;
166   do {
167     pred = static_cast<uint8_t*>(prediction) + x;
168     obmc_pred = static_cast<const uint8_t*>(obmc_prediction) + x;
169     const __m128i mask_val = LoadUnaligned16(mask + x);
170     // 64 - mask
171     const __m128i obmc_mask_val = _mm_sub_epi8(mask_inverter, mask_val);
172     const __m128i masks_lo = _mm_unpacklo_epi8(mask_val, obmc_mask_val);
173     const __m128i masks_hi = _mm_unpackhi_epi8(mask_val, obmc_mask_val);
174 
175     int y = 0;
176     do {
177       const __m128i pred_val = LoadUnaligned16(pred);
178       const __m128i obmc_pred_val = LoadUnaligned16(obmc_pred);
179       const __m128i terms_lo = _mm_unpacklo_epi8(pred_val, obmc_pred_val);
180       const __m128i result_lo =
181           RightShiftWithRounding_U16(_mm_maddubs_epi16(terms_lo, masks_lo), 6);
182       const __m128i terms_hi = _mm_unpackhi_epi8(pred_val, obmc_pred_val);
183       const __m128i result_hi =
184           RightShiftWithRounding_U16(_mm_maddubs_epi16(terms_hi, masks_hi), 6);
185       StoreUnaligned16(pred, _mm_packus_epi16(result_lo, result_hi));
186 
187       pred += prediction_stride;
188       obmc_pred += obmc_prediction_stride;
189     } while (++y < height);
190     x += 16;
191   } while (x < width);
192 }
193 
OverlapBlendFromTop4xH_SSE4_1(uint8_t * LIBGAV1_RESTRICT const prediction,const ptrdiff_t prediction_stride,const int height,const uint8_t * LIBGAV1_RESTRICT const obmc_prediction)194 inline void OverlapBlendFromTop4xH_SSE4_1(
195     uint8_t* LIBGAV1_RESTRICT const prediction,
196     const ptrdiff_t prediction_stride, const int height,
197     const uint8_t* LIBGAV1_RESTRICT const obmc_prediction) {
198   constexpr int obmc_prediction_stride = 4;
199   uint8_t* pred = prediction;
200   const uint8_t* obmc_pred = obmc_prediction;
201   const __m128i mask_inverter = _mm_set1_epi16(64);
202   const __m128i mask_shuffler = _mm_set_epi32(0x01010101, 0x01010101, 0, 0);
203   const __m128i mask_preinverter = _mm_set1_epi16(-256 | 1);
204 
205   const uint8_t* mask = kObmcMask + height - 2;
206   const int compute_height = height - (height >> 2);
207   int y = 0;
208   do {
209     // First mask in the first half, second mask in the second half.
210     const __m128i mask_val = _mm_shuffle_epi8(
211         _mm_cvtsi32_si128(*reinterpret_cast<const uint16_t*>(mask + y)),
212         mask_shuffler);
213     const __m128i masks =
214         _mm_sub_epi8(mask_inverter, _mm_sign_epi8(mask_val, mask_preinverter));
215     const __m128i pred_val0 = Load4(pred);
216 
217     const __m128i obmc_pred_val = LoadLo8(obmc_pred);
218     pred += prediction_stride;
219     const __m128i pred_val =
220         _mm_alignr_epi8(Load4(pred), _mm_slli_si128(pred_val0, 12), 12);
221     const __m128i terms = _mm_unpacklo_epi8(obmc_pred_val, pred_val);
222     const __m128i result =
223         RightShiftWithRounding_U16(_mm_maddubs_epi16(terms, masks), 6);
224 
225     const __m128i packed_result = _mm_packus_epi16(result, result);
226     Store4(pred - prediction_stride, packed_result);
227     Store4(pred, _mm_srli_si128(packed_result, 4));
228     pred += prediction_stride;
229     obmc_pred += obmc_prediction_stride << 1;
230     y += 2;
231   } while (y < compute_height);
232 }
233 
OverlapBlendFromTop8xH_SSE4_1(uint8_t * LIBGAV1_RESTRICT const prediction,const ptrdiff_t prediction_stride,const int height,const uint8_t * LIBGAV1_RESTRICT const obmc_prediction)234 inline void OverlapBlendFromTop8xH_SSE4_1(
235     uint8_t* LIBGAV1_RESTRICT const prediction,
236     const ptrdiff_t prediction_stride, const int height,
237     const uint8_t* LIBGAV1_RESTRICT const obmc_prediction) {
238   constexpr int obmc_prediction_stride = 8;
239   uint8_t* pred = prediction;
240   const uint8_t* obmc_pred = obmc_prediction;
241   const uint8_t* mask = kObmcMask + height - 2;
242   const __m128i mask_inverter = _mm_set1_epi8(64);
243   const int compute_height = height - (height >> 2);
244   int y = compute_height;
245   do {
246     const __m128i mask_val0 = _mm_set1_epi8(mask[compute_height - y]);
247     // 64 - mask
248     const __m128i obmc_mask_val0 = _mm_sub_epi8(mask_inverter, mask_val0);
249     const __m128i masks0 = _mm_unpacklo_epi8(mask_val0, obmc_mask_val0);
250 
251     const __m128i pred_val = LoadHi8(LoadLo8(pred), pred + prediction_stride);
252     const __m128i obmc_pred_val = LoadUnaligned16(obmc_pred);
253 
254     const __m128i terms_lo = _mm_unpacklo_epi8(pred_val, obmc_pred_val);
255     const __m128i result_lo =
256         RightShiftWithRounding_U16(_mm_maddubs_epi16(terms_lo, masks0), 6);
257 
258     --y;
259     const __m128i mask_val1 = _mm_set1_epi8(mask[compute_height - y]);
260     // 64 - mask
261     const __m128i obmc_mask_val1 = _mm_sub_epi8(mask_inverter, mask_val1);
262     const __m128i masks1 = _mm_unpacklo_epi8(mask_val1, obmc_mask_val1);
263 
264     const __m128i terms_hi = _mm_unpackhi_epi8(pred_val, obmc_pred_val);
265     const __m128i result_hi =
266         RightShiftWithRounding_U16(_mm_maddubs_epi16(terms_hi, masks1), 6);
267 
268     const __m128i result = _mm_packus_epi16(result_lo, result_hi);
269     StoreLo8(pred, result);
270     pred += prediction_stride;
271     StoreHi8(pred, result);
272     pred += prediction_stride;
273     obmc_pred += obmc_prediction_stride << 1;
274   } while (--y > 0);
275 }
276 
OverlapBlendFromTop_SSE4_1(void * LIBGAV1_RESTRICT const prediction,const ptrdiff_t prediction_stride,const int width,const int height,const void * LIBGAV1_RESTRICT const obmc_prediction,const ptrdiff_t obmc_prediction_stride)277 void OverlapBlendFromTop_SSE4_1(
278     void* LIBGAV1_RESTRICT const prediction, const ptrdiff_t prediction_stride,
279     const int width, const int height,
280     const void* LIBGAV1_RESTRICT const obmc_prediction,
281     const ptrdiff_t obmc_prediction_stride) {
282   auto* pred = static_cast<uint8_t*>(prediction);
283   const auto* obmc_pred = static_cast<const uint8_t*>(obmc_prediction);
284   assert(width >= 4);
285   assert(height >= 2);
286 
287   if (width == 4) {
288     OverlapBlendFromTop4xH_SSE4_1(pred, prediction_stride, height, obmc_pred);
289     return;
290   }
291   if (width == 8) {
292     OverlapBlendFromTop8xH_SSE4_1(pred, prediction_stride, height, obmc_pred);
293     return;
294   }
295 
296   // Stop when mask value becomes 64.
297   const int compute_height = height - (height >> 2);
298   const __m128i mask_inverter = _mm_set1_epi8(64);
299   int y = 0;
300   const uint8_t* mask = kObmcMask + height - 2;
301   do {
302     const __m128i mask_val = _mm_set1_epi8(mask[y]);
303     // 64 - mask
304     const __m128i obmc_mask_val = _mm_sub_epi8(mask_inverter, mask_val);
305     const __m128i masks = _mm_unpacklo_epi8(mask_val, obmc_mask_val);
306     int x = 0;
307     do {
308       const __m128i pred_val = LoadUnaligned16(pred + x);
309       const __m128i obmc_pred_val = LoadUnaligned16(obmc_pred + x);
310       const __m128i terms_lo = _mm_unpacklo_epi8(pred_val, obmc_pred_val);
311       const __m128i result_lo =
312           RightShiftWithRounding_U16(_mm_maddubs_epi16(terms_lo, masks), 6);
313       const __m128i terms_hi = _mm_unpackhi_epi8(pred_val, obmc_pred_val);
314       const __m128i result_hi =
315           RightShiftWithRounding_U16(_mm_maddubs_epi16(terms_hi, masks), 6);
316       StoreUnaligned16(pred + x, _mm_packus_epi16(result_lo, result_hi));
317       x += 16;
318     } while (x < width);
319     pred += prediction_stride;
320     obmc_pred += obmc_prediction_stride;
321   } while (++y < compute_height);
322 }
323 
// Installs the SSE4.1 OBMC blend functions into the writable 8-bit Dsp table.
// Each entry is only overwritten when the matching DSP_ENABLED_8BPP_SSE4_1
// check evaluates to true.
void Init8bpp() {
  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
  assert(dsp != nullptr);
#if DSP_ENABLED_8BPP_SSE4_1(ObmcVertical)
  // Vertical direction: weights vary per row.
  dsp->obmc_blend[kObmcDirectionVertical] = OverlapBlendFromTop_SSE4_1;
#endif
#if DSP_ENABLED_8BPP_SSE4_1(ObmcHorizontal)
  // Horizontal direction: weights vary per column.
  dsp->obmc_blend[kObmcDirectionHorizontal] = OverlapBlendFromLeft_SSE4_1;
#endif
}
334 
335 }  // namespace
336 }  // namespace low_bitdepth
337 
338 #if LIBGAV1_MAX_BITDEPTH >= 10
339 namespace high_bitdepth {
340 namespace {
341 
342 #include "src/dsp/obmc.inc"
343 
// Right-shift applied (with rounding) to the weighted sum in the high
// bitdepth kernels; the two weights of each pixel pair sum to 64 == 1 << 6.
constexpr int kRoundBitsObmcBlend = 6;
345 
OverlapBlendFromLeft2xH_SSE4_1(uint16_t * LIBGAV1_RESTRICT const prediction,const ptrdiff_t pred_stride,const int height,const uint16_t * LIBGAV1_RESTRICT const obmc_prediction)346 inline void OverlapBlendFromLeft2xH_SSE4_1(
347     uint16_t* LIBGAV1_RESTRICT const prediction, const ptrdiff_t pred_stride,
348     const int height, const uint16_t* LIBGAV1_RESTRICT const obmc_prediction) {
349   constexpr int obmc_pred_stride = 2;
350   uint16_t* pred = prediction;
351   const uint16_t* obmc_pred = obmc_prediction;
352   const ptrdiff_t pred_stride2 = pred_stride << 1;
353   const ptrdiff_t obmc_pred_stride2 = obmc_pred_stride << 1;
354   const __m128i mask_inverter = _mm_cvtsi32_si128(0x40404040);
355   const __m128i mask_val = _mm_shufflelo_epi16(Load2(kObmcMask), 0x00);
356   // 64 - mask.
357   const __m128i obmc_mask_val = _mm_sub_epi8(mask_inverter, mask_val);
358   const __m128i masks =
359       _mm_cvtepi8_epi16(_mm_unpacklo_epi8(mask_val, obmc_mask_val));
360   int y = height;
361   do {
362     const __m128i pred_val = Load4x2(pred, pred + pred_stride);
363     const __m128i obmc_pred_val = LoadLo8(obmc_pred);
364     const __m128i terms = _mm_unpacklo_epi16(pred_val, obmc_pred_val);
365     const __m128i result = RightShiftWithRounding_U32(
366         _mm_madd_epi16(terms, masks), kRoundBitsObmcBlend);
367     const __m128i packed_result = _mm_packus_epi32(result, result);
368     Store4(pred, packed_result);
369     Store4(pred + pred_stride, _mm_srli_si128(packed_result, 4));
370     pred += pred_stride2;
371     obmc_pred += obmc_pred_stride2;
372     y -= 2;
373   } while (y != 0);
374 }
375 
OverlapBlendFromLeft4xH_SSE4_1(uint16_t * LIBGAV1_RESTRICT const prediction,const ptrdiff_t pred_stride,const int height,const uint16_t * LIBGAV1_RESTRICT const obmc_prediction)376 inline void OverlapBlendFromLeft4xH_SSE4_1(
377     uint16_t* LIBGAV1_RESTRICT const prediction, const ptrdiff_t pred_stride,
378     const int height, const uint16_t* LIBGAV1_RESTRICT const obmc_prediction) {
379   constexpr int obmc_pred_stride = 4;
380   uint16_t* pred = prediction;
381   const uint16_t* obmc_pred = obmc_prediction;
382   const ptrdiff_t pred_stride2 = pred_stride << 1;
383   const ptrdiff_t obmc_pred_stride2 = obmc_pred_stride << 1;
384   const __m128i mask_inverter = _mm_cvtsi32_si128(0x40404040);
385   const __m128i mask_val = Load4(kObmcMask + 2);
386   // 64 - mask.
387   const __m128i obmc_mask_val = _mm_sub_epi8(mask_inverter, mask_val);
388   const __m128i masks =
389       _mm_cvtepi8_epi16(_mm_unpacklo_epi8(mask_val, obmc_mask_val));
390   int y = height;
391   do {
392     const __m128i pred_val = LoadHi8(LoadLo8(pred), pred + pred_stride);
393     const __m128i obmc_pred_val = LoadUnaligned16(obmc_pred);
394     const __m128i terms_lo = _mm_unpacklo_epi16(pred_val, obmc_pred_val);
395     const __m128i terms_hi = _mm_unpackhi_epi16(pred_val, obmc_pred_val);
396     const __m128i result_lo = RightShiftWithRounding_U32(
397         _mm_madd_epi16(terms_lo, masks), kRoundBitsObmcBlend);
398     const __m128i result_hi = RightShiftWithRounding_U32(
399         _mm_madd_epi16(terms_hi, masks), kRoundBitsObmcBlend);
400     const __m128i packed_result = _mm_packus_epi32(result_lo, result_hi);
401     StoreLo8(pred, packed_result);
402     StoreHi8(pred + pred_stride, packed_result);
403     pred += pred_stride2;
404     obmc_pred += obmc_pred_stride2;
405     y -= 2;
406   } while (y != 0);
407 }
408 
OverlapBlendFromLeft10bpp_SSE4_1(void * LIBGAV1_RESTRICT const prediction,const ptrdiff_t prediction_stride,const int width,const int height,const void * LIBGAV1_RESTRICT const obmc_prediction,const ptrdiff_t obmc_prediction_stride)409 void OverlapBlendFromLeft10bpp_SSE4_1(
410     void* LIBGAV1_RESTRICT const prediction, const ptrdiff_t prediction_stride,
411     const int width, const int height,
412     const void* LIBGAV1_RESTRICT const obmc_prediction,
413     const ptrdiff_t obmc_prediction_stride) {
414   auto* pred = static_cast<uint16_t*>(prediction);
415   const auto* obmc_pred = static_cast<const uint16_t*>(obmc_prediction);
416   const ptrdiff_t pred_stride = prediction_stride / sizeof(pred[0]);
417   const ptrdiff_t obmc_pred_stride =
418       obmc_prediction_stride / sizeof(obmc_pred[0]);
419   assert(width >= 2);
420   assert(height >= 4);
421 
422   if (width == 2) {
423     OverlapBlendFromLeft2xH_SSE4_1(pred, pred_stride, height, obmc_pred);
424     return;
425   }
426   if (width == 4) {
427     OverlapBlendFromLeft4xH_SSE4_1(pred, pred_stride, height, obmc_pred);
428     return;
429   }
430   const __m128i mask_inverter = _mm_set1_epi8(64);
431   const uint8_t* mask = kObmcMask + width - 2;
432   int x = 0;
433   do {
434     pred = static_cast<uint16_t*>(prediction) + x;
435     obmc_pred = static_cast<const uint16_t*>(obmc_prediction) + x;
436     const __m128i mask_val = LoadLo8(mask + x);
437     // 64 - mask
438     const __m128i obmc_mask_val = _mm_sub_epi8(mask_inverter, mask_val);
439     const __m128i masks = _mm_unpacklo_epi8(mask_val, obmc_mask_val);
440     const __m128i masks_lo = _mm_cvtepi8_epi16(masks);
441     const __m128i masks_hi = _mm_cvtepi8_epi16(_mm_srli_si128(masks, 8));
442     int y = height;
443     do {
444       const __m128i pred_val = LoadUnaligned16(pred);
445       const __m128i obmc_pred_val = LoadUnaligned16(obmc_pred);
446       const __m128i terms_lo = _mm_unpacklo_epi16(pred_val, obmc_pred_val);
447       const __m128i terms_hi = _mm_unpackhi_epi16(pred_val, obmc_pred_val);
448       const __m128i result_lo = RightShiftWithRounding_U32(
449           _mm_madd_epi16(terms_lo, masks_lo), kRoundBitsObmcBlend);
450       const __m128i result_hi = RightShiftWithRounding_U32(
451           _mm_madd_epi16(terms_hi, masks_hi), kRoundBitsObmcBlend);
452       StoreUnaligned16(pred, _mm_packus_epi32(result_lo, result_hi));
453 
454       pred += pred_stride;
455       obmc_pred += obmc_pred_stride;
456     } while (--y != 0);
457     x += 8;
458   } while (x < width);
459 }
460 
OverlapBlendFromTop4xH_SSE4_1(uint16_t * LIBGAV1_RESTRICT const prediction,const ptrdiff_t pred_stride,const int height,const uint16_t * LIBGAV1_RESTRICT const obmc_prediction)461 inline void OverlapBlendFromTop4xH_SSE4_1(
462     uint16_t* LIBGAV1_RESTRICT const prediction, const ptrdiff_t pred_stride,
463     const int height, const uint16_t* LIBGAV1_RESTRICT const obmc_prediction) {
464   constexpr int obmc_pred_stride = 4;
465   uint16_t* pred = prediction;
466   const uint16_t* obmc_pred = obmc_prediction;
467   const __m128i mask_inverter = _mm_set1_epi16(64);
468   const __m128i mask_shuffler = _mm_set_epi32(0x01010101, 0x01010101, 0, 0);
469   const __m128i mask_preinverter = _mm_set1_epi16(-256 | 1);
470   const uint8_t* mask = kObmcMask + height - 2;
471   const int compute_height = height - (height >> 2);
472   const ptrdiff_t pred_stride2 = pred_stride << 1;
473   const ptrdiff_t obmc_pred_stride2 = obmc_pred_stride << 1;
474   int y = 0;
475   do {
476     // First mask in the first half, second mask in the second half.
477     const __m128i mask_val = _mm_shuffle_epi8(Load4(mask + y), mask_shuffler);
478     const __m128i masks =
479         _mm_sub_epi8(mask_inverter, _mm_sign_epi8(mask_val, mask_preinverter));
480     const __m128i masks_lo = _mm_cvtepi8_epi16(masks);
481     const __m128i masks_hi = _mm_cvtepi8_epi16(_mm_srli_si128(masks, 8));
482 
483     const __m128i pred_val = LoadHi8(LoadLo8(pred), pred + pred_stride);
484     const __m128i obmc_pred_val = LoadUnaligned16(obmc_pred);
485     const __m128i terms_lo = _mm_unpacklo_epi16(obmc_pred_val, pred_val);
486     const __m128i terms_hi = _mm_unpackhi_epi16(obmc_pred_val, pred_val);
487     const __m128i result_lo = RightShiftWithRounding_U32(
488         _mm_madd_epi16(terms_lo, masks_lo), kRoundBitsObmcBlend);
489     const __m128i result_hi = RightShiftWithRounding_U32(
490         _mm_madd_epi16(terms_hi, masks_hi), kRoundBitsObmcBlend);
491     const __m128i packed_result = _mm_packus_epi32(result_lo, result_hi);
492 
493     StoreLo8(pred, packed_result);
494     StoreHi8(pred + pred_stride, packed_result);
495     pred += pred_stride2;
496     obmc_pred += obmc_pred_stride2;
497     y += 2;
498   } while (y < compute_height);
499 }
500 
OverlapBlendFromTop10bpp_SSE4_1(void * LIBGAV1_RESTRICT const prediction,const ptrdiff_t prediction_stride,const int width,const int height,const void * LIBGAV1_RESTRICT const obmc_prediction,const ptrdiff_t obmc_prediction_stride)501 void OverlapBlendFromTop10bpp_SSE4_1(
502     void* LIBGAV1_RESTRICT const prediction, const ptrdiff_t prediction_stride,
503     const int width, const int height,
504     const void* LIBGAV1_RESTRICT const obmc_prediction,
505     const ptrdiff_t obmc_prediction_stride) {
506   auto* pred = static_cast<uint16_t*>(prediction);
507   const auto* obmc_pred = static_cast<const uint16_t*>(obmc_prediction);
508   const ptrdiff_t pred_stride = prediction_stride / sizeof(pred[0]);
509   const ptrdiff_t obmc_pred_stride =
510       obmc_prediction_stride / sizeof(obmc_pred[0]);
511   assert(width >= 4);
512   assert(height >= 2);
513 
514   if (width == 4) {
515     OverlapBlendFromTop4xH_SSE4_1(pred, pred_stride, height, obmc_pred);
516     return;
517   }
518 
519   const __m128i mask_inverter = _mm_set1_epi8(64);
520   const int compute_height = height - (height >> 2);
521   const uint8_t* mask = kObmcMask + height - 2;
522   pred = static_cast<uint16_t*>(prediction);
523   obmc_pred = static_cast<const uint16_t*>(obmc_prediction);
524   int y = 0;
525   do {
526     const __m128i mask_val = _mm_set1_epi8(mask[y]);
527     // 64 - mask
528     const __m128i obmc_mask_val = _mm_sub_epi8(mask_inverter, mask_val);
529     const __m128i masks = _mm_unpacklo_epi8(mask_val, obmc_mask_val);
530     const __m128i masks_lo = _mm_cvtepi8_epi16(masks);
531     const __m128i masks_hi = _mm_cvtepi8_epi16(_mm_srli_si128(masks, 8));
532     int x = 0;
533     do {
534       const __m128i pred_val = LoadUnaligned16(pred + x);
535       const __m128i obmc_pred_val = LoadUnaligned16(obmc_pred + x);
536       const __m128i terms_lo = _mm_unpacklo_epi16(pred_val, obmc_pred_val);
537       const __m128i terms_hi = _mm_unpackhi_epi16(pred_val, obmc_pred_val);
538       const __m128i result_lo = RightShiftWithRounding_U32(
539           _mm_madd_epi16(terms_lo, masks_lo), kRoundBitsObmcBlend);
540       const __m128i result_hi = RightShiftWithRounding_U32(
541           _mm_madd_epi16(terms_hi, masks_hi), kRoundBitsObmcBlend);
542       StoreUnaligned16(pred + x, _mm_packus_epi32(result_lo, result_hi));
543       x += 8;
544     } while (x < width);
545     pred += pred_stride;
546     obmc_pred += obmc_pred_stride;
547   } while (++y < compute_height);
548 }
549 
// Installs the SSE4.1 OBMC blend functions into the writable 10-bit Dsp
// table. Each entry is only overwritten when the matching
// DSP_ENABLED_10BPP_SSE4_1 check evaluates to true.
void Init10bpp() {
  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10);
  assert(dsp != nullptr);
#if DSP_ENABLED_10BPP_SSE4_1(ObmcVertical)
  // Vertical direction: weights vary per row.
  dsp->obmc_blend[kObmcDirectionVertical] = OverlapBlendFromTop10bpp_SSE4_1;
#endif
#if DSP_ENABLED_10BPP_SSE4_1(ObmcHorizontal)
  // Horizontal direction: weights vary per column.
  dsp->obmc_blend[kObmcDirectionHorizontal] = OverlapBlendFromLeft10bpp_SSE4_1;
#endif
}
560 
561 }  // namespace
562 }  // namespace high_bitdepth
563 #endif  // LIBGAV1_MAX_BITDEPTH >= 10
564 
// Registers the SSE4.1 OBMC implementations: always the 8-bit set, plus the
// 10-bit set when the library is built with LIBGAV1_MAX_BITDEPTH >= 10.
void ObmcInit_SSE4_1() {
  low_bitdepth::Init8bpp();
#if LIBGAV1_MAX_BITDEPTH >= 10
  high_bitdepth::Init10bpp();
#endif  // LIBGAV1_MAX_BITDEPTH >= 10
}
571 
572 }  // namespace dsp
573 }  // namespace libgav1
574 
575 #else   // !LIBGAV1_TARGETING_SSE4_1
576 
577 namespace libgav1 {
578 namespace dsp {
579 
// No-op definition used when the build does not target SSE4.1, so callers can
// invoke ObmcInit_SSE4_1() unconditionally.
void ObmcInit_SSE4_1() {}
581 
582 }  // namespace dsp
583 }  // namespace libgav1
584 #endif  // LIBGAV1_TARGETING_SSE4_1
585