xref: /aosp_15_r20/external/libgav1/src/dsp/arm/weight_mask_neon.cc (revision 095378508e87ed692bf8dfeb34008b65b3735891)
1 // Copyright 2019 The libgav1 Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "src/dsp/arm/weight_mask_neon.h"
16 
17 #include "src/dsp/weight_mask.h"
18 #include "src/utils/cpu.h"
19 
20 #if LIBGAV1_ENABLE_NEON
21 
22 #include <arm_neon.h>
23 
24 #include <cassert>
25 #include <cstddef>
26 #include <cstdint>
27 
28 #include "src/dsp/arm/common_neon.h"
29 #include "src/dsp/constants.h"
30 #include "src/dsp/dsp.h"
31 #include "src/utils/common.h"
32 
33 namespace libgav1 {
34 namespace dsp {
35 namespace {
36 
LoadPred(const int16_t * LIBGAV1_RESTRICT prediction_0,const int16_t * LIBGAV1_RESTRICT prediction_1)37 inline int16x8x2_t LoadPred(const int16_t* LIBGAV1_RESTRICT prediction_0,
38                             const int16_t* LIBGAV1_RESTRICT prediction_1) {
39   const int16x8x2_t pred = {vld1q_s16(prediction_0), vld1q_s16(prediction_1)};
40   return pred;
41 }
42 
43 #if LIBGAV1_MAX_BITDEPTH >= 10
LoadPred(const uint16_t * LIBGAV1_RESTRICT prediction_0,const uint16_t * LIBGAV1_RESTRICT prediction_1)44 inline uint16x8x2_t LoadPred(const uint16_t* LIBGAV1_RESTRICT prediction_0,
45                              const uint16_t* LIBGAV1_RESTRICT prediction_1) {
46   const uint16x8x2_t pred = {vld1q_u16(prediction_0), vld1q_u16(prediction_1)};
47   return pred;
48 }
49 #endif  // LIBGAV1_MAX_BITDEPTH >= 10
50 
51 template <int bitdepth>
AbsolutePredDifference(const int16x8x2_t pred)52 inline uint16x8_t AbsolutePredDifference(const int16x8x2_t pred) {
53   static_assert(bitdepth == 8, "");
54   constexpr int rounding_bits = bitdepth - 8 + ((bitdepth == 12) ? 2 : 4);
55   return vrshrq_n_u16(
56       vreinterpretq_u16_s16(vabdq_s16(pred.val[0], pred.val[1])),
57       rounding_bits);
58 }
59 
60 template <int bitdepth>
AbsolutePredDifference(const uint16x8x2_t pred)61 inline uint16x8_t AbsolutePredDifference(const uint16x8x2_t pred) {
62   constexpr int rounding_bits = bitdepth - 8 + ((bitdepth == 12) ? 2 : 4);
63   return vrshrq_n_u16(vabdq_u16(pred.val[0], pred.val[1]), rounding_bits);
64 }
65 
66 template <bool mask_is_inverse, int bitdepth>
WeightMask8_NEON(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask)67 inline void WeightMask8_NEON(const void* LIBGAV1_RESTRICT prediction_0,
68                              const void* LIBGAV1_RESTRICT prediction_1,
69                              uint8_t* LIBGAV1_RESTRICT mask) {
70   using PredType =
71       typename std::conditional<bitdepth == 8, int16_t, uint16_t>::type;
72   using PredTypeVecx2 =
73       typename std::conditional<bitdepth == 8, int16x8x2_t, uint16x8x2_t>::type;
74   const PredTypeVecx2 pred =
75       LoadPred(static_cast<const PredType*>(prediction_0),
76                static_cast<const PredType*>(prediction_1));
77   const uint16x8_t difference = AbsolutePredDifference<bitdepth>(pred);
78   const uint8x8_t difference_offset = vdup_n_u8(38);
79   const uint8x8_t mask_ceiling = vdup_n_u8(64);
80   const uint8x8_t adjusted_difference =
81       vqadd_u8(vqshrn_n_u16(difference, 4), difference_offset);
82   const uint8x8_t mask_value = vmin_u8(adjusted_difference, mask_ceiling);
83   if (mask_is_inverse) {
84     const uint8x8_t inverted_mask_value = vsub_u8(mask_ceiling, mask_value);
85     vst1_u8(mask, inverted_mask_value);
86   } else {
87     vst1_u8(mask, mask_value);
88   }
89 }
90 
// Computes one 8-wide row of the mask at the current |pred_0|, |pred_1| and
// |mask| positions. The expansion site must have those names in scope, along
// with the |mask_is_inverse| and |bitdepth| template parameters.
#define WEIGHT8_WITHOUT_STRIDE \
  WeightMask8_NEON<mask_is_inverse, bitdepth>(pred_0, pred_1, mask)

// Computes one 8-wide row, then steps to the next row: |mask| moves forward by
// |mask_stride| and each prediction pointer moves forward by 8 values.
#define WEIGHT8_AND_STRIDE \
  WEIGHT8_WITHOUT_STRIDE;  \
  mask += mask_stride;     \
  pred_0 += 8;             \
  pred_1 += 8
99 
100 // |pred_0| and |pred_1| are cast as int16_t* for the sake of pointer math. They
101 // are uint16_t* for 10bpp and 12bpp, and this is handled in WeightMask8_NEON.
102 template <bool mask_is_inverse, int bitdepth>
WeightMask8x8_NEON(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)103 void WeightMask8x8_NEON(const void* LIBGAV1_RESTRICT prediction_0,
104                         const void* LIBGAV1_RESTRICT prediction_1,
105                         uint8_t* LIBGAV1_RESTRICT mask, ptrdiff_t mask_stride) {
106   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
107   const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
108   int y = 0;
109   do {
110     WEIGHT8_AND_STRIDE;
111   } while (++y < 7);
112   WEIGHT8_WITHOUT_STRIDE;
113 }
114 
115 template <bool mask_is_inverse, int bitdepth>
WeightMask8x16_NEON(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)116 void WeightMask8x16_NEON(const void* LIBGAV1_RESTRICT prediction_0,
117                          const void* LIBGAV1_RESTRICT prediction_1,
118                          uint8_t* LIBGAV1_RESTRICT mask,
119                          ptrdiff_t mask_stride) {
120   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
121   const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
122   int y3 = 0;
123   do {
124     WEIGHT8_AND_STRIDE;
125     WEIGHT8_AND_STRIDE;
126     WEIGHT8_AND_STRIDE;
127   } while (++y3 < 5);
128   WEIGHT8_WITHOUT_STRIDE;
129 }
130 
131 template <bool mask_is_inverse, int bitdepth>
WeightMask8x32_NEON(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)132 void WeightMask8x32_NEON(const void* LIBGAV1_RESTRICT prediction_0,
133                          const void* LIBGAV1_RESTRICT prediction_1,
134                          uint8_t* LIBGAV1_RESTRICT mask,
135                          ptrdiff_t mask_stride) {
136   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
137   const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
138   int y5 = 0;
139   do {
140     WEIGHT8_AND_STRIDE;
141     WEIGHT8_AND_STRIDE;
142     WEIGHT8_AND_STRIDE;
143     WEIGHT8_AND_STRIDE;
144     WEIGHT8_AND_STRIDE;
145   } while (++y5 < 6);
146   WEIGHT8_AND_STRIDE;
147   WEIGHT8_WITHOUT_STRIDE;
148 }
149 
// Computes one 16-wide row of the mask as two 8-wide pieces at offsets 0 and
// 8 from the current |pred_0|, |pred_1| and |mask| positions. The expansion
// site must have those names in scope, along with the |mask_is_inverse| and
// |bitdepth| template parameters.
#define WEIGHT16_WITHOUT_STRIDE                                      \
  WeightMask8_NEON<mask_is_inverse, bitdepth>(pred_0, pred_1, mask); \
  WeightMask8_NEON<mask_is_inverse, bitdepth>(pred_0 + 8, pred_1 + 8, mask + 8)

// Computes one 16-wide row, then steps to the next row: |mask| moves forward
// by |mask_stride| and each prediction pointer moves forward by 16 values.
#define WEIGHT16_AND_STRIDE \
  WEIGHT16_WITHOUT_STRIDE;  \
  mask += mask_stride;      \
  pred_0 += 16;             \
  pred_1 += 16
159 
160 template <bool mask_is_inverse, int bitdepth>
WeightMask16x8_NEON(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)161 void WeightMask16x8_NEON(const void* LIBGAV1_RESTRICT prediction_0,
162                          const void* LIBGAV1_RESTRICT prediction_1,
163                          uint8_t* LIBGAV1_RESTRICT mask,
164                          ptrdiff_t mask_stride) {
165   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
166   const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
167   int y = 0;
168   do {
169     WEIGHT16_AND_STRIDE;
170   } while (++y < 7);
171   WEIGHT16_WITHOUT_STRIDE;
172 }
173 
174 template <bool mask_is_inverse, int bitdepth>
WeightMask16x16_NEON(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)175 void WeightMask16x16_NEON(const void* LIBGAV1_RESTRICT prediction_0,
176                           const void* LIBGAV1_RESTRICT prediction_1,
177                           uint8_t* LIBGAV1_RESTRICT mask,
178                           ptrdiff_t mask_stride) {
179   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
180   const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
181   int y3 = 0;
182   do {
183     WEIGHT16_AND_STRIDE;
184     WEIGHT16_AND_STRIDE;
185     WEIGHT16_AND_STRIDE;
186   } while (++y3 < 5);
187   WEIGHT16_WITHOUT_STRIDE;
188 }
189 
190 template <bool mask_is_inverse, int bitdepth>
WeightMask16x32_NEON(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)191 void WeightMask16x32_NEON(const void* LIBGAV1_RESTRICT prediction_0,
192                           const void* LIBGAV1_RESTRICT prediction_1,
193                           uint8_t* LIBGAV1_RESTRICT mask,
194                           ptrdiff_t mask_stride) {
195   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
196   const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
197   int y5 = 0;
198   do {
199     WEIGHT16_AND_STRIDE;
200     WEIGHT16_AND_STRIDE;
201     WEIGHT16_AND_STRIDE;
202     WEIGHT16_AND_STRIDE;
203     WEIGHT16_AND_STRIDE;
204   } while (++y5 < 6);
205   WEIGHT16_AND_STRIDE;
206   WEIGHT16_WITHOUT_STRIDE;
207 }
208 
209 template <bool mask_is_inverse, int bitdepth>
WeightMask16x64_NEON(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)210 void WeightMask16x64_NEON(const void* LIBGAV1_RESTRICT prediction_0,
211                           const void* LIBGAV1_RESTRICT prediction_1,
212                           uint8_t* LIBGAV1_RESTRICT mask,
213                           ptrdiff_t mask_stride) {
214   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
215   const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
216   int y3 = 0;
217   do {
218     WEIGHT16_AND_STRIDE;
219     WEIGHT16_AND_STRIDE;
220     WEIGHT16_AND_STRIDE;
221   } while (++y3 < 21);
222   WEIGHT16_WITHOUT_STRIDE;
223 }
224 
// Computes one 32-wide row of the mask as four 8-wide pieces at offsets 0, 8,
// 16 and 24 from the current |pred_0|, |pred_1| and |mask| positions. The
// expansion site must have those names in scope, along with the
// |mask_is_inverse| and |bitdepth| template parameters.
#define WEIGHT32_WITHOUT_STRIDE                                         \
  WeightMask8_NEON<mask_is_inverse, bitdepth>(pred_0, pred_1, mask);    \
  WeightMask8_NEON<mask_is_inverse, bitdepth>(pred_0 + 8, pred_1 + 8,   \
                                              mask + 8);                \
  WeightMask8_NEON<mask_is_inverse, bitdepth>(pred_0 + 16, pred_1 + 16, \
                                              mask + 16);               \
  WeightMask8_NEON<mask_is_inverse, bitdepth>(pred_0 + 24, pred_1 + 24, \
                                              mask + 24)

// Computes one 32-wide row, then steps to the next row: |mask| moves forward
// by |mask_stride| and each prediction pointer moves forward by 32 values.
#define WEIGHT32_AND_STRIDE \
  WEIGHT32_WITHOUT_STRIDE;  \
  mask += mask_stride;      \
  pred_0 += 32;             \
  pred_1 += 32
239 
240 template <bool mask_is_inverse, int bitdepth>
WeightMask32x8_NEON(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)241 void WeightMask32x8_NEON(const void* LIBGAV1_RESTRICT prediction_0,
242                          const void* LIBGAV1_RESTRICT prediction_1,
243                          uint8_t* LIBGAV1_RESTRICT mask,
244                          ptrdiff_t mask_stride) {
245   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
246   const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
247   WEIGHT32_AND_STRIDE;
248   WEIGHT32_AND_STRIDE;
249   WEIGHT32_AND_STRIDE;
250   WEIGHT32_AND_STRIDE;
251   WEIGHT32_AND_STRIDE;
252   WEIGHT32_AND_STRIDE;
253   WEIGHT32_AND_STRIDE;
254   WEIGHT32_WITHOUT_STRIDE;
255 }
256 
257 template <bool mask_is_inverse, int bitdepth>
WeightMask32x16_NEON(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)258 void WeightMask32x16_NEON(const void* LIBGAV1_RESTRICT prediction_0,
259                           const void* LIBGAV1_RESTRICT prediction_1,
260                           uint8_t* LIBGAV1_RESTRICT mask,
261                           ptrdiff_t mask_stride) {
262   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
263   const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
264   int y3 = 0;
265   do {
266     WEIGHT32_AND_STRIDE;
267     WEIGHT32_AND_STRIDE;
268     WEIGHT32_AND_STRIDE;
269   } while (++y3 < 5);
270   WEIGHT32_WITHOUT_STRIDE;
271 }
272 
273 template <bool mask_is_inverse, int bitdepth>
WeightMask32x32_NEON(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)274 void WeightMask32x32_NEON(const void* LIBGAV1_RESTRICT prediction_0,
275                           const void* LIBGAV1_RESTRICT prediction_1,
276                           uint8_t* LIBGAV1_RESTRICT mask,
277                           ptrdiff_t mask_stride) {
278   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
279   const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
280   int y5 = 0;
281   do {
282     WEIGHT32_AND_STRIDE;
283     WEIGHT32_AND_STRIDE;
284     WEIGHT32_AND_STRIDE;
285     WEIGHT32_AND_STRIDE;
286     WEIGHT32_AND_STRIDE;
287   } while (++y5 < 6);
288   WEIGHT32_AND_STRIDE;
289   WEIGHT32_WITHOUT_STRIDE;
290 }
291 
292 template <bool mask_is_inverse, int bitdepth>
WeightMask32x64_NEON(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)293 void WeightMask32x64_NEON(const void* LIBGAV1_RESTRICT prediction_0,
294                           const void* LIBGAV1_RESTRICT prediction_1,
295                           uint8_t* LIBGAV1_RESTRICT mask,
296                           ptrdiff_t mask_stride) {
297   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
298   const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
299   int y3 = 0;
300   do {
301     WEIGHT32_AND_STRIDE;
302     WEIGHT32_AND_STRIDE;
303     WEIGHT32_AND_STRIDE;
304   } while (++y3 < 21);
305   WEIGHT32_WITHOUT_STRIDE;
306 }
307 
// Computes one 64-wide row of the mask as eight 8-wide pieces at offsets 0, 8,
// ..., 56 from the current |pred_0|, |pred_1| and |mask| positions. The
// expansion site must have those names in scope, along with the
// |mask_is_inverse| and |bitdepth| template parameters.
#define WEIGHT64_WITHOUT_STRIDE                                         \
  WeightMask8_NEON<mask_is_inverse, bitdepth>(pred_0, pred_1, mask);    \
  WeightMask8_NEON<mask_is_inverse, bitdepth>(pred_0 + 8, pred_1 + 8,   \
                                              mask + 8);                \
  WeightMask8_NEON<mask_is_inverse, bitdepth>(pred_0 + 16, pred_1 + 16, \
                                              mask + 16);               \
  WeightMask8_NEON<mask_is_inverse, bitdepth>(pred_0 + 24, pred_1 + 24, \
                                              mask + 24);               \
  WeightMask8_NEON<mask_is_inverse, bitdepth>(pred_0 + 32, pred_1 + 32, \
                                              mask + 32);               \
  WeightMask8_NEON<mask_is_inverse, bitdepth>(pred_0 + 40, pred_1 + 40, \
                                              mask + 40);               \
  WeightMask8_NEON<mask_is_inverse, bitdepth>(pred_0 + 48, pred_1 + 48, \
                                              mask + 48);               \
  WeightMask8_NEON<mask_is_inverse, bitdepth>(pred_0 + 56, pred_1 + 56, \
                                              mask + 56)

// Computes one 64-wide row, then steps to the next row: |mask| moves forward
// by |mask_stride| and each prediction pointer moves forward by 64 values.
#define WEIGHT64_AND_STRIDE \
  WEIGHT64_WITHOUT_STRIDE;  \
  mask += mask_stride;      \
  pred_0 += 64;             \
  pred_1 += 64
330 
331 template <bool mask_is_inverse, int bitdepth>
WeightMask64x16_NEON(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)332 void WeightMask64x16_NEON(const void* LIBGAV1_RESTRICT prediction_0,
333                           const void* LIBGAV1_RESTRICT prediction_1,
334                           uint8_t* LIBGAV1_RESTRICT mask,
335                           ptrdiff_t mask_stride) {
336   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
337   const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
338   int y3 = 0;
339   do {
340     WEIGHT64_AND_STRIDE;
341     WEIGHT64_AND_STRIDE;
342     WEIGHT64_AND_STRIDE;
343   } while (++y3 < 5);
344   WEIGHT64_WITHOUT_STRIDE;
345 }
346 
347 template <bool mask_is_inverse, int bitdepth>
WeightMask64x32_NEON(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)348 void WeightMask64x32_NEON(const void* LIBGAV1_RESTRICT prediction_0,
349                           const void* LIBGAV1_RESTRICT prediction_1,
350                           uint8_t* LIBGAV1_RESTRICT mask,
351                           ptrdiff_t mask_stride) {
352   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
353   const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
354   int y5 = 0;
355   do {
356     WEIGHT64_AND_STRIDE;
357     WEIGHT64_AND_STRIDE;
358     WEIGHT64_AND_STRIDE;
359     WEIGHT64_AND_STRIDE;
360     WEIGHT64_AND_STRIDE;
361   } while (++y5 < 6);
362   WEIGHT64_AND_STRIDE;
363   WEIGHT64_WITHOUT_STRIDE;
364 }
365 
366 template <bool mask_is_inverse, int bitdepth>
WeightMask64x64_NEON(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)367 void WeightMask64x64_NEON(const void* LIBGAV1_RESTRICT prediction_0,
368                           const void* LIBGAV1_RESTRICT prediction_1,
369                           uint8_t* LIBGAV1_RESTRICT mask,
370                           ptrdiff_t mask_stride) {
371   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
372   const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
373   int y3 = 0;
374   do {
375     WEIGHT64_AND_STRIDE;
376     WEIGHT64_AND_STRIDE;
377     WEIGHT64_AND_STRIDE;
378   } while (++y3 < 21);
379   WEIGHT64_WITHOUT_STRIDE;
380 }
381 
382 template <bool mask_is_inverse, int bitdepth>
WeightMask64x128_NEON(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)383 void WeightMask64x128_NEON(const void* LIBGAV1_RESTRICT prediction_0,
384                            const void* LIBGAV1_RESTRICT prediction_1,
385                            uint8_t* LIBGAV1_RESTRICT mask,
386                            ptrdiff_t mask_stride) {
387   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
388   const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
389   int y3 = 0;
390   do {
391     WEIGHT64_AND_STRIDE;
392     WEIGHT64_AND_STRIDE;
393     WEIGHT64_AND_STRIDE;
394   } while (++y3 < 42);
395   WEIGHT64_AND_STRIDE;
396   WEIGHT64_WITHOUT_STRIDE;
397 }
398 
399 template <bool mask_is_inverse, int bitdepth>
WeightMask128x64_NEON(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)400 void WeightMask128x64_NEON(const void* LIBGAV1_RESTRICT prediction_0,
401                            const void* LIBGAV1_RESTRICT prediction_1,
402                            uint8_t* LIBGAV1_RESTRICT mask,
403                            ptrdiff_t mask_stride) {
404   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
405   const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
406   int y3 = 0;
407   const ptrdiff_t adjusted_mask_stride = mask_stride - 64;
408   do {
409     WEIGHT64_WITHOUT_STRIDE;
410     pred_0 += 64;
411     pred_1 += 64;
412     mask += 64;
413     WEIGHT64_WITHOUT_STRIDE;
414     pred_0 += 64;
415     pred_1 += 64;
416     mask += adjusted_mask_stride;
417 
418     WEIGHT64_WITHOUT_STRIDE;
419     pred_0 += 64;
420     pred_1 += 64;
421     mask += 64;
422     WEIGHT64_WITHOUT_STRIDE;
423     pred_0 += 64;
424     pred_1 += 64;
425     mask += adjusted_mask_stride;
426 
427     WEIGHT64_WITHOUT_STRIDE;
428     pred_0 += 64;
429     pred_1 += 64;
430     mask += 64;
431     WEIGHT64_WITHOUT_STRIDE;
432     pred_0 += 64;
433     pred_1 += 64;
434     mask += adjusted_mask_stride;
435   } while (++y3 < 21);
436   WEIGHT64_WITHOUT_STRIDE;
437   pred_0 += 64;
438   pred_1 += 64;
439   mask += 64;
440   WEIGHT64_WITHOUT_STRIDE;
441 }
442 
443 template <bool mask_is_inverse, int bitdepth>
WeightMask128x128_NEON(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)444 void WeightMask128x128_NEON(const void* LIBGAV1_RESTRICT prediction_0,
445                             const void* LIBGAV1_RESTRICT prediction_1,
446                             uint8_t* LIBGAV1_RESTRICT mask,
447                             ptrdiff_t mask_stride) {
448   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
449   const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
450   int y3 = 0;
451   const ptrdiff_t adjusted_mask_stride = mask_stride - 64;
452   do {
453     WEIGHT64_WITHOUT_STRIDE;
454     pred_0 += 64;
455     pred_1 += 64;
456     mask += 64;
457     WEIGHT64_WITHOUT_STRIDE;
458     pred_0 += 64;
459     pred_1 += 64;
460     mask += adjusted_mask_stride;
461 
462     WEIGHT64_WITHOUT_STRIDE;
463     pred_0 += 64;
464     pred_1 += 64;
465     mask += 64;
466     WEIGHT64_WITHOUT_STRIDE;
467     pred_0 += 64;
468     pred_1 += 64;
469     mask += adjusted_mask_stride;
470 
471     WEIGHT64_WITHOUT_STRIDE;
472     pred_0 += 64;
473     pred_1 += 64;
474     mask += 64;
475     WEIGHT64_WITHOUT_STRIDE;
476     pred_0 += 64;
477     pred_1 += 64;
478     mask += adjusted_mask_stride;
479   } while (++y3 < 42);
480   WEIGHT64_WITHOUT_STRIDE;
481   pred_0 += 64;
482   pred_1 += 64;
483   mask += 64;
484   WEIGHT64_WITHOUT_STRIDE;
485   pred_0 += 64;
486   pred_1 += 64;
487   mask += adjusted_mask_stride;
488 
489   WEIGHT64_WITHOUT_STRIDE;
490   pred_0 += 64;
491   pred_1 += 64;
492   mask += 64;
493   WEIGHT64_WITHOUT_STRIDE;
494 }
// The row helper macros are only meant for the WeightMask*_NEON functions
// above; remove them so they do not leak into the rest of the file.
#undef WEIGHT8_WITHOUT_STRIDE
#undef WEIGHT8_AND_STRIDE

#undef WEIGHT16_WITHOUT_STRIDE
#undef WEIGHT16_AND_STRIDE

#undef WEIGHT32_WITHOUT_STRIDE
#undef WEIGHT32_AND_STRIDE

#undef WEIGHT64_WITHOUT_STRIDE
#undef WEIGHT64_AND_STRIDE
503 
504 #define INIT_WEIGHT_MASK_8BPP(width, height, w_index, h_index) \
505   dsp->weight_mask[w_index][h_index][0] =                      \
506       WeightMask##width##x##height##_NEON<0, 8>;               \
507   dsp->weight_mask[w_index][h_index][1] =                      \
508       WeightMask##width##x##height##_NEON<1, 8>
Init8bpp()509 void Init8bpp() {
510   Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
511   assert(dsp != nullptr);
512   INIT_WEIGHT_MASK_8BPP(8, 8, 0, 0);
513   INIT_WEIGHT_MASK_8BPP(8, 16, 0, 1);
514   INIT_WEIGHT_MASK_8BPP(8, 32, 0, 2);
515   INIT_WEIGHT_MASK_8BPP(16, 8, 1, 0);
516   INIT_WEIGHT_MASK_8BPP(16, 16, 1, 1);
517   INIT_WEIGHT_MASK_8BPP(16, 32, 1, 2);
518   INIT_WEIGHT_MASK_8BPP(16, 64, 1, 3);
519   INIT_WEIGHT_MASK_8BPP(32, 8, 2, 0);
520   INIT_WEIGHT_MASK_8BPP(32, 16, 2, 1);
521   INIT_WEIGHT_MASK_8BPP(32, 32, 2, 2);
522   INIT_WEIGHT_MASK_8BPP(32, 64, 2, 3);
523   INIT_WEIGHT_MASK_8BPP(64, 16, 3, 1);
524   INIT_WEIGHT_MASK_8BPP(64, 32, 3, 2);
525   INIT_WEIGHT_MASK_8BPP(64, 64, 3, 3);
526   INIT_WEIGHT_MASK_8BPP(64, 128, 3, 4);
527   INIT_WEIGHT_MASK_8BPP(128, 64, 4, 3);
528   INIT_WEIGHT_MASK_8BPP(128, 128, 4, 4);
529 }
530 #undef INIT_WEIGHT_MASK_8BPP
531 
532 }  // namespace
533 
534 #if LIBGAV1_MAX_BITDEPTH >= 10
535 namespace high_bitdepth {
536 namespace {
537 
538 #define INIT_WEIGHT_MASK_10BPP(width, height, w_index, h_index) \
539   dsp->weight_mask[w_index][h_index][0] =                       \
540       WeightMask##width##x##height##_NEON<0, 10>;               \
541   dsp->weight_mask[w_index][h_index][1] =                       \
542       WeightMask##width##x##height##_NEON<1, 10>
Init10bpp()543 void Init10bpp() {
544   Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10);
545   assert(dsp != nullptr);
546   INIT_WEIGHT_MASK_10BPP(8, 8, 0, 0);
547   INIT_WEIGHT_MASK_10BPP(8, 16, 0, 1);
548   INIT_WEIGHT_MASK_10BPP(8, 32, 0, 2);
549   INIT_WEIGHT_MASK_10BPP(16, 8, 1, 0);
550   INIT_WEIGHT_MASK_10BPP(16, 16, 1, 1);
551   INIT_WEIGHT_MASK_10BPP(16, 32, 1, 2);
552   INIT_WEIGHT_MASK_10BPP(16, 64, 1, 3);
553   INIT_WEIGHT_MASK_10BPP(32, 8, 2, 0);
554   INIT_WEIGHT_MASK_10BPP(32, 16, 2, 1);
555   INIT_WEIGHT_MASK_10BPP(32, 32, 2, 2);
556   INIT_WEIGHT_MASK_10BPP(32, 64, 2, 3);
557   INIT_WEIGHT_MASK_10BPP(64, 16, 3, 1);
558   INIT_WEIGHT_MASK_10BPP(64, 32, 3, 2);
559   INIT_WEIGHT_MASK_10BPP(64, 64, 3, 3);
560   INIT_WEIGHT_MASK_10BPP(64, 128, 3, 4);
561   INIT_WEIGHT_MASK_10BPP(128, 64, 4, 3);
562   INIT_WEIGHT_MASK_10BPP(128, 128, 4, 4);
563 }
564 #undef INIT_WEIGHT_MASK_10BPP
565 
566 }  // namespace
567 }  // namespace high_bitdepth
568 #endif  // LIBGAV1_MAX_BITDEPTH >= 10
WeightMaskInit_NEON()569 void WeightMaskInit_NEON() {
570   Init8bpp();
571 #if LIBGAV1_MAX_BITDEPTH >= 10
572   high_bitdepth::Init10bpp();
573 #endif  // LIBGAV1_MAX_BITDEPTH >= 10
574 }
575 
576 }  // namespace dsp
577 }  // namespace libgav1
578 
579 #else   // !LIBGAV1_ENABLE_NEON
580 
581 namespace libgav1 {
582 namespace dsp {
583 
// NEON is not enabled for this build, so there is nothing to install.
void WeightMaskInit_NEON() {}
585 
586 }  // namespace dsp
587 }  // namespace libgav1
588 #endif  // LIBGAV1_ENABLE_NEON
589