xref: /aosp_15_r20/external/libgav1/src/dsp/x86/weight_mask_sse4.cc (revision 095378508e87ed692bf8dfeb34008b65b3735891)
1 // Copyright 2020 The libgav1 Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "src/dsp/x86/weight_mask_sse4.h"
16 
17 #include "src/utils/cpu.h"
18 
19 #if LIBGAV1_TARGETING_SSE4_1
20 
21 #include <smmintrin.h>
22 
23 #include <cassert>
24 #include <cstddef>
25 #include <cstdint>
26 
27 #include "src/dsp/constants.h"
28 #include "src/dsp/dsp.h"
29 #include "src/dsp/x86/common_sse4.h"
30 #include "src/utils/common.h"
31 
32 namespace libgav1 {
33 namespace dsp {
34 namespace low_bitdepth {
35 namespace {
36 
37 constexpr int kRoundingBits8bpp = 4;
38 
39 template <bool mask_is_inverse, bool is_store_16>
WeightMask16_SSE4_1(const int16_t * LIBGAV1_RESTRICT prediction_0,const int16_t * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)40 inline void WeightMask16_SSE4_1(const int16_t* LIBGAV1_RESTRICT prediction_0,
41                                 const int16_t* LIBGAV1_RESTRICT prediction_1,
42                                 uint8_t* LIBGAV1_RESTRICT mask,
43                                 ptrdiff_t mask_stride) {
44   const __m128i pred_00 = LoadAligned16(prediction_0);
45   const __m128i pred_10 = LoadAligned16(prediction_1);
46   const __m128i difference_0 = RightShiftWithRounding_U16(
47       _mm_abs_epi16(_mm_sub_epi16(pred_00, pred_10)), kRoundingBits8bpp);
48   const __m128i scaled_difference_0 = _mm_srli_epi16(difference_0, 4);
49 
50   const __m128i pred_01 = LoadAligned16(prediction_0 + 8);
51   const __m128i pred_11 = LoadAligned16(prediction_1 + 8);
52   const __m128i difference_1 = RightShiftWithRounding_U16(
53       _mm_abs_epi16(_mm_sub_epi16(pred_01, pred_11)), kRoundingBits8bpp);
54   const __m128i scaled_difference_1 = _mm_srli_epi16(difference_1, 4);
55 
56   const __m128i difference_offset = _mm_set1_epi8(38);
57   const __m128i adjusted_difference =
58       _mm_adds_epu8(_mm_packus_epi16(scaled_difference_0, scaled_difference_1),
59                     difference_offset);
60   const __m128i mask_ceiling = _mm_set1_epi8(64);
61   const __m128i mask_value = _mm_min_epi8(adjusted_difference, mask_ceiling);
62   if (mask_is_inverse) {
63     const __m128i inverted_mask_value = _mm_sub_epi8(mask_ceiling, mask_value);
64     if (is_store_16) {
65       StoreAligned16(mask, inverted_mask_value);
66     } else {
67       StoreLo8(mask, inverted_mask_value);
68       StoreHi8(mask + mask_stride, inverted_mask_value);
69     }
70   } else {
71     if (is_store_16) {
72       StoreAligned16(mask, mask_value);
73     } else {
74       StoreLo8(mask, mask_value);
75       StoreHi8(mask + mask_stride, mask_value);
76     }
77   }
78 }
79 
80 #define WEIGHT8_PAIR_WITHOUT_STRIDE \
81   WeightMask16_SSE4_1<mask_is_inverse, false>(pred_0, pred_1, mask, mask_stride)
82 
83 #define WEIGHT8_PAIR_AND_STRIDE \
84   WEIGHT8_PAIR_WITHOUT_STRIDE;  \
85   pred_0 += 8 << 1;             \
86   pred_1 += 8 << 1;             \
87   mask += mask_stride << 1
88 
89 template <bool mask_is_inverse>
WeightMask8x8_SSE4_1(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)90 void WeightMask8x8_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
91                           const void* LIBGAV1_RESTRICT prediction_1,
92                           uint8_t* LIBGAV1_RESTRICT mask,
93                           ptrdiff_t mask_stride) {
94   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
95   const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
96 
97   WEIGHT8_PAIR_AND_STRIDE;
98   WEIGHT8_PAIR_AND_STRIDE;
99   WEIGHT8_PAIR_AND_STRIDE;
100   WEIGHT8_PAIR_WITHOUT_STRIDE;
101 }
102 
103 template <bool mask_is_inverse>
WeightMask8x16_SSE4_1(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)104 void WeightMask8x16_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
105                            const void* LIBGAV1_RESTRICT prediction_1,
106                            uint8_t* LIBGAV1_RESTRICT mask,
107                            ptrdiff_t mask_stride) {
108   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
109   const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
110   int y3 = 3;
111   do {
112     WEIGHT8_PAIR_AND_STRIDE;
113     WEIGHT8_PAIR_AND_STRIDE;
114   } while (--y3 != 0);
115   WEIGHT8_PAIR_AND_STRIDE;
116   WEIGHT8_PAIR_WITHOUT_STRIDE;
117 }
118 
119 template <bool mask_is_inverse>
WeightMask8x32_SSE4_1(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)120 void WeightMask8x32_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
121                            const void* LIBGAV1_RESTRICT prediction_1,
122                            uint8_t* LIBGAV1_RESTRICT mask,
123                            ptrdiff_t mask_stride) {
124   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
125   const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
126   int y5 = 5;
127   do {
128     WEIGHT8_PAIR_AND_STRIDE;
129     WEIGHT8_PAIR_AND_STRIDE;
130     WEIGHT8_PAIR_AND_STRIDE;
131   } while (--y5 != 0);
132   WEIGHT8_PAIR_WITHOUT_STRIDE;
133 }
134 
135 #define WEIGHT16_WITHOUT_STRIDE \
136   WeightMask16_SSE4_1<mask_is_inverse, true>(pred_0, pred_1, mask, mask_stride)
137 
138 #define WEIGHT16_AND_STRIDE \
139   WEIGHT16_WITHOUT_STRIDE;  \
140   pred_0 += 16;             \
141   pred_1 += 16;             \
142   mask += mask_stride
143 
144 template <bool mask_is_inverse>
WeightMask16x8_SSE4_1(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)145 void WeightMask16x8_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
146                            const void* LIBGAV1_RESTRICT prediction_1,
147                            uint8_t* LIBGAV1_RESTRICT mask,
148                            ptrdiff_t mask_stride) {
149   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
150   const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
151   int y = 7;
152   do {
153     WEIGHT16_AND_STRIDE;
154   } while (--y != 0);
155   WEIGHT16_WITHOUT_STRIDE;
156 }
157 
158 template <bool mask_is_inverse>
WeightMask16x16_SSE4_1(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)159 void WeightMask16x16_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
160                             const void* LIBGAV1_RESTRICT prediction_1,
161                             uint8_t* LIBGAV1_RESTRICT mask,
162                             ptrdiff_t mask_stride) {
163   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
164   const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
165   int y3 = 5;
166   do {
167     WEIGHT16_AND_STRIDE;
168     WEIGHT16_AND_STRIDE;
169     WEIGHT16_AND_STRIDE;
170   } while (--y3 != 0);
171   WEIGHT16_WITHOUT_STRIDE;
172 }
173 
174 template <bool mask_is_inverse>
WeightMask16x32_SSE4_1(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)175 void WeightMask16x32_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
176                             const void* LIBGAV1_RESTRICT prediction_1,
177                             uint8_t* LIBGAV1_RESTRICT mask,
178                             ptrdiff_t mask_stride) {
179   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
180   const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
181   int y5 = 6;
182   do {
183     WEIGHT16_AND_STRIDE;
184     WEIGHT16_AND_STRIDE;
185     WEIGHT16_AND_STRIDE;
186     WEIGHT16_AND_STRIDE;
187     WEIGHT16_AND_STRIDE;
188   } while (--y5 != 0);
189   WEIGHT16_AND_STRIDE;
190   WEIGHT16_WITHOUT_STRIDE;
191 }
192 
193 template <bool mask_is_inverse>
WeightMask16x64_SSE4_1(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)194 void WeightMask16x64_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
195                             const void* LIBGAV1_RESTRICT prediction_1,
196                             uint8_t* LIBGAV1_RESTRICT mask,
197                             ptrdiff_t mask_stride) {
198   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
199   const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
200   int y3 = 21;
201   do {
202     WEIGHT16_AND_STRIDE;
203     WEIGHT16_AND_STRIDE;
204     WEIGHT16_AND_STRIDE;
205   } while (--y3 != 0);
206   WEIGHT16_WITHOUT_STRIDE;
207 }
208 
209 #define WEIGHT32_WITHOUT_STRIDE                                        \
210   WeightMask16_SSE4_1<mask_is_inverse, true>(pred_0, pred_1, mask,     \
211                                              mask_stride);             \
212   WeightMask16_SSE4_1<mask_is_inverse, true>(pred_0 + 16, pred_1 + 16, \
213                                              mask + 16, mask_stride)
214 
215 #define WEIGHT32_AND_STRIDE \
216   WEIGHT32_WITHOUT_STRIDE;  \
217   pred_0 += 32;             \
218   pred_1 += 32;             \
219   mask += mask_stride
220 
221 template <bool mask_is_inverse>
WeightMask32x8_SSE4_1(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)222 void WeightMask32x8_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
223                            const void* LIBGAV1_RESTRICT prediction_1,
224                            uint8_t* LIBGAV1_RESTRICT mask,
225                            ptrdiff_t mask_stride) {
226   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
227   const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
228   WEIGHT32_AND_STRIDE;
229   WEIGHT32_AND_STRIDE;
230   WEIGHT32_AND_STRIDE;
231   WEIGHT32_AND_STRIDE;
232   WEIGHT32_AND_STRIDE;
233   WEIGHT32_AND_STRIDE;
234   WEIGHT32_AND_STRIDE;
235   WEIGHT32_WITHOUT_STRIDE;
236 }
237 
238 template <bool mask_is_inverse>
WeightMask32x16_SSE4_1(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)239 void WeightMask32x16_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
240                             const void* LIBGAV1_RESTRICT prediction_1,
241                             uint8_t* LIBGAV1_RESTRICT mask,
242                             ptrdiff_t mask_stride) {
243   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
244   const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
245   int y3 = 5;
246   do {
247     WEIGHT32_AND_STRIDE;
248     WEIGHT32_AND_STRIDE;
249     WEIGHT32_AND_STRIDE;
250   } while (--y3 != 0);
251   WEIGHT32_WITHOUT_STRIDE;
252 }
253 
254 template <bool mask_is_inverse>
WeightMask32x32_SSE4_1(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)255 void WeightMask32x32_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
256                             const void* LIBGAV1_RESTRICT prediction_1,
257                             uint8_t* LIBGAV1_RESTRICT mask,
258                             ptrdiff_t mask_stride) {
259   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
260   const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
261   int y5 = 6;
262   do {
263     WEIGHT32_AND_STRIDE;
264     WEIGHT32_AND_STRIDE;
265     WEIGHT32_AND_STRIDE;
266     WEIGHT32_AND_STRIDE;
267     WEIGHT32_AND_STRIDE;
268   } while (--y5 != 0);
269   WEIGHT32_AND_STRIDE;
270   WEIGHT32_WITHOUT_STRIDE;
271 }
272 
273 template <bool mask_is_inverse>
WeightMask32x64_SSE4_1(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)274 void WeightMask32x64_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
275                             const void* LIBGAV1_RESTRICT prediction_1,
276                             uint8_t* LIBGAV1_RESTRICT mask,
277                             ptrdiff_t mask_stride) {
278   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
279   const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
280   int y3 = 21;
281   do {
282     WEIGHT32_AND_STRIDE;
283     WEIGHT32_AND_STRIDE;
284     WEIGHT32_AND_STRIDE;
285   } while (--y3 != 0);
286   WEIGHT32_WITHOUT_STRIDE;
287 }
288 
289 #define WEIGHT64_WITHOUT_STRIDE                                        \
290   WeightMask16_SSE4_1<mask_is_inverse, true>(pred_0, pred_1, mask,     \
291                                              mask_stride);             \
292   WeightMask16_SSE4_1<mask_is_inverse, true>(pred_0 + 16, pred_1 + 16, \
293                                              mask + 16, mask_stride);  \
294   WeightMask16_SSE4_1<mask_is_inverse, true>(pred_0 + 32, pred_1 + 32, \
295                                              mask + 32, mask_stride);  \
296   WeightMask16_SSE4_1<mask_is_inverse, true>(pred_0 + 48, pred_1 + 48, \
297                                              mask + 48, mask_stride)
298 
299 #define WEIGHT64_AND_STRIDE \
300   WEIGHT64_WITHOUT_STRIDE;  \
301   pred_0 += 64;             \
302   pred_1 += 64;             \
303   mask += mask_stride
304 
305 template <bool mask_is_inverse>
WeightMask64x16_SSE4_1(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)306 void WeightMask64x16_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
307                             const void* LIBGAV1_RESTRICT prediction_1,
308                             uint8_t* LIBGAV1_RESTRICT mask,
309                             ptrdiff_t mask_stride) {
310   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
311   const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
312   int y3 = 0;
313   do {
314     WEIGHT64_AND_STRIDE;
315     WEIGHT64_AND_STRIDE;
316     WEIGHT64_AND_STRIDE;
317   } while (++y3 < 5);
318   WEIGHT64_WITHOUT_STRIDE;
319 }
320 
321 template <bool mask_is_inverse>
WeightMask64x32_SSE4_1(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)322 void WeightMask64x32_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
323                             const void* LIBGAV1_RESTRICT prediction_1,
324                             uint8_t* LIBGAV1_RESTRICT mask,
325                             ptrdiff_t mask_stride) {
326   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
327   const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
328   int y5 = 0;
329   do {
330     WEIGHT64_AND_STRIDE;
331     WEIGHT64_AND_STRIDE;
332     WEIGHT64_AND_STRIDE;
333     WEIGHT64_AND_STRIDE;
334     WEIGHT64_AND_STRIDE;
335   } while (++y5 < 6);
336   WEIGHT64_AND_STRIDE;
337   WEIGHT64_WITHOUT_STRIDE;
338 }
339 
340 template <bool mask_is_inverse>
WeightMask64x64_SSE4_1(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)341 void WeightMask64x64_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
342                             const void* LIBGAV1_RESTRICT prediction_1,
343                             uint8_t* LIBGAV1_RESTRICT mask,
344                             ptrdiff_t mask_stride) {
345   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
346   const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
347   int y3 = 0;
348   do {
349     WEIGHT64_AND_STRIDE;
350     WEIGHT64_AND_STRIDE;
351     WEIGHT64_AND_STRIDE;
352   } while (++y3 < 21);
353   WEIGHT64_WITHOUT_STRIDE;
354 }
355 
356 template <bool mask_is_inverse>
WeightMask64x128_SSE4_1(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)357 void WeightMask64x128_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
358                              const void* LIBGAV1_RESTRICT prediction_1,
359                              uint8_t* LIBGAV1_RESTRICT mask,
360                              ptrdiff_t mask_stride) {
361   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
362   const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
363   int y3 = 0;
364   do {
365     WEIGHT64_AND_STRIDE;
366     WEIGHT64_AND_STRIDE;
367     WEIGHT64_AND_STRIDE;
368   } while (++y3 < 42);
369   WEIGHT64_AND_STRIDE;
370   WEIGHT64_WITHOUT_STRIDE;
371 }
372 
373 template <bool mask_is_inverse>
WeightMask128x64_SSE4_1(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)374 void WeightMask128x64_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
375                              const void* LIBGAV1_RESTRICT prediction_1,
376                              uint8_t* LIBGAV1_RESTRICT mask,
377                              ptrdiff_t mask_stride) {
378   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
379   const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
380   int y3 = 0;
381   const ptrdiff_t adjusted_mask_stride = mask_stride - 64;
382   do {
383     WEIGHT64_WITHOUT_STRIDE;
384     pred_0 += 64;
385     pred_1 += 64;
386     mask += 64;
387     WEIGHT64_WITHOUT_STRIDE;
388     pred_0 += 64;
389     pred_1 += 64;
390     mask += adjusted_mask_stride;
391 
392     WEIGHT64_WITHOUT_STRIDE;
393     pred_0 += 64;
394     pred_1 += 64;
395     mask += 64;
396     WEIGHT64_WITHOUT_STRIDE;
397     pred_0 += 64;
398     pred_1 += 64;
399     mask += adjusted_mask_stride;
400 
401     WEIGHT64_WITHOUT_STRIDE;
402     pred_0 += 64;
403     pred_1 += 64;
404     mask += 64;
405     WEIGHT64_WITHOUT_STRIDE;
406     pred_0 += 64;
407     pred_1 += 64;
408     mask += adjusted_mask_stride;
409   } while (++y3 < 21);
410   WEIGHT64_WITHOUT_STRIDE;
411   pred_0 += 64;
412   pred_1 += 64;
413   mask += 64;
414   WEIGHT64_WITHOUT_STRIDE;
415 }
416 
417 template <bool mask_is_inverse>
WeightMask128x128_SSE4_1(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)418 void WeightMask128x128_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
419                               const void* LIBGAV1_RESTRICT prediction_1,
420                               uint8_t* LIBGAV1_RESTRICT mask,
421                               ptrdiff_t mask_stride) {
422   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
423   const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
424   int y3 = 0;
425   const ptrdiff_t adjusted_mask_stride = mask_stride - 64;
426   do {
427     WEIGHT64_WITHOUT_STRIDE;
428     pred_0 += 64;
429     pred_1 += 64;
430     mask += 64;
431     WEIGHT64_WITHOUT_STRIDE;
432     pred_0 += 64;
433     pred_1 += 64;
434     mask += adjusted_mask_stride;
435 
436     WEIGHT64_WITHOUT_STRIDE;
437     pred_0 += 64;
438     pred_1 += 64;
439     mask += 64;
440     WEIGHT64_WITHOUT_STRIDE;
441     pred_0 += 64;
442     pred_1 += 64;
443     mask += adjusted_mask_stride;
444 
445     WEIGHT64_WITHOUT_STRIDE;
446     pred_0 += 64;
447     pred_1 += 64;
448     mask += 64;
449     WEIGHT64_WITHOUT_STRIDE;
450     pred_0 += 64;
451     pred_1 += 64;
452     mask += adjusted_mask_stride;
453   } while (++y3 < 42);
454   WEIGHT64_WITHOUT_STRIDE;
455   pred_0 += 64;
456   pred_1 += 64;
457   mask += 64;
458   WEIGHT64_WITHOUT_STRIDE;
459   pred_0 += 64;
460   pred_1 += 64;
461   mask += adjusted_mask_stride;
462 
463   WEIGHT64_WITHOUT_STRIDE;
464   pred_0 += 64;
465   pred_1 += 64;
466   mask += 64;
467   WEIGHT64_WITHOUT_STRIDE;
468 }
469 
470 #define INIT_WEIGHT_MASK_8BPP(width, height, w_index, h_index) \
471   dsp->weight_mask[w_index][h_index][0] =                      \
472       WeightMask##width##x##height##_SSE4_1<0>;                \
473   dsp->weight_mask[w_index][h_index][1] =                      \
474       WeightMask##width##x##height##_SSE4_1<1>
Init8bpp()475 void Init8bpp() {
476   Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
477   assert(dsp != nullptr);
478   INIT_WEIGHT_MASK_8BPP(8, 8, 0, 0);
479   INIT_WEIGHT_MASK_8BPP(8, 16, 0, 1);
480   INIT_WEIGHT_MASK_8BPP(8, 32, 0, 2);
481   INIT_WEIGHT_MASK_8BPP(16, 8, 1, 0);
482   INIT_WEIGHT_MASK_8BPP(16, 16, 1, 1);
483   INIT_WEIGHT_MASK_8BPP(16, 32, 1, 2);
484   INIT_WEIGHT_MASK_8BPP(16, 64, 1, 3);
485   INIT_WEIGHT_MASK_8BPP(32, 8, 2, 0);
486   INIT_WEIGHT_MASK_8BPP(32, 16, 2, 1);
487   INIT_WEIGHT_MASK_8BPP(32, 32, 2, 2);
488   INIT_WEIGHT_MASK_8BPP(32, 64, 2, 3);
489   INIT_WEIGHT_MASK_8BPP(64, 16, 3, 1);
490   INIT_WEIGHT_MASK_8BPP(64, 32, 3, 2);
491   INIT_WEIGHT_MASK_8BPP(64, 64, 3, 3);
492   INIT_WEIGHT_MASK_8BPP(64, 128, 3, 4);
493   INIT_WEIGHT_MASK_8BPP(128, 64, 4, 3);
494   INIT_WEIGHT_MASK_8BPP(128, 128, 4, 4);
495 }
496 
497 }  // namespace
498 }  // namespace low_bitdepth
499 
500 #if LIBGAV1_MAX_BITDEPTH >= 10
501 namespace high_bitdepth {
502 namespace {
503 
504 constexpr int kRoundingBits10bpp = 6;
505 constexpr int kScaledDiffShift = 4;
506 
507 template <bool mask_is_inverse, bool is_store_16>
WeightMask16_10bpp_SSE4_1(const uint16_t * LIBGAV1_RESTRICT prediction_0,const uint16_t * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)508 inline void WeightMask16_10bpp_SSE4_1(
509     const uint16_t* LIBGAV1_RESTRICT prediction_0,
510     const uint16_t* LIBGAV1_RESTRICT prediction_1,
511     uint8_t* LIBGAV1_RESTRICT mask, ptrdiff_t mask_stride) {
512   const __m128i diff_offset = _mm_set1_epi8(38);
513   const __m128i mask_ceiling = _mm_set1_epi8(64);
514   const __m128i zero = _mm_setzero_si128();
515 
516   // Range of prediction: [3988, 61532].
517   const __m128i pred_00 = LoadAligned16(prediction_0);
518   const __m128i pred_10 = LoadAligned16(prediction_1);
519   const __m128i pred_lo_00 = _mm_cvtepu16_epi32(pred_00);
520   const __m128i pred_lo_10 = _mm_cvtepu16_epi32(pred_10);
521   const __m128i diff_lo_0 = RightShiftWithRounding_U32(
522       _mm_abs_epi32(_mm_sub_epi32(pred_lo_00, pred_lo_10)), kRoundingBits10bpp);
523 
524   const __m128i pred_hi_00 = _mm_unpackhi_epi16(pred_00, zero);
525   const __m128i pred_hi_10 = _mm_unpackhi_epi16(pred_10, zero);
526   const __m128i diff_hi_0 = RightShiftWithRounding_U32(
527       _mm_abs_epi32(_mm_sub_epi32(pred_hi_00, pred_hi_10)), kRoundingBits10bpp);
528 
529   const __m128i diff_0 = _mm_packus_epi32(diff_lo_0, diff_hi_0);
530   const __m128i scaled_diff_0 = _mm_srli_epi16(diff_0, kScaledDiffShift);
531 
532   const __m128i pred_01 = LoadAligned16(prediction_0 + 8);
533   const __m128i pred_11 = LoadAligned16(prediction_1 + 8);
534   const __m128i pred_lo_01 = _mm_cvtepu16_epi32(pred_01);
535   const __m128i pred_lo_11 = _mm_cvtepu16_epi32(pred_11);
536   const __m128i diff_lo_1 = RightShiftWithRounding_U32(
537       _mm_abs_epi32(_mm_sub_epi32(pred_lo_01, pred_lo_11)), kRoundingBits10bpp);
538 
539   const __m128i pred_hi_01 = _mm_unpackhi_epi16(pred_01, zero);
540   const __m128i pred_hi_11 = _mm_unpackhi_epi16(pred_11, zero);
541   const __m128i diff_hi_1 = RightShiftWithRounding_U32(
542       _mm_abs_epi32(_mm_sub_epi32(pred_hi_01, pred_hi_11)), kRoundingBits10bpp);
543 
544   const __m128i diff_1 = _mm_packus_epi32(diff_lo_1, diff_hi_1);
545   const __m128i scaled_diff_1 = _mm_srli_epi16(diff_1, kScaledDiffShift);
546 
547   const __m128i adjusted_diff = _mm_adds_epu8(
548       _mm_packus_epi16(scaled_diff_0, scaled_diff_1), diff_offset);
549   const __m128i mask_value = _mm_min_epi8(adjusted_diff, mask_ceiling);
550 
551   if (mask_is_inverse) {
552     const __m128i inverted_mask_value = _mm_sub_epi8(mask_ceiling, mask_value);
553     if (is_store_16) {
554       StoreAligned16(mask, inverted_mask_value);
555     } else {
556       StoreLo8(mask, inverted_mask_value);
557       StoreHi8(mask + mask_stride, inverted_mask_value);
558     }
559   } else {
560     if (is_store_16) {
561       StoreAligned16(mask, mask_value);
562     } else {
563       StoreLo8(mask, mask_value);
564       StoreHi8(mask + mask_stride, mask_value);
565     }
566   }
567 }
568 
569 #define WEIGHT8_PAIR_WITHOUT_STRIDE_10BPP                                 \
570   WeightMask16_10bpp_SSE4_1<mask_is_inverse, false>(pred_0, pred_1, mask, \
571                                                     mask_stride)
572 
573 #define WEIGHT8_PAIR_AND_STRIDE_10BPP \
574   WEIGHT8_PAIR_WITHOUT_STRIDE_10BPP;  \
575   pred_0 += 8 << 1;                   \
576   pred_1 += 8 << 1;                   \
577   mask += mask_stride << 1
578 
579 template <bool mask_is_inverse>
WeightMask8x8_10bpp_SSE4_1(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)580 void WeightMask8x8_10bpp_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
581                                 const void* LIBGAV1_RESTRICT prediction_1,
582                                 uint8_t* LIBGAV1_RESTRICT mask,
583                                 ptrdiff_t mask_stride) {
584   const auto* pred_0 = static_cast<const uint16_t*>(prediction_0);
585   const auto* pred_1 = static_cast<const uint16_t*>(prediction_1);
586 
587   WEIGHT8_PAIR_AND_STRIDE_10BPP;
588   WEIGHT8_PAIR_AND_STRIDE_10BPP;
589   WEIGHT8_PAIR_AND_STRIDE_10BPP;
590   WEIGHT8_PAIR_WITHOUT_STRIDE_10BPP;
591 }
592 
593 template <bool mask_is_inverse>
WeightMask8x16_10bpp_SSE4_1(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)594 void WeightMask8x16_10bpp_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
595                                  const void* LIBGAV1_RESTRICT prediction_1,
596                                  uint8_t* LIBGAV1_RESTRICT mask,
597                                  ptrdiff_t mask_stride) {
598   const auto* pred_0 = static_cast<const uint16_t*>(prediction_0);
599   const auto* pred_1 = static_cast<const uint16_t*>(prediction_1);
600   int y3 = 3;
601   do {
602     WEIGHT8_PAIR_AND_STRIDE_10BPP;
603     WEIGHT8_PAIR_AND_STRIDE_10BPP;
604   } while (--y3 != 0);
605   WEIGHT8_PAIR_AND_STRIDE_10BPP;
606   WEIGHT8_PAIR_WITHOUT_STRIDE_10BPP;
607 }
608 
609 template <bool mask_is_inverse>
WeightMask8x32_10bpp_SSE4_1(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)610 void WeightMask8x32_10bpp_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
611                                  const void* LIBGAV1_RESTRICT prediction_1,
612                                  uint8_t* LIBGAV1_RESTRICT mask,
613                                  ptrdiff_t mask_stride) {
614   const auto* pred_0 = static_cast<const uint16_t*>(prediction_0);
615   const auto* pred_1 = static_cast<const uint16_t*>(prediction_1);
616   int y5 = 5;
617   do {
618     WEIGHT8_PAIR_AND_STRIDE_10BPP;
619     WEIGHT8_PAIR_AND_STRIDE_10BPP;
620     WEIGHT8_PAIR_AND_STRIDE_10BPP;
621   } while (--y5 != 0);
622   WEIGHT8_PAIR_WITHOUT_STRIDE_10BPP;
623 }
624 
625 #define WEIGHT16_WITHOUT_STRIDE_10BPP                                    \
626   WeightMask16_10bpp_SSE4_1<mask_is_inverse, true>(pred_0, pred_1, mask, \
627                                                    mask_stride)
628 
629 #define WEIGHT16_AND_STRIDE_10BPP \
630   WEIGHT16_WITHOUT_STRIDE_10BPP;  \
631   pred_0 += 16;                   \
632   pred_1 += 16;                   \
633   mask += mask_stride
634 
635 template <bool mask_is_inverse>
WeightMask16x8_10bpp_SSE4_1(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)636 void WeightMask16x8_10bpp_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
637                                  const void* LIBGAV1_RESTRICT prediction_1,
638                                  uint8_t* LIBGAV1_RESTRICT mask,
639                                  ptrdiff_t mask_stride) {
640   const auto* pred_0 = static_cast<const uint16_t*>(prediction_0);
641   const auto* pred_1 = static_cast<const uint16_t*>(prediction_1);
642   int y = 7;
643   do {
644     WEIGHT16_AND_STRIDE_10BPP;
645   } while (--y != 0);
646   WEIGHT16_WITHOUT_STRIDE_10BPP;
647 }
648 
649 template <bool mask_is_inverse>
WeightMask16x16_10bpp_SSE4_1(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)650 void WeightMask16x16_10bpp_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
651                                   const void* LIBGAV1_RESTRICT prediction_1,
652                                   uint8_t* LIBGAV1_RESTRICT mask,
653                                   ptrdiff_t mask_stride) {
654   const auto* pred_0 = static_cast<const uint16_t*>(prediction_0);
655   const auto* pred_1 = static_cast<const uint16_t*>(prediction_1);
656   int y3 = 5;
657   do {
658     WEIGHT16_AND_STRIDE_10BPP;
659     WEIGHT16_AND_STRIDE_10BPP;
660     WEIGHT16_AND_STRIDE_10BPP;
661   } while (--y3 != 0);
662   WEIGHT16_WITHOUT_STRIDE_10BPP;
663 }
664 
665 template <bool mask_is_inverse>
WeightMask16x32_10bpp_SSE4_1(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)666 void WeightMask16x32_10bpp_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
667                                   const void* LIBGAV1_RESTRICT prediction_1,
668                                   uint8_t* LIBGAV1_RESTRICT mask,
669                                   ptrdiff_t mask_stride) {
670   const auto* pred_0 = static_cast<const uint16_t*>(prediction_0);
671   const auto* pred_1 = static_cast<const uint16_t*>(prediction_1);
672   int y5 = 6;
673   do {
674     WEIGHT16_AND_STRIDE_10BPP;
675     WEIGHT16_AND_STRIDE_10BPP;
676     WEIGHT16_AND_STRIDE_10BPP;
677     WEIGHT16_AND_STRIDE_10BPP;
678     WEIGHT16_AND_STRIDE_10BPP;
679   } while (--y5 != 0);
680   WEIGHT16_AND_STRIDE_10BPP;
681   WEIGHT16_WITHOUT_STRIDE_10BPP;
682 }
683 
684 template <bool mask_is_inverse>
WeightMask16x64_10bpp_SSE4_1(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)685 void WeightMask16x64_10bpp_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
686                                   const void* LIBGAV1_RESTRICT prediction_1,
687                                   uint8_t* LIBGAV1_RESTRICT mask,
688                                   ptrdiff_t mask_stride) {
689   const auto* pred_0 = static_cast<const uint16_t*>(prediction_0);
690   const auto* pred_1 = static_cast<const uint16_t*>(prediction_1);
691   int y3 = 21;
692   do {
693     WEIGHT16_AND_STRIDE_10BPP;
694     WEIGHT16_AND_STRIDE_10BPP;
695     WEIGHT16_AND_STRIDE_10BPP;
696   } while (--y3 != 0);
697   WEIGHT16_WITHOUT_STRIDE_10BPP;
698 }
699 
// Processes 32 columns as two WeightMask16_10bpp_SSE4_1 calls at column
// offsets 0 and 16. Reads pred_0, pred_1, mask, mask_stride and the template
// parameter mask_is_inverse from the enclosing scope and advances nothing.
// Wrapped in do { } while (0) so the multi-statement expansion behaves as a
// single statement (e.g. as the body of an unbraced if/else or loop).
#define WEIGHT32_WITHOUT_STRIDE_10BPP                                      \
  do {                                                                     \
    WeightMask16_10bpp_SSE4_1<mask_is_inverse, true>(pred_0, pred_1, mask, \
                                                     mask_stride);         \
    WeightMask16_10bpp_SSE4_1<mask_is_inverse, true>(                      \
        pred_0 + 16, pred_1 + 16, mask + 16, mask_stride);                 \
  } while (0)

// WEIGHT32_WITHOUT_STRIDE_10BPP, then steps both prediction pointers past the
// 32 consumed elements and moves |mask| down by |mask_stride| (next row).
#define WEIGHT32_AND_STRIDE_10BPP  \
  do {                             \
    WEIGHT32_WITHOUT_STRIDE_10BPP; \
    pred_0 += 32;                  \
    pred_1 += 32;                  \
    mask += mask_stride;           \
  } while (0)
711 
712 template <bool mask_is_inverse>
WeightMask32x8_10bpp_SSE4_1(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)713 void WeightMask32x8_10bpp_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
714                                  const void* LIBGAV1_RESTRICT prediction_1,
715                                  uint8_t* LIBGAV1_RESTRICT mask,
716                                  ptrdiff_t mask_stride) {
717   const auto* pred_0 = static_cast<const uint16_t*>(prediction_0);
718   const auto* pred_1 = static_cast<const uint16_t*>(prediction_1);
719   WEIGHT32_AND_STRIDE_10BPP;
720   WEIGHT32_AND_STRIDE_10BPP;
721   WEIGHT32_AND_STRIDE_10BPP;
722   WEIGHT32_AND_STRIDE_10BPP;
723   WEIGHT32_AND_STRIDE_10BPP;
724   WEIGHT32_AND_STRIDE_10BPP;
725   WEIGHT32_AND_STRIDE_10BPP;
726   WEIGHT32_WITHOUT_STRIDE_10BPP;
727 }
728 
729 template <bool mask_is_inverse>
WeightMask32x16_10bpp_SSE4_1(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)730 void WeightMask32x16_10bpp_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
731                                   const void* LIBGAV1_RESTRICT prediction_1,
732                                   uint8_t* LIBGAV1_RESTRICT mask,
733                                   ptrdiff_t mask_stride) {
734   const auto* pred_0 = static_cast<const uint16_t*>(prediction_0);
735   const auto* pred_1 = static_cast<const uint16_t*>(prediction_1);
736   int y3 = 5;
737   do {
738     WEIGHT32_AND_STRIDE_10BPP;
739     WEIGHT32_AND_STRIDE_10BPP;
740     WEIGHT32_AND_STRIDE_10BPP;
741   } while (--y3 != 0);
742   WEIGHT32_WITHOUT_STRIDE_10BPP;
743 }
744 
745 template <bool mask_is_inverse>
WeightMask32x32_10bpp_SSE4_1(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)746 void WeightMask32x32_10bpp_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
747                                   const void* LIBGAV1_RESTRICT prediction_1,
748                                   uint8_t* LIBGAV1_RESTRICT mask,
749                                   ptrdiff_t mask_stride) {
750   const auto* pred_0 = static_cast<const uint16_t*>(prediction_0);
751   const auto* pred_1 = static_cast<const uint16_t*>(prediction_1);
752   int y5 = 6;
753   do {
754     WEIGHT32_AND_STRIDE_10BPP;
755     WEIGHT32_AND_STRIDE_10BPP;
756     WEIGHT32_AND_STRIDE_10BPP;
757     WEIGHT32_AND_STRIDE_10BPP;
758     WEIGHT32_AND_STRIDE_10BPP;
759   } while (--y5 != 0);
760   WEIGHT32_AND_STRIDE_10BPP;
761   WEIGHT32_WITHOUT_STRIDE_10BPP;
762 }
763 
764 template <bool mask_is_inverse>
WeightMask32x64_10bpp_SSE4_1(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)765 void WeightMask32x64_10bpp_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
766                                   const void* LIBGAV1_RESTRICT prediction_1,
767                                   uint8_t* LIBGAV1_RESTRICT mask,
768                                   ptrdiff_t mask_stride) {
769   const auto* pred_0 = static_cast<const uint16_t*>(prediction_0);
770   const auto* pred_1 = static_cast<const uint16_t*>(prediction_1);
771   int y3 = 21;
772   do {
773     WEIGHT32_AND_STRIDE_10BPP;
774     WEIGHT32_AND_STRIDE_10BPP;
775     WEIGHT32_AND_STRIDE_10BPP;
776   } while (--y3 != 0);
777   WEIGHT32_WITHOUT_STRIDE_10BPP;
778 }
779 
// Processes 64 columns as four WeightMask16_10bpp_SSE4_1 calls at column
// offsets 0, 16, 32 and 48. Reads pred_0, pred_1, mask, mask_stride and the
// template parameter mask_is_inverse from the enclosing scope and advances
// nothing. Wrapped in do { } while (0) so the multi-statement expansion
// behaves as a single statement (e.g. as the body of an unbraced if/else).
#define WEIGHT64_WITHOUT_STRIDE_10BPP                                      \
  do {                                                                     \
    WeightMask16_10bpp_SSE4_1<mask_is_inverse, true>(pred_0, pred_1, mask, \
                                                     mask_stride);         \
    WeightMask16_10bpp_SSE4_1<mask_is_inverse, true>(                      \
        pred_0 + 16, pred_1 + 16, mask + 16, mask_stride);                 \
    WeightMask16_10bpp_SSE4_1<mask_is_inverse, true>(                      \
        pred_0 + 32, pred_1 + 32, mask + 32, mask_stride);                 \
    WeightMask16_10bpp_SSE4_1<mask_is_inverse, true>(                      \
        pred_0 + 48, pred_1 + 48, mask + 48, mask_stride);                 \
  } while (0)

// WEIGHT64_WITHOUT_STRIDE_10BPP, then steps both prediction pointers past the
// 64 consumed elements and moves |mask| down by |mask_stride| (next row).
#define WEIGHT64_AND_STRIDE_10BPP  \
  do {                             \
    WEIGHT64_WITHOUT_STRIDE_10BPP; \
    pred_0 += 64;                  \
    pred_1 += 64;                  \
    mask += mask_stride;           \
  } while (0)
795 
796 template <bool mask_is_inverse>
WeightMask64x16_10bpp_SSE4_1(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)797 void WeightMask64x16_10bpp_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
798                                   const void* LIBGAV1_RESTRICT prediction_1,
799                                   uint8_t* LIBGAV1_RESTRICT mask,
800                                   ptrdiff_t mask_stride) {
801   const auto* pred_0 = static_cast<const uint16_t*>(prediction_0);
802   const auto* pred_1 = static_cast<const uint16_t*>(prediction_1);
803   int y3 = 5;
804   do {
805     WEIGHT64_AND_STRIDE_10BPP;
806     WEIGHT64_AND_STRIDE_10BPP;
807     WEIGHT64_AND_STRIDE_10BPP;
808   } while (--y3 != 0);
809   WEIGHT64_WITHOUT_STRIDE_10BPP;
810 }
811 
812 template <bool mask_is_inverse>
WeightMask64x32_10bpp_SSE4_1(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)813 void WeightMask64x32_10bpp_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
814                                   const void* LIBGAV1_RESTRICT prediction_1,
815                                   uint8_t* LIBGAV1_RESTRICT mask,
816                                   ptrdiff_t mask_stride) {
817   const auto* pred_0 = static_cast<const uint16_t*>(prediction_0);
818   const auto* pred_1 = static_cast<const uint16_t*>(prediction_1);
819   int y5 = 6;
820   do {
821     WEIGHT64_AND_STRIDE_10BPP;
822     WEIGHT64_AND_STRIDE_10BPP;
823     WEIGHT64_AND_STRIDE_10BPP;
824     WEIGHT64_AND_STRIDE_10BPP;
825     WEIGHT64_AND_STRIDE_10BPP;
826   } while (--y5 != 0);
827   WEIGHT64_AND_STRIDE_10BPP;
828   WEIGHT64_WITHOUT_STRIDE_10BPP;
829 }
830 
831 template <bool mask_is_inverse>
WeightMask64x64_10bpp_SSE4_1(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)832 void WeightMask64x64_10bpp_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
833                                   const void* LIBGAV1_RESTRICT prediction_1,
834                                   uint8_t* LIBGAV1_RESTRICT mask,
835                                   ptrdiff_t mask_stride) {
836   const auto* pred_0 = static_cast<const uint16_t*>(prediction_0);
837   const auto* pred_1 = static_cast<const uint16_t*>(prediction_1);
838   int y3 = 21;
839   do {
840     WEIGHT64_AND_STRIDE_10BPP;
841     WEIGHT64_AND_STRIDE_10BPP;
842     WEIGHT64_AND_STRIDE_10BPP;
843   } while (--y3 != 0);
844   WEIGHT64_WITHOUT_STRIDE_10BPP;
845 }
846 
847 template <bool mask_is_inverse>
WeightMask64x128_10bpp_SSE4_1(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)848 void WeightMask64x128_10bpp_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
849                                    const void* LIBGAV1_RESTRICT prediction_1,
850                                    uint8_t* LIBGAV1_RESTRICT mask,
851                                    ptrdiff_t mask_stride) {
852   const auto* pred_0 = static_cast<const uint16_t*>(prediction_0);
853   const auto* pred_1 = static_cast<const uint16_t*>(prediction_1);
854   int y3 = 42;
855   do {
856     WEIGHT64_AND_STRIDE_10BPP;
857     WEIGHT64_AND_STRIDE_10BPP;
858     WEIGHT64_AND_STRIDE_10BPP;
859   } while (--y3 != 0);
860   WEIGHT64_AND_STRIDE_10BPP;
861   WEIGHT64_WITHOUT_STRIDE_10BPP;
862 }
863 
864 template <bool mask_is_inverse>
WeightMask128x64_10bpp_SSE4_1(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)865 void WeightMask128x64_10bpp_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
866                                    const void* LIBGAV1_RESTRICT prediction_1,
867                                    uint8_t* LIBGAV1_RESTRICT mask,
868                                    ptrdiff_t mask_stride) {
869   const auto* pred_0 = static_cast<const uint16_t*>(prediction_0);
870   const auto* pred_1 = static_cast<const uint16_t*>(prediction_1);
871   int y3 = 21;
872   const ptrdiff_t adjusted_mask_stride = mask_stride - 64;
873   do {
874     WEIGHT64_WITHOUT_STRIDE_10BPP;
875     pred_0 += 64;
876     pred_1 += 64;
877     mask += 64;
878     WEIGHT64_WITHOUT_STRIDE_10BPP;
879     pred_0 += 64;
880     pred_1 += 64;
881     mask += adjusted_mask_stride;
882 
883     WEIGHT64_WITHOUT_STRIDE_10BPP;
884     pred_0 += 64;
885     pred_1 += 64;
886     mask += 64;
887     WEIGHT64_WITHOUT_STRIDE_10BPP;
888     pred_0 += 64;
889     pred_1 += 64;
890     mask += adjusted_mask_stride;
891 
892     WEIGHT64_WITHOUT_STRIDE_10BPP;
893     pred_0 += 64;
894     pred_1 += 64;
895     mask += 64;
896     WEIGHT64_WITHOUT_STRIDE_10BPP;
897     pred_0 += 64;
898     pred_1 += 64;
899     mask += adjusted_mask_stride;
900   } while (--y3 != 0);
901   WEIGHT64_WITHOUT_STRIDE_10BPP;
902   pred_0 += 64;
903   pred_1 += 64;
904   mask += 64;
905   WEIGHT64_WITHOUT_STRIDE_10BPP;
906 }
907 
908 template <bool mask_is_inverse>
WeightMask128x128_10bpp_SSE4_1(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)909 void WeightMask128x128_10bpp_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
910                                     const void* LIBGAV1_RESTRICT prediction_1,
911                                     uint8_t* LIBGAV1_RESTRICT mask,
912                                     ptrdiff_t mask_stride) {
913   const auto* pred_0 = static_cast<const uint16_t*>(prediction_0);
914   const auto* pred_1 = static_cast<const uint16_t*>(prediction_1);
915   int y3 = 42;
916   const ptrdiff_t adjusted_mask_stride = mask_stride - 64;
917   do {
918     WEIGHT64_WITHOUT_STRIDE_10BPP;
919     pred_0 += 64;
920     pred_1 += 64;
921     mask += 64;
922     WEIGHT64_WITHOUT_STRIDE_10BPP;
923     pred_0 += 64;
924     pred_1 += 64;
925     mask += adjusted_mask_stride;
926 
927     WEIGHT64_WITHOUT_STRIDE_10BPP;
928     pred_0 += 64;
929     pred_1 += 64;
930     mask += 64;
931     WEIGHT64_WITHOUT_STRIDE_10BPP;
932     pred_0 += 64;
933     pred_1 += 64;
934     mask += adjusted_mask_stride;
935 
936     WEIGHT64_WITHOUT_STRIDE_10BPP;
937     pred_0 += 64;
938     pred_1 += 64;
939     mask += 64;
940     WEIGHT64_WITHOUT_STRIDE_10BPP;
941     pred_0 += 64;
942     pred_1 += 64;
943     mask += adjusted_mask_stride;
944   } while (--y3 != 0);
945   WEIGHT64_WITHOUT_STRIDE_10BPP;
946   pred_0 += 64;
947   pred_1 += 64;
948   mask += 64;
949   WEIGHT64_WITHOUT_STRIDE_10BPP;
950   pred_0 += 64;
951   pred_1 += 64;
952   mask += adjusted_mask_stride;
953 
954   WEIGHT64_WITHOUT_STRIDE_10BPP;
955   pred_0 += 64;
956   pred_1 += 64;
957   mask += 64;
958   WEIGHT64_WITHOUT_STRIDE_10BPP;
959 }
960 
961 #define INIT_WEIGHT_MASK_10BPP(width, height, w_index, h_index) \
962   dsp->weight_mask[w_index][h_index][0] =                       \
963       WeightMask##width##x##height##_10bpp_SSE4_1<0>;           \
964   dsp->weight_mask[w_index][h_index][1] =                       \
965       WeightMask##width##x##height##_10bpp_SSE4_1<1>
Init10bpp()966 void Init10bpp() {
967   Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10);
968   assert(dsp != nullptr);
969   INIT_WEIGHT_MASK_10BPP(8, 8, 0, 0);
970   INIT_WEIGHT_MASK_10BPP(8, 16, 0, 1);
971   INIT_WEIGHT_MASK_10BPP(8, 32, 0, 2);
972   INIT_WEIGHT_MASK_10BPP(16, 8, 1, 0);
973   INIT_WEIGHT_MASK_10BPP(16, 16, 1, 1);
974   INIT_WEIGHT_MASK_10BPP(16, 32, 1, 2);
975   INIT_WEIGHT_MASK_10BPP(16, 64, 1, 3);
976   INIT_WEIGHT_MASK_10BPP(32, 8, 2, 0);
977   INIT_WEIGHT_MASK_10BPP(32, 16, 2, 1);
978   INIT_WEIGHT_MASK_10BPP(32, 32, 2, 2);
979   INIT_WEIGHT_MASK_10BPP(32, 64, 2, 3);
980   INIT_WEIGHT_MASK_10BPP(64, 16, 3, 1);
981   INIT_WEIGHT_MASK_10BPP(64, 32, 3, 2);
982   INIT_WEIGHT_MASK_10BPP(64, 64, 3, 3);
983   INIT_WEIGHT_MASK_10BPP(64, 128, 3, 4);
984   INIT_WEIGHT_MASK_10BPP(128, 64, 4, 3);
985   INIT_WEIGHT_MASK_10BPP(128, 128, 4, 4);
986 }
987 
988 }  // namespace
989 }  // namespace high_bitdepth
990 #endif  // LIBGAV1_MAX_BITDEPTH >= 10
991 
WeightMaskInit_SSE4_1()992 void WeightMaskInit_SSE4_1() {
993   low_bitdepth::Init8bpp();
994 #if LIBGAV1_MAX_BITDEPTH >= 10
995   high_bitdepth::Init10bpp();
996 #endif
997 }
998 
999 }  // namespace dsp
1000 }  // namespace libgav1
1001 
1002 #else   // !LIBGAV1_TARGETING_SSE4_1
1003 
1004 namespace libgav1 {
1005 namespace dsp {
1006 
// Fallback used when LIBGAV1_TARGETING_SSE4_1 is false: no SSE4.1 kernels are
// compiled, so there is nothing to register.
void WeightMaskInit_SSE4_1() {
  // Intentionally empty.
}
1008 
1009 }  // namespace dsp
1010 }  // namespace libgav1
1011 #endif  // LIBGAV1_TARGETING_SSE4_1
1012