/*
 * Copyright (c) 2023, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <arm_neon.h>
#include <assert.h>

#include "config/aom_config.h"
#include "config/aom_dsp_rtcd.h"

#include "aom/aom_integer.h"
#include "aom_dsp/arm/blend_neon.h"
#include "aom_dsp/arm/mem_neon.h"
#include "aom_dsp/arm/sum_neon.h"
#include "aom_dsp/blend.h"

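// Compute the masked SAD for one row of 8 high-bitdepth pixels: blend a and b
// with the 6-bit mask m (weights m and 64 - m, with rounding), then
// accumulate the absolute difference against src into the 16-bit lanes of
// sad.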
static inline uint16x8_t masked_sad_8x1_neon(uint16x8_t sad,
                                             const uint16_t *src,
                                             const uint16_t *a,
                                             const uint16_t *b,
                                             const uint8_t *m) {
  const uint16x8_t s0 = vld1q_u16(src);
  const uint16x8_t a0 = vld1q_u16(a);
  const uint16x8_t b0 = vld1q_u16(b);
  const uint16x8_t m0 = vmovl_u8(vld1_u8(m));

  uint16x8_t blend_u16 = alpha_blend_a64_u16x8(m0, a0, b0);

  return vaddq_u16(sad, vabdq_u16(blend_u16, s0));
}

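// Wider rows chain the 8x1 helper across consecutive 8-pixel chunks.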
static inline uint16x8_t masked_sad_16x1_neon(uint16x8_t sad,
                                              const uint16_t *src,
                                              const uint16_t *a,
                                              const uint16_t *b,
                                              const uint8_t *m) {
  sad = masked_sad_8x1_neon(sad, src, a, b, m);
  return masked_sad_8x1_neon(sad, &src[8], &a[8], &b[8], &m[8]);
}

static inline uint16x8_t masked_sad_32x1_neon(uint16x8_t sad,
                                              const uint16_t *src,
                                              const uint16_t *a,
                                              const uint16_t *b,
                                              const uint8_t *m) {
  sad = masked_sad_16x1_neon(sad, src, a, b, m);
  return masked_sad_16x1_neon(sad, &src[16], &a[16], &b[16], &m[16]);
}

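// The *_large_neon helpers below handle blocks too tall for a single pass
// through a uint16x8_t accumulator. In every variant each 16-bit lane
// accumulates at most 16 absolute differences of at most 4095 (12-bit input)
// before the partial sums are widened into 32-bit lanes with vpadalq_u16:
// 16 * 4095 = 65520, which still fits in 16 bits.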
static inline unsigned int masked_sad_128xh_large_neon(
    const uint8_t *src8, int src_stride, const uint8_t *a8, int a_stride,
    const uint8_t *b8, int b_stride, const uint8_t *m, int m_stride,
    int height) {
  const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
  const uint16_t *a = CONVERT_TO_SHORTPTR(a8);
  const uint16_t *b = CONVERT_TO_SHORTPTR(b8);
  uint32x4_t sad_u32[] = { vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0),
                           vdupq_n_u32(0) };

  do {
    uint16x8_t sad[] = { vdupq_n_u16(0), vdupq_n_u16(0), vdupq_n_u16(0),
                         vdupq_n_u16(0) };
    for (int h = 0; h < 4; ++h) {
      sad[0] = masked_sad_32x1_neon(sad[0], src, a, b, m);
      sad[1] = masked_sad_32x1_neon(sad[1], &src[32], &a[32], &b[32], &m[32]);
      sad[2] = masked_sad_32x1_neon(sad[2], &src[64], &a[64], &b[64], &m[64]);
      sad[3] = masked_sad_32x1_neon(sad[3], &src[96], &a[96], &b[96], &m[96]);

      src += src_stride;
      a += a_stride;
      b += b_stride;
      m += m_stride;
    }

    sad_u32[0] = vpadalq_u16(sad_u32[0], sad[0]);
    sad_u32[1] = vpadalq_u16(sad_u32[1], sad[1]);
    sad_u32[2] = vpadalq_u16(sad_u32[2], sad[2]);
    sad_u32[3] = vpadalq_u16(sad_u32[3], sad[3]);
    height -= 4;
  } while (height != 0);

  sad_u32[0] = vaddq_u32(sad_u32[0], sad_u32[1]);
  sad_u32[2] = vaddq_u32(sad_u32[2], sad_u32[3]);
  sad_u32[0] = vaddq_u32(sad_u32[0], sad_u32[2]);

  return horizontal_add_u32x4(sad_u32[0]);
}

static inline unsigned int masked_sad_64xh_large_neon(
    const uint8_t *src8, int src_stride, const uint8_t *a8, int a_stride,
    const uint8_t *b8, int b_stride, const uint8_t *m, int m_stride,
    int height) {
  const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
  const uint16_t *a = CONVERT_TO_SHORTPTR(a8);
  const uint16_t *b = CONVERT_TO_SHORTPTR(b8);
  uint32x4_t sad_u32[] = { vdupq_n_u32(0), vdupq_n_u32(0) };

  do {
    uint16x8_t sad[] = { vdupq_n_u16(0), vdupq_n_u16(0) };
    for (int h = 0; h < 4; ++h) {
      sad[0] = masked_sad_32x1_neon(sad[0], src, a, b, m);
      sad[1] = masked_sad_32x1_neon(sad[1], &src[32], &a[32], &b[32], &m[32]);

      src += src_stride;
      a += a_stride;
      b += b_stride;
      m += m_stride;
    }

    sad_u32[0] = vpadalq_u16(sad_u32[0], sad[0]);
    sad_u32[1] = vpadalq_u16(sad_u32[1], sad[1]);
    height -= 4;
  } while (height != 0);

  return horizontal_add_u32x4(vaddq_u32(sad_u32[0], sad_u32[1]));
}

static inline unsigned int masked_sad_32xh_large_neon(
    const uint8_t *src8, int src_stride, const uint8_t *a8, int a_stride,
    const uint8_t *b8, int b_stride, const uint8_t *m, int m_stride,
    int height) {
  const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
  const uint16_t *a = CONVERT_TO_SHORTPTR(a8);
  const uint16_t *b = CONVERT_TO_SHORTPTR(b8);
  uint32x4_t sad_u32 = vdupq_n_u32(0);

  do {
    uint16x8_t sad = vdupq_n_u16(0);
    for (int h = 0; h < 4; ++h) {
      sad = masked_sad_32x1_neon(sad, src, a, b, m);

      src += src_stride;
      a += a_stride;
      b += b_stride;
      m += m_stride;
    }

    sad_u32 = vpadalq_u16(sad_u32, sad);
    height -= 4;
  } while (height != 0);

  return horizontal_add_u32x4(sad_u32);
}

static inline unsigned int masked_sad_16xh_large_neon(
    const uint8_t *src8, int src_stride, const uint8_t *a8, int a_stride,
    const uint8_t *b8, int b_stride, const uint8_t *m, int m_stride,
    int height) {
  const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
  const uint16_t *a = CONVERT_TO_SHORTPTR(a8);
  const uint16_t *b = CONVERT_TO_SHORTPTR(b8);
  uint32x4_t sad_u32 = vdupq_n_u32(0);

  do {
    uint16x8_t sad_u16 = vdupq_n_u16(0);

    for (int h = 0; h < 8; ++h) {
      sad_u16 = masked_sad_16x1_neon(sad_u16, src, a, b, m);

      src += src_stride;
      a += a_stride;
      b += b_stride;
      m += m_stride;
    }

    sad_u32 = vpadalq_u16(sad_u32, sad_u16);
    height -= 8;
  } while (height != 0);

  return horizontal_add_u32x4(sad_u32);
}

#if !CONFIG_REALTIME_ONLY
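// With only 8 columns each 16-bit lane accumulates a single element per row,
// so 16 rows can be processed per pass before widening.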
static inline unsigned int masked_sad_8xh_large_neon(
    const uint8_t *src8, int src_stride, const uint8_t *a8, int a_stride,
    const uint8_t *b8, int b_stride, const uint8_t *m, int m_stride,
    int height) {
  const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
  const uint16_t *a = CONVERT_TO_SHORTPTR(a8);
  const uint16_t *b = CONVERT_TO_SHORTPTR(b8);
  uint32x4_t sad_u32 = vdupq_n_u32(0);

  do {
    uint16x8_t sad_u16 = vdupq_n_u16(0);

    for (int h = 0; h < 16; ++h) {
      sad_u16 = masked_sad_8x1_neon(sad_u16, src, a, b, m);

      src += src_stride;
      a += a_stride;
      b += b_stride;
      m += m_stride;
    }

    sad_u32 = vpadalq_u16(sad_u32, sad_u16);
    height -= 16;
  } while (height != 0);

  return horizontal_add_u32x4(sad_u32);
}
#endif  // !CONFIG_REALTIME_ONLY

static inline unsigned int masked_sad_16xh_small_neon(
    const uint8_t *src8, int src_stride, const uint8_t *a8, int a_stride,
    const uint8_t *b8, int b_stride, const uint8_t *m, int m_stride,
    int height) {
  // For 12-bit data, we can only accumulate up to 128 elements in the
  // uint16x8_t type sad accumulator, so we can only process up to 8 rows
  // before we have to accumulate into 32-bit elements.
  assert(height <= 8);
  const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
  const uint16_t *a = CONVERT_TO_SHORTPTR(a8);
  const uint16_t *b = CONVERT_TO_SHORTPTR(b8);
  uint16x8_t sad = vdupq_n_u16(0);

  do {
    sad = masked_sad_16x1_neon(sad, src, a, b, m);

    src += src_stride;
    a += a_stride;
    b += b_stride;
    m += m_stride;
  } while (--height != 0);

  return horizontal_add_u16x8(sad);
}

static inline unsigned int masked_sad_8xh_small_neon(
    const uint8_t *src8, int src_stride, const uint8_t *a8, int a_stride,
    const uint8_t *b8, int b_stride, const uint8_t *m, int m_stride,
    int height) {
  // For 12-bit data, we can only accumulate up to 128 elements in the
  // uint16x8_t type sad accumulator, so we can only process up to 16 rows
  // before we have to accumulate into 32-bit elements.
  assert(height <= 16);
  const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
  const uint16_t *a = CONVERT_TO_SHORTPTR(a8);
  const uint16_t *b = CONVERT_TO_SHORTPTR(b8);
  uint16x8_t sad = vdupq_n_u16(0);

  do {
    sad = masked_sad_8x1_neon(sad, src, a, b, m);

    src += src_stride;
    a += a_stride;
    b += b_stride;
    m += m_stride;
  } while (--height != 0);

  return horizontal_add_u16x8(sad);
}

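// Rows of 4 16-bit pixels are not guaranteed to be 8-byte aligned, hence the
// unaligned load helpers below.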
static inline unsigned int masked_sad_4xh_small_neon(
    const uint8_t *src8, int src_stride, const uint8_t *a8, int a_stride,
    const uint8_t *b8, int b_stride, const uint8_t *m, int m_stride,
    int height) {
  // For 12-bit data, we can only accumulate up to 64 elements in the
  // uint16x4_t type sad accumulator, so we can only process up to 16 rows
  // before we have to accumulate into 32-bit elements.
  assert(height <= 16);
  const uint16_t *src = CONVERT_TO_SHORTPTR(src8);
  const uint16_t *a = CONVERT_TO_SHORTPTR(a8);
  const uint16_t *b = CONVERT_TO_SHORTPTR(b8);

  uint16x4_t sad = vdup_n_u16(0);
  do {
    uint16x4_t m0 = vget_low_u16(vmovl_u8(load_unaligned_u8_4x1(m)));
    uint16x4_t a0 = load_unaligned_u16_4x1(a);
    uint16x4_t b0 = load_unaligned_u16_4x1(b);
    uint16x4_t s0 = load_unaligned_u16_4x1(src);

    uint16x4_t blend_u16 = alpha_blend_a64_u16x4(m0, a0, b0);

    sad = vadd_u16(sad, vabd_u16(blend_u16, s0));

    src += src_stride;
    a += a_stride;
    b += b_stride;
    m += m_stride;
  } while (--height != 0);

  return horizontal_add_u16x4(sad);
}

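// The wrappers below dispatch on invert_mask. Since the blend weights m and
// 64 - m sum to AOM_BLEND_A64_MAX_ALPHA, inverting the mask is equivalent to
// swapping the two prediction buffers (ref/second_pred), so both cases share
// one kernel.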
#define HIGHBD_MASKED_SAD_WXH_SMALL_NEON(w, h)                                \
  unsigned int aom_highbd_masked_sad##w##x##h##_neon(                         \
      const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride, \
      const uint8_t *second_pred, const uint8_t *msk, int msk_stride,         \
      int invert_mask) {                                                      \
    if (!invert_mask)                                                         \
      return masked_sad_##w##xh_small_neon(src, src_stride, ref, ref_stride,  \
                                           second_pred, w, msk, msk_stride,   \
                                           h);                                \
    else                                                                      \
      return masked_sad_##w##xh_small_neon(src, src_stride, second_pred, w,   \
                                           ref, ref_stride, msk, msk_stride,  \
                                           h);                                \
  }

#define HIGHBD_MASKED_SAD_WXH_LARGE_NEON(w, h)                                \
  unsigned int aom_highbd_masked_sad##w##x##h##_neon(                         \
      const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride, \
      const uint8_t *second_pred, const uint8_t *msk, int msk_stride,         \
      int invert_mask) {                                                      \
    if (!invert_mask)                                                         \
      return masked_sad_##w##xh_large_neon(src, src_stride, ref, ref_stride,  \
                                           second_pred, w, msk, msk_stride,   \
                                           h);                                \
    else                                                                      \
      return masked_sad_##w##xh_large_neon(src, src_stride, second_pred, w,   \
                                           ref, ref_stride, msk, msk_stride,  \
                                           h);                                \
  }

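// Sizes whose full height stays within the 16-bit per-lane bounds asserted
// above use the small kernels; taller blocks use the widening kernels.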
HIGHBD_MASKED_SAD_WXH_SMALL_NEON(4, 4)
HIGHBD_MASKED_SAD_WXH_SMALL_NEON(4, 8)

HIGHBD_MASKED_SAD_WXH_SMALL_NEON(8, 4)
HIGHBD_MASKED_SAD_WXH_SMALL_NEON(8, 8)
HIGHBD_MASKED_SAD_WXH_SMALL_NEON(8, 16)

HIGHBD_MASKED_SAD_WXH_SMALL_NEON(16, 8)
HIGHBD_MASKED_SAD_WXH_LARGE_NEON(16, 16)
HIGHBD_MASKED_SAD_WXH_LARGE_NEON(16, 32)

HIGHBD_MASKED_SAD_WXH_LARGE_NEON(32, 16)
HIGHBD_MASKED_SAD_WXH_LARGE_NEON(32, 32)
HIGHBD_MASKED_SAD_WXH_LARGE_NEON(32, 64)

HIGHBD_MASKED_SAD_WXH_LARGE_NEON(64, 32)
HIGHBD_MASKED_SAD_WXH_LARGE_NEON(64, 64)
HIGHBD_MASKED_SAD_WXH_LARGE_NEON(64, 128)

HIGHBD_MASKED_SAD_WXH_LARGE_NEON(128, 64)
HIGHBD_MASKED_SAD_WXH_LARGE_NEON(128, 128)

#if !CONFIG_REALTIME_ONLY
HIGHBD_MASKED_SAD_WXH_SMALL_NEON(4, 16)

HIGHBD_MASKED_SAD_WXH_LARGE_NEON(8, 32)

HIGHBD_MASKED_SAD_WXH_SMALL_NEON(16, 4)
HIGHBD_MASKED_SAD_WXH_LARGE_NEON(16, 64)

HIGHBD_MASKED_SAD_WXH_LARGE_NEON(32, 8)

HIGHBD_MASKED_SAD_WXH_LARGE_NEON(64, 16)
#endif  // !CONFIG_REALTIME_ONLY