/*
 * Copyright (c) 2018, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <arm_neon.h>
#include <assert.h>

#include "config/aom_dsp_rtcd.h"

#include "aom/aom_integer.h"
#include "aom_dsp/aom_dsp_common.h"
#include "aom_dsp/arm/blend_neon.h"
#include "aom_dsp/arm/mem_neon.h"
#include "aom_dsp/blend.h"

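// Blend eight 16-bit compound prediction samples with a mask in [0, 64]:
// (m * a + (64 - m) * b) >> AOM_BLEND_A64_ROUND_BITS, then subtract the
// compound round offset and narrow back to 8-bit pixels with a rounding
// shift.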
static uint8x8_t alpha_blend_a64_d16_u16x8(uint16x8_t m, uint16x8_t a,
                                           uint16x8_t b,
                                           uint16x8_t round_offset) {
  const uint16x8_t m_inv = vsubq_u16(vdupq_n_u16(AOM_BLEND_A64_MAX_ALPHA), m);

  uint32x4_t blend_u32_lo = vmull_u16(vget_low_u16(m), vget_low_u16(a));
  uint32x4_t blend_u32_hi = vmull_u16(vget_high_u16(m), vget_high_u16(a));

  blend_u32_lo = vmlal_u16(blend_u32_lo, vget_low_u16(m_inv), vget_low_u16(b));
  blend_u32_hi =
      vmlal_u16(blend_u32_hi, vget_high_u16(m_inv), vget_high_u16(b));

  uint16x4_t blend_u16_lo = vshrn_n_u32(blend_u32_lo, AOM_BLEND_A64_ROUND_BITS);
  uint16x4_t blend_u16_hi = vshrn_n_u32(blend_u32_hi, AOM_BLEND_A64_ROUND_BITS);

  uint16x8_t res = vcombine_u16(blend_u16_lo, blend_u16_hi);

  res = vqsubq_u16(res, round_offset);

  return vqrshrn_n_u16(res,
                       2 * FILTER_BITS - ROUND0_BITS - COMPOUND_ROUND1_BITS);
}

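// Blend two 16-bit compound predictors (src0/src1, CONV_BUF_TYPE) into an
// 8-bit destination using an alpha mask with values in [0, 64]. subw/subh
// signal that the mask is stored at twice the block width/height and must be
// averaged down to block resolution before blending.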
void aom_lowbd_blend_a64_d16_mask_neon(
    uint8_t *dst, uint32_t dst_stride, const CONV_BUF_TYPE *src0,
    uint32_t src0_stride, const CONV_BUF_TYPE *src1, uint32_t src1_stride,
    const uint8_t *mask, uint32_t mask_stride, int w, int h, int subw, int subh,
    ConvolveParams *conv_params) {
  (void)conv_params;

  const int bd = 8;
  const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
  const int round_offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
                           (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));
  const uint16x8_t offset_vec = vdupq_n_u16(round_offset);

  assert(IMPLIES((void *)src0 == dst, src0_stride == dst_stride));
  assert(IMPLIES((void *)src1 == dst, src1_stride == dst_stride));

  assert(h >= 4);
  assert(w >= 4);
  assert(IS_POWER_OF_TWO(h));
  assert(IS_POWER_OF_TWO(w));

  if (subw == 0 && subh == 0) {
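    // Mask is at the same resolution as the blocks: apply it directly.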
    if (w >= 8) {
      do {
        int i = 0;
        do {
          uint16x8_t m0 = vmovl_u8(vld1_u8(mask + i));
          uint16x8_t s0 = vld1q_u16(src0 + i);
          uint16x8_t s1 = vld1q_u16(src1 + i);

          uint8x8_t blend = alpha_blend_a64_d16_u16x8(m0, s0, s1, offset_vec);

          vst1_u8(dst + i, blend);
          i += 8;
        } while (i < w);

        mask += mask_stride;
        src0 += src0_stride;
        src1 += src1_stride;
        dst += dst_stride;
      } while (--h != 0);
    } else {
      do {
        uint16x8_t m0 = vmovl_u8(load_unaligned_u8_4x2(mask, mask_stride));
        uint16x8_t s0 = load_unaligned_u16_4x2(src0, src0_stride);
        uint16x8_t s1 = load_unaligned_u16_4x2(src1, src1_stride);

        uint8x8_t blend = alpha_blend_a64_d16_u16x8(m0, s0, s1, offset_vec);

        store_u8x4_strided_x2(dst, dst_stride, blend);

        mask += 2 * mask_stride;
        src0 += 2 * src0_stride;
        src1 += 2 * src1_stride;
        dst += 2 * dst_stride;
        h -= 2;
      } while (h != 0);
    }
  } else if (subw == 1 && subh == 1) {
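    // Mask is stored at twice the block width and height: average each 2x2
    // group of mask values down to block resolution before blending.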
    if (w >= 8) {
      do {
        int i = 0;
        do {
          uint8x8_t m0 = vld1_u8(mask + 0 * mask_stride + 2 * i);
          uint8x8_t m1 = vld1_u8(mask + 1 * mask_stride + 2 * i);
          uint8x8_t m2 = vld1_u8(mask + 0 * mask_stride + 2 * i + 8);
          uint8x8_t m3 = vld1_u8(mask + 1 * mask_stride + 2 * i + 8);
          uint16x8_t s0 = vld1q_u16(src0 + i);
          uint16x8_t s1 = vld1q_u16(src1 + i);

          uint16x8_t m_avg =
              vmovl_u8(avg_blend_pairwise_u8x8_4(m0, m1, m2, m3));

          uint8x8_t blend =
              alpha_blend_a64_d16_u16x8(m_avg, s0, s1, offset_vec);

          vst1_u8(dst + i, blend);
          i += 8;
        } while (i < w);

        mask += 2 * mask_stride;
        src0 += src0_stride;
        src1 += src1_stride;
        dst += dst_stride;
      } while (--h != 0);
    } else {
      do {
        uint8x8_t m0 = vld1_u8(mask + 0 * mask_stride);
        uint8x8_t m1 = vld1_u8(mask + 1 * mask_stride);
        uint8x8_t m2 = vld1_u8(mask + 2 * mask_stride);
        uint8x8_t m3 = vld1_u8(mask + 3 * mask_stride);
        uint16x8_t s0 = load_unaligned_u16_4x2(src0, src0_stride);
        uint16x8_t s1 = load_unaligned_u16_4x2(src1, src1_stride);

        uint16x8_t m_avg = vmovl_u8(avg_blend_pairwise_u8x8_4(m0, m1, m2, m3));
        uint8x8_t blend = alpha_blend_a64_d16_u16x8(m_avg, s0, s1, offset_vec);

        store_u8x4_strided_x2(dst, dst_stride, blend);

        mask += 4 * mask_stride;
        src0 += 2 * src0_stride;
        src1 += 2 * src1_stride;
        dst += 2 * dst_stride;
        h -= 2;
      } while (h != 0);
    }
  } else if (subw == 1 && subh == 0) {
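    // Mask is stored at twice the block width only: average horizontally
    // adjacent pairs of mask values.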
    if (w >= 8) {
      do {
        int i = 0;
        do {
          uint8x8_t m0 = vld1_u8(mask + 2 * i);
          uint8x8_t m1 = vld1_u8(mask + 2 * i + 8);
          uint16x8_t s0 = vld1q_u16(src0 + i);
          uint16x8_t s1 = vld1q_u16(src1 + i);

          uint16x8_t m_avg = vmovl_u8(avg_blend_pairwise_u8x8(m0, m1));
          uint8x8_t blend =
              alpha_blend_a64_d16_u16x8(m_avg, s0, s1, offset_vec);

          vst1_u8(dst + i, blend);
          i += 8;
        } while (i < w);

        mask += mask_stride;
        src0 += src0_stride;
        src1 += src1_stride;
        dst += dst_stride;
      } while (--h != 0);
    } else {
      do {
        uint8x8_t m0 = vld1_u8(mask + 0 * mask_stride);
        uint8x8_t m1 = vld1_u8(mask + 1 * mask_stride);
        uint16x8_t s0 = load_unaligned_u16_4x2(src0, src0_stride);
        uint16x8_t s1 = load_unaligned_u16_4x2(src1, src1_stride);

        uint16x8_t m_avg = vmovl_u8(avg_blend_pairwise_u8x8(m0, m1));
        uint8x8_t blend = alpha_blend_a64_d16_u16x8(m_avg, s0, s1, offset_vec);

        store_u8x4_strided_x2(dst, dst_stride, blend);

        mask += 2 * mask_stride;
        src0 += 2 * src0_stride;
        src1 += 2 * src1_stride;
        dst += 2 * dst_stride;
        h -= 2;
      } while (h != 0);
    }
  } else {
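    // subw == 0 && subh == 1: mask is stored at twice the block height only,
    // so average vertically adjacent mask rows.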
    if (w >= 8) {
      do {
        int i = 0;
        do {
          uint8x8_t m0 = vld1_u8(mask + 0 * mask_stride + i);
          uint8x8_t m1 = vld1_u8(mask + 1 * mask_stride + i);
          uint16x8_t s0 = vld1q_u16(src0 + i);
          uint16x8_t s1 = vld1q_u16(src1 + i);

          uint16x8_t m_avg = vmovl_u8(avg_blend_u8x8(m0, m1));
          uint8x8_t blend =
              alpha_blend_a64_d16_u16x8(m_avg, s0, s1, offset_vec);

          vst1_u8(dst + i, blend);
          i += 8;
        } while (i < w);

        mask += 2 * mask_stride;
        src0 += src0_stride;
        src1 += src1_stride;
        dst += dst_stride;
      } while (--h != 0);
    } else {
      do {
        uint8x8_t m0_2 =
            load_unaligned_u8_4x2(mask + 0 * mask_stride, 2 * mask_stride);
        uint8x8_t m1_3 =
            load_unaligned_u8_4x2(mask + 1 * mask_stride, 2 * mask_stride);
        uint16x8_t s0 = load_unaligned_u16_4x2(src0, src0_stride);
        uint16x8_t s1 = load_unaligned_u16_4x2(src1, src1_stride);

        uint16x8_t m_avg = vmovl_u8(avg_blend_u8x8(m0_2, m1_3));
        uint8x8_t blend = alpha_blend_a64_d16_u16x8(m_avg, s0, s1, offset_vec);

        store_u8x4_strided_x2(dst, dst_stride, blend);

        mask += 4 * mask_stride;
        src0 += 2 * src0_stride;
        src1 += 2 * src1_stride;
        dst += 2 * dst_stride;
        h -= 2;
      } while (h != 0);
    }
  }
}

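// 8-bit blend: dst = (m * src0 + (64 - m) * src1 + 32) >> 6, with the same
// subw/subh mask subsampling handling as the d16 variant above.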
void aom_blend_a64_mask_neon(uint8_t *dst, uint32_t dst_stride,
                             const uint8_t *src0, uint32_t src0_stride,
                             const uint8_t *src1, uint32_t src1_stride,
                             const uint8_t *mask, uint32_t mask_stride, int w,
                             int h, int subw, int subh) {
  assert(IMPLIES(src0 == dst, src0_stride == dst_stride));
  assert(IMPLIES(src1 == dst, src1_stride == dst_stride));

  assert(h >= 1);
  assert(w >= 1);
  assert(IS_POWER_OF_TWO(h));
  assert(IS_POWER_OF_TWO(w));

  if ((subw | subh) == 0) {
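    // Mask is at the same resolution as the blocks: apply it directly.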
    if (w > 8) {
      do {
        int i = 0;
        do {
          uint8x16_t m0 = vld1q_u8(mask + i);
          uint8x16_t s0 = vld1q_u8(src0 + i);
          uint8x16_t s1 = vld1q_u8(src1 + i);

          uint8x16_t blend = alpha_blend_a64_u8x16(m0, s0, s1);

          vst1q_u8(dst + i, blend);
          i += 16;
        } while (i < w);

        mask += mask_stride;
        src0 += src0_stride;
        src1 += src1_stride;
        dst += dst_stride;
      } while (--h != 0);
    } else if (w == 8) {
      do {
        uint8x8_t m0 = vld1_u8(mask);
        uint8x8_t s0 = vld1_u8(src0);
        uint8x8_t s1 = vld1_u8(src1);

        uint8x8_t blend = alpha_blend_a64_u8x8(m0, s0, s1);

        vst1_u8(dst, blend);

        mask += mask_stride;
        src0 += src0_stride;
        src1 += src1_stride;
        dst += dst_stride;
      } while (--h != 0);
    } else {
      do {
        uint8x8_t m0 = load_unaligned_u8_4x2(mask, mask_stride);
        uint8x8_t s0 = load_unaligned_u8_4x2(src0, src0_stride);
        uint8x8_t s1 = load_unaligned_u8_4x2(src1, src1_stride);

        uint8x8_t blend = alpha_blend_a64_u8x8(m0, s0, s1);

        store_u8x4_strided_x2(dst, dst_stride, blend);

        mask += 2 * mask_stride;
        src0 += 2 * src0_stride;
        src1 += 2 * src1_stride;
        dst += 2 * dst_stride;
        h -= 2;
      } while (h != 0);
    }
  } else if ((subw & subh) == 1) {
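    // Mask is stored at twice the block width and height: average each 2x2
    // group of mask values down to block resolution before blending.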
    if (w > 8) {
      do {
        int i = 0;
        do {
          uint8x16_t m0 = vld1q_u8(mask + 0 * mask_stride + 2 * i);
          uint8x16_t m1 = vld1q_u8(mask + 1 * mask_stride + 2 * i);
          uint8x16_t m2 = vld1q_u8(mask + 0 * mask_stride + 2 * i + 16);
          uint8x16_t m3 = vld1q_u8(mask + 1 * mask_stride + 2 * i + 16);
          uint8x16_t s0 = vld1q_u8(src0 + i);
          uint8x16_t s1 = vld1q_u8(src1 + i);

          uint8x16_t m_avg = avg_blend_pairwise_u8x16_4(m0, m1, m2, m3);
          uint8x16_t blend = alpha_blend_a64_u8x16(m_avg, s0, s1);

          vst1q_u8(dst + i, blend);

          i += 16;
        } while (i < w);

        mask += 2 * mask_stride;
        src0 += src0_stride;
        src1 += src1_stride;
        dst += dst_stride;
      } while (--h != 0);
    } else if (w == 8) {
      do {
        uint8x8_t m0 = vld1_u8(mask + 0 * mask_stride);
        uint8x8_t m1 = vld1_u8(mask + 1 * mask_stride);
        uint8x8_t m2 = vld1_u8(mask + 0 * mask_stride + 8);
        uint8x8_t m3 = vld1_u8(mask + 1 * mask_stride + 8);
        uint8x8_t s0 = vld1_u8(src0);
        uint8x8_t s1 = vld1_u8(src1);

        uint8x8_t m_avg = avg_blend_pairwise_u8x8_4(m0, m1, m2, m3);
        uint8x8_t blend = alpha_blend_a64_u8x8(m_avg, s0, s1);

        vst1_u8(dst, blend);

        mask += 2 * mask_stride;
        src0 += src0_stride;
        src1 += src1_stride;
        dst += dst_stride;
      } while (--h != 0);
    } else {
      do {
        uint8x8_t m0 = vld1_u8(mask + 0 * mask_stride);
        uint8x8_t m1 = vld1_u8(mask + 1 * mask_stride);
        uint8x8_t m2 = vld1_u8(mask + 2 * mask_stride);
        uint8x8_t m3 = vld1_u8(mask + 3 * mask_stride);
        uint8x8_t s0 = load_unaligned_u8_4x2(src0, src0_stride);
        uint8x8_t s1 = load_unaligned_u8_4x2(src1, src1_stride);

        uint8x8_t m_avg = avg_blend_pairwise_u8x8_4(m0, m1, m2, m3);
        uint8x8_t blend = alpha_blend_a64_u8x8(m_avg, s0, s1);

        store_u8x4_strided_x2(dst, dst_stride, blend);

        mask += 4 * mask_stride;
        src0 += 2 * src0_stride;
        src1 += 2 * src1_stride;
        dst += 2 * dst_stride;
        h -= 2;
      } while (h != 0);
    }
  } else if (subw == 1 && subh == 0) {
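    // Mask is stored at twice the block width only: average horizontally
    // adjacent pairs of mask values.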
    if (w > 8) {
      do {
        int i = 0;

        do {
          uint8x16_t m0 = vld1q_u8(mask + 2 * i);
          uint8x16_t m1 = vld1q_u8(mask + 2 * i + 16);
          uint8x16_t s0 = vld1q_u8(src0 + i);
          uint8x16_t s1 = vld1q_u8(src1 + i);

          uint8x16_t m_avg = avg_blend_pairwise_u8x16(m0, m1);
          uint8x16_t blend = alpha_blend_a64_u8x16(m_avg, s0, s1);

          vst1q_u8(dst + i, blend);

          i += 16;
        } while (i < w);

        mask += mask_stride;
        src0 += src0_stride;
        src1 += src1_stride;
        dst += dst_stride;
      } while (--h != 0);
    } else if (w == 8) {
      do {
        uint8x8_t m0 = vld1_u8(mask);
        uint8x8_t m1 = vld1_u8(mask + 8);
        uint8x8_t s0 = vld1_u8(src0);
        uint8x8_t s1 = vld1_u8(src1);

        uint8x8_t m_avg = avg_blend_pairwise_u8x8(m0, m1);
        uint8x8_t blend = alpha_blend_a64_u8x8(m_avg, s0, s1);

        vst1_u8(dst, blend);

        mask += mask_stride;
        src0 += src0_stride;
        src1 += src1_stride;
        dst += dst_stride;
      } while (--h != 0);
    } else {
      do {
        uint8x8_t m0 = vld1_u8(mask + 0 * mask_stride);
        uint8x8_t m1 = vld1_u8(mask + 1 * mask_stride);
        uint8x8_t s0 = load_unaligned_u8_4x2(src0, src0_stride);
        uint8x8_t s1 = load_unaligned_u8_4x2(src1, src1_stride);

        uint8x8_t m_avg = avg_blend_pairwise_u8x8(m0, m1);
        uint8x8_t blend = alpha_blend_a64_u8x8(m_avg, s0, s1);

        store_u8x4_strided_x2(dst, dst_stride, blend);

        mask += 2 * mask_stride;
        src0 += 2 * src0_stride;
        src1 += 2 * src1_stride;
        dst += 2 * dst_stride;
        h -= 2;
      } while (h != 0);
    }
  } else {
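    // subw == 0 && subh == 1: mask is stored at twice the block height only,
    // so average vertically adjacent mask rows.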
    if (w > 8) {
      do {
        int i = 0;
        do {
          uint8x16_t m0 = vld1q_u8(mask + 0 * mask_stride + i);
          uint8x16_t m1 = vld1q_u8(mask + 1 * mask_stride + i);
          uint8x16_t s0 = vld1q_u8(src0 + i);
          uint8x16_t s1 = vld1q_u8(src1 + i);

          uint8x16_t m_avg = avg_blend_u8x16(m0, m1);
          uint8x16_t blend = alpha_blend_a64_u8x16(m_avg, s0, s1);

          vst1q_u8(dst + i, blend);

          i += 16;
        } while (i < w);

        mask += 2 * mask_stride;
        src0 += src0_stride;
        src1 += src1_stride;
        dst += dst_stride;
      } while (--h != 0);
    } else if (w == 8) {
      do {
        uint8x8_t m0 = vld1_u8(mask + 0 * mask_stride);
        uint8x8_t m1 = vld1_u8(mask + 1 * mask_stride);
        uint8x8_t s0 = vld1_u8(src0);
        uint8x8_t s1 = vld1_u8(src1);

        uint8x8_t m_avg = avg_blend_u8x8(m0, m1);
        uint8x8_t blend = alpha_blend_a64_u8x8(m_avg, s0, s1);

        vst1_u8(dst, blend);

        mask += 2 * mask_stride;
        src0 += src0_stride;
        src1 += src1_stride;
        dst += dst_stride;
      } while (--h != 0);
    } else {
      do {
        uint8x8_t m0_2 =
            load_unaligned_u8_4x2(mask + 0 * mask_stride, 2 * mask_stride);
        uint8x8_t m1_3 =
            load_unaligned_u8_4x2(mask + 1 * mask_stride, 2 * mask_stride);
        uint8x8_t s0 = load_unaligned_u8_4x2(src0, src0_stride);
        uint8x8_t s1 = load_unaligned_u8_4x2(src1, src1_stride);

        uint8x8_t m_avg = avg_blend_u8x8(m0_2, m1_3);
        uint8x8_t blend = alpha_blend_a64_u8x8(m_avg, s0, s1);

        store_u8x4_strided_x2(dst, dst_stride, blend);

        mask += 4 * mask_stride;
        src0 += 2 * src0_stride;
        src1 += 2 * src1_stride;
        dst += 2 * dst_stride;
        h -= 2;
      } while (h != 0);
    }
  }
}