/*
 * Copyright (c) 2023, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <arm_neon.h>
#include <assert.h>

#include "config/aom_config.h"
#include "config/aom_dsp_rtcd.h"

#include "aom_dsp/arm/blend_neon.h"
#include "aom_dsp/arm/mem_neon.h"
#include "aom_dsp/blend.h"

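// HBD_BLEND_A64_D16_MASK(bd, round0_bits) expands to two bit-depth specific
// helpers:
//  * alpha_<bd>_blend_a64_d16_u16x8() blends eight 16-bit compound (d16)
//    prediction samples against a 64-level alpha mask, removes the compound
//    rounding offset and clamps the result to [0, (1 << bd) - 1].
//  * highbd_<bd>_blend_a64_d16_mask_neon() applies that blend to a whole
//    w x h block, averaging the mask first where it is supersampled relative
//    to the block (subw/subh).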
#define HBD_BLEND_A64_D16_MASK(bd, round0_bits) \
  static inline uint16x8_t alpha_##bd##_blend_a64_d16_u16x8( \
      uint16x8_t m, uint16x8_t a, uint16x8_t b, int32x4_t round_offset) { \
    const uint16x8_t m_inv = \
        vsubq_u16(vdupq_n_u16(AOM_BLEND_A64_MAX_ALPHA), m); \
  \
    uint32x4_t blend_u32_lo = vmlal_u16(vreinterpretq_u32_s32(round_offset), \
                                        vget_low_u16(m), vget_low_u16(a)); \
    uint32x4_t blend_u32_hi = vmlal_u16(vreinterpretq_u32_s32(round_offset), \
                                        vget_high_u16(m), vget_high_u16(a)); \
  \
    blend_u32_lo = \
        vmlal_u16(blend_u32_lo, vget_low_u16(m_inv), vget_low_u16(b)); \
    blend_u32_hi = \
        vmlal_u16(blend_u32_hi, vget_high_u16(m_inv), vget_high_u16(b)); \
  \
    uint16x4_t blend_u16_lo = \
        vqrshrun_n_s32(vreinterpretq_s32_u32(blend_u32_lo), \
                       AOM_BLEND_A64_ROUND_BITS + 2 * FILTER_BITS - \
                           round0_bits - COMPOUND_ROUND1_BITS); \
    uint16x4_t blend_u16_hi = \
        vqrshrun_n_s32(vreinterpretq_s32_u32(blend_u32_hi), \
                       AOM_BLEND_A64_ROUND_BITS + 2 * FILTER_BITS - \
                           round0_bits - COMPOUND_ROUND1_BITS); \
  \
    uint16x8_t blend_u16 = vcombine_u16(blend_u16_lo, blend_u16_hi); \
    blend_u16 = vminq_u16(blend_u16, vdupq_n_u16((1 << bd) - 1)); \
  \
    return blend_u16; \
  } \
  \
  static inline void highbd_##bd##_blend_a64_d16_mask_neon( \
      uint16_t *dst, uint32_t dst_stride, const CONV_BUF_TYPE *src0, \
      uint32_t src0_stride, const CONV_BUF_TYPE *src1, uint32_t src1_stride, \
      const uint8_t *mask, uint32_t mask_stride, int w, int h, int subw, \
      int subh) { \
    const int offset_bits = bd + 2 * FILTER_BITS - round0_bits; \
    int32_t round_offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) + \
                           (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1)); \
    int32x4_t offset = \
        vdupq_n_s32(-(round_offset << AOM_BLEND_A64_ROUND_BITS)); \
  \
    if ((subw | subh) == 0) { \
      if (w >= 8) { \
        do { \
          int i = 0; \
          do { \
            uint16x8_t m0 = vmovl_u8(vld1_u8(mask + i)); \
            uint16x8_t s0 = vld1q_u16(src0 + i); \
            uint16x8_t s1 = vld1q_u16(src1 + i); \
  \
            uint16x8_t blend = \
                alpha_##bd##_blend_a64_d16_u16x8(m0, s0, s1, offset); \
  \
            vst1q_u16(dst + i, blend); \
            i += 8; \
          } while (i < w); \
  \
          mask += mask_stride; \
          src0 += src0_stride; \
          src1 += src1_stride; \
          dst += dst_stride; \
        } while (--h != 0); \
      } else { \
        do { \
          uint16x8_t m0 = vmovl_u8(load_unaligned_u8_4x2(mask, mask_stride)); \
          uint16x8_t s0 = load_unaligned_u16_4x2(src0, src0_stride); \
          uint16x8_t s1 = load_unaligned_u16_4x2(src1, src1_stride); \
  \
          uint16x8_t blend = \
              alpha_##bd##_blend_a64_d16_u16x8(m0, s0, s1, offset); \
  \
          store_u16x4_strided_x2(dst, dst_stride, blend); \
  \
          mask += 2 * mask_stride; \
          src0 += 2 * src0_stride; \
          src1 += 2 * src1_stride; \
          dst += 2 * dst_stride; \
          h -= 2; \
        } while (h != 0); \
      } \
    } else if ((subw & subh) == 1) { \
      if (w >= 8) { \
        do { \
          int i = 0; \
          do { \
            uint8x16_t m0 = vld1q_u8(mask + 0 * mask_stride + 2 * i); \
            uint8x16_t m1 = vld1q_u8(mask + 1 * mask_stride + 2 * i); \
            uint16x8_t s0 = vld1q_u16(src0 + i); \
            uint16x8_t s1 = vld1q_u16(src1 + i); \
  \
            uint16x8_t m_avg = vmovl_u8(avg_blend_pairwise_u8x8_4( \
                vget_low_u8(m0), vget_low_u8(m1), vget_high_u8(m0), \
                vget_high_u8(m1))); \
            uint16x8_t blend = \
                alpha_##bd##_blend_a64_d16_u16x8(m_avg, s0, s1, offset); \
  \
            vst1q_u16(dst + i, blend); \
            i += 8; \
          } while (i < w); \
  \
          mask += 2 * mask_stride; \
          src0 += src0_stride; \
          src1 += src1_stride; \
          dst += dst_stride; \
        } while (--h != 0); \
      } else { \
        do { \
          uint8x8_t m0 = vld1_u8(mask + 0 * mask_stride); \
          uint8x8_t m1 = vld1_u8(mask + 1 * mask_stride); \
          uint8x8_t m2 = vld1_u8(mask + 2 * mask_stride); \
          uint8x8_t m3 = vld1_u8(mask + 3 * mask_stride); \
          uint16x8_t s0 = load_unaligned_u16_4x2(src0, src0_stride); \
          uint16x8_t s1 = load_unaligned_u16_4x2(src1, src1_stride); \
  \
          uint16x8_t m_avg = \
              vmovl_u8(avg_blend_pairwise_u8x8_4(m0, m1, m2, m3)); \
          uint16x8_t blend = \
              alpha_##bd##_blend_a64_d16_u16x8(m_avg, s0, s1, offset); \
  \
          store_u16x4_strided_x2(dst, dst_stride, blend); \
  \
          mask += 4 * mask_stride; \
          src0 += 2 * src0_stride; \
          src1 += 2 * src1_stride; \
          dst += 2 * dst_stride; \
          h -= 2; \
        } while (h != 0); \
      } \
    } else if (subw == 1 && subh == 0) { \
      if (w >= 8) { \
        do { \
          int i = 0; \
          do { \
            uint8x8_t m0 = vld1_u8(mask + 2 * i); \
            uint8x8_t m1 = vld1_u8(mask + 2 * i + 8); \
            uint16x8_t s0 = vld1q_u16(src0 + i); \
            uint16x8_t s1 = vld1q_u16(src1 + i); \
  \
            uint16x8_t m_avg = vmovl_u8(avg_blend_pairwise_u8x8(m0, m1)); \
            uint16x8_t blend = \
                alpha_##bd##_blend_a64_d16_u16x8(m_avg, s0, s1, offset); \
  \
            vst1q_u16(dst + i, blend); \
            i += 8; \
          } while (i < w); \
  \
          mask += mask_stride; \
          src0 += src0_stride; \
          src1 += src1_stride; \
          dst += dst_stride; \
        } while (--h != 0); \
      } else { \
        do { \
          uint8x8_t m0 = vld1_u8(mask + 0 * mask_stride); \
          uint8x8_t m1 = vld1_u8(mask + 1 * mask_stride); \
          uint16x8_t s0 = load_unaligned_u16_4x2(src0, src0_stride); \
          uint16x8_t s1 = load_unaligned_u16_4x2(src1, src1_stride); \
  \
          uint16x8_t m_avg = vmovl_u8(avg_blend_pairwise_u8x8(m0, m1)); \
          uint16x8_t blend = \
              alpha_##bd##_blend_a64_d16_u16x8(m_avg, s0, s1, offset); \
  \
          store_u16x4_strided_x2(dst, dst_stride, blend); \
  \
          mask += 2 * mask_stride; \
          src0 += 2 * src0_stride; \
          src1 += 2 * src1_stride; \
          dst += 2 * dst_stride; \
          h -= 2; \
        } while (h != 0); \
      } \
    } else { \
      if (w >= 8) { \
        do { \
          int i = 0; \
          do { \
            uint8x8_t m0 = vld1_u8(mask + 0 * mask_stride + i); \
            uint8x8_t m1 = vld1_u8(mask + 1 * mask_stride + i); \
            uint16x8_t s0 = vld1q_u16(src0 + i); \
            uint16x8_t s1 = vld1q_u16(src1 + i); \
  \
            uint16x8_t m_avg = vmovl_u8(avg_blend_u8x8(m0, m1)); \
            uint16x8_t blend = \
                alpha_##bd##_blend_a64_d16_u16x8(m_avg, s0, s1, offset); \
  \
            vst1q_u16(dst + i, blend); \
            i += 8; \
          } while (i < w); \
  \
          mask += 2 * mask_stride; \
          src0 += src0_stride; \
          src1 += src1_stride; \
          dst += dst_stride; \
        } while (--h != 0); \
      } else { \
        do { \
          uint8x8_t m0_2 = \
              load_unaligned_u8_4x2(mask + 0 * mask_stride, 2 * mask_stride); \
          uint8x8_t m1_3 = \
              load_unaligned_u8_4x2(mask + 1 * mask_stride, 2 * mask_stride); \
          uint16x8_t s0 = load_unaligned_u16_4x2(src0, src0_stride); \
          uint16x8_t s1 = load_unaligned_u16_4x2(src1, src1_stride); \
  \
          uint16x8_t m_avg = vmovl_u8(avg_blend_u8x8(m0_2, m1_3)); \
          uint16x8_t blend = \
              alpha_##bd##_blend_a64_d16_u16x8(m_avg, s0, s1, offset); \
  \
          store_u16x4_strided_x2(dst, dst_stride, blend); \
  \
          mask += 4 * mask_stride; \
          src0 += 2 * src0_stride; \
          src1 += 2 * src1_stride; \
          dst += 2 * dst_stride; \
          h -= 2; \
        } while (h != 0); \
      } \
    } \
  }

// 12 bitdepth
HBD_BLEND_A64_D16_MASK(12, (ROUND0_BITS + 2))
// 10 bitdepth
HBD_BLEND_A64_D16_MASK(10, ROUND0_BITS)
// 8 bitdepth
HBD_BLEND_A64_D16_MASK(8, ROUND0_BITS)

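// Selects the bit-depth specific implementation generated by
// HBD_BLEND_A64_D16_MASK above based on the runtime bit depth.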
void aom_highbd_blend_a64_d16_mask_neon(
    uint8_t *dst_8, uint32_t dst_stride, const CONV_BUF_TYPE *src0,
    uint32_t src0_stride, const CONV_BUF_TYPE *src1, uint32_t src1_stride,
    const uint8_t *mask, uint32_t mask_stride, int w, int h, int subw, int subh,
    ConvolveParams *conv_params, const int bd) {
  (void)conv_params;
  assert(h >= 1);
  assert(w >= 1);
  assert(IS_POWER_OF_TWO(h));
  assert(IS_POWER_OF_TWO(w));

  uint16_t *dst = CONVERT_TO_SHORTPTR(dst_8);
  assert(IMPLIES(src0 == dst, src0_stride == dst_stride));
  assert(IMPLIES(src1 == dst, src1_stride == dst_stride));

  if (bd == 12) {
    highbd_12_blend_a64_d16_mask_neon(dst, dst_stride, src0, src0_stride, src1,
                                      src1_stride, mask, mask_stride, w, h,
                                      subw, subh);
  } else if (bd == 10) {
    highbd_10_blend_a64_d16_mask_neon(dst, dst_stride, src0, src0_stride, src1,
                                      src1_stride, mask, mask_stride, w, h,
                                      subw, subh);
  } else {
    highbd_8_blend_a64_d16_mask_neon(dst, dst_stride, src0, src0_stride, src1,
                                     src1_stride, mask, mask_stride, w, h, subw,
                                     subh);
  }
}

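// Pixel-domain variant: blends two high bit-depth pixel buffers directly.
// The mask may be supersampled by 2x relative to the block in either or both
// dimensions (subw/subh), in which case adjacent mask values are averaged
// before blending.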
void aom_highbd_blend_a64_mask_neon(uint8_t *dst_8, uint32_t dst_stride,
                                    const uint8_t *src0_8, uint32_t src0_stride,
                                    const uint8_t *src1_8, uint32_t src1_stride,
                                    const uint8_t *mask, uint32_t mask_stride,
                                    int w, int h, int subw, int subh, int bd) {
  (void)bd;

  const uint16_t *src0 = CONVERT_TO_SHORTPTR(src0_8);
  const uint16_t *src1 = CONVERT_TO_SHORTPTR(src1_8);
  uint16_t *dst = CONVERT_TO_SHORTPTR(dst_8);

  assert(IMPLIES(src0 == dst, src0_stride == dst_stride));
  assert(IMPLIES(src1 == dst, src1_stride == dst_stride));

  assert(h >= 1);
  assert(w >= 1);
  assert(IS_POWER_OF_TWO(h));
  assert(IS_POWER_OF_TWO(w));

  assert(bd == 8 || bd == 10 || bd == 12);

  if ((subw | subh) == 0) {
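    // Mask is at the same resolution as the sources: use the mask values
    // directly.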
    if (w >= 8) {
      do {
        int i = 0;
        do {
          uint16x8_t m0 = vmovl_u8(vld1_u8(mask + i));
          uint16x8_t s0 = vld1q_u16(src0 + i);
          uint16x8_t s1 = vld1q_u16(src1 + i);

          uint16x8_t blend = alpha_blend_a64_u16x8(m0, s0, s1);

          vst1q_u16(dst + i, blend);
          i += 8;
        } while (i < w);

        mask += mask_stride;
        src0 += src0_stride;
        src1 += src1_stride;
        dst += dst_stride;
      } while (--h != 0);
    } else {
      do {
        uint16x8_t m0 = vmovl_u8(load_unaligned_u8_4x2(mask, mask_stride));
        uint16x8_t s0 = load_unaligned_u16_4x2(src0, src0_stride);
        uint16x8_t s1 = load_unaligned_u16_4x2(src1, src1_stride);

        uint16x8_t blend = alpha_blend_a64_u16x8(m0, s0, s1);

        store_u16x4_strided_x2(dst, dst_stride, blend);

        mask += 2 * mask_stride;
        src0 += 2 * src0_stride;
        src1 += 2 * src1_stride;
        dst += 2 * dst_stride;
        h -= 2;
      } while (h != 0);
    }
  } else if ((subw & subh) == 1) {
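    // Mask is 2x the block size in both dimensions: average each 2x2 group of
    // mask values before blending.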
    if (w >= 8) {
      do {
        int i = 0;
        do {
          uint8x8_t m0 = vld1_u8(mask + 0 * mask_stride + 2 * i);
          uint8x8_t m1 = vld1_u8(mask + 1 * mask_stride + 2 * i);
          uint8x8_t m2 = vld1_u8(mask + 0 * mask_stride + 2 * i + 8);
          uint8x8_t m3 = vld1_u8(mask + 1 * mask_stride + 2 * i + 8);
          uint16x8_t s0 = vld1q_u16(src0 + i);
          uint16x8_t s1 = vld1q_u16(src1 + i);

          uint16x8_t m_avg =
              vmovl_u8(avg_blend_pairwise_u8x8_4(m0, m1, m2, m3));

          uint16x8_t blend = alpha_blend_a64_u16x8(m_avg, s0, s1);

          vst1q_u16(dst + i, blend);

          i += 8;
        } while (i < w);

        mask += 2 * mask_stride;
        src0 += src0_stride;
        src1 += src1_stride;
        dst += dst_stride;
      } while (--h != 0);
    } else {
      do {
        uint8x8_t m0 = vld1_u8(mask + 0 * mask_stride);
        uint8x8_t m1 = vld1_u8(mask + 1 * mask_stride);
        uint8x8_t m2 = vld1_u8(mask + 2 * mask_stride);
        uint8x8_t m3 = vld1_u8(mask + 3 * mask_stride);
        uint16x8_t s0 = load_unaligned_u16_4x2(src0, src0_stride);
        uint16x8_t s1 = load_unaligned_u16_4x2(src1, src1_stride);

        uint16x8_t m_avg = vmovl_u8(avg_blend_pairwise_u8x8_4(m0, m1, m2, m3));
        uint16x8_t blend = alpha_blend_a64_u16x8(m_avg, s0, s1);

        store_u16x4_strided_x2(dst, dst_stride, blend);

        mask += 4 * mask_stride;
        src0 += 2 * src0_stride;
        src1 += 2 * src1_stride;
        dst += 2 * dst_stride;
        h -= 2;
      } while (h != 0);
    }
  } else if (subw == 1 && subh == 0) {
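    // Mask is 2x the block width only: average horizontally adjacent pairs of
    // mask values.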
    if (w >= 8) {
      do {
        int i = 0;

        do {
          uint8x8_t m0 = vld1_u8(mask + 2 * i);
          uint8x8_t m1 = vld1_u8(mask + 2 * i + 8);
          uint16x8_t s0 = vld1q_u16(src0 + i);
          uint16x8_t s1 = vld1q_u16(src1 + i);

          uint16x8_t m_avg = vmovl_u8(avg_blend_pairwise_u8x8(m0, m1));
          uint16x8_t blend = alpha_blend_a64_u16x8(m_avg, s0, s1);

          vst1q_u16(dst + i, blend);

          i += 8;
        } while (i < w);

        mask += mask_stride;
        src0 += src0_stride;
        src1 += src1_stride;
        dst += dst_stride;
      } while (--h != 0);
    } else {
      do {
        uint8x8_t m0 = vld1_u8(mask + 0 * mask_stride);
        uint8x8_t m1 = vld1_u8(mask + 1 * mask_stride);
        uint16x8_t s0 = load_unaligned_u16_4x2(src0, src0_stride);
        uint16x8_t s1 = load_unaligned_u16_4x2(src1, src1_stride);

        uint16x8_t m_avg = vmovl_u8(avg_blend_pairwise_u8x8(m0, m1));
        uint16x8_t blend = alpha_blend_a64_u16x8(m_avg, s0, s1);

        store_u16x4_strided_x2(dst, dst_stride, blend);

        mask += 2 * mask_stride;
        src0 += 2 * src0_stride;
        src1 += 2 * src1_stride;
        dst += 2 * dst_stride;
        h -= 2;
      } while (h != 0);
    }
  } else {
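    // Mask is 2x the block height only: average vertically adjacent pairs of
    // mask values.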
    if (w >= 8) {
      do {
        int i = 0;
        do {
          uint8x8_t m0 = vld1_u8(mask + 0 * mask_stride + i);
          uint8x8_t m1 = vld1_u8(mask + 1 * mask_stride + i);
          uint16x8_t s0 = vld1q_u16(src0 + i);
          uint16x8_t s1 = vld1q_u16(src1 + i);

          uint16x8_t m_avg = vmovl_u8(avg_blend_u8x8(m0, m1));
          uint16x8_t blend = alpha_blend_a64_u16x8(m_avg, s0, s1);

          vst1q_u16(dst + i, blend);

          i += 8;
        } while (i < w);

        mask += 2 * mask_stride;
        src0 += src0_stride;
        src1 += src1_stride;
        dst += dst_stride;
      } while (--h != 0);
    } else {
      do {
        uint8x8_t m0_2 =
            load_unaligned_u8_4x2(mask + 0 * mask_stride, 2 * mask_stride);
        uint8x8_t m1_3 =
            load_unaligned_u8_4x2(mask + 1 * mask_stride, 2 * mask_stride);
        uint16x8_t s0 = load_unaligned_u16_4x2(src0, src0_stride);
        uint16x8_t s1 = load_unaligned_u16_4x2(src1, src1_stride);

        uint16x8_t m_avg = vmovl_u8(avg_blend_u8x8(m0_2, m1_3));
        uint16x8_t blend = alpha_blend_a64_u16x8(m_avg, s0, s1);

        store_u16x4_strided_x2(dst, dst_stride, blend);

        mask += 4 * mask_stride;
        src0 += 2 * src0_stride;
        src1 += 2 * src1_stride;
        dst += 2 * dst_stride;
        h -= 2;
      } while (h != 0);
    }
  }
}