/*
 * Copyright (c) 2023, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <arm_neon.h>
#include <assert.h>

#include "config/aom_config.h"
#include "config/aom_dsp_rtcd.h"

#include "aom_dsp/arm/blend_neon.h"
#include "aom_dsp/arm/mem_neon.h"
#include "aom_dsp/blend.h"

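// HBD_BLEND_A64_D16_MASK(bd, round0_bits) generates the bit-depth-specific
// blend kernels: the d16 (compound convolution) inputs are combined as
// m * a + (AOM_BLEND_A64_MAX_ALPHA - m) * b, the compound round offset is
// subtracted, and the result is narrowed with rounding and clamped to the
// maximum pixel value for the bit depth.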
#define HBD_BLEND_A64_D16_MASK(bd, round0_bits)                               \
  static inline uint16x8_t alpha_##bd##_blend_a64_d16_u16x8(                  \
      uint16x8_t m, uint16x8_t a, uint16x8_t b, int32x4_t round_offset) {     \
    const uint16x8_t m_inv =                                                  \
        vsubq_u16(vdupq_n_u16(AOM_BLEND_A64_MAX_ALPHA), m);                   \
                                                                              \
    uint32x4_t blend_u32_lo = vmlal_u16(vreinterpretq_u32_s32(round_offset),  \
                                        vget_low_u16(m), vget_low_u16(a));    \
    uint32x4_t blend_u32_hi = vmlal_u16(vreinterpretq_u32_s32(round_offset),  \
                                        vget_high_u16(m), vget_high_u16(a));  \
                                                                              \
    blend_u32_lo =                                                            \
        vmlal_u16(blend_u32_lo, vget_low_u16(m_inv), vget_low_u16(b));        \
    blend_u32_hi =                                                            \
        vmlal_u16(blend_u32_hi, vget_high_u16(m_inv), vget_high_u16(b));      \
                                                                              \
    uint16x4_t blend_u16_lo =                                                 \
        vqrshrun_n_s32(vreinterpretq_s32_u32(blend_u32_lo),                   \
                       AOM_BLEND_A64_ROUND_BITS + 2 * FILTER_BITS -           \
                           round0_bits - COMPOUND_ROUND1_BITS);               \
    uint16x4_t blend_u16_hi =                                                 \
        vqrshrun_n_s32(vreinterpretq_s32_u32(blend_u32_hi),                   \
                       AOM_BLEND_A64_ROUND_BITS + 2 * FILTER_BITS -           \
                           round0_bits - COMPOUND_ROUND1_BITS);               \
                                                                              \
    uint16x8_t blend_u16 = vcombine_u16(blend_u16_lo, blend_u16_hi);          \
    blend_u16 = vminq_u16(blend_u16, vdupq_n_u16((1 << bd) - 1));             \
                                                                              \
    return blend_u16;                                                         \
  }                                                                           \
                                                                              \
  static inline void highbd_##bd##_blend_a64_d16_mask_neon(                   \
      uint16_t *dst, uint32_t dst_stride, const CONV_BUF_TYPE *src0,          \
      uint32_t src0_stride, const CONV_BUF_TYPE *src1, uint32_t src1_stride,  \
      const uint8_t *mask, uint32_t mask_stride, int w, int h, int subw,      \
      int subh) {                                                             \
    const int offset_bits = bd + 2 * FILTER_BITS - round0_bits;               \
    int32_t round_offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +      \
                           (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));   \
    int32x4_t offset =                                                        \
        vdupq_n_s32(-(round_offset << AOM_BLEND_A64_ROUND_BITS));             \
                                                                              \
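    /* Mask at the same resolution as the source blocks. */                  \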
    if ((subw | subh) == 0) {                                                 \
      if (w >= 8) {                                                           \
        do {                                                                  \
          int i = 0;                                                          \
          do {                                                                \
            uint16x8_t m0 = vmovl_u8(vld1_u8(mask + i));                      \
            uint16x8_t s0 = vld1q_u16(src0 + i);                              \
            uint16x8_t s1 = vld1q_u16(src1 + i);                              \
                                                                              \
            uint16x8_t blend =                                                \
                alpha_##bd##_blend_a64_d16_u16x8(m0, s0, s1, offset);         \
                                                                              \
            vst1q_u16(dst + i, blend);                                        \
            i += 8;                                                           \
          } while (i < w);                                                    \
                                                                              \
          mask += mask_stride;                                                \
          src0 += src0_stride;                                                \
          src1 += src1_stride;                                                \
          dst += dst_stride;                                                  \
        } while (--h != 0);                                                   \
      } else {                                                                \
        do {                                                                  \
          uint16x8_t m0 = vmovl_u8(load_unaligned_u8_4x2(mask, mask_stride)); \
          uint16x8_t s0 = load_unaligned_u16_4x2(src0, src0_stride);          \
          uint16x8_t s1 = load_unaligned_u16_4x2(src1, src1_stride);          \
                                                                              \
          uint16x8_t blend =                                                  \
              alpha_##bd##_blend_a64_d16_u16x8(m0, s0, s1, offset);           \
                                                                              \
          store_u16x4_strided_x2(dst, dst_stride, blend);                     \
                                                                              \
          mask += 2 * mask_stride;                                            \
          src0 += 2 * src0_stride;                                            \
          src1 += 2 * src1_stride;                                            \
          dst += 2 * dst_stride;                                              \
          h -= 2;                                                             \
        } while (h != 0);                                                     \
      }                                                                       \
    } else if ((subw & subh) == 1) {                                          \
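      /* Mask subsampled 2x in both directions: average each 2x2 block. */    \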
      if (w >= 8) {                                                           \
        do {                                                                  \
          int i = 0;                                                          \
          do {                                                                \
            uint8x16_t m0 = vld1q_u8(mask + 0 * mask_stride + 2 * i);         \
            uint8x16_t m1 = vld1q_u8(mask + 1 * mask_stride + 2 * i);         \
            uint16x8_t s0 = vld1q_u16(src0 + i);                              \
            uint16x8_t s1 = vld1q_u16(src1 + i);                              \
                                                                              \
            uint16x8_t m_avg = vmovl_u8(avg_blend_pairwise_u8x8_4(            \
                vget_low_u8(m0), vget_low_u8(m1), vget_high_u8(m0),           \
                vget_high_u8(m1)));                                           \
            uint16x8_t blend =                                                \
                alpha_##bd##_blend_a64_d16_u16x8(m_avg, s0, s1, offset);      \
                                                                              \
            vst1q_u16(dst + i, blend);                                        \
            i += 8;                                                           \
          } while (i < w);                                                    \
                                                                              \
          mask += 2 * mask_stride;                                            \
          src0 += src0_stride;                                                \
          src1 += src1_stride;                                                \
          dst += dst_stride;                                                  \
        } while (--h != 0);                                                   \
      } else {                                                                \
        do {                                                                  \
          uint8x8_t m0 = vld1_u8(mask + 0 * mask_stride);                     \
          uint8x8_t m1 = vld1_u8(mask + 1 * mask_stride);                     \
          uint8x8_t m2 = vld1_u8(mask + 2 * mask_stride);                     \
          uint8x8_t m3 = vld1_u8(mask + 3 * mask_stride);                     \
          uint16x8_t s0 = load_unaligned_u16_4x2(src0, src0_stride);          \
          uint16x8_t s1 = load_unaligned_u16_4x2(src1, src1_stride);          \
                                                                              \
          uint16x8_t m_avg =                                                  \
              vmovl_u8(avg_blend_pairwise_u8x8_4(m0, m1, m2, m3));            \
          uint16x8_t blend =                                                  \
              alpha_##bd##_blend_a64_d16_u16x8(m_avg, s0, s1, offset);        \
                                                                              \
          store_u16x4_strided_x2(dst, dst_stride, blend);                     \
                                                                              \
          mask += 4 * mask_stride;                                            \
          src0 += 2 * src0_stride;                                            \
          src1 += 2 * src1_stride;                                            \
          dst += 2 * dst_stride;                                              \
          h -= 2;                                                             \
        } while (h != 0);                                                     \
      }                                                                       \
    } else if (subw == 1 && subh == 0) {                                      \
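      /* Mask subsampled 2x horizontally: average adjacent column pairs. */   \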
      if (w >= 8) {                                                           \
        do {                                                                  \
          int i = 0;                                                          \
          do {                                                                \
            uint8x8_t m0 = vld1_u8(mask + 2 * i);                             \
            uint8x8_t m1 = vld1_u8(mask + 2 * i + 8);                         \
            uint16x8_t s0 = vld1q_u16(src0 + i);                              \
            uint16x8_t s1 = vld1q_u16(src1 + i);                              \
                                                                              \
            uint16x8_t m_avg = vmovl_u8(avg_blend_pairwise_u8x8(m0, m1));     \
            uint16x8_t blend =                                                \
                alpha_##bd##_blend_a64_d16_u16x8(m_avg, s0, s1, offset);      \
                                                                              \
            vst1q_u16(dst + i, blend);                                        \
            i += 8;                                                           \
          } while (i < w);                                                    \
                                                                              \
          mask += mask_stride;                                                \
          src0 += src0_stride;                                                \
          src1 += src1_stride;                                                \
          dst += dst_stride;                                                  \
        } while (--h != 0);                                                   \
      } else {                                                                \
        do {                                                                  \
          uint8x8_t m0 = vld1_u8(mask + 0 * mask_stride);                     \
          uint8x8_t m1 = vld1_u8(mask + 1 * mask_stride);                     \
          uint16x8_t s0 = load_unaligned_u16_4x2(src0, src0_stride);          \
          uint16x8_t s1 = load_unaligned_u16_4x2(src1, src1_stride);          \
                                                                              \
          uint16x8_t m_avg = vmovl_u8(avg_blend_pairwise_u8x8(m0, m1));       \
          uint16x8_t blend =                                                  \
              alpha_##bd##_blend_a64_d16_u16x8(m_avg, s0, s1, offset);        \
                                                                              \
          store_u16x4_strided_x2(dst, dst_stride, blend);                     \
                                                                              \
          mask += 2 * mask_stride;                                            \
          src0 += 2 * src0_stride;                                            \
          src1 += 2 * src1_stride;                                            \
          dst += 2 * dst_stride;                                              \
          h -= 2;                                                             \
        } while (h != 0);                                                     \
      }                                                                       \
    } else {                                                                  \
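      /* subw == 0 && subh == 1: average vertically adjacent mask rows. */    \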
      if (w >= 8) {                                                           \
        do {                                                                  \
          int i = 0;                                                          \
          do {                                                                \
            uint8x8_t m0 = vld1_u8(mask + 0 * mask_stride + i);               \
            uint8x8_t m1 = vld1_u8(mask + 1 * mask_stride + i);               \
            uint16x8_t s0 = vld1q_u16(src0 + i);                              \
            uint16x8_t s1 = vld1q_u16(src1 + i);                              \
                                                                              \
            uint16x8_t m_avg = vmovl_u8(avg_blend_u8x8(m0, m1));              \
            uint16x8_t blend =                                                \
                alpha_##bd##_blend_a64_d16_u16x8(m_avg, s0, s1, offset);      \
                                                                              \
            vst1q_u16(dst + i, blend);                                        \
            i += 8;                                                           \
          } while (i < w);                                                    \
                                                                              \
          mask += 2 * mask_stride;                                            \
          src0 += src0_stride;                                                \
          src1 += src1_stride;                                                \
          dst += dst_stride;                                                  \
        } while (--h != 0);                                                   \
      } else {                                                                \
        do {                                                                  \
          uint8x8_t m0_2 =                                                    \
              load_unaligned_u8_4x2(mask + 0 * mask_stride, 2 * mask_stride); \
          uint8x8_t m1_3 =                                                    \
              load_unaligned_u8_4x2(mask + 1 * mask_stride, 2 * mask_stride); \
          uint16x8_t s0 = load_unaligned_u16_4x2(src0, src0_stride);          \
          uint16x8_t s1 = load_unaligned_u16_4x2(src1, src1_stride);          \
                                                                              \
          uint16x8_t m_avg = vmovl_u8(avg_blend_u8x8(m0_2, m1_3));            \
          uint16x8_t blend =                                                  \
              alpha_##bd##_blend_a64_d16_u16x8(m_avg, s0, s1, offset);        \
                                                                              \
          store_u16x4_strided_x2(dst, dst_stride, blend);                     \
                                                                              \
          mask += 4 * mask_stride;                                            \
          src0 += 2 * src0_stride;                                            \
          src1 += 2 * src1_stride;                                            \
          dst += 2 * dst_stride;                                              \
          h -= 2;                                                             \
        } while (h != 0);                                                     \
      }                                                                       \
    }                                                                         \
  }

// 12 bitdepth
HBD_BLEND_A64_D16_MASK(12, (ROUND0_BITS + 2))
// 10 bitdepth
HBD_BLEND_A64_D16_MASK(10, ROUND0_BITS)
// 8 bitdepth
HBD_BLEND_A64_D16_MASK(8, ROUND0_BITS)

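// Blends two d16 (compound convolution) buffers into a high-bitdepth pixel
// buffer, dispatching to the bit-depth-specific kernel generated above. The
// 12-bit kernel is instantiated with a larger round0 shift, matching the
// wider intermediates produced by 12-bit compound convolution.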
void aom_highbd_blend_a64_d16_mask_neon(
    uint8_t *dst_8, uint32_t dst_stride, const CONV_BUF_TYPE *src0,
    uint32_t src0_stride, const CONV_BUF_TYPE *src1, uint32_t src1_stride,
    const uint8_t *mask, uint32_t mask_stride, int w, int h, int subw, int subh,
    ConvolveParams *conv_params, const int bd) {
  (void)conv_params;
  assert(h >= 1);
  assert(w >= 1);
  assert(IS_POWER_OF_TWO(h));
  assert(IS_POWER_OF_TWO(w));

  uint16_t *dst = CONVERT_TO_SHORTPTR(dst_8);
  assert(IMPLIES(src0 == dst, src0_stride == dst_stride));
  assert(IMPLIES(src1 == dst, src1_stride == dst_stride));

  if (bd == 12) {
    highbd_12_blend_a64_d16_mask_neon(dst, dst_stride, src0, src0_stride, src1,
                                      src1_stride, mask, mask_stride, w, h,
                                      subw, subh);
  } else if (bd == 10) {
    highbd_10_blend_a64_d16_mask_neon(dst, dst_stride, src0, src0_stride, src1,
                                      src1_stride, mask, mask_stride, w, h,
                                      subw, subh);
  } else {
    highbd_8_blend_a64_d16_mask_neon(dst, dst_stride, src0, src0_stride, src1,
                                     src1_stride, mask, mask_stride, w, h, subw,
                                     subh);
  }
}

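// Blends two already-rounded high-bitdepth pixel buffers with a [0, 64]
// alpha mask; unlike the d16 path above, no compound round offset is
// involved.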
void aom_highbd_blend_a64_mask_neon(uint8_t *dst_8, uint32_t dst_stride,
                                    const uint8_t *src0_8, uint32_t src0_stride,
                                    const uint8_t *src1_8, uint32_t src1_stride,
                                    const uint8_t *mask, uint32_t mask_stride,
                                    int w, int h, int subw, int subh, int bd) {
  (void)bd;

  const uint16_t *src0 = CONVERT_TO_SHORTPTR(src0_8);
  const uint16_t *src1 = CONVERT_TO_SHORTPTR(src1_8);
  uint16_t *dst = CONVERT_TO_SHORTPTR(dst_8);

  assert(IMPLIES(src0 == dst, src0_stride == dst_stride));
  assert(IMPLIES(src1 == dst, src1_stride == dst_stride));

  assert(h >= 1);
  assert(w >= 1);
  assert(IS_POWER_OF_TWO(h));
  assert(IS_POWER_OF_TWO(w));

  assert(bd == 8 || bd == 10 || bd == 12);

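  // Mask at the same resolution as the source blocks.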
  if ((subw | subh) == 0) {
    if (w >= 8) {
      do {
        int i = 0;
        do {
          uint16x8_t m0 = vmovl_u8(vld1_u8(mask + i));
          uint16x8_t s0 = vld1q_u16(src0 + i);
          uint16x8_t s1 = vld1q_u16(src1 + i);

          uint16x8_t blend = alpha_blend_a64_u16x8(m0, s0, s1);

          vst1q_u16(dst + i, blend);
          i += 8;
        } while (i < w);

        mask += mask_stride;
        src0 += src0_stride;
        src1 += src1_stride;
        dst += dst_stride;
      } while (--h != 0);
    } else {
      do {
        uint16x8_t m0 = vmovl_u8(load_unaligned_u8_4x2(mask, mask_stride));
        uint16x8_t s0 = load_unaligned_u16_4x2(src0, src0_stride);
        uint16x8_t s1 = load_unaligned_u16_4x2(src1, src1_stride);

        uint16x8_t blend = alpha_blend_a64_u16x8(m0, s0, s1);

        store_u16x4_strided_x2(dst, dst_stride, blend);

        mask += 2 * mask_stride;
        src0 += 2 * src0_stride;
        src1 += 2 * src1_stride;
        dst += 2 * dst_stride;
        h -= 2;
      } while (h != 0);
    }
  } else if ((subw & subh) == 1) {
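    // Mask subsampled 2x in both directions: average each 2x2 block.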
    if (w >= 8) {
      do {
        int i = 0;
        do {
          uint8x8_t m0 = vld1_u8(mask + 0 * mask_stride + 2 * i);
          uint8x8_t m1 = vld1_u8(mask + 1 * mask_stride + 2 * i);
          uint8x8_t m2 = vld1_u8(mask + 0 * mask_stride + 2 * i + 8);
          uint8x8_t m3 = vld1_u8(mask + 1 * mask_stride + 2 * i + 8);
          uint16x8_t s0 = vld1q_u16(src0 + i);
          uint16x8_t s1 = vld1q_u16(src1 + i);

          uint16x8_t m_avg =
              vmovl_u8(avg_blend_pairwise_u8x8_4(m0, m1, m2, m3));

          uint16x8_t blend = alpha_blend_a64_u16x8(m_avg, s0, s1);

          vst1q_u16(dst + i, blend);

          i += 8;
        } while (i < w);

        mask += 2 * mask_stride;
        src0 += src0_stride;
        src1 += src1_stride;
        dst += dst_stride;
      } while (--h != 0);
    } else {
      do {
        uint8x8_t m0 = vld1_u8(mask + 0 * mask_stride);
        uint8x8_t m1 = vld1_u8(mask + 1 * mask_stride);
        uint8x8_t m2 = vld1_u8(mask + 2 * mask_stride);
        uint8x8_t m3 = vld1_u8(mask + 3 * mask_stride);
        uint16x8_t s0 = load_unaligned_u16_4x2(src0, src0_stride);
        uint16x8_t s1 = load_unaligned_u16_4x2(src1, src1_stride);

        uint16x8_t m_avg = vmovl_u8(avg_blend_pairwise_u8x8_4(m0, m1, m2, m3));
        uint16x8_t blend = alpha_blend_a64_u16x8(m_avg, s0, s1);

        store_u16x4_strided_x2(dst, dst_stride, blend);

        mask += 4 * mask_stride;
        src0 += 2 * src0_stride;
        src1 += 2 * src1_stride;
        dst += 2 * dst_stride;
        h -= 2;
      } while (h != 0);
    }
  } else if (subw == 1 && subh == 0) {
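    // Mask subsampled 2x horizontally: average adjacent column pairs.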
    if (w >= 8) {
      do {
        int i = 0;

        do {
          uint8x8_t m0 = vld1_u8(mask + 2 * i);
          uint8x8_t m1 = vld1_u8(mask + 2 * i + 8);
          uint16x8_t s0 = vld1q_u16(src0 + i);
          uint16x8_t s1 = vld1q_u16(src1 + i);

          uint16x8_t m_avg = vmovl_u8(avg_blend_pairwise_u8x8(m0, m1));
          uint16x8_t blend = alpha_blend_a64_u16x8(m_avg, s0, s1);

          vst1q_u16(dst + i, blend);

          i += 8;
        } while (i < w);

        mask += mask_stride;
        src0 += src0_stride;
        src1 += src1_stride;
        dst += dst_stride;
      } while (--h != 0);
    } else {
      do {
        uint8x8_t m0 = vld1_u8(mask + 0 * mask_stride);
        uint8x8_t m1 = vld1_u8(mask + 1 * mask_stride);
        uint16x8_t s0 = load_unaligned_u16_4x2(src0, src0_stride);
        uint16x8_t s1 = load_unaligned_u16_4x2(src1, src1_stride);

        uint16x8_t m_avg = vmovl_u8(avg_blend_pairwise_u8x8(m0, m1));
        uint16x8_t blend = alpha_blend_a64_u16x8(m_avg, s0, s1);

        store_u16x4_strided_x2(dst, dst_stride, blend);

        mask += 2 * mask_stride;
        src0 += 2 * src0_stride;
        src1 += 2 * src1_stride;
        dst += 2 * dst_stride;
        h -= 2;
      } while (h != 0);
    }
  } else {
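    // subw == 0 && subh == 1: average vertically adjacent mask rows.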
    if (w >= 8) {
      do {
        int i = 0;
        do {
          uint8x8_t m0 = vld1_u8(mask + 0 * mask_stride + i);
          uint8x8_t m1 = vld1_u8(mask + 1 * mask_stride + i);
          uint16x8_t s0 = vld1q_u16(src0 + i);
          uint16x8_t s1 = vld1q_u16(src1 + i);

          uint16x8_t m_avg = vmovl_u8(avg_blend_u8x8(m0, m1));
          uint16x8_t blend = alpha_blend_a64_u16x8(m_avg, s0, s1);

          vst1q_u16(dst + i, blend);

          i += 8;
        } while (i < w);

        mask += 2 * mask_stride;
        src0 += src0_stride;
        src1 += src1_stride;
        dst += dst_stride;
      } while (--h != 0);
    } else {
      do {
        uint8x8_t m0_2 =
            load_unaligned_u8_4x2(mask + 0 * mask_stride, 2 * mask_stride);
        uint8x8_t m1_3 =
            load_unaligned_u8_4x2(mask + 1 * mask_stride, 2 * mask_stride);
        uint16x8_t s0 = load_unaligned_u16_4x2(src0, src0_stride);
        uint16x8_t s1 = load_unaligned_u16_4x2(src1, src1_stride);

        uint16x8_t m_avg = vmovl_u8(avg_blend_u8x8(m0_2, m1_3));
        uint16x8_t blend = alpha_blend_a64_u16x8(m_avg, s0, s1);

        store_u16x4_strided_x2(dst, dst_stride, blend);

        mask += 4 * mask_stride;
        src0 += 2 * src0_stride;
        src1 += 2 * src1_stride;
        dst += 2 * dst_stride;
        h -= 2;
      } while (h != 0);
    }
  }
}