1 // Copyright 2019 The libgav1 Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "src/dsp/arm/weight_mask_neon.h"
16
17 #include "src/dsp/weight_mask.h"
18 #include "src/utils/cpu.h"
19
20 #if LIBGAV1_ENABLE_NEON
21
22 #include <arm_neon.h>
23
24 #include <cassert>
25 #include <cstddef>
26 #include <cstdint>
27
28 #include "src/dsp/arm/common_neon.h"
29 #include "src/dsp/constants.h"
30 #include "src/dsp/dsp.h"
31 #include "src/utils/common.h"
32
33 namespace libgav1 {
34 namespace dsp {
35 namespace {
36
LoadPred(const int16_t * LIBGAV1_RESTRICT prediction_0,const int16_t * LIBGAV1_RESTRICT prediction_1)37 inline int16x8x2_t LoadPred(const int16_t* LIBGAV1_RESTRICT prediction_0,
38 const int16_t* LIBGAV1_RESTRICT prediction_1) {
39 const int16x8x2_t pred = {vld1q_s16(prediction_0), vld1q_s16(prediction_1)};
40 return pred;
41 }
42
43 #if LIBGAV1_MAX_BITDEPTH >= 10
LoadPred(const uint16_t * LIBGAV1_RESTRICT prediction_0,const uint16_t * LIBGAV1_RESTRICT prediction_1)44 inline uint16x8x2_t LoadPred(const uint16_t* LIBGAV1_RESTRICT prediction_0,
45 const uint16_t* LIBGAV1_RESTRICT prediction_1) {
46 const uint16x8x2_t pred = {vld1q_u16(prediction_0), vld1q_u16(prediction_1)};
47 return pred;
48 }
49 #endif // LIBGAV1_MAX_BITDEPTH >= 10
50
51 template <int bitdepth>
AbsolutePredDifference(const int16x8x2_t pred)52 inline uint16x8_t AbsolutePredDifference(const int16x8x2_t pred) {
53 static_assert(bitdepth == 8, "");
54 constexpr int rounding_bits = bitdepth - 8 + ((bitdepth == 12) ? 2 : 4);
55 return vrshrq_n_u16(
56 vreinterpretq_u16_s16(vabdq_s16(pred.val[0], pred.val[1])),
57 rounding_bits);
58 }
59
60 template <int bitdepth>
AbsolutePredDifference(const uint16x8x2_t pred)61 inline uint16x8_t AbsolutePredDifference(const uint16x8x2_t pred) {
62 constexpr int rounding_bits = bitdepth - 8 + ((bitdepth == 12) ? 2 : 4);
63 return vrshrq_n_u16(vabdq_u16(pred.val[0], pred.val[1]), rounding_bits);
64 }
65
66 template <bool mask_is_inverse, int bitdepth>
WeightMask8_NEON(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask)67 inline void WeightMask8_NEON(const void* LIBGAV1_RESTRICT prediction_0,
68 const void* LIBGAV1_RESTRICT prediction_1,
69 uint8_t* LIBGAV1_RESTRICT mask) {
70 using PredType =
71 typename std::conditional<bitdepth == 8, int16_t, uint16_t>::type;
72 using PredTypeVecx2 =
73 typename std::conditional<bitdepth == 8, int16x8x2_t, uint16x8x2_t>::type;
74 const PredTypeVecx2 pred =
75 LoadPred(static_cast<const PredType*>(prediction_0),
76 static_cast<const PredType*>(prediction_1));
77 const uint16x8_t difference = AbsolutePredDifference<bitdepth>(pred);
78 const uint8x8_t difference_offset = vdup_n_u8(38);
79 const uint8x8_t mask_ceiling = vdup_n_u8(64);
80 const uint8x8_t adjusted_difference =
81 vqadd_u8(vqshrn_n_u16(difference, 4), difference_offset);
82 const uint8x8_t mask_value = vmin_u8(adjusted_difference, mask_ceiling);
83 if (mask_is_inverse) {
84 const uint8x8_t inverted_mask_value = vsub_u8(mask_ceiling, mask_value);
85 vst1_u8(mask, inverted_mask_value);
86 } else {
87 vst1_u8(mask, mask_value);
88 }
89 }
90
// Row helpers for 8-wide blocks. Both macros expect |pred_0|, |pred_1|, |mask|
// and the template parameters |mask_is_inverse| and |bitdepth| to be in scope
// at the point of expansion; WEIGHT8_AND_STRIDE additionally needs
// |mask_stride|.

// Computes one row of 8 mask values without moving any pointer.
#define WEIGHT8_WITHOUT_STRIDE \
  WeightMask8_NEON<mask_is_inverse, bitdepth>(pred_0, pred_1, mask)

// Computes one row of 8 mask values, then advances the predictions by 8
// values and |mask| by |mask_stride|. Wrapped in do/while (0) so that the
// expansion is a single statement and stays correct under an unbraced
// if/else/for.
#define WEIGHT8_AND_STRIDE  \
  do {                      \
    WEIGHT8_WITHOUT_STRIDE; \
    pred_0 += 8;            \
    pred_1 += 8;            \
    mask += mask_stride;    \
  } while (0)
99
100 // |pred_0| and |pred_1| are cast as int16_t* for the sake of pointer math. They
101 // are uint16_t* for 10bpp and 12bpp, and this is handled in WeightMask8_NEON.
102 template <bool mask_is_inverse, int bitdepth>
WeightMask8x8_NEON(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)103 void WeightMask8x8_NEON(const void* LIBGAV1_RESTRICT prediction_0,
104 const void* LIBGAV1_RESTRICT prediction_1,
105 uint8_t* LIBGAV1_RESTRICT mask, ptrdiff_t mask_stride) {
106 const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
107 const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
108 int y = 0;
109 do {
110 WEIGHT8_AND_STRIDE;
111 } while (++y < 7);
112 WEIGHT8_WITHOUT_STRIDE;
113 }
114
115 template <bool mask_is_inverse, int bitdepth>
WeightMask8x16_NEON(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)116 void WeightMask8x16_NEON(const void* LIBGAV1_RESTRICT prediction_0,
117 const void* LIBGAV1_RESTRICT prediction_1,
118 uint8_t* LIBGAV1_RESTRICT mask,
119 ptrdiff_t mask_stride) {
120 const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
121 const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
122 int y3 = 0;
123 do {
124 WEIGHT8_AND_STRIDE;
125 WEIGHT8_AND_STRIDE;
126 WEIGHT8_AND_STRIDE;
127 } while (++y3 < 5);
128 WEIGHT8_WITHOUT_STRIDE;
129 }
130
131 template <bool mask_is_inverse, int bitdepth>
WeightMask8x32_NEON(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)132 void WeightMask8x32_NEON(const void* LIBGAV1_RESTRICT prediction_0,
133 const void* LIBGAV1_RESTRICT prediction_1,
134 uint8_t* LIBGAV1_RESTRICT mask,
135 ptrdiff_t mask_stride) {
136 const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
137 const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
138 int y5 = 0;
139 do {
140 WEIGHT8_AND_STRIDE;
141 WEIGHT8_AND_STRIDE;
142 WEIGHT8_AND_STRIDE;
143 WEIGHT8_AND_STRIDE;
144 WEIGHT8_AND_STRIDE;
145 } while (++y5 < 6);
146 WEIGHT8_AND_STRIDE;
147 WEIGHT8_WITHOUT_STRIDE;
148 }
149
// Row helpers for 16-wide blocks. Both macros expect |pred_0|, |pred_1|,
// |mask| and the template parameters |mask_is_inverse| and |bitdepth| to be in
// scope at the point of expansion; WEIGHT16_AND_STRIDE additionally needs
// |mask_stride|. Each macro is wrapped in do/while (0) so that it expands to
// a single statement and stays correct under an unbraced if/else/for.

// Computes one row of 16 mask values (two 8-wide pieces) without moving any
// pointer.
#define WEIGHT16_WITHOUT_STRIDE                                          \
  do {                                                                   \
    WeightMask8_NEON<mask_is_inverse, bitdepth>(pred_0, pred_1, mask);   \
    WeightMask8_NEON<mask_is_inverse, bitdepth>(pred_0 + 8, pred_1 + 8,  \
                                                mask + 8);               \
  } while (0)

// Computes one row of 16 mask values, then advances the predictions by 16
// values and |mask| by |mask_stride|.
#define WEIGHT16_AND_STRIDE  \
  do {                       \
    WEIGHT16_WITHOUT_STRIDE; \
    pred_0 += 16;            \
    pred_1 += 16;            \
    mask += mask_stride;     \
  } while (0)
159
160 template <bool mask_is_inverse, int bitdepth>
WeightMask16x8_NEON(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)161 void WeightMask16x8_NEON(const void* LIBGAV1_RESTRICT prediction_0,
162 const void* LIBGAV1_RESTRICT prediction_1,
163 uint8_t* LIBGAV1_RESTRICT mask,
164 ptrdiff_t mask_stride) {
165 const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
166 const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
167 int y = 0;
168 do {
169 WEIGHT16_AND_STRIDE;
170 } while (++y < 7);
171 WEIGHT16_WITHOUT_STRIDE;
172 }
173
174 template <bool mask_is_inverse, int bitdepth>
WeightMask16x16_NEON(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)175 void WeightMask16x16_NEON(const void* LIBGAV1_RESTRICT prediction_0,
176 const void* LIBGAV1_RESTRICT prediction_1,
177 uint8_t* LIBGAV1_RESTRICT mask,
178 ptrdiff_t mask_stride) {
179 const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
180 const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
181 int y3 = 0;
182 do {
183 WEIGHT16_AND_STRIDE;
184 WEIGHT16_AND_STRIDE;
185 WEIGHT16_AND_STRIDE;
186 } while (++y3 < 5);
187 WEIGHT16_WITHOUT_STRIDE;
188 }
189
190 template <bool mask_is_inverse, int bitdepth>
WeightMask16x32_NEON(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)191 void WeightMask16x32_NEON(const void* LIBGAV1_RESTRICT prediction_0,
192 const void* LIBGAV1_RESTRICT prediction_1,
193 uint8_t* LIBGAV1_RESTRICT mask,
194 ptrdiff_t mask_stride) {
195 const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
196 const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
197 int y5 = 0;
198 do {
199 WEIGHT16_AND_STRIDE;
200 WEIGHT16_AND_STRIDE;
201 WEIGHT16_AND_STRIDE;
202 WEIGHT16_AND_STRIDE;
203 WEIGHT16_AND_STRIDE;
204 } while (++y5 < 6);
205 WEIGHT16_AND_STRIDE;
206 WEIGHT16_WITHOUT_STRIDE;
207 }
208
209 template <bool mask_is_inverse, int bitdepth>
WeightMask16x64_NEON(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)210 void WeightMask16x64_NEON(const void* LIBGAV1_RESTRICT prediction_0,
211 const void* LIBGAV1_RESTRICT prediction_1,
212 uint8_t* LIBGAV1_RESTRICT mask,
213 ptrdiff_t mask_stride) {
214 const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
215 const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
216 int y3 = 0;
217 do {
218 WEIGHT16_AND_STRIDE;
219 WEIGHT16_AND_STRIDE;
220 WEIGHT16_AND_STRIDE;
221 } while (++y3 < 21);
222 WEIGHT16_WITHOUT_STRIDE;
223 }
224
// Row helpers for 32-wide blocks. Both macros expect |pred_0|, |pred_1|,
// |mask| and the template parameters |mask_is_inverse| and |bitdepth| to be in
// scope at the point of expansion; WEIGHT32_AND_STRIDE additionally needs
// |mask_stride|. Each macro is wrapped in do/while (0) so that it expands to
// a single statement and stays correct under an unbraced if/else/for.

// Computes one row of 32 mask values (four 8-wide pieces) without moving any
// pointer.
#define WEIGHT32_WITHOUT_STRIDE                                            \
  do {                                                                     \
    WeightMask8_NEON<mask_is_inverse, bitdepth>(pred_0, pred_1, mask);     \
    WeightMask8_NEON<mask_is_inverse, bitdepth>(pred_0 + 8, pred_1 + 8,    \
                                                mask + 8);                 \
    WeightMask8_NEON<mask_is_inverse, bitdepth>(pred_0 + 16, pred_1 + 16,  \
                                                mask + 16);                \
    WeightMask8_NEON<mask_is_inverse, bitdepth>(pred_0 + 24, pred_1 + 24,  \
                                                mask + 24);                \
  } while (0)

// Computes one row of 32 mask values, then advances the predictions by 32
// values and |mask| by |mask_stride|.
#define WEIGHT32_AND_STRIDE  \
  do {                       \
    WEIGHT32_WITHOUT_STRIDE; \
    pred_0 += 32;            \
    pred_1 += 32;            \
    mask += mask_stride;     \
  } while (0)
239
240 template <bool mask_is_inverse, int bitdepth>
WeightMask32x8_NEON(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)241 void WeightMask32x8_NEON(const void* LIBGAV1_RESTRICT prediction_0,
242 const void* LIBGAV1_RESTRICT prediction_1,
243 uint8_t* LIBGAV1_RESTRICT mask,
244 ptrdiff_t mask_stride) {
245 const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
246 const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
247 WEIGHT32_AND_STRIDE;
248 WEIGHT32_AND_STRIDE;
249 WEIGHT32_AND_STRIDE;
250 WEIGHT32_AND_STRIDE;
251 WEIGHT32_AND_STRIDE;
252 WEIGHT32_AND_STRIDE;
253 WEIGHT32_AND_STRIDE;
254 WEIGHT32_WITHOUT_STRIDE;
255 }
256
257 template <bool mask_is_inverse, int bitdepth>
WeightMask32x16_NEON(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)258 void WeightMask32x16_NEON(const void* LIBGAV1_RESTRICT prediction_0,
259 const void* LIBGAV1_RESTRICT prediction_1,
260 uint8_t* LIBGAV1_RESTRICT mask,
261 ptrdiff_t mask_stride) {
262 const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
263 const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
264 int y3 = 0;
265 do {
266 WEIGHT32_AND_STRIDE;
267 WEIGHT32_AND_STRIDE;
268 WEIGHT32_AND_STRIDE;
269 } while (++y3 < 5);
270 WEIGHT32_WITHOUT_STRIDE;
271 }
272
273 template <bool mask_is_inverse, int bitdepth>
WeightMask32x32_NEON(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)274 void WeightMask32x32_NEON(const void* LIBGAV1_RESTRICT prediction_0,
275 const void* LIBGAV1_RESTRICT prediction_1,
276 uint8_t* LIBGAV1_RESTRICT mask,
277 ptrdiff_t mask_stride) {
278 const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
279 const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
280 int y5 = 0;
281 do {
282 WEIGHT32_AND_STRIDE;
283 WEIGHT32_AND_STRIDE;
284 WEIGHT32_AND_STRIDE;
285 WEIGHT32_AND_STRIDE;
286 WEIGHT32_AND_STRIDE;
287 } while (++y5 < 6);
288 WEIGHT32_AND_STRIDE;
289 WEIGHT32_WITHOUT_STRIDE;
290 }
291
292 template <bool mask_is_inverse, int bitdepth>
WeightMask32x64_NEON(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)293 void WeightMask32x64_NEON(const void* LIBGAV1_RESTRICT prediction_0,
294 const void* LIBGAV1_RESTRICT prediction_1,
295 uint8_t* LIBGAV1_RESTRICT mask,
296 ptrdiff_t mask_stride) {
297 const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
298 const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
299 int y3 = 0;
300 do {
301 WEIGHT32_AND_STRIDE;
302 WEIGHT32_AND_STRIDE;
303 WEIGHT32_AND_STRIDE;
304 } while (++y3 < 21);
305 WEIGHT32_WITHOUT_STRIDE;
306 }
307
// Row helpers for 64-wide blocks. Both macros expect |pred_0|, |pred_1|,
// |mask| and the template parameters |mask_is_inverse| and |bitdepth| to be in
// scope at the point of expansion; WEIGHT64_AND_STRIDE additionally needs
// |mask_stride|. Each macro is wrapped in do/while (0) so that it expands to
// a single statement and stays correct under an unbraced if/else/for.

// Computes one row of 64 mask values (eight 8-wide pieces) without moving any
// pointer.
#define WEIGHT64_WITHOUT_STRIDE                                            \
  do {                                                                     \
    WeightMask8_NEON<mask_is_inverse, bitdepth>(pred_0, pred_1, mask);     \
    WeightMask8_NEON<mask_is_inverse, bitdepth>(pred_0 + 8, pred_1 + 8,    \
                                                mask + 8);                 \
    WeightMask8_NEON<mask_is_inverse, bitdepth>(pred_0 + 16, pred_1 + 16,  \
                                                mask + 16);                \
    WeightMask8_NEON<mask_is_inverse, bitdepth>(pred_0 + 24, pred_1 + 24,  \
                                                mask + 24);                \
    WeightMask8_NEON<mask_is_inverse, bitdepth>(pred_0 + 32, pred_1 + 32,  \
                                                mask + 32);                \
    WeightMask8_NEON<mask_is_inverse, bitdepth>(pred_0 + 40, pred_1 + 40,  \
                                                mask + 40);                \
    WeightMask8_NEON<mask_is_inverse, bitdepth>(pred_0 + 48, pred_1 + 48,  \
                                                mask + 48);                \
    WeightMask8_NEON<mask_is_inverse, bitdepth>(pred_0 + 56, pred_1 + 56,  \
                                                mask + 56);                \
  } while (0)

// Computes one row of 64 mask values, then advances the predictions by 64
// values and |mask| by |mask_stride|.
#define WEIGHT64_AND_STRIDE  \
  do {                       \
    WEIGHT64_WITHOUT_STRIDE; \
    pred_0 += 64;            \
    pred_1 += 64;            \
    mask += mask_stride;     \
  } while (0)
330
331 template <bool mask_is_inverse, int bitdepth>
WeightMask64x16_NEON(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)332 void WeightMask64x16_NEON(const void* LIBGAV1_RESTRICT prediction_0,
333 const void* LIBGAV1_RESTRICT prediction_1,
334 uint8_t* LIBGAV1_RESTRICT mask,
335 ptrdiff_t mask_stride) {
336 const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
337 const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
338 int y3 = 0;
339 do {
340 WEIGHT64_AND_STRIDE;
341 WEIGHT64_AND_STRIDE;
342 WEIGHT64_AND_STRIDE;
343 } while (++y3 < 5);
344 WEIGHT64_WITHOUT_STRIDE;
345 }
346
347 template <bool mask_is_inverse, int bitdepth>
WeightMask64x32_NEON(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)348 void WeightMask64x32_NEON(const void* LIBGAV1_RESTRICT prediction_0,
349 const void* LIBGAV1_RESTRICT prediction_1,
350 uint8_t* LIBGAV1_RESTRICT mask,
351 ptrdiff_t mask_stride) {
352 const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
353 const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
354 int y5 = 0;
355 do {
356 WEIGHT64_AND_STRIDE;
357 WEIGHT64_AND_STRIDE;
358 WEIGHT64_AND_STRIDE;
359 WEIGHT64_AND_STRIDE;
360 WEIGHT64_AND_STRIDE;
361 } while (++y5 < 6);
362 WEIGHT64_AND_STRIDE;
363 WEIGHT64_WITHOUT_STRIDE;
364 }
365
366 template <bool mask_is_inverse, int bitdepth>
WeightMask64x64_NEON(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)367 void WeightMask64x64_NEON(const void* LIBGAV1_RESTRICT prediction_0,
368 const void* LIBGAV1_RESTRICT prediction_1,
369 uint8_t* LIBGAV1_RESTRICT mask,
370 ptrdiff_t mask_stride) {
371 const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
372 const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
373 int y3 = 0;
374 do {
375 WEIGHT64_AND_STRIDE;
376 WEIGHT64_AND_STRIDE;
377 WEIGHT64_AND_STRIDE;
378 } while (++y3 < 21);
379 WEIGHT64_WITHOUT_STRIDE;
380 }
381
382 template <bool mask_is_inverse, int bitdepth>
WeightMask64x128_NEON(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)383 void WeightMask64x128_NEON(const void* LIBGAV1_RESTRICT prediction_0,
384 const void* LIBGAV1_RESTRICT prediction_1,
385 uint8_t* LIBGAV1_RESTRICT mask,
386 ptrdiff_t mask_stride) {
387 const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
388 const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
389 int y3 = 0;
390 do {
391 WEIGHT64_AND_STRIDE;
392 WEIGHT64_AND_STRIDE;
393 WEIGHT64_AND_STRIDE;
394 } while (++y3 < 42);
395 WEIGHT64_AND_STRIDE;
396 WEIGHT64_WITHOUT_STRIDE;
397 }
398
399 template <bool mask_is_inverse, int bitdepth>
WeightMask128x64_NEON(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)400 void WeightMask128x64_NEON(const void* LIBGAV1_RESTRICT prediction_0,
401 const void* LIBGAV1_RESTRICT prediction_1,
402 uint8_t* LIBGAV1_RESTRICT mask,
403 ptrdiff_t mask_stride) {
404 const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
405 const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
406 int y3 = 0;
407 const ptrdiff_t adjusted_mask_stride = mask_stride - 64;
408 do {
409 WEIGHT64_WITHOUT_STRIDE;
410 pred_0 += 64;
411 pred_1 += 64;
412 mask += 64;
413 WEIGHT64_WITHOUT_STRIDE;
414 pred_0 += 64;
415 pred_1 += 64;
416 mask += adjusted_mask_stride;
417
418 WEIGHT64_WITHOUT_STRIDE;
419 pred_0 += 64;
420 pred_1 += 64;
421 mask += 64;
422 WEIGHT64_WITHOUT_STRIDE;
423 pred_0 += 64;
424 pred_1 += 64;
425 mask += adjusted_mask_stride;
426
427 WEIGHT64_WITHOUT_STRIDE;
428 pred_0 += 64;
429 pred_1 += 64;
430 mask += 64;
431 WEIGHT64_WITHOUT_STRIDE;
432 pred_0 += 64;
433 pred_1 += 64;
434 mask += adjusted_mask_stride;
435 } while (++y3 < 21);
436 WEIGHT64_WITHOUT_STRIDE;
437 pred_0 += 64;
438 pred_1 += 64;
439 mask += 64;
440 WEIGHT64_WITHOUT_STRIDE;
441 }
442
443 template <bool mask_is_inverse, int bitdepth>
WeightMask128x128_NEON(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT mask,ptrdiff_t mask_stride)444 void WeightMask128x128_NEON(const void* LIBGAV1_RESTRICT prediction_0,
445 const void* LIBGAV1_RESTRICT prediction_1,
446 uint8_t* LIBGAV1_RESTRICT mask,
447 ptrdiff_t mask_stride) {
448 const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
449 const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
450 int y3 = 0;
451 const ptrdiff_t adjusted_mask_stride = mask_stride - 64;
452 do {
453 WEIGHT64_WITHOUT_STRIDE;
454 pred_0 += 64;
455 pred_1 += 64;
456 mask += 64;
457 WEIGHT64_WITHOUT_STRIDE;
458 pred_0 += 64;
459 pred_1 += 64;
460 mask += adjusted_mask_stride;
461
462 WEIGHT64_WITHOUT_STRIDE;
463 pred_0 += 64;
464 pred_1 += 64;
465 mask += 64;
466 WEIGHT64_WITHOUT_STRIDE;
467 pred_0 += 64;
468 pred_1 += 64;
469 mask += adjusted_mask_stride;
470
471 WEIGHT64_WITHOUT_STRIDE;
472 pred_0 += 64;
473 pred_1 += 64;
474 mask += 64;
475 WEIGHT64_WITHOUT_STRIDE;
476 pred_0 += 64;
477 pred_1 += 64;
478 mask += adjusted_mask_stride;
479 } while (++y3 < 42);
480 WEIGHT64_WITHOUT_STRIDE;
481 pred_0 += 64;
482 pred_1 += 64;
483 mask += 64;
484 WEIGHT64_WITHOUT_STRIDE;
485 pred_0 += 64;
486 pred_1 += 64;
487 mask += adjusted_mask_stride;
488
489 WEIGHT64_WITHOUT_STRIDE;
490 pred_0 += 64;
491 pred_1 += 64;
492 mask += 64;
493 WEIGHT64_WITHOUT_STRIDE;
494 }
// The row helper macros are only meaningful inside the WeightMask*_NEON
// functions above; remove them so they cannot leak into later code.
#undef WEIGHT8_AND_STRIDE
#undef WEIGHT8_WITHOUT_STRIDE
#undef WEIGHT16_AND_STRIDE
#undef WEIGHT16_WITHOUT_STRIDE
#undef WEIGHT32_AND_STRIDE
#undef WEIGHT32_WITHOUT_STRIDE
#undef WEIGHT64_AND_STRIDE
#undef WEIGHT64_WITHOUT_STRIDE
503
504 #define INIT_WEIGHT_MASK_8BPP(width, height, w_index, h_index) \
505 dsp->weight_mask[w_index][h_index][0] = \
506 WeightMask##width##x##height##_NEON<0, 8>; \
507 dsp->weight_mask[w_index][h_index][1] = \
508 WeightMask##width##x##height##_NEON<1, 8>
Init8bpp()509 void Init8bpp() {
510 Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
511 assert(dsp != nullptr);
512 INIT_WEIGHT_MASK_8BPP(8, 8, 0, 0);
513 INIT_WEIGHT_MASK_8BPP(8, 16, 0, 1);
514 INIT_WEIGHT_MASK_8BPP(8, 32, 0, 2);
515 INIT_WEIGHT_MASK_8BPP(16, 8, 1, 0);
516 INIT_WEIGHT_MASK_8BPP(16, 16, 1, 1);
517 INIT_WEIGHT_MASK_8BPP(16, 32, 1, 2);
518 INIT_WEIGHT_MASK_8BPP(16, 64, 1, 3);
519 INIT_WEIGHT_MASK_8BPP(32, 8, 2, 0);
520 INIT_WEIGHT_MASK_8BPP(32, 16, 2, 1);
521 INIT_WEIGHT_MASK_8BPP(32, 32, 2, 2);
522 INIT_WEIGHT_MASK_8BPP(32, 64, 2, 3);
523 INIT_WEIGHT_MASK_8BPP(64, 16, 3, 1);
524 INIT_WEIGHT_MASK_8BPP(64, 32, 3, 2);
525 INIT_WEIGHT_MASK_8BPP(64, 64, 3, 3);
526 INIT_WEIGHT_MASK_8BPP(64, 128, 3, 4);
527 INIT_WEIGHT_MASK_8BPP(128, 64, 4, 3);
528 INIT_WEIGHT_MASK_8BPP(128, 128, 4, 4);
529 }
530 #undef INIT_WEIGHT_MASK_8BPP
531
532 } // namespace
533
534 #if LIBGAV1_MAX_BITDEPTH >= 10
535 namespace high_bitdepth {
536 namespace {
537
538 #define INIT_WEIGHT_MASK_10BPP(width, height, w_index, h_index) \
539 dsp->weight_mask[w_index][h_index][0] = \
540 WeightMask##width##x##height##_NEON<0, 10>; \
541 dsp->weight_mask[w_index][h_index][1] = \
542 WeightMask##width##x##height##_NEON<1, 10>
Init10bpp()543 void Init10bpp() {
544 Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10);
545 assert(dsp != nullptr);
546 INIT_WEIGHT_MASK_10BPP(8, 8, 0, 0);
547 INIT_WEIGHT_MASK_10BPP(8, 16, 0, 1);
548 INIT_WEIGHT_MASK_10BPP(8, 32, 0, 2);
549 INIT_WEIGHT_MASK_10BPP(16, 8, 1, 0);
550 INIT_WEIGHT_MASK_10BPP(16, 16, 1, 1);
551 INIT_WEIGHT_MASK_10BPP(16, 32, 1, 2);
552 INIT_WEIGHT_MASK_10BPP(16, 64, 1, 3);
553 INIT_WEIGHT_MASK_10BPP(32, 8, 2, 0);
554 INIT_WEIGHT_MASK_10BPP(32, 16, 2, 1);
555 INIT_WEIGHT_MASK_10BPP(32, 32, 2, 2);
556 INIT_WEIGHT_MASK_10BPP(32, 64, 2, 3);
557 INIT_WEIGHT_MASK_10BPP(64, 16, 3, 1);
558 INIT_WEIGHT_MASK_10BPP(64, 32, 3, 2);
559 INIT_WEIGHT_MASK_10BPP(64, 64, 3, 3);
560 INIT_WEIGHT_MASK_10BPP(64, 128, 3, 4);
561 INIT_WEIGHT_MASK_10BPP(128, 64, 4, 3);
562 INIT_WEIGHT_MASK_10BPP(128, 128, 4, 4);
563 }
564 #undef INIT_WEIGHT_MASK_10BPP
565
566 } // namespace
567 } // namespace high_bitdepth
568 #endif // LIBGAV1_MAX_BITDEPTH >= 10
WeightMaskInit_NEON()569 void WeightMaskInit_NEON() {
570 Init8bpp();
571 #if LIBGAV1_MAX_BITDEPTH >= 10
572 high_bitdepth::Init10bpp();
573 #endif // LIBGAV1_MAX_BITDEPTH >= 10
574 }
575
576 } // namespace dsp
577 } // namespace libgav1
578
579 #else // !LIBGAV1_ENABLE_NEON
580
581 namespace libgav1 {
582 namespace dsp {
583
// NEON is disabled in this build, so there is nothing to register.
void WeightMaskInit_NEON() {}
585
586 } // namespace dsp
587 } // namespace libgav1
588 #endif // LIBGAV1_ENABLE_NEON
589