1 // Copyright 2020 The libgav1 Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "src/dsp/x86/weight_mask_sse4.h"
16
17 #include "src/utils/cpu.h"
18
19 #if LIBGAV1_TARGETING_SSE4_1
20
21 #include <smmintrin.h>
22
23 #include <cassert>
24 #include <cstddef>
25 #include <cstdint>
26
27 #include "src/dsp/constants.h"
28 #include "src/dsp/dsp.h"
29 #include "src/dsp/x86/common_sse4.h"
30 #include "src/utils/common.h"
31
32 namespace libgav1 {
33 namespace dsp {
34 namespace low_bitdepth {
35 namespace {
36
37 constexpr int kRoundingBits8bpp = 4;
38
39 template <bool mask_is_inverse, bool is_store_16>
WeightMask16_SSE4_1(const int16_t * LIBGAV1_RESTRICT prediction_0,const int16_t * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)40 inline void WeightMask16_SSE4_1(const int16_t* LIBGAV1_RESTRICT prediction_0,
41 const int16_t* LIBGAV1_RESTRICT prediction_1,
42 uint8_t* LIBGAV1_RESTRICT mask,
43 ptrdiff_t mask_stride) {
44 const __m128i pred_00 = LoadAligned16(prediction_0);
45 const __m128i pred_10 = LoadAligned16(prediction_1);
46 const __m128i difference_0 = RightShiftWithRounding_U16(
47 _mm_abs_epi16(_mm_sub_epi16(pred_00, pred_10)), kRoundingBits8bpp);
48 const __m128i scaled_difference_0 = _mm_srli_epi16(difference_0, 4);
49
50 const __m128i pred_01 = LoadAligned16(prediction_0 + 8);
51 const __m128i pred_11 = LoadAligned16(prediction_1 + 8);
52 const __m128i difference_1 = RightShiftWithRounding_U16(
53 _mm_abs_epi16(_mm_sub_epi16(pred_01, pred_11)), kRoundingBits8bpp);
54 const __m128i scaled_difference_1 = _mm_srli_epi16(difference_1, 4);
55
56 const __m128i difference_offset = _mm_set1_epi8(38);
57 const __m128i adjusted_difference =
58 _mm_adds_epu8(_mm_packus_epi16(scaled_difference_0, scaled_difference_1),
59 difference_offset);
60 const __m128i mask_ceiling = _mm_set1_epi8(64);
61 const __m128i mask_value = _mm_min_epi8(adjusted_difference, mask_ceiling);
62 if (mask_is_inverse) {
63 const __m128i inverted_mask_value = _mm_sub_epi8(mask_ceiling, mask_value);
64 if (is_store_16) {
65 StoreAligned16(mask, inverted_mask_value);
66 } else {
67 StoreLo8(mask, inverted_mask_value);
68 StoreHi8(mask + mask_stride, inverted_mask_value);
69 }
70 } else {
71 if (is_store_16) {
72 StoreAligned16(mask, mask_value);
73 } else {
74 StoreLo8(mask, mask_value);
75 StoreHi8(mask + mask_stride, mask_value);
76 }
77 }
78 }
79
// Weights two 8-wide rows (16 predictions per source) starting at |pred_0| /
// |pred_1| and writes them to |mask| and |mask + mask_stride|. Expects
// |pred_0|, |pred_1|, |mask|, |mask_stride| and |mask_is_inverse| in scope.
#define WEIGHT8_PAIR_WITHOUT_STRIDE \
  WeightMask16_SSE4_1<mask_is_inverse, false>(pred_0, pred_1, mask, mask_stride)

// Same as above, then steps the sources past both rows (2 * 8 predictions)
// and the mask down by two rows.
#define WEIGHT8_PAIR_AND_STRIDE \
  WEIGHT8_PAIR_WITHOUT_STRIDE;  \
  pred_0 += 2 * 8;              \
  pred_1 += 2 * 8;              \
  mask += 2 * mask_stride
88
89 template <bool mask_is_inverse>
WeightMask8x8_SSE4_1(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)90 void WeightMask8x8_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
91 const void* LIBGAV1_RESTRICT prediction_1,
92 uint8_t* LIBGAV1_RESTRICT mask,
93 ptrdiff_t mask_stride) {
94 const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
95 const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
96
97 WEIGHT8_PAIR_AND_STRIDE;
98 WEIGHT8_PAIR_AND_STRIDE;
99 WEIGHT8_PAIR_AND_STRIDE;
100 WEIGHT8_PAIR_WITHOUT_STRIDE;
101 }
102
103 template <bool mask_is_inverse>
WeightMask8x16_SSE4_1(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)104 void WeightMask8x16_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
105 const void* LIBGAV1_RESTRICT prediction_1,
106 uint8_t* LIBGAV1_RESTRICT mask,
107 ptrdiff_t mask_stride) {
108 const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
109 const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
110 int y3 = 3;
111 do {
112 WEIGHT8_PAIR_AND_STRIDE;
113 WEIGHT8_PAIR_AND_STRIDE;
114 } while (--y3 != 0);
115 WEIGHT8_PAIR_AND_STRIDE;
116 WEIGHT8_PAIR_WITHOUT_STRIDE;
117 }
118
119 template <bool mask_is_inverse>
WeightMask8x32_SSE4_1(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)120 void WeightMask8x32_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
121 const void* LIBGAV1_RESTRICT prediction_1,
122 uint8_t* LIBGAV1_RESTRICT mask,
123 ptrdiff_t mask_stride) {
124 const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
125 const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
126 int y5 = 5;
127 do {
128 WEIGHT8_PAIR_AND_STRIDE;
129 WEIGHT8_PAIR_AND_STRIDE;
130 WEIGHT8_PAIR_AND_STRIDE;
131 } while (--y5 != 0);
132 WEIGHT8_PAIR_WITHOUT_STRIDE;
133 }
134
// Weights one 16-wide row at |pred_0| / |pred_1| into |mask| with a single
// aligned store. Expects |pred_0|, |pred_1|, |mask|, |mask_stride| and
// |mask_is_inverse| in scope.
#define WEIGHT16_WITHOUT_STRIDE \
  WeightMask16_SSE4_1<mask_is_inverse, true>(pred_0, pred_1, mask, mask_stride)

// Same as above, then steps to the next row: 16 predictions per source and
// one mask stride.
#define WEIGHT16_AND_STRIDE \
  WEIGHT16_WITHOUT_STRIDE;  \
  pred_1 += 16;             \
  pred_0 += 16;             \
  mask += mask_stride
143
144 template <bool mask_is_inverse>
WeightMask16x8_SSE4_1(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)145 void WeightMask16x8_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
146 const void* LIBGAV1_RESTRICT prediction_1,
147 uint8_t* LIBGAV1_RESTRICT mask,
148 ptrdiff_t mask_stride) {
149 const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
150 const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
151 int y = 7;
152 do {
153 WEIGHT16_AND_STRIDE;
154 } while (--y != 0);
155 WEIGHT16_WITHOUT_STRIDE;
156 }
157
158 template <bool mask_is_inverse>
WeightMask16x16_SSE4_1(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)159 void WeightMask16x16_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
160 const void* LIBGAV1_RESTRICT prediction_1,
161 uint8_t* LIBGAV1_RESTRICT mask,
162 ptrdiff_t mask_stride) {
163 const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
164 const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
165 int y3 = 5;
166 do {
167 WEIGHT16_AND_STRIDE;
168 WEIGHT16_AND_STRIDE;
169 WEIGHT16_AND_STRIDE;
170 } while (--y3 != 0);
171 WEIGHT16_WITHOUT_STRIDE;
172 }
173
174 template <bool mask_is_inverse>
WeightMask16x32_SSE4_1(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)175 void WeightMask16x32_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
176 const void* LIBGAV1_RESTRICT prediction_1,
177 uint8_t* LIBGAV1_RESTRICT mask,
178 ptrdiff_t mask_stride) {
179 const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
180 const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
181 int y5 = 6;
182 do {
183 WEIGHT16_AND_STRIDE;
184 WEIGHT16_AND_STRIDE;
185 WEIGHT16_AND_STRIDE;
186 WEIGHT16_AND_STRIDE;
187 WEIGHT16_AND_STRIDE;
188 } while (--y5 != 0);
189 WEIGHT16_AND_STRIDE;
190 WEIGHT16_WITHOUT_STRIDE;
191 }
192
193 template <bool mask_is_inverse>
WeightMask16x64_SSE4_1(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)194 void WeightMask16x64_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
195 const void* LIBGAV1_RESTRICT prediction_1,
196 uint8_t* LIBGAV1_RESTRICT mask,
197 ptrdiff_t mask_stride) {
198 const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
199 const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
200 int y3 = 21;
201 do {
202 WEIGHT16_AND_STRIDE;
203 WEIGHT16_AND_STRIDE;
204 WEIGHT16_AND_STRIDE;
205 } while (--y3 != 0);
206 WEIGHT16_WITHOUT_STRIDE;
207 }
208
// Weights one 32-wide row as two 16-wide halves at offsets 0 and 16. Expects
// |pred_0|, |pred_1|, |mask|, |mask_stride| and |mask_is_inverse| in scope.
#define WEIGHT32_WITHOUT_STRIDE                                             \
  WeightMask16_SSE4_1<mask_is_inverse, true>(pred_0, pred_1, mask,          \
                                             mask_stride);                  \
  WeightMask16_SSE4_1<mask_is_inverse, true>(pred_0 + 16, pred_1 + 16,      \
                                             mask + 16, mask_stride)

// Same as above, then steps to the next row: 32 predictions per source and
// one mask stride.
#define WEIGHT32_AND_STRIDE \
  WEIGHT32_WITHOUT_STRIDE;  \
  pred_1 += 32;             \
  pred_0 += 32;             \
  mask += mask_stride
220
221 template <bool mask_is_inverse>
WeightMask32x8_SSE4_1(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)222 void WeightMask32x8_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
223 const void* LIBGAV1_RESTRICT prediction_1,
224 uint8_t* LIBGAV1_RESTRICT mask,
225 ptrdiff_t mask_stride) {
226 const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
227 const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
228 WEIGHT32_AND_STRIDE;
229 WEIGHT32_AND_STRIDE;
230 WEIGHT32_AND_STRIDE;
231 WEIGHT32_AND_STRIDE;
232 WEIGHT32_AND_STRIDE;
233 WEIGHT32_AND_STRIDE;
234 WEIGHT32_AND_STRIDE;
235 WEIGHT32_WITHOUT_STRIDE;
236 }
237
238 template <bool mask_is_inverse>
WeightMask32x16_SSE4_1(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)239 void WeightMask32x16_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
240 const void* LIBGAV1_RESTRICT prediction_1,
241 uint8_t* LIBGAV1_RESTRICT mask,
242 ptrdiff_t mask_stride) {
243 const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
244 const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
245 int y3 = 5;
246 do {
247 WEIGHT32_AND_STRIDE;
248 WEIGHT32_AND_STRIDE;
249 WEIGHT32_AND_STRIDE;
250 } while (--y3 != 0);
251 WEIGHT32_WITHOUT_STRIDE;
252 }
253
254 template <bool mask_is_inverse>
WeightMask32x32_SSE4_1(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)255 void WeightMask32x32_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
256 const void* LIBGAV1_RESTRICT prediction_1,
257 uint8_t* LIBGAV1_RESTRICT mask,
258 ptrdiff_t mask_stride) {
259 const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
260 const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
261 int y5 = 6;
262 do {
263 WEIGHT32_AND_STRIDE;
264 WEIGHT32_AND_STRIDE;
265 WEIGHT32_AND_STRIDE;
266 WEIGHT32_AND_STRIDE;
267 WEIGHT32_AND_STRIDE;
268 } while (--y5 != 0);
269 WEIGHT32_AND_STRIDE;
270 WEIGHT32_WITHOUT_STRIDE;
271 }
272
273 template <bool mask_is_inverse>
WeightMask32x64_SSE4_1(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)274 void WeightMask32x64_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
275 const void* LIBGAV1_RESTRICT prediction_1,
276 uint8_t* LIBGAV1_RESTRICT mask,
277 ptrdiff_t mask_stride) {
278 const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
279 const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
280 int y3 = 21;
281 do {
282 WEIGHT32_AND_STRIDE;
283 WEIGHT32_AND_STRIDE;
284 WEIGHT32_AND_STRIDE;
285 } while (--y3 != 0);
286 WEIGHT32_WITHOUT_STRIDE;
287 }
288
// Weights one 64-wide row as four 16-wide quarters at offsets 0, 16, 32 and
// 48. Expects |pred_0|, |pred_1|, |mask|, |mask_stride| and |mask_is_inverse|
// in scope.
#define WEIGHT64_WITHOUT_STRIDE                                             \
  WeightMask16_SSE4_1<mask_is_inverse, true>(pred_0, pred_1, mask,          \
                                             mask_stride);                  \
  WeightMask16_SSE4_1<mask_is_inverse, true>(pred_0 + 16, pred_1 + 16,      \
                                             mask + 16, mask_stride);       \
  WeightMask16_SSE4_1<mask_is_inverse, true>(pred_0 + 32, pred_1 + 32,      \
                                             mask + 32, mask_stride);       \
  WeightMask16_SSE4_1<mask_is_inverse, true>(pred_0 + 48, pred_1 + 48,      \
                                             mask + 48, mask_stride)

// Same as above, then steps to the next row: 64 predictions per source and
// one mask stride.
#define WEIGHT64_AND_STRIDE \
  WEIGHT64_WITHOUT_STRIDE;  \
  pred_1 += 64;             \
  pred_0 += 64;             \
  mask += mask_stride
304
305 template <bool mask_is_inverse>
WeightMask64x16_SSE4_1(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)306 void WeightMask64x16_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
307 const void* LIBGAV1_RESTRICT prediction_1,
308 uint8_t* LIBGAV1_RESTRICT mask,
309 ptrdiff_t mask_stride) {
310 const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
311 const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
312 int y3 = 0;
313 do {
314 WEIGHT64_AND_STRIDE;
315 WEIGHT64_AND_STRIDE;
316 WEIGHT64_AND_STRIDE;
317 } while (++y3 < 5);
318 WEIGHT64_WITHOUT_STRIDE;
319 }
320
321 template <bool mask_is_inverse>
WeightMask64x32_SSE4_1(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)322 void WeightMask64x32_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
323 const void* LIBGAV1_RESTRICT prediction_1,
324 uint8_t* LIBGAV1_RESTRICT mask,
325 ptrdiff_t mask_stride) {
326 const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
327 const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
328 int y5 = 0;
329 do {
330 WEIGHT64_AND_STRIDE;
331 WEIGHT64_AND_STRIDE;
332 WEIGHT64_AND_STRIDE;
333 WEIGHT64_AND_STRIDE;
334 WEIGHT64_AND_STRIDE;
335 } while (++y5 < 6);
336 WEIGHT64_AND_STRIDE;
337 WEIGHT64_WITHOUT_STRIDE;
338 }
339
340 template <bool mask_is_inverse>
WeightMask64x64_SSE4_1(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)341 void WeightMask64x64_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
342 const void* LIBGAV1_RESTRICT prediction_1,
343 uint8_t* LIBGAV1_RESTRICT mask,
344 ptrdiff_t mask_stride) {
345 const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
346 const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
347 int y3 = 0;
348 do {
349 WEIGHT64_AND_STRIDE;
350 WEIGHT64_AND_STRIDE;
351 WEIGHT64_AND_STRIDE;
352 } while (++y3 < 21);
353 WEIGHT64_WITHOUT_STRIDE;
354 }
355
356 template <bool mask_is_inverse>
WeightMask64x128_SSE4_1(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)357 void WeightMask64x128_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
358 const void* LIBGAV1_RESTRICT prediction_1,
359 uint8_t* LIBGAV1_RESTRICT mask,
360 ptrdiff_t mask_stride) {
361 const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
362 const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
363 int y3 = 0;
364 do {
365 WEIGHT64_AND_STRIDE;
366 WEIGHT64_AND_STRIDE;
367 WEIGHT64_AND_STRIDE;
368 } while (++y3 < 42);
369 WEIGHT64_AND_STRIDE;
370 WEIGHT64_WITHOUT_STRIDE;
371 }
372
373 template <bool mask_is_inverse>
WeightMask128x64_SSE4_1(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)374 void WeightMask128x64_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
375 const void* LIBGAV1_RESTRICT prediction_1,
376 uint8_t* LIBGAV1_RESTRICT mask,
377 ptrdiff_t mask_stride) {
378 const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
379 const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
380 int y3 = 0;
381 const ptrdiff_t adjusted_mask_stride = mask_stride - 64;
382 do {
383 WEIGHT64_WITHOUT_STRIDE;
384 pred_0 += 64;
385 pred_1 += 64;
386 mask += 64;
387 WEIGHT64_WITHOUT_STRIDE;
388 pred_0 += 64;
389 pred_1 += 64;
390 mask += adjusted_mask_stride;
391
392 WEIGHT64_WITHOUT_STRIDE;
393 pred_0 += 64;
394 pred_1 += 64;
395 mask += 64;
396 WEIGHT64_WITHOUT_STRIDE;
397 pred_0 += 64;
398 pred_1 += 64;
399 mask += adjusted_mask_stride;
400
401 WEIGHT64_WITHOUT_STRIDE;
402 pred_0 += 64;
403 pred_1 += 64;
404 mask += 64;
405 WEIGHT64_WITHOUT_STRIDE;
406 pred_0 += 64;
407 pred_1 += 64;
408 mask += adjusted_mask_stride;
409 } while (++y3 < 21);
410 WEIGHT64_WITHOUT_STRIDE;
411 pred_0 += 64;
412 pred_1 += 64;
413 mask += 64;
414 WEIGHT64_WITHOUT_STRIDE;
415 }
416
417 template <bool mask_is_inverse>
WeightMask128x128_SSE4_1(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)418 void WeightMask128x128_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
419 const void* LIBGAV1_RESTRICT prediction_1,
420 uint8_t* LIBGAV1_RESTRICT mask,
421 ptrdiff_t mask_stride) {
422 const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
423 const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
424 int y3 = 0;
425 const ptrdiff_t adjusted_mask_stride = mask_stride - 64;
426 do {
427 WEIGHT64_WITHOUT_STRIDE;
428 pred_0 += 64;
429 pred_1 += 64;
430 mask += 64;
431 WEIGHT64_WITHOUT_STRIDE;
432 pred_0 += 64;
433 pred_1 += 64;
434 mask += adjusted_mask_stride;
435
436 WEIGHT64_WITHOUT_STRIDE;
437 pred_0 += 64;
438 pred_1 += 64;
439 mask += 64;
440 WEIGHT64_WITHOUT_STRIDE;
441 pred_0 += 64;
442 pred_1 += 64;
443 mask += adjusted_mask_stride;
444
445 WEIGHT64_WITHOUT_STRIDE;
446 pred_0 += 64;
447 pred_1 += 64;
448 mask += 64;
449 WEIGHT64_WITHOUT_STRIDE;
450 pred_0 += 64;
451 pred_1 += 64;
452 mask += adjusted_mask_stride;
453 } while (++y3 < 42);
454 WEIGHT64_WITHOUT_STRIDE;
455 pred_0 += 64;
456 pred_1 += 64;
457 mask += 64;
458 WEIGHT64_WITHOUT_STRIDE;
459 pred_0 += 64;
460 pred_1 += 64;
461 mask += adjusted_mask_stride;
462
463 WEIGHT64_WITHOUT_STRIDE;
464 pred_0 += 64;
465 pred_1 += 64;
466 mask += 64;
467 WEIGHT64_WITHOUT_STRIDE;
468 }
469
470 #define INIT_WEIGHT_MASK_8BPP(width, height, w_index, h_index) \
471 dsp->weight_mask[w_index][h_index][0] = \
472 WeightMask##width##x##height##_SSE4_1<0>; \
473 dsp->weight_mask[w_index][h_index][1] = \
474 WeightMask##width##x##height##_SSE4_1<1>
Init8bpp()475 void Init8bpp() {
476 Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
477 assert(dsp != nullptr);
478 INIT_WEIGHT_MASK_8BPP(8, 8, 0, 0);
479 INIT_WEIGHT_MASK_8BPP(8, 16, 0, 1);
480 INIT_WEIGHT_MASK_8BPP(8, 32, 0, 2);
481 INIT_WEIGHT_MASK_8BPP(16, 8, 1, 0);
482 INIT_WEIGHT_MASK_8BPP(16, 16, 1, 1);
483 INIT_WEIGHT_MASK_8BPP(16, 32, 1, 2);
484 INIT_WEIGHT_MASK_8BPP(16, 64, 1, 3);
485 INIT_WEIGHT_MASK_8BPP(32, 8, 2, 0);
486 INIT_WEIGHT_MASK_8BPP(32, 16, 2, 1);
487 INIT_WEIGHT_MASK_8BPP(32, 32, 2, 2);
488 INIT_WEIGHT_MASK_8BPP(32, 64, 2, 3);
489 INIT_WEIGHT_MASK_8BPP(64, 16, 3, 1);
490 INIT_WEIGHT_MASK_8BPP(64, 32, 3, 2);
491 INIT_WEIGHT_MASK_8BPP(64, 64, 3, 3);
492 INIT_WEIGHT_MASK_8BPP(64, 128, 3, 4);
493 INIT_WEIGHT_MASK_8BPP(128, 64, 4, 3);
494 INIT_WEIGHT_MASK_8BPP(128, 128, 4, 4);
495 }
496
497 } // namespace
498 } // namespace low_bitdepth
499
500 #if LIBGAV1_MAX_BITDEPTH >= 10
501 namespace high_bitdepth {
502 namespace {
503
504 constexpr int kRoundingBits10bpp = 6;
505 constexpr int kScaledDiffShift = 4;
506
507 template <bool mask_is_inverse, bool is_store_16>
WeightMask16_10bpp_SSE4_1(const uint16_t * LIBGAV1_RESTRICT prediction_0,const uint16_t * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)508 inline void WeightMask16_10bpp_SSE4_1(
509 const uint16_t* LIBGAV1_RESTRICT prediction_0,
510 const uint16_t* LIBGAV1_RESTRICT prediction_1,
511 uint8_t* LIBGAV1_RESTRICT mask, ptrdiff_t mask_stride) {
512 const __m128i diff_offset = _mm_set1_epi8(38);
513 const __m128i mask_ceiling = _mm_set1_epi8(64);
514 const __m128i zero = _mm_setzero_si128();
515
516 // Range of prediction: [3988, 61532].
517 const __m128i pred_00 = LoadAligned16(prediction_0);
518 const __m128i pred_10 = LoadAligned16(prediction_1);
519 const __m128i pred_lo_00 = _mm_cvtepu16_epi32(pred_00);
520 const __m128i pred_lo_10 = _mm_cvtepu16_epi32(pred_10);
521 const __m128i diff_lo_0 = RightShiftWithRounding_U32(
522 _mm_abs_epi32(_mm_sub_epi32(pred_lo_00, pred_lo_10)), kRoundingBits10bpp);
523
524 const __m128i pred_hi_00 = _mm_unpackhi_epi16(pred_00, zero);
525 const __m128i pred_hi_10 = _mm_unpackhi_epi16(pred_10, zero);
526 const __m128i diff_hi_0 = RightShiftWithRounding_U32(
527 _mm_abs_epi32(_mm_sub_epi32(pred_hi_00, pred_hi_10)), kRoundingBits10bpp);
528
529 const __m128i diff_0 = _mm_packus_epi32(diff_lo_0, diff_hi_0);
530 const __m128i scaled_diff_0 = _mm_srli_epi16(diff_0, kScaledDiffShift);
531
532 const __m128i pred_01 = LoadAligned16(prediction_0 + 8);
533 const __m128i pred_11 = LoadAligned16(prediction_1 + 8);
534 const __m128i pred_lo_01 = _mm_cvtepu16_epi32(pred_01);
535 const __m128i pred_lo_11 = _mm_cvtepu16_epi32(pred_11);
536 const __m128i diff_lo_1 = RightShiftWithRounding_U32(
537 _mm_abs_epi32(_mm_sub_epi32(pred_lo_01, pred_lo_11)), kRoundingBits10bpp);
538
539 const __m128i pred_hi_01 = _mm_unpackhi_epi16(pred_01, zero);
540 const __m128i pred_hi_11 = _mm_unpackhi_epi16(pred_11, zero);
541 const __m128i diff_hi_1 = RightShiftWithRounding_U32(
542 _mm_abs_epi32(_mm_sub_epi32(pred_hi_01, pred_hi_11)), kRoundingBits10bpp);
543
544 const __m128i diff_1 = _mm_packus_epi32(diff_lo_1, diff_hi_1);
545 const __m128i scaled_diff_1 = _mm_srli_epi16(diff_1, kScaledDiffShift);
546
547 const __m128i adjusted_diff = _mm_adds_epu8(
548 _mm_packus_epi16(scaled_diff_0, scaled_diff_1), diff_offset);
549 const __m128i mask_value = _mm_min_epi8(adjusted_diff, mask_ceiling);
550
551 if (mask_is_inverse) {
552 const __m128i inverted_mask_value = _mm_sub_epi8(mask_ceiling, mask_value);
553 if (is_store_16) {
554 StoreAligned16(mask, inverted_mask_value);
555 } else {
556 StoreLo8(mask, inverted_mask_value);
557 StoreHi8(mask + mask_stride, inverted_mask_value);
558 }
559 } else {
560 if (is_store_16) {
561 StoreAligned16(mask, mask_value);
562 } else {
563 StoreLo8(mask, mask_value);
564 StoreHi8(mask + mask_stride, mask_value);
565 }
566 }
567 }
568
// Weights two 8-wide rows (16 predictions per source) starting at |pred_0| /
// |pred_1| and writes them to |mask| and |mask + mask_stride|. Expects
// |pred_0|, |pred_1|, |mask|, |mask_stride| and |mask_is_inverse| in scope.
#define WEIGHT8_PAIR_WITHOUT_STRIDE_10BPP                                  \
  WeightMask16_10bpp_SSE4_1<mask_is_inverse, false>(pred_0, pred_1, mask,  \
                                                    mask_stride)

// Same as above, then steps the sources past both rows (2 * 8 predictions)
// and the mask down by two rows.
#define WEIGHT8_PAIR_AND_STRIDE_10BPP \
  WEIGHT8_PAIR_WITHOUT_STRIDE_10BPP;  \
  pred_0 += 2 * 8;                    \
  pred_1 += 2 * 8;                    \
  mask += 2 * mask_stride
578
579 template <bool mask_is_inverse>
WeightMask8x8_10bpp_SSE4_1(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)580 void WeightMask8x8_10bpp_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
581 const void* LIBGAV1_RESTRICT prediction_1,
582 uint8_t* LIBGAV1_RESTRICT mask,
583 ptrdiff_t mask_stride) {
584 const auto* pred_0 = static_cast<const uint16_t*>(prediction_0);
585 const auto* pred_1 = static_cast<const uint16_t*>(prediction_1);
586
587 WEIGHT8_PAIR_AND_STRIDE_10BPP;
588 WEIGHT8_PAIR_AND_STRIDE_10BPP;
589 WEIGHT8_PAIR_AND_STRIDE_10BPP;
590 WEIGHT8_PAIR_WITHOUT_STRIDE_10BPP;
591 }
592
593 template <bool mask_is_inverse>
WeightMask8x16_10bpp_SSE4_1(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)594 void WeightMask8x16_10bpp_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
595 const void* LIBGAV1_RESTRICT prediction_1,
596 uint8_t* LIBGAV1_RESTRICT mask,
597 ptrdiff_t mask_stride) {
598 const auto* pred_0 = static_cast<const uint16_t*>(prediction_0);
599 const auto* pred_1 = static_cast<const uint16_t*>(prediction_1);
600 int y3 = 3;
601 do {
602 WEIGHT8_PAIR_AND_STRIDE_10BPP;
603 WEIGHT8_PAIR_AND_STRIDE_10BPP;
604 } while (--y3 != 0);
605 WEIGHT8_PAIR_AND_STRIDE_10BPP;
606 WEIGHT8_PAIR_WITHOUT_STRIDE_10BPP;
607 }
608
609 template <bool mask_is_inverse>
WeightMask8x32_10bpp_SSE4_1(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)610 void WeightMask8x32_10bpp_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
611 const void* LIBGAV1_RESTRICT prediction_1,
612 uint8_t* LIBGAV1_RESTRICT mask,
613 ptrdiff_t mask_stride) {
614 const auto* pred_0 = static_cast<const uint16_t*>(prediction_0);
615 const auto* pred_1 = static_cast<const uint16_t*>(prediction_1);
616 int y5 = 5;
617 do {
618 WEIGHT8_PAIR_AND_STRIDE_10BPP;
619 WEIGHT8_PAIR_AND_STRIDE_10BPP;
620 WEIGHT8_PAIR_AND_STRIDE_10BPP;
621 } while (--y5 != 0);
622 WEIGHT8_PAIR_WITHOUT_STRIDE_10BPP;
623 }
624
// Weights one 16-wide row at |pred_0| / |pred_1| into |mask| with a single
// aligned store. Expects |pred_0|, |pred_1|, |mask|, |mask_stride| and
// |mask_is_inverse| in scope.
#define WEIGHT16_WITHOUT_STRIDE_10BPP                                     \
  WeightMask16_10bpp_SSE4_1<mask_is_inverse, true>(pred_0, pred_1, mask,  \
                                                   mask_stride)

// Same as above, then steps to the next row: 16 predictions per source and
// one mask stride.
#define WEIGHT16_AND_STRIDE_10BPP \
  WEIGHT16_WITHOUT_STRIDE_10BPP;  \
  pred_1 += 16;                   \
  pred_0 += 16;                   \
  mask += mask_stride
634
635 template <bool mask_is_inverse>
WeightMask16x8_10bpp_SSE4_1(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)636 void WeightMask16x8_10bpp_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
637 const void* LIBGAV1_RESTRICT prediction_1,
638 uint8_t* LIBGAV1_RESTRICT mask,
639 ptrdiff_t mask_stride) {
640 const auto* pred_0 = static_cast<const uint16_t*>(prediction_0);
641 const auto* pred_1 = static_cast<const uint16_t*>(prediction_1);
642 int y = 7;
643 do {
644 WEIGHT16_AND_STRIDE_10BPP;
645 } while (--y != 0);
646 WEIGHT16_WITHOUT_STRIDE_10BPP;
647 }
648
649 template <bool mask_is_inverse>
WeightMask16x16_10bpp_SSE4_1(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)650 void WeightMask16x16_10bpp_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
651 const void* LIBGAV1_RESTRICT prediction_1,
652 uint8_t* LIBGAV1_RESTRICT mask,
653 ptrdiff_t mask_stride) {
654 const auto* pred_0 = static_cast<const uint16_t*>(prediction_0);
655 const auto* pred_1 = static_cast<const uint16_t*>(prediction_1);
656 int y3 = 5;
657 do {
658 WEIGHT16_AND_STRIDE_10BPP;
659 WEIGHT16_AND_STRIDE_10BPP;
660 WEIGHT16_AND_STRIDE_10BPP;
661 } while (--y3 != 0);
662 WEIGHT16_WITHOUT_STRIDE_10BPP;
663 }
664
665 template <bool mask_is_inverse>
WeightMask16x32_10bpp_SSE4_1(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)666 void WeightMask16x32_10bpp_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
667 const void* LIBGAV1_RESTRICT prediction_1,
668 uint8_t* LIBGAV1_RESTRICT mask,
669 ptrdiff_t mask_stride) {
670 const auto* pred_0 = static_cast<const uint16_t*>(prediction_0);
671 const auto* pred_1 = static_cast<const uint16_t*>(prediction_1);
672 int y5 = 6;
673 do {
674 WEIGHT16_AND_STRIDE_10BPP;
675 WEIGHT16_AND_STRIDE_10BPP;
676 WEIGHT16_AND_STRIDE_10BPP;
677 WEIGHT16_AND_STRIDE_10BPP;
678 WEIGHT16_AND_STRIDE_10BPP;
679 } while (--y5 != 0);
680 WEIGHT16_AND_STRIDE_10BPP;
681 WEIGHT16_WITHOUT_STRIDE_10BPP;
682 }
683
684 template <bool mask_is_inverse>
WeightMask16x64_10bpp_SSE4_1(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)685 void WeightMask16x64_10bpp_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
686 const void* LIBGAV1_RESTRICT prediction_1,
687 uint8_t* LIBGAV1_RESTRICT mask,
688 ptrdiff_t mask_stride) {
689 const auto* pred_0 = static_cast<const uint16_t*>(prediction_0);
690 const auto* pred_1 = static_cast<const uint16_t*>(prediction_1);
691 int y3 = 21;
692 do {
693 WEIGHT16_AND_STRIDE_10BPP;
694 WEIGHT16_AND_STRIDE_10BPP;
695 WEIGHT16_AND_STRIDE_10BPP;
696 } while (--y3 != 0);
697 WEIGHT16_WITHOUT_STRIDE_10BPP;
698 }
699
// Weights one 32-wide row as two 16-wide halves at offsets 0 and 16. Expects
// |pred_0|, |pred_1|, |mask|, |mask_stride| and |mask_is_inverse| in scope.
#define WEIGHT32_WITHOUT_STRIDE_10BPP                                        \
  WeightMask16_10bpp_SSE4_1<mask_is_inverse, true>(pred_0, pred_1, mask,     \
                                                   mask_stride);             \
  WeightMask16_10bpp_SSE4_1<mask_is_inverse, true>(pred_0 + 16, pred_1 + 16, \
                                                   mask + 16, mask_stride)

// Same as above, then steps to the next row: 32 predictions per source and
// one mask stride.
#define WEIGHT32_AND_STRIDE_10BPP \
  WEIGHT32_WITHOUT_STRIDE_10BPP;  \
  pred_1 += 32;                   \
  pred_0 += 32;                   \
  mask += mask_stride
711
712 template <bool mask_is_inverse>
WeightMask32x8_10bpp_SSE4_1(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)713 void WeightMask32x8_10bpp_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
714 const void* LIBGAV1_RESTRICT prediction_1,
715 uint8_t* LIBGAV1_RESTRICT mask,
716 ptrdiff_t mask_stride) {
717 const auto* pred_0 = static_cast<const uint16_t*>(prediction_0);
718 const auto* pred_1 = static_cast<const uint16_t*>(prediction_1);
719 WEIGHT32_AND_STRIDE_10BPP;
720 WEIGHT32_AND_STRIDE_10BPP;
721 WEIGHT32_AND_STRIDE_10BPP;
722 WEIGHT32_AND_STRIDE_10BPP;
723 WEIGHT32_AND_STRIDE_10BPP;
724 WEIGHT32_AND_STRIDE_10BPP;
725 WEIGHT32_AND_STRIDE_10BPP;
726 WEIGHT32_WITHOUT_STRIDE_10BPP;
727 }
728
729 template <bool mask_is_inverse>
WeightMask32x16_10bpp_SSE4_1(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)730 void WeightMask32x16_10bpp_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
731 const void* LIBGAV1_RESTRICT prediction_1,
732 uint8_t* LIBGAV1_RESTRICT mask,
733 ptrdiff_t mask_stride) {
734 const auto* pred_0 = static_cast<const uint16_t*>(prediction_0);
735 const auto* pred_1 = static_cast<const uint16_t*>(prediction_1);
736 int y3 = 5;
737 do {
738 WEIGHT32_AND_STRIDE_10BPP;
739 WEIGHT32_AND_STRIDE_10BPP;
740 WEIGHT32_AND_STRIDE_10BPP;
741 } while (--y3 != 0);
742 WEIGHT32_WITHOUT_STRIDE_10BPP;
743 }
744
745 template <bool mask_is_inverse>
WeightMask32x32_10bpp_SSE4_1(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)746 void WeightMask32x32_10bpp_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
747 const void* LIBGAV1_RESTRICT prediction_1,
748 uint8_t* LIBGAV1_RESTRICT mask,
749 ptrdiff_t mask_stride) {
750 const auto* pred_0 = static_cast<const uint16_t*>(prediction_0);
751 const auto* pred_1 = static_cast<const uint16_t*>(prediction_1);
752 int y5 = 6;
753 do {
754 WEIGHT32_AND_STRIDE_10BPP;
755 WEIGHT32_AND_STRIDE_10BPP;
756 WEIGHT32_AND_STRIDE_10BPP;
757 WEIGHT32_AND_STRIDE_10BPP;
758 WEIGHT32_AND_STRIDE_10BPP;
759 } while (--y5 != 0);
760 WEIGHT32_AND_STRIDE_10BPP;
761 WEIGHT32_WITHOUT_STRIDE_10BPP;
762 }
763
764 template <bool mask_is_inverse>
WeightMask32x64_10bpp_SSE4_1(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)765 void WeightMask32x64_10bpp_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
766 const void* LIBGAV1_RESTRICT prediction_1,
767 uint8_t* LIBGAV1_RESTRICT mask,
768 ptrdiff_t mask_stride) {
769 const auto* pred_0 = static_cast<const uint16_t*>(prediction_0);
770 const auto* pred_1 = static_cast<const uint16_t*>(prediction_1);
771 int y3 = 21;
772 do {
773 WEIGHT32_AND_STRIDE_10BPP;
774 WEIGHT32_AND_STRIDE_10BPP;
775 WEIGHT32_AND_STRIDE_10BPP;
776 } while (--y3 != 0);
777 WEIGHT32_WITHOUT_STRIDE_10BPP;
778 }
779
// Weights one 64-wide row as four 16-wide quarters at offsets 0, 16, 32 and
// 48. Expects |pred_0|, |pred_1|, |mask|, |mask_stride| and |mask_is_inverse|
// in scope.
#define WEIGHT64_WITHOUT_STRIDE_10BPP                                        \
  WeightMask16_10bpp_SSE4_1<mask_is_inverse, true>(pred_0, pred_1, mask,     \
                                                   mask_stride);             \
  WeightMask16_10bpp_SSE4_1<mask_is_inverse, true>(pred_0 + 16, pred_1 + 16, \
                                                   mask + 16, mask_stride);  \
  WeightMask16_10bpp_SSE4_1<mask_is_inverse, true>(pred_0 + 32, pred_1 + 32, \
                                                   mask + 32, mask_stride);  \
  WeightMask16_10bpp_SSE4_1<mask_is_inverse, true>(pred_0 + 48, pred_1 + 48, \
                                                   mask + 48, mask_stride)

// Same as above, then steps to the next row: 64 predictions per source and
// one mask stride.
#define WEIGHT64_AND_STRIDE_10BPP \
  WEIGHT64_WITHOUT_STRIDE_10BPP;  \
  pred_1 += 64;                   \
  pred_0 += 64;                   \
  mask += mask_stride
795
796 template <bool mask_is_inverse>
WeightMask64x16_10bpp_SSE4_1(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)797 void WeightMask64x16_10bpp_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
798 const void* LIBGAV1_RESTRICT prediction_1,
799 uint8_t* LIBGAV1_RESTRICT mask,
800 ptrdiff_t mask_stride) {
801 const auto* pred_0 = static_cast<const uint16_t*>(prediction_0);
802 const auto* pred_1 = static_cast<const uint16_t*>(prediction_1);
803 int y3 = 5;
804 do {
805 WEIGHT64_AND_STRIDE_10BPP;
806 WEIGHT64_AND_STRIDE_10BPP;
807 WEIGHT64_AND_STRIDE_10BPP;
808 } while (--y3 != 0);
809 WEIGHT64_WITHOUT_STRIDE_10BPP;
810 }
811
812 template <bool mask_is_inverse>
WeightMask64x32_10bpp_SSE4_1(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)813 void WeightMask64x32_10bpp_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
814 const void* LIBGAV1_RESTRICT prediction_1,
815 uint8_t* LIBGAV1_RESTRICT mask,
816 ptrdiff_t mask_stride) {
817 const auto* pred_0 = static_cast<const uint16_t*>(prediction_0);
818 const auto* pred_1 = static_cast<const uint16_t*>(prediction_1);
819 int y5 = 6;
820 do {
821 WEIGHT64_AND_STRIDE_10BPP;
822 WEIGHT64_AND_STRIDE_10BPP;
823 WEIGHT64_AND_STRIDE_10BPP;
824 WEIGHT64_AND_STRIDE_10BPP;
825 WEIGHT64_AND_STRIDE_10BPP;
826 } while (--y5 != 0);
827 WEIGHT64_AND_STRIDE_10BPP;
828 WEIGHT64_WITHOUT_STRIDE_10BPP;
829 }
830
831 template <bool mask_is_inverse>
WeightMask64x64_10bpp_SSE4_1(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)832 void WeightMask64x64_10bpp_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
833 const void* LIBGAV1_RESTRICT prediction_1,
834 uint8_t* LIBGAV1_RESTRICT mask,
835 ptrdiff_t mask_stride) {
836 const auto* pred_0 = static_cast<const uint16_t*>(prediction_0);
837 const auto* pred_1 = static_cast<const uint16_t*>(prediction_1);
838 int y3 = 21;
839 do {
840 WEIGHT64_AND_STRIDE_10BPP;
841 WEIGHT64_AND_STRIDE_10BPP;
842 WEIGHT64_AND_STRIDE_10BPP;
843 } while (--y3 != 0);
844 WEIGHT64_WITHOUT_STRIDE_10BPP;
845 }
846
847 template <bool mask_is_inverse>
WeightMask64x128_10bpp_SSE4_1(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)848 void WeightMask64x128_10bpp_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
849 const void* LIBGAV1_RESTRICT prediction_1,
850 uint8_t* LIBGAV1_RESTRICT mask,
851 ptrdiff_t mask_stride) {
852 const auto* pred_0 = static_cast<const uint16_t*>(prediction_0);
853 const auto* pred_1 = static_cast<const uint16_t*>(prediction_1);
854 int y3 = 42;
855 do {
856 WEIGHT64_AND_STRIDE_10BPP;
857 WEIGHT64_AND_STRIDE_10BPP;
858 WEIGHT64_AND_STRIDE_10BPP;
859 } while (--y3 != 0);
860 WEIGHT64_AND_STRIDE_10BPP;
861 WEIGHT64_WITHOUT_STRIDE_10BPP;
862 }
863
864 template <bool mask_is_inverse>
WeightMask128x64_10bpp_SSE4_1(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)865 void WeightMask128x64_10bpp_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
866 const void* LIBGAV1_RESTRICT prediction_1,
867 uint8_t* LIBGAV1_RESTRICT mask,
868 ptrdiff_t mask_stride) {
869 const auto* pred_0 = static_cast<const uint16_t*>(prediction_0);
870 const auto* pred_1 = static_cast<const uint16_t*>(prediction_1);
871 int y3 = 21;
872 const ptrdiff_t adjusted_mask_stride = mask_stride - 64;
873 do {
874 WEIGHT64_WITHOUT_STRIDE_10BPP;
875 pred_0 += 64;
876 pred_1 += 64;
877 mask += 64;
878 WEIGHT64_WITHOUT_STRIDE_10BPP;
879 pred_0 += 64;
880 pred_1 += 64;
881 mask += adjusted_mask_stride;
882
883 WEIGHT64_WITHOUT_STRIDE_10BPP;
884 pred_0 += 64;
885 pred_1 += 64;
886 mask += 64;
887 WEIGHT64_WITHOUT_STRIDE_10BPP;
888 pred_0 += 64;
889 pred_1 += 64;
890 mask += adjusted_mask_stride;
891
892 WEIGHT64_WITHOUT_STRIDE_10BPP;
893 pred_0 += 64;
894 pred_1 += 64;
895 mask += 64;
896 WEIGHT64_WITHOUT_STRIDE_10BPP;
897 pred_0 += 64;
898 pred_1 += 64;
899 mask += adjusted_mask_stride;
900 } while (--y3 != 0);
901 WEIGHT64_WITHOUT_STRIDE_10BPP;
902 pred_0 += 64;
903 pred_1 += 64;
904 mask += 64;
905 WEIGHT64_WITHOUT_STRIDE_10BPP;
906 }
907
908 template <bool mask_is_inverse>
WeightMask128x128_10bpp_SSE4_1(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)909 void WeightMask128x128_10bpp_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
910 const void* LIBGAV1_RESTRICT prediction_1,
911 uint8_t* LIBGAV1_RESTRICT mask,
912 ptrdiff_t mask_stride) {
913 const auto* pred_0 = static_cast<const uint16_t*>(prediction_0);
914 const auto* pred_1 = static_cast<const uint16_t*>(prediction_1);
915 int y3 = 42;
916 const ptrdiff_t adjusted_mask_stride = mask_stride - 64;
917 do {
918 WEIGHT64_WITHOUT_STRIDE_10BPP;
919 pred_0 += 64;
920 pred_1 += 64;
921 mask += 64;
922 WEIGHT64_WITHOUT_STRIDE_10BPP;
923 pred_0 += 64;
924 pred_1 += 64;
925 mask += adjusted_mask_stride;
926
927 WEIGHT64_WITHOUT_STRIDE_10BPP;
928 pred_0 += 64;
929 pred_1 += 64;
930 mask += 64;
931 WEIGHT64_WITHOUT_STRIDE_10BPP;
932 pred_0 += 64;
933 pred_1 += 64;
934 mask += adjusted_mask_stride;
935
936 WEIGHT64_WITHOUT_STRIDE_10BPP;
937 pred_0 += 64;
938 pred_1 += 64;
939 mask += 64;
940 WEIGHT64_WITHOUT_STRIDE_10BPP;
941 pred_0 += 64;
942 pred_1 += 64;
943 mask += adjusted_mask_stride;
944 } while (--y3 != 0);
945 WEIGHT64_WITHOUT_STRIDE_10BPP;
946 pred_0 += 64;
947 pred_1 += 64;
948 mask += 64;
949 WEIGHT64_WITHOUT_STRIDE_10BPP;
950 pred_0 += 64;
951 pred_1 += 64;
952 mask += adjusted_mask_stride;
953
954 WEIGHT64_WITHOUT_STRIDE_10BPP;
955 pred_0 += 64;
956 pred_1 += 64;
957 mask += 64;
958 WEIGHT64_WITHOUT_STRIDE_10BPP;
959 }
960
// Stores both instantiations of WeightMask<width>x<height>_10bpp_SSE4_1 in
// dsp->weight_mask[w_index][h_index]. Slot [0] receives the
// mask_is_inverse == false kernel and slot [1] the mask_is_inverse == true
// kernel. Expects a `dsp` pointer in the caller's scope. The use site
// supplies the trailing semicolon.
#define INIT_WEIGHT_MASK_10BPP(width, height, w_index, h_index) \
  dsp->weight_mask[w_index][h_index][0] =                       \
      WeightMask##width##x##height##_10bpp_SSE4_1<0>;           \
  dsp->weight_mask[w_index][h_index][1] =                       \
      WeightMask##width##x##height##_10bpp_SSE4_1<1>
// Registers the SSE4.1 10bpp weight-mask kernels in the writable 10-bit DSP
// table. Block dimensions map to table indices as 8->0, 16->1, 32->2, 64->3
// and 128->4. This function only touches the (width, height) pairs listed
// below.
void Init10bpp() {
  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10);
  assert(dsp != nullptr);
  // Width 8.
  INIT_WEIGHT_MASK_10BPP(8, 8, 0, 0);
  INIT_WEIGHT_MASK_10BPP(8, 16, 0, 1);
  INIT_WEIGHT_MASK_10BPP(8, 32, 0, 2);
  // Width 16.
  INIT_WEIGHT_MASK_10BPP(16, 8, 1, 0);
  INIT_WEIGHT_MASK_10BPP(16, 16, 1, 1);
  INIT_WEIGHT_MASK_10BPP(16, 32, 1, 2);
  INIT_WEIGHT_MASK_10BPP(16, 64, 1, 3);
  // Width 32.
  INIT_WEIGHT_MASK_10BPP(32, 8, 2, 0);
  INIT_WEIGHT_MASK_10BPP(32, 16, 2, 1);
  INIT_WEIGHT_MASK_10BPP(32, 32, 2, 2);
  INIT_WEIGHT_MASK_10BPP(32, 64, 2, 3);
  // Width 64.
  INIT_WEIGHT_MASK_10BPP(64, 16, 3, 1);
  INIT_WEIGHT_MASK_10BPP(64, 32, 3, 2);
  INIT_WEIGHT_MASK_10BPP(64, 64, 3, 3);
  INIT_WEIGHT_MASK_10BPP(64, 128, 3, 4);
  // Width 128.
  INIT_WEIGHT_MASK_10BPP(128, 64, 4, 3);
  INIT_WEIGHT_MASK_10BPP(128, 128, 4, 4);
}
987
988 } // namespace
989 } // namespace high_bitdepth
990 #endif // LIBGAV1_MAX_BITDEPTH >= 10
991
// Entry point for SSE4.1 weight-mask registration. It always calls
// low_bitdepth::Init8bpp(). It calls high_bitdepth::Init10bpp() only when the
// build enables 10-bit support (LIBGAV1_MAX_BITDEPTH >= 10).
void WeightMaskInit_SSE4_1() {
  low_bitdepth::Init8bpp();
#if LIBGAV1_MAX_BITDEPTH >= 10
  high_bitdepth::Init10bpp();
#endif
}
998
999 } // namespace dsp
1000 } // namespace libgav1
1001
1002 #else // !LIBGAV1_TARGETING_SSE4_1
1003
1004 namespace libgav1 {
1005 namespace dsp {
1006
// Stub for builds where LIBGAV1_TARGETING_SSE4_1 is false. It registers
// nothing, and keeps the symbol defined so callers need no build-specific
// guards.
void WeightMaskInit_SSE4_1() {}
1008
1009 } // namespace dsp
1010 } // namespace libgav1
1011 #endif // LIBGAV1_TARGETING_SSE4_1
1012