xref: /aosp_15_r20/external/libaom/aom_dsp/arm/highbd_sadxd_neon.c (revision 77c1e3ccc04c968bd2bc212e87364f250e820521)
1 /*
2  * Copyright (c) 2023 The WebM project authors. All rights reserved.
3  * Copyright (c) 2023, Alliance for Open Media. All rights reserved.
4  *
5  * This source code is subject to the terms of the BSD 2 Clause License and
6  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
7  * was not distributed with this source code in the LICENSE file, you can
8  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
9  * Media Patent License 1.0 was not distributed with this source code in the
10  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
11  */
12 
13 #include <arm_neon.h>
14 
15 #include "config/aom_config.h"
16 #include "config/aom_dsp_rtcd.h"
17 
18 #include "aom/aom_integer.h"
19 #include "aom_dsp/arm/mem_neon.h"
20 #include "aom_dsp/arm/sum_neon.h"
21 
// Computes four SADs in one pass: res[k] (k = 0..3) receives the sum of
// absolute differences between the 4 x h source block and the 4 x h block at
// ref_ptr[k]. The uint8_t pointers are converted to uint16_t sample pointers
// with CONVERT_TO_SHORTPTR; both strides are therefore in samples. The row loop
// is a do/while, so at least one row is always processed.
static inline void highbd_sad4xhx4d_small_neon(const uint8_t *src_ptr,
                                               int src_stride,
                                               const uint8_t *const ref_ptr[4],
                                               int ref_stride, uint32_t res[4],
                                               int h) {
  const uint16_t *const src16 = CONVERT_TO_SHORTPTR(src_ptr);
  const uint16_t *const ref16[4] = { CONVERT_TO_SHORTPTR(ref_ptr[0]),
                                     CONVERT_TO_SHORTPTR(ref_ptr[1]),
                                     CONVERT_TO_SHORTPTR(ref_ptr[2]),
                                     CONVERT_TO_SHORTPTR(ref_ptr[3]) };

  // vabal_u16 widens every |src - ref| straight into 32-bit lanes, so one
  // 32-bit accumulator per reference is enough.
  uint32x4_t acc[4];
  acc[0] = vdupq_n_u32(0);
  acc[1] = acc[0];
  acc[2] = acc[0];
  acc[3] = acc[0];

  int row = 0;
  do {
    const uint16x4_t src_row = vld1_u16(src16 + row * src_stride);
    const int ref_offset = row * ref_stride;

    acc[0] = vabal_u16(acc[0], src_row, vld1_u16(ref16[0] + ref_offset));
    acc[1] = vabal_u16(acc[1], src_row, vld1_u16(ref16[1] + ref_offset));
    acc[2] = vabal_u16(acc[2], src_row, vld1_u16(ref16[2] + ref_offset));
    acc[3] = vabal_u16(acc[3], src_row, vld1_u16(ref16[3] + ref_offset));
  } while (++row < h);

  vst1q_u32(res, horizontal_add_4d_u32x4(acc));
}
53 
// 8-wide, four-reference SAD for short blocks: res[k] (k = 0..3) receives the
// SAD between the 8 x h source block and the block at ref_ptr[k].
// Differences are accumulated in 16-bit lanes (one lane per column) and widened
// to 32 bits only once, after the loop. Each lane gains one |src - ref| per
// row. Assuming samples of at most 12 bits (max difference 4095), the lanes
// cannot wrap while h <= 16 (16 * 4095 = 65520). Every caller in this file
// stays within that limit; taller 8-wide blocks use the large kernel.
static inline void highbd_sad8xhx4d_small_neon(const uint8_t *src_ptr,
                                               int src_stride,
                                               const uint8_t *const ref_ptr[4],
                                               int ref_stride, uint32_t res[4],
                                               int h) {
  const uint16_t *const src16 = CONVERT_TO_SHORTPTR(src_ptr);
  const uint16_t *const ref16[4] = { CONVERT_TO_SHORTPTR(ref_ptr[0]),
                                     CONVERT_TO_SHORTPTR(ref_ptr[1]),
                                     CONVERT_TO_SHORTPTR(ref_ptr[2]),
                                     CONVERT_TO_SHORTPTR(ref_ptr[3]) };

  uint16x8_t acc16[4];
  acc16[0] = vdupq_n_u16(0);
  acc16[1] = acc16[0];
  acc16[2] = acc16[0];
  acc16[3] = acc16[0];

  int row = 0;
  do {
    const uint16x8_t src_row = vld1q_u16(src16 + row * src_stride);
    const int ref_offset = row * ref_stride;

    acc16[0] = vabaq_u16(acc16[0], src_row, vld1q_u16(ref16[0] + ref_offset));
    acc16[1] = vabaq_u16(acc16[1], src_row, vld1q_u16(ref16[1] + ref_offset));
    acc16[2] = vabaq_u16(acc16[2], src_row, vld1q_u16(ref16[2] + ref_offset));
    acc16[3] = vabaq_u16(acc16[3], src_row, vld1q_u16(ref16[3] + ref_offset));
  } while (++row < h);

  // Pairwise-add adjacent 16-bit lanes into 32-bit lanes, then reduce each
  // reference's vector to a scalar.
  uint32x4_t acc32[4];
  for (int k = 0; k < 4; ++k) {
    acc32[k] = vpaddlq_u16(acc16[k]);
  }
  vst1q_u32(res, horizontal_add_4d_u32x4(acc32));
}
86 
// Adds the SAD of eight 16-bit lanes to *sad_sum. The eight absolute
// differences are pairwise added and widened into the four 32-bit lanes of the
// accumulator.
static inline void sad8_neon(uint16x8_t src, uint16x8_t ref,
                             uint32x4_t *const sad_sum) {
  *sad_sum = vpadalq_u16(*sad_sum, vabdq_u16(src, ref));
}
92 
93 #if !CONFIG_REALTIME_ONLY
// 8-wide, four-reference SAD for taller blocks: res[k] (k = 0..3) receives the
// SAD between the 8 x h source block and the block at ref_ptr[k]. Each row's
// differences are widened directly into 32-bit lanes (via sad8_neon), so this
// kernel does not have the h <= 16 limit of the 16-bit small kernel.
static inline void highbd_sad8xhx4d_large_neon(const uint8_t *src_ptr,
                                               int src_stride,
                                               const uint8_t *const ref_ptr[4],
                                               int ref_stride, uint32_t res[4],
                                               int h) {
  const uint16_t *const src16 = CONVERT_TO_SHORTPTR(src_ptr);
  const uint16_t *const ref16[4] = { CONVERT_TO_SHORTPTR(ref_ptr[0]),
                                     CONVERT_TO_SHORTPTR(ref_ptr[1]),
                                     CONVERT_TO_SHORTPTR(ref_ptr[2]),
                                     CONVERT_TO_SHORTPTR(ref_ptr[3]) };

  uint32x4_t acc[4];
  acc[0] = vdupq_n_u32(0);
  acc[1] = acc[0];
  acc[2] = acc[0];
  acc[3] = acc[0];

  int row = 0;
  do {
    const uint16x8_t src_row = vld1q_u16(src16 + row * src_stride);
    const int ref_offset = row * ref_stride;

    sad8_neon(src_row, vld1q_u16(ref16[0] + ref_offset), &acc[0]);
    sad8_neon(src_row, vld1q_u16(ref16[1] + ref_offset), &acc[1]);
    sad8_neon(src_row, vld1q_u16(ref16[2] + ref_offset), &acc[2]);
    sad8_neon(src_row, vld1q_u16(ref16[3] + ref_offset), &acc[3]);
  } while (++row < h);

  vst1q_u32(res, horizontal_add_4d_u32x4(acc));
}
120 #endif  // !CONFIG_REALTIME_ONLY
121 
// 16-wide, four-reference SAD: res[k] (k = 0..3) receives the SAD between the
// 16 x h source block and the block at ref_ptr[k]. Columns 0-7 and 8-15 feed
// separate 32-bit accumulators, which are merged after the row loop.
static inline void highbd_sad16xhx4d_large_neon(const uint8_t *src_ptr,
                                                int src_stride,
                                                const uint8_t *const ref_ptr[4],
                                                int ref_stride, uint32_t res[4],
                                                int h) {
  const uint16_t *const src16 = CONVERT_TO_SHORTPTR(src_ptr);
  const uint16_t *const ref16[4] = { CONVERT_TO_SHORTPTR(ref_ptr[0]),
                                     CONVERT_TO_SHORTPTR(ref_ptr[1]),
                                     CONVERT_TO_SHORTPTR(ref_ptr[2]),
                                     CONVERT_TO_SHORTPTR(ref_ptr[3]) };

  uint32x4_t acc_lo[4];
  uint32x4_t acc_hi[4];
  for (int k = 0; k < 4; ++k) {
    acc_lo[k] = vdupq_n_u32(0);
    acc_hi[k] = vdupq_n_u32(0);
  }

  int row = 0;
  do {
    const uint16_t *const s = src16 + row * src_stride;
    const int r = row * ref_stride;

    // Columns 0-7.
    const uint16x8_t s_lo = vld1q_u16(s);
    sad8_neon(s_lo, vld1q_u16(ref16[0] + r), &acc_lo[0]);
    sad8_neon(s_lo, vld1q_u16(ref16[1] + r), &acc_lo[1]);
    sad8_neon(s_lo, vld1q_u16(ref16[2] + r), &acc_lo[2]);
    sad8_neon(s_lo, vld1q_u16(ref16[3] + r), &acc_lo[3]);

    // Columns 8-15.
    const uint16x8_t s_hi = vld1q_u16(s + 8);
    sad8_neon(s_hi, vld1q_u16(ref16[0] + r + 8), &acc_hi[0]);
    sad8_neon(s_hi, vld1q_u16(ref16[1] + r + 8), &acc_hi[1]);
    sad8_neon(s_hi, vld1q_u16(ref16[2] + r + 8), &acc_hi[2]);
    sad8_neon(s_hi, vld1q_u16(ref16[3] + r + 8), &acc_hi[3]);
  } while (++row < h);

  // Fold the high-half accumulators into the low-half ones, then reduce.
  for (int k = 0; k < 4; ++k) {
    acc_lo[k] = vaddq_u32(acc_lo[k], acc_hi[k]);
  }
  vst1q_u32(res, horizontal_add_4d_u32x4(acc_lo));
}
162 
// Generic wide four-reference SAD: res[k] (k = 0..3) receives the SAD between
// the w x h source block and the block at ref_ptr[k]. Each inner iteration
// covers 32 columns, so w is assumed to be a positive multiple of 32. Callers
// in this file pass 32, 64 or 128. Columns [0,8) and [16,24) of each 32-column
// chunk feed acc_lo; [8,16) and [24,32) feed acc_hi.
static inline void highbd_sadwxhx4d_large_neon(const uint8_t *src_ptr,
                                               int src_stride,
                                               const uint8_t *const ref_ptr[4],
                                               int ref_stride, uint32_t res[4],
                                               int w, int h) {
  const uint16_t *const src16 = CONVERT_TO_SHORTPTR(src_ptr);
  const uint16_t *const ref16[4] = { CONVERT_TO_SHORTPTR(ref_ptr[0]),
                                     CONVERT_TO_SHORTPTR(ref_ptr[1]),
                                     CONVERT_TO_SHORTPTR(ref_ptr[2]),
                                     CONVERT_TO_SHORTPTR(ref_ptr[3]) };

  uint32x4_t acc_lo[4];
  uint32x4_t acc_hi[4];
  for (int k = 0; k < 4; ++k) {
    acc_lo[k] = vdupq_n_u32(0);
    acc_hi[k] = vdupq_n_u32(0);
  }

  int row = 0;
  do {
    int col = 0;
    do {
      const uint16_t *const s = src16 + row * src_stride + col;
      const int r = row * ref_stride + col;

      const uint16x8_t s0 = vld1q_u16(s);
      sad8_neon(s0, vld1q_u16(ref16[0] + r), &acc_lo[0]);
      sad8_neon(s0, vld1q_u16(ref16[1] + r), &acc_lo[1]);
      sad8_neon(s0, vld1q_u16(ref16[2] + r), &acc_lo[2]);
      sad8_neon(s0, vld1q_u16(ref16[3] + r), &acc_lo[3]);

      const uint16x8_t s1 = vld1q_u16(s + 8);
      sad8_neon(s1, vld1q_u16(ref16[0] + r + 8), &acc_hi[0]);
      sad8_neon(s1, vld1q_u16(ref16[1] + r + 8), &acc_hi[1]);
      sad8_neon(s1, vld1q_u16(ref16[2] + r + 8), &acc_hi[2]);
      sad8_neon(s1, vld1q_u16(ref16[3] + r + 8), &acc_hi[3]);

      const uint16x8_t s2 = vld1q_u16(s + 16);
      sad8_neon(s2, vld1q_u16(ref16[0] + r + 16), &acc_lo[0]);
      sad8_neon(s2, vld1q_u16(ref16[1] + r + 16), &acc_lo[1]);
      sad8_neon(s2, vld1q_u16(ref16[2] + r + 16), &acc_lo[2]);
      sad8_neon(s2, vld1q_u16(ref16[3] + r + 16), &acc_lo[3]);

      const uint16x8_t s3 = vld1q_u16(s + 24);
      sad8_neon(s3, vld1q_u16(ref16[0] + r + 24), &acc_hi[0]);
      sad8_neon(s3, vld1q_u16(ref16[1] + r + 24), &acc_hi[1]);
      sad8_neon(s3, vld1q_u16(ref16[2] + r + 24), &acc_hi[2]);
      sad8_neon(s3, vld1q_u16(ref16[3] + r + 24), &acc_hi[3]);

      col += 32;
    } while (col < w);
  } while (++row < h);

  // Merge the two accumulator sets, then reduce each reference to a scalar.
  for (int k = 0; k < 4; ++k) {
    acc_lo[k] = vaddq_u32(acc_lo[k], acc_hi[k]);
  }
  vst1q_u32(res, horizontal_add_4d_u32x4(acc_lo));
}
228 
// 128-wide four-reference SAD; forwards to the generic 32-column kernel.
static inline void highbd_sad128xhx4d_large_neon(
    const uint8_t *src_ptr, int src_stride, const uint8_t *const ref_ptr[4],
    int ref_stride, uint32_t res[4], int h) {
  const int width = 128;
  highbd_sadwxhx4d_large_neon(src_ptr, src_stride, ref_ptr, ref_stride, res,
                              width, h);
}
235 
// 64-wide four-reference SAD; forwards to the generic 32-column kernel.
static inline void highbd_sad64xhx4d_large_neon(
    const uint8_t *src_ptr, int src_stride, const uint8_t *const ref_ptr[4],
    int ref_stride, uint32_t res[4], int h) {
  const int width = 64;
  highbd_sadwxhx4d_large_neon(src_ptr, src_stride, ref_ptr, ref_stride, res,
                              width, h);
}
244 
// 32-wide four-reference SAD; the generic kernel makes exactly one
// 32-column pass per row.
static inline void highbd_sad32xhx4d_large_neon(
    const uint8_t *src_ptr, int src_stride, const uint8_t *const ref_ptr[4],
    int ref_stride, uint32_t res[4], int h) {
  const int width = 32;
  highbd_sadwxhx4d_large_neon(src_ptr, src_stride, ref_ptr, ref_stride, res,
                              width, h);
}
253 
// Defines aom_highbd_sad<w>x<h>x4d_neon(). It writes the SADs of one w x h
// source block against four reference blocks into sad_array[0..3], using the
// highbd_sad<w>xhx4d_small_neon kernel.
#define HBD_SAD_WXH_4D_SMALL_NEON(w, h)                                      \
  void aom_highbd_sad##w##x##h##x4d_neon(                                    \
      const uint8_t *src, int src_stride, const uint8_t *const ref_array[4], \
      int ref_stride, uint32_t sad_array[4]) {                               \
    const int rows = (h);                                                    \
    highbd_sad##w##xhx4d_small_neon(src, src_stride, ref_array, ref_stride,  \
                                    sad_array, rows);                        \
  }
261 
// Defines aom_highbd_sad<w>x<h>x4d_neon(). It writes the SADs of one w x h
// source block against four reference blocks into sad_array[0..3], using the
// 32-bit-accumulating highbd_sad<w>xhx4d_large_neon kernel.
#define HBD_SAD_WXH_4D_LARGE_NEON(w, h)                                      \
  void aom_highbd_sad##w##x##h##x4d_neon(                                    \
      const uint8_t *src, int src_stride, const uint8_t *const ref_array[4], \
      int ref_stride, uint32_t sad_array[4]) {                               \
    const int rows = (h);                                                    \
    highbd_sad##w##xhx4d_large_neon(src, src_stride, ref_array, ref_stride,  \
                                    sad_array, rows);                        \
  }
269 
270 HBD_SAD_WXH_4D_SMALL_NEON(4, 4)
271 HBD_SAD_WXH_4D_SMALL_NEON(4, 8)
272 
273 HBD_SAD_WXH_4D_SMALL_NEON(8, 4)
274 HBD_SAD_WXH_4D_SMALL_NEON(8, 8)
275 HBD_SAD_WXH_4D_SMALL_NEON(8, 16)
276 
277 HBD_SAD_WXH_4D_LARGE_NEON(16, 8)
278 HBD_SAD_WXH_4D_LARGE_NEON(16, 16)
279 HBD_SAD_WXH_4D_LARGE_NEON(16, 32)
280 
281 HBD_SAD_WXH_4D_LARGE_NEON(32, 16)
282 HBD_SAD_WXH_4D_LARGE_NEON(32, 32)
283 HBD_SAD_WXH_4D_LARGE_NEON(32, 64)
284 
285 HBD_SAD_WXH_4D_LARGE_NEON(64, 32)
286 HBD_SAD_WXH_4D_LARGE_NEON(64, 64)
287 HBD_SAD_WXH_4D_LARGE_NEON(64, 128)
288 
289 HBD_SAD_WXH_4D_LARGE_NEON(128, 64)
290 HBD_SAD_WXH_4D_LARGE_NEON(128, 128)
291 
292 #if !CONFIG_REALTIME_ONLY
293 HBD_SAD_WXH_4D_SMALL_NEON(4, 16)
294 
295 HBD_SAD_WXH_4D_LARGE_NEON(8, 32)
296 
297 HBD_SAD_WXH_4D_LARGE_NEON(16, 4)
298 HBD_SAD_WXH_4D_LARGE_NEON(16, 64)
299 
300 HBD_SAD_WXH_4D_LARGE_NEON(32, 8)
301 
302 HBD_SAD_WXH_4D_LARGE_NEON(64, 16)
303 #endif  // !CONFIG_REALTIME_ONLY
304 
// Defines aom_highbd_sad_skip_<w>x<h>x4d_neon(). It approximates the four
// w x h SADs by visiting only every other row, using the small kernel, and
// doubles each result.
#define HBD_SAD_SKIP_WXH_4D_SMALL_NEON(w, h)                                 \
  void aom_highbd_sad_skip_##w##x##h##x4d_neon(                              \
      const uint8_t *src, int src_stride, const uint8_t *const ref_array[4], \
      int ref_stride, uint32_t sad_array[4]) {                               \
    /* Doubled strides and half the height: even rows only. */               \
    highbd_sad##w##xhx4d_small_neon(src, src_stride * 2, ref_array,          \
                                    ref_stride * 2, sad_array, (h) / 2);     \
    /* Scale back up to the full-height estimate. */                         \
    for (int idx = 0; idx < 4; ++idx) {                                      \
      sad_array[idx] *= 2;                                                   \
    }                                                                        \
  }
316 
// Defines aom_highbd_sad_skip_<w>x<h>x4d_neon(). It approximates the four
// w x h SADs by visiting only every other row, using the large kernel, and
// doubles each result.
#define HBD_SAD_SKIP_WXH_4D_LARGE_NEON(w, h)                                 \
  void aom_highbd_sad_skip_##w##x##h##x4d_neon(                              \
      const uint8_t *src, int src_stride, const uint8_t *const ref_array[4], \
      int ref_stride, uint32_t sad_array[4]) {                               \
    /* Doubled strides and half the height: even rows only. */               \
    highbd_sad##w##xhx4d_large_neon(src, src_stride * 2, ref_array,          \
                                    ref_stride * 2, sad_array, (h) / 2);     \
    /* Scale back up to the full-height estimate. */                         \
    for (int idx = 0; idx < 4; ++idx) {                                      \
      sad_array[idx] *= 2;                                                   \
    }                                                                        \
  }
328 
329 HBD_SAD_SKIP_WXH_4D_SMALL_NEON(4, 4)
330 HBD_SAD_SKIP_WXH_4D_SMALL_NEON(4, 8)
331 
332 HBD_SAD_SKIP_WXH_4D_SMALL_NEON(8, 4)
333 HBD_SAD_SKIP_WXH_4D_SMALL_NEON(8, 8)
334 HBD_SAD_SKIP_WXH_4D_SMALL_NEON(8, 16)
335 
336 HBD_SAD_SKIP_WXH_4D_LARGE_NEON(16, 8)
337 HBD_SAD_SKIP_WXH_4D_LARGE_NEON(16, 16)
338 HBD_SAD_SKIP_WXH_4D_LARGE_NEON(16, 32)
339 
340 HBD_SAD_SKIP_WXH_4D_LARGE_NEON(32, 16)
341 HBD_SAD_SKIP_WXH_4D_LARGE_NEON(32, 32)
342 HBD_SAD_SKIP_WXH_4D_LARGE_NEON(32, 64)
343 
344 HBD_SAD_SKIP_WXH_4D_LARGE_NEON(64, 32)
345 HBD_SAD_SKIP_WXH_4D_LARGE_NEON(64, 64)
346 HBD_SAD_SKIP_WXH_4D_LARGE_NEON(64, 128)
347 
348 HBD_SAD_SKIP_WXH_4D_LARGE_NEON(128, 64)
349 HBD_SAD_SKIP_WXH_4D_LARGE_NEON(128, 128)
350 
351 #if !CONFIG_REALTIME_ONLY
352 HBD_SAD_SKIP_WXH_4D_SMALL_NEON(4, 16)
353 
354 HBD_SAD_SKIP_WXH_4D_SMALL_NEON(8, 32)
355 
356 HBD_SAD_SKIP_WXH_4D_LARGE_NEON(16, 4)
357 HBD_SAD_SKIP_WXH_4D_LARGE_NEON(16, 64)
358 
359 HBD_SAD_SKIP_WXH_4D_LARGE_NEON(32, 8)
360 
361 HBD_SAD_SKIP_WXH_4D_LARGE_NEON(64, 16)
362 #endif  // !CONFIG_REALTIME_ONLY
363 
// Three-reference variant of the 4-wide small kernel: res[k] (k = 0..2)
// receives the SAD between the 4 x h source block and the block at ref_ptr[k].
// ref_ptr[3] is never read and res[3] is never written.
static inline void highbd_sad4xhx3d_small_neon(const uint8_t *src_ptr,
                                               int src_stride,
                                               const uint8_t *const ref_ptr[4],
                                               int ref_stride, uint32_t res[4],
                                               int h) {
  const uint16_t *const src16 = CONVERT_TO_SHORTPTR(src_ptr);
  const uint16_t *const ref16[3] = { CONVERT_TO_SHORTPTR(ref_ptr[0]),
                                     CONVERT_TO_SHORTPTR(ref_ptr[1]),
                                     CONVERT_TO_SHORTPTR(ref_ptr[2]) };

  uint32x4_t acc[3];
  acc[0] = vdupq_n_u32(0);
  acc[1] = acc[0];
  acc[2] = acc[0];

  int row = 0;
  do {
    const uint16x4_t src_row = vld1_u16(src16 + row * src_stride);
    const int ref_offset = row * ref_stride;

    acc[0] = vabal_u16(acc[0], src_row, vld1_u16(ref16[0] + ref_offset));
    acc[1] = vabal_u16(acc[1], src_row, vld1_u16(ref16[1] + ref_offset));
    acc[2] = vabal_u16(acc[2], src_row, vld1_u16(ref16[2] + ref_offset));
  } while (++row < h);

  for (int k = 0; k < 3; ++k) {
    res[k] = horizontal_add_u32x4(acc[k]);
  }
}
393 
// Three-reference variant of the 8-wide small kernel: res[k] (k = 0..2)
// receives the SAD against ref_ptr[k]; ref_ptr[3]/res[3] are untouched.
// Differences are accumulated in 16-bit lanes, one |src - ref| per lane per
// row. Assuming samples of at most 12 bits, this is safe only while h <= 16
// (16 * 4095 = 65520).
static inline void highbd_sad8xhx3d_small_neon(const uint8_t *src_ptr,
                                               int src_stride,
                                               const uint8_t *const ref_ptr[4],
                                               int ref_stride, uint32_t res[4],
                                               int h) {
  const uint16_t *const src16 = CONVERT_TO_SHORTPTR(src_ptr);
  const uint16_t *const ref16[3] = { CONVERT_TO_SHORTPTR(ref_ptr[0]),
                                     CONVERT_TO_SHORTPTR(ref_ptr[1]),
                                     CONVERT_TO_SHORTPTR(ref_ptr[2]) };

  uint16x8_t acc16[3];
  acc16[0] = vdupq_n_u16(0);
  acc16[1] = acc16[0];
  acc16[2] = acc16[0];

  int row = 0;
  do {
    const uint16x8_t src_row = vld1q_u16(src16 + row * src_stride);
    const int ref_offset = row * ref_stride;

    acc16[0] = vabaq_u16(acc16[0], src_row, vld1q_u16(ref16[0] + ref_offset));
    acc16[1] = vabaq_u16(acc16[1], src_row, vld1q_u16(ref16[1] + ref_offset));
    acc16[2] = vabaq_u16(acc16[2], src_row, vld1q_u16(ref16[2] + ref_offset));
  } while (++row < h);

  // Widen each 16-bit accumulator pairwise to 32 bits before reducing.
  for (int k = 0; k < 3; ++k) {
    res[k] = horizontal_add_u32x4(vpaddlq_u16(acc16[k]));
  }
}
420 
421 #if !CONFIG_REALTIME_ONLY
// Three-reference variant of the 8-wide large kernel: res[k] (k = 0..2)
// receives the SAD against ref_ptr[k]. Differences are widened into 32-bit
// lanes every row, so there is no 16-bit height limit. ref_ptr[3] and res[3]
// are untouched.
static inline void highbd_sad8xhx3d_large_neon(const uint8_t *src_ptr,
                                               int src_stride,
                                               const uint8_t *const ref_ptr[4],
                                               int ref_stride, uint32_t res[4],
                                               int h) {
  const uint16_t *const src16 = CONVERT_TO_SHORTPTR(src_ptr);
  const uint16_t *const ref16[3] = { CONVERT_TO_SHORTPTR(ref_ptr[0]),
                                     CONVERT_TO_SHORTPTR(ref_ptr[1]),
                                     CONVERT_TO_SHORTPTR(ref_ptr[2]) };

  uint32x4_t acc[3];
  acc[0] = vdupq_n_u32(0);
  acc[1] = acc[0];
  acc[2] = acc[0];

  int row = 0;
  do {
    const uint16x8_t src_row = vld1q_u16(src16 + row * src_stride);
    const int ref_offset = row * ref_stride;

    sad8_neon(src_row, vld1q_u16(ref16[0] + ref_offset), &acc[0]);
    sad8_neon(src_row, vld1q_u16(ref16[1] + ref_offset), &acc[1]);
    sad8_neon(src_row, vld1q_u16(ref16[2] + ref_offset), &acc[2]);
  } while (++row < h);

  for (int k = 0; k < 3; ++k) {
    res[k] = horizontal_add_u32x4(acc[k]);
  }
}
451 #endif  // !CONFIG_REALTIME_ONLY
452 
// Three-reference, 16-wide SAD: res[k] (k = 0..2) receives the SAD between the
// 16 x h source block and the block at ref_ptr[k]. Columns 0-7 and 8-15 use
// separate 32-bit accumulators, which are merged after the loop. ref_ptr[3]
// and res[3] are untouched.
static inline void highbd_sad16xhx3d_large_neon(const uint8_t *src_ptr,
                                                int src_stride,
                                                const uint8_t *const ref_ptr[4],
                                                int ref_stride, uint32_t res[4],
                                                int h) {
  const uint16_t *const src16 = CONVERT_TO_SHORTPTR(src_ptr);
  const uint16_t *const ref16[3] = { CONVERT_TO_SHORTPTR(ref_ptr[0]),
                                     CONVERT_TO_SHORTPTR(ref_ptr[1]),
                                     CONVERT_TO_SHORTPTR(ref_ptr[2]) };

  uint32x4_t acc_lo[3];
  uint32x4_t acc_hi[3];
  for (int k = 0; k < 3; ++k) {
    acc_lo[k] = vdupq_n_u32(0);
    acc_hi[k] = vdupq_n_u32(0);
  }

  int row = 0;
  do {
    const uint16_t *const s = src16 + row * src_stride;
    const int r = row * ref_stride;

    // Columns 0-7.
    const uint16x8_t s_lo = vld1q_u16(s);
    sad8_neon(s_lo, vld1q_u16(ref16[0] + r), &acc_lo[0]);
    sad8_neon(s_lo, vld1q_u16(ref16[1] + r), &acc_lo[1]);
    sad8_neon(s_lo, vld1q_u16(ref16[2] + r), &acc_lo[2]);

    // Columns 8-15.
    const uint16x8_t s_hi = vld1q_u16(s + 8);
    sad8_neon(s_hi, vld1q_u16(ref16[0] + r + 8), &acc_hi[0]);
    sad8_neon(s_hi, vld1q_u16(ref16[1] + r + 8), &acc_hi[1]);
    sad8_neon(s_hi, vld1q_u16(ref16[2] + r + 8), &acc_hi[2]);
  } while (++row < h);

  for (int k = 0; k < 3; ++k) {
    res[k] = horizontal_add_u32x4(vaddq_u32(acc_lo[k], acc_hi[k]));
  }
}
484 
// Generic wide three-reference SAD: res[k] (k = 0..2) receives the SAD between
// the w x h source block and the block at ref_ptr[k]. Each inner iteration
// covers 32 columns, so w is assumed to be a positive multiple of 32 (callers
// pass 32, 64 or 128). Columns [0,8) and [16,24) of each chunk feed acc_lo;
// [8,16) and [24,32) feed acc_hi. ref_ptr[3] and res[3] are untouched.
static inline void highbd_sadwxhx3d_large_neon(const uint8_t *src_ptr,
                                               int src_stride,
                                               const uint8_t *const ref_ptr[4],
                                               int ref_stride, uint32_t res[4],
                                               int w, int h) {
  const uint16_t *const src16 = CONVERT_TO_SHORTPTR(src_ptr);
  const uint16_t *const ref16[3] = { CONVERT_TO_SHORTPTR(ref_ptr[0]),
                                     CONVERT_TO_SHORTPTR(ref_ptr[1]),
                                     CONVERT_TO_SHORTPTR(ref_ptr[2]) };

  uint32x4_t acc_lo[3];
  uint32x4_t acc_hi[3];
  for (int k = 0; k < 3; ++k) {
    acc_lo[k] = vdupq_n_u32(0);
    acc_hi[k] = vdupq_n_u32(0);
  }

  int row = 0;
  do {
    int col = 0;
    do {
      const uint16_t *const s = src16 + row * src_stride + col;
      const int r = row * ref_stride + col;

      const uint16x8_t s0 = vld1q_u16(s);
      sad8_neon(s0, vld1q_u16(ref16[0] + r), &acc_lo[0]);
      sad8_neon(s0, vld1q_u16(ref16[1] + r), &acc_lo[1]);
      sad8_neon(s0, vld1q_u16(ref16[2] + r), &acc_lo[2]);

      const uint16x8_t s1 = vld1q_u16(s + 8);
      sad8_neon(s1, vld1q_u16(ref16[0] + r + 8), &acc_hi[0]);
      sad8_neon(s1, vld1q_u16(ref16[1] + r + 8), &acc_hi[1]);
      sad8_neon(s1, vld1q_u16(ref16[2] + r + 8), &acc_hi[2]);

      const uint16x8_t s2 = vld1q_u16(s + 16);
      sad8_neon(s2, vld1q_u16(ref16[0] + r + 16), &acc_lo[0]);
      sad8_neon(s2, vld1q_u16(ref16[1] + r + 16), &acc_lo[1]);
      sad8_neon(s2, vld1q_u16(ref16[2] + r + 16), &acc_lo[2]);

      const uint16x8_t s3 = vld1q_u16(s + 24);
      sad8_neon(s3, vld1q_u16(ref16[0] + r + 24), &acc_hi[0]);
      sad8_neon(s3, vld1q_u16(ref16[1] + r + 24), &acc_hi[1]);
      sad8_neon(s3, vld1q_u16(ref16[2] + r + 24), &acc_hi[2]);

      col += 32;
    } while (col < w);
  } while (++row < h);

  for (int k = 0; k < 3; ++k) {
    res[k] = horizontal_add_u32x4(vaddq_u32(acc_lo[k], acc_hi[k]));
  }
}
542 
// 128-wide three-reference SAD; forwards to the generic 32-column kernel.
static inline void highbd_sad128xhx3d_large_neon(
    const uint8_t *src_ptr, int src_stride, const uint8_t *const ref_ptr[4],
    int ref_stride, uint32_t res[4], int h) {
  const int width = 128;
  highbd_sadwxhx3d_large_neon(src_ptr, src_stride, ref_ptr, ref_stride, res,
                              width, h);
}
549 
// 64-wide three-reference SAD; forwards to the generic 32-column kernel.
static inline void highbd_sad64xhx3d_large_neon(
    const uint8_t *src_ptr, int src_stride, const uint8_t *const ref_ptr[4],
    int ref_stride, uint32_t res[4], int h) {
  const int width = 64;
  highbd_sadwxhx3d_large_neon(src_ptr, src_stride, ref_ptr, ref_stride, res,
                              width, h);
}
558 
// 32-wide three-reference SAD; the generic kernel makes exactly one
// 32-column pass per row.
static inline void highbd_sad32xhx3d_large_neon(
    const uint8_t *src_ptr, int src_stride, const uint8_t *const ref_ptr[4],
    int ref_stride, uint32_t res[4], int h) {
  const int width = 32;
  highbd_sadwxhx3d_large_neon(src_ptr, src_stride, ref_ptr, ref_stride, res,
                              width, h);
}
567 
// Defines aom_highbd_sad<w>x<h>x3d_neon(). It writes the SADs of one w x h
// source block against ref_array[0..2] into sad_array[0..2], using the
// highbd_sad<w>xhx3d_small_neon kernel; sad_array[3] is not written.
#define HBD_SAD_WXH_3D_SMALL_NEON(w, h)                                      \
  void aom_highbd_sad##w##x##h##x3d_neon(                                    \
      const uint8_t *src, int src_stride, const uint8_t *const ref_array[4], \
      int ref_stride, uint32_t sad_array[4]) {                               \
    const int rows = (h);                                                    \
    highbd_sad##w##xhx3d_small_neon(src, src_stride, ref_array, ref_stride,  \
                                    sad_array, rows);                        \
  }
575 
// Defines aom_highbd_sad<w>x<h>x3d_neon(). It writes the SADs of one w x h
// source block against ref_array[0..2] into sad_array[0..2], using the
// 32-bit-accumulating highbd_sad<w>xhx3d_large_neon kernel; sad_array[3] is
// not written.
#define HBD_SAD_WXH_3D_LARGE_NEON(w, h)                                      \
  void aom_highbd_sad##w##x##h##x3d_neon(                                    \
      const uint8_t *src, int src_stride, const uint8_t *const ref_array[4], \
      int ref_stride, uint32_t sad_array[4]) {                               \
    const int rows = (h);                                                    \
    highbd_sad##w##xhx3d_large_neon(src, src_stride, ref_array, ref_stride,  \
                                    sad_array, rows);                        \
  }
583 
584 HBD_SAD_WXH_3D_SMALL_NEON(4, 4)
585 HBD_SAD_WXH_3D_SMALL_NEON(4, 8)
586 
587 HBD_SAD_WXH_3D_SMALL_NEON(8, 4)
588 HBD_SAD_WXH_3D_SMALL_NEON(8, 8)
589 HBD_SAD_WXH_3D_SMALL_NEON(8, 16)
590 
591 HBD_SAD_WXH_3D_LARGE_NEON(16, 8)
592 HBD_SAD_WXH_3D_LARGE_NEON(16, 16)
593 HBD_SAD_WXH_3D_LARGE_NEON(16, 32)
594 
595 HBD_SAD_WXH_3D_LARGE_NEON(32, 16)
596 HBD_SAD_WXH_3D_LARGE_NEON(32, 32)
597 HBD_SAD_WXH_3D_LARGE_NEON(32, 64)
598 
599 HBD_SAD_WXH_3D_LARGE_NEON(64, 32)
600 HBD_SAD_WXH_3D_LARGE_NEON(64, 64)
601 HBD_SAD_WXH_3D_LARGE_NEON(64, 128)
602 
603 HBD_SAD_WXH_3D_LARGE_NEON(128, 64)
604 HBD_SAD_WXH_3D_LARGE_NEON(128, 128)
605 
606 #if !CONFIG_REALTIME_ONLY
607 HBD_SAD_WXH_3D_SMALL_NEON(4, 16)
608 
609 HBD_SAD_WXH_3D_LARGE_NEON(8, 32)
610 
611 HBD_SAD_WXH_3D_LARGE_NEON(16, 4)
612 HBD_SAD_WXH_3D_LARGE_NEON(16, 64)
613 
614 HBD_SAD_WXH_3D_LARGE_NEON(32, 8)
615 
616 HBD_SAD_WXH_3D_LARGE_NEON(64, 16)
617 #endif  // !CONFIG_REALTIME_ONLY
618