xref: /aosp_15_r20/external/libaom/aom_dsp/arm/sadxd_neon_dotprod.c (revision 77c1e3ccc04c968bd2bc212e87364f250e820521)
/*
 * Copyright (c) 2023, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */
11 
12 #include <arm_neon.h>
13 
14 #include "config/aom_config.h"
15 #include "config/aom_dsp_rtcd.h"
16 
17 #include "aom/aom_integer.h"
18 #include "aom_dsp/arm/mem_neon.h"
19 #include "aom_dsp/arm/sum_neon.h"
20 
// Adds the sum of absolute differences of two 16-byte vectors into *sad_sum.
// vabdq_u8 produces the per-byte |src - ref|; a UDOT against a vector of ones
// then adds each group of four adjacent bytes into the matching 32-bit lane
// of the accumulator.
static inline void sad16_neon(uint8x16_t src, uint8x16_t ref,
                              uint32x4_t *const sad_sum) {
  const uint8x16_t ones = vdupq_n_u8(1);
  *sad_sum = vdotq_u32(*sad_sum, vabdq_u8(src, ref), ones);
}
26 
// Sum of absolute differences of a w x h block of src against the three
// reference blocks ref[0..2], which all share ref_stride. Results are written
// to res[0..2]; ref[3] and res[3] are not accessed.
//
// Each row is walked in 32-byte columns. Every reference keeps one
// accumulator for the low and one for the high 16 bytes of a column, and the
// two are merged only after the last row. Both loops are do/while, so this
// assumes w is a positive multiple of 32 and h is positive.
static inline void sadwxhx3d_large_neon_dotprod(const uint8_t *src,
                                                int src_stride,
                                                const uint8_t *const ref[4],
                                                int ref_stride, uint32_t res[4],
                                                int w, int h) {
  uint32x4_t acc_lo0 = vdupq_n_u32(0), acc_hi0 = vdupq_n_u32(0);
  uint32x4_t acc_lo1 = vdupq_n_u32(0), acc_hi1 = vdupq_n_u32(0);
  uint32x4_t acc_lo2 = vdupq_n_u32(0), acc_hi2 = vdupq_n_u32(0);

  int offset = 0;  // Offset of the current row inside every reference block.
  int rows_left = h;
  do {
    const uint8_t *const row0 = ref[0] + offset;
    const uint8_t *const row1 = ref[1] + offset;
    const uint8_t *const row2 = ref[2] + offset;

    int col = 0;
    do {
      const uint8x16_t src_lo = vld1q_u8(src + col);
      const uint8x16_t src_hi = vld1q_u8(src + col + 16);

      sad16_neon(src_lo, vld1q_u8(row0 + col), &acc_lo0);
      sad16_neon(src_hi, vld1q_u8(row0 + col + 16), &acc_hi0);
      sad16_neon(src_lo, vld1q_u8(row1 + col), &acc_lo1);
      sad16_neon(src_hi, vld1q_u8(row1 + col + 16), &acc_hi1);
      sad16_neon(src_lo, vld1q_u8(row2 + col), &acc_lo2);
      sad16_neon(src_hi, vld1q_u8(row2 + col + 16), &acc_hi2);

      col += 32;
    } while (col < w);

    src += src_stride;
    offset += ref_stride;
  } while (--rows_left != 0);

  res[0] = horizontal_add_u32x4(vaddq_u32(acc_lo0, acc_hi0));
  res[1] = horizontal_add_u32x4(vaddq_u32(acc_lo1, acc_hi1));
  res[2] = horizontal_add_u32x4(vaddq_u32(acc_lo2, acc_hi2));
}
61 
// x3d SAD for 128-pixel-wide blocks. Delegates to the generic large-block
// kernel, whose column loop takes four 32-byte steps per row at this width.
static inline void sad128xhx3d_neon_dotprod(const uint8_t *src, int src_stride,
                                            const uint8_t *const ref[4],
                                            int ref_stride, uint32_t res[4],
                                            int h) {
  const int block_width = 128;
  sadwxhx3d_large_neon_dotprod(src, src_stride, ref, ref_stride, res,
                               block_width, h);
}
68 
// x3d SAD for 64-pixel-wide blocks. Delegates to the generic large-block
// kernel, whose column loop takes two 32-byte steps per row at this width.
static inline void sad64xhx3d_neon_dotprod(const uint8_t *src, int src_stride,
                                           const uint8_t *const ref[4],
                                           int ref_stride, uint32_t res[4],
                                           int h) {
  const int block_width = 64;
  sadwxhx3d_large_neon_dotprod(src, src_stride, ref, ref_stride, res,
                               block_width, h);
}
75 
// x3d SAD for 32-pixel-wide blocks. Delegates to the generic large-block
// kernel, whose column loop runs exactly once per row at this width.
static inline void sad32xhx3d_neon_dotprod(const uint8_t *src, int src_stride,
                                           const uint8_t *const ref[4],
                                           int ref_stride, uint32_t res[4],
                                           int h) {
  const int block_width = 32;
  sadwxhx3d_large_neon_dotprod(src, src_stride, ref, ref_stride, res,
                               block_width, h);
}
82 
// Sum of absolute differences of a 16 x h block of src against ref[0..2],
// written to res[0..2]; ref[3] and res[3] are not accessed. Each row is a
// single 16-byte vector, so one accumulator per reference is enough. The row
// loop is a do/while, so h is assumed to be positive.
static inline void sad16xhx3d_neon_dotprod(const uint8_t *src, int src_stride,
                                           const uint8_t *const ref[4],
                                           int ref_stride, uint32_t res[4],
                                           int h) {
  uint32x4_t acc0 = vdupq_n_u32(0);
  uint32x4_t acc1 = vdupq_n_u32(0);
  uint32x4_t acc2 = vdupq_n_u32(0);

  int offset = 0;  // Offset of the current row inside every reference block.
  int rows_left = h;
  do {
    const uint8x16_t src_row = vld1q_u8(src);
    sad16_neon(src_row, vld1q_u8(ref[0] + offset), &acc0);
    sad16_neon(src_row, vld1q_u8(ref[1] + offset), &acc1);
    sad16_neon(src_row, vld1q_u8(ref[2] + offset), &acc2);

    src += src_stride;
    offset += ref_stride;
  } while (--rows_left != 0);

  res[0] = horizontal_add_u32x4(acc0);
  res[1] = horizontal_add_u32x4(acc1);
  res[2] = horizontal_add_u32x4(acc2);
}
105 
// Defines the exported aom_sad<width>x<height>x3d_neon_dotprod() for one block
// size by forwarding to sad<width>xhx3d_neon_dotprod with <height> rows.
#define SAD_WXH_3D_NEON_DOTPROD(width, height)                                \
  void aom_sad##width##x##height##x3d_neon_dotprod(                           \
      const uint8_t *src, int src_stride, const uint8_t *const ref[4],        \
      int ref_stride, uint32_t res[4]) {                                      \
    sad##width##xhx3d_neon_dotprod(src, src_stride, ref, ref_stride, res,     \
                                   (height));                                 \
  }

// Square, 2:1 and 1:2 block sizes.
SAD_WXH_3D_NEON_DOTPROD(16, 8)
SAD_WXH_3D_NEON_DOTPROD(16, 16)
SAD_WXH_3D_NEON_DOTPROD(16, 32)

SAD_WXH_3D_NEON_DOTPROD(32, 16)
SAD_WXH_3D_NEON_DOTPROD(32, 32)
SAD_WXH_3D_NEON_DOTPROD(32, 64)

SAD_WXH_3D_NEON_DOTPROD(64, 32)
SAD_WXH_3D_NEON_DOTPROD(64, 64)
SAD_WXH_3D_NEON_DOTPROD(64, 128)

SAD_WXH_3D_NEON_DOTPROD(128, 64)
SAD_WXH_3D_NEON_DOTPROD(128, 128)

// 4:1 and 1:4 block sizes, built only when CONFIG_REALTIME_ONLY is not set.
#if !CONFIG_REALTIME_ONLY
SAD_WXH_3D_NEON_DOTPROD(16, 4)
SAD_WXH_3D_NEON_DOTPROD(16, 64)
SAD_WXH_3D_NEON_DOTPROD(32, 8)
SAD_WXH_3D_NEON_DOTPROD(64, 16)
#endif  // !CONFIG_REALTIME_ONLY

#undef SAD_WXH_3D_NEON_DOTPROD
136 
// Sum of absolute differences of a w x h block of src against the four
// reference blocks ref[0..3], which all share ref_stride. The four results
// are written to res[0..3] with a single vector store.
//
// Rows are consumed 32 bytes at a time. Each reference keeps separate
// accumulators for the low and high 16-byte halves, merged only after the
// last row. Both loops are do/while, so this assumes w is a positive multiple
// of 32 and h is positive.
static inline void sadwxhx4d_large_neon_dotprod(const uint8_t *src,
                                                int src_stride,
                                                const uint8_t *const ref[4],
                                                int ref_stride, uint32_t res[4],
                                                int w, int h) {
  uint32x4_t acc_lo0 = vdupq_n_u32(0), acc_hi0 = vdupq_n_u32(0);
  uint32x4_t acc_lo1 = vdupq_n_u32(0), acc_hi1 = vdupq_n_u32(0);
  uint32x4_t acc_lo2 = vdupq_n_u32(0), acc_hi2 = vdupq_n_u32(0);
  uint32x4_t acc_lo3 = vdupq_n_u32(0), acc_hi3 = vdupq_n_u32(0);

  int offset = 0;  // Offset of the current row inside every reference block.
  int rows_left = h;
  do {
    const uint8_t *const row0 = ref[0] + offset;
    const uint8_t *const row1 = ref[1] + offset;
    const uint8_t *const row2 = ref[2] + offset;
    const uint8_t *const row3 = ref[3] + offset;

    int col = 0;
    do {
      const uint8x16_t src_lo = vld1q_u8(src + col);
      const uint8x16_t src_hi = vld1q_u8(src + col + 16);

      sad16_neon(src_lo, vld1q_u8(row0 + col), &acc_lo0);
      sad16_neon(src_hi, vld1q_u8(row0 + col + 16), &acc_hi0);
      sad16_neon(src_lo, vld1q_u8(row1 + col), &acc_lo1);
      sad16_neon(src_hi, vld1q_u8(row1 + col + 16), &acc_hi1);
      sad16_neon(src_lo, vld1q_u8(row2 + col), &acc_lo2);
      sad16_neon(src_hi, vld1q_u8(row2 + col + 16), &acc_hi2);
      sad16_neon(src_lo, vld1q_u8(row3 + col), &acc_lo3);
      sad16_neon(src_hi, vld1q_u8(row3 + col + 16), &acc_hi3);

      col += 32;
    } while (col < w);

    src += src_stride;
    offset += ref_stride;
  } while (--rows_left != 0);

  // Merge the two halves per reference, then reduce all four at once.
  uint32x4_t totals[4] = { vaddq_u32(acc_lo0, acc_hi0),
                           vaddq_u32(acc_lo1, acc_hi1),
                           vaddq_u32(acc_lo2, acc_hi2),
                           vaddq_u32(acc_lo3, acc_hi3) };
  vst1q_u32(res, horizontal_add_4d_u32x4(totals));
}
179 
// x4d SAD for 128-pixel-wide blocks. Delegates to the generic large-block
// kernel, whose column loop takes four 32-byte steps per row at this width.
static inline void sad128xhx4d_neon_dotprod(const uint8_t *src, int src_stride,
                                            const uint8_t *const ref[4],
                                            int ref_stride, uint32_t res[4],
                                            int h) {
  const int block_width = 128;
  sadwxhx4d_large_neon_dotprod(src, src_stride, ref, ref_stride, res,
                               block_width, h);
}
186 
// x4d SAD for 64-pixel-wide blocks. Delegates to the generic large-block
// kernel, whose column loop takes two 32-byte steps per row at this width.
static inline void sad64xhx4d_neon_dotprod(const uint8_t *src, int src_stride,
                                           const uint8_t *const ref[4],
                                           int ref_stride, uint32_t res[4],
                                           int h) {
  const int block_width = 64;
  sadwxhx4d_large_neon_dotprod(src, src_stride, ref, ref_stride, res,
                               block_width, h);
}
193 
// x4d SAD for 32-pixel-wide blocks. Delegates to the generic large-block
// kernel, whose column loop runs exactly once per row at this width.
static inline void sad32xhx4d_neon_dotprod(const uint8_t *src, int src_stride,
                                           const uint8_t *const ref[4],
                                           int ref_stride, uint32_t res[4],
                                           int h) {
  const int block_width = 32;
  sadwxhx4d_large_neon_dotprod(src, src_stride, ref, ref_stride, res,
                               block_width, h);
}
200 
// Sum of absolute differences of a 16 x h block of src against ref[0..3],
// written to res[0..3] with a single vector store. Each row is one 16-byte
// vector, so one accumulator per reference is enough. The row loop is a
// do/while, so h is assumed to be positive.
static inline void sad16xhx4d_neon_dotprod(const uint8_t *src, int src_stride,
                                           const uint8_t *const ref[4],
                                           int ref_stride, uint32_t res[4],
                                           int h) {
  uint32x4_t acc0 = vdupq_n_u32(0);
  uint32x4_t acc1 = vdupq_n_u32(0);
  uint32x4_t acc2 = vdupq_n_u32(0);
  uint32x4_t acc3 = vdupq_n_u32(0);

  int offset = 0;  // Offset of the current row inside every reference block.
  int rows_left = h;
  do {
    const uint8x16_t src_row = vld1q_u8(src);
    sad16_neon(src_row, vld1q_u8(ref[0] + offset), &acc0);
    sad16_neon(src_row, vld1q_u8(ref[1] + offset), &acc1);
    sad16_neon(src_row, vld1q_u8(ref[2] + offset), &acc2);
    sad16_neon(src_row, vld1q_u8(ref[3] + offset), &acc3);

    src += src_stride;
    offset += ref_stride;
  } while (--rows_left != 0);

  uint32x4_t totals[4] = { acc0, acc1, acc2, acc3 };
  vst1q_u32(res, horizontal_add_4d_u32x4(totals));
}
223 
// Defines the exported aom_sad<width>x<height>x4d_neon_dotprod() for one block
// size by forwarding to sad<width>xhx4d_neon_dotprod with <height> rows.
#define SAD_WXH_4D_NEON_DOTPROD(width, height)                                \
  void aom_sad##width##x##height##x4d_neon_dotprod(                           \
      const uint8_t *src, int src_stride, const uint8_t *const ref[4],        \
      int ref_stride, uint32_t res[4]) {                                      \
    sad##width##xhx4d_neon_dotprod(src, src_stride, ref, ref_stride, res,     \
                                   (height));                                 \
  }

// Square, 2:1 and 1:2 block sizes.
SAD_WXH_4D_NEON_DOTPROD(16, 8)
SAD_WXH_4D_NEON_DOTPROD(16, 16)
SAD_WXH_4D_NEON_DOTPROD(16, 32)

SAD_WXH_4D_NEON_DOTPROD(32, 16)
SAD_WXH_4D_NEON_DOTPROD(32, 32)
SAD_WXH_4D_NEON_DOTPROD(32, 64)

SAD_WXH_4D_NEON_DOTPROD(64, 32)
SAD_WXH_4D_NEON_DOTPROD(64, 64)
SAD_WXH_4D_NEON_DOTPROD(64, 128)

SAD_WXH_4D_NEON_DOTPROD(128, 64)
SAD_WXH_4D_NEON_DOTPROD(128, 128)

// 4:1 and 1:4 block sizes, built only when CONFIG_REALTIME_ONLY is not set.
#if !CONFIG_REALTIME_ONLY
SAD_WXH_4D_NEON_DOTPROD(16, 4)
SAD_WXH_4D_NEON_DOTPROD(16, 64)
SAD_WXH_4D_NEON_DOTPROD(32, 8)
SAD_WXH_4D_NEON_DOTPROD(64, 16)
#endif  // !CONFIG_REALTIME_ONLY

#undef SAD_WXH_4D_NEON_DOTPROD
254 
// Defines the exported aom_sad_skip_<width>x<height>x4d_neon_dotprod().
// It measures only every other row: both strides are doubled and the row
// count is halved. Each of the four results is then doubled, so it is on the
// same scale as a full <width> x <height> SAD.
#define SAD_SKIP_WXH_4D_NEON_DOTPROD(width, height)                           \
  void aom_sad_skip_##width##x##height##x4d_neon_dotprod(                     \
      const uint8_t *src, int src_stride, const uint8_t *const ref[4],        \
      int ref_stride, uint32_t res[4]) {                                      \
    sad##width##xhx4d_neon_dotprod(src, 2 * src_stride, ref, 2 * ref_stride,  \
                                   res, (height) / 2);                        \
    for (int k = 0; k < 4; ++k) {                                             \
      res[k] *= 2;                                                            \
    }                                                                         \
  }

// Square, 2:1 and 1:2 block sizes.
SAD_SKIP_WXH_4D_NEON_DOTPROD(16, 8)
SAD_SKIP_WXH_4D_NEON_DOTPROD(16, 16)
SAD_SKIP_WXH_4D_NEON_DOTPROD(16, 32)

SAD_SKIP_WXH_4D_NEON_DOTPROD(32, 16)
SAD_SKIP_WXH_4D_NEON_DOTPROD(32, 32)
SAD_SKIP_WXH_4D_NEON_DOTPROD(32, 64)

SAD_SKIP_WXH_4D_NEON_DOTPROD(64, 32)
SAD_SKIP_WXH_4D_NEON_DOTPROD(64, 64)
SAD_SKIP_WXH_4D_NEON_DOTPROD(64, 128)

SAD_SKIP_WXH_4D_NEON_DOTPROD(128, 64)
SAD_SKIP_WXH_4D_NEON_DOTPROD(128, 128)

// 4:1 and 1:4 block sizes, built only when CONFIG_REALTIME_ONLY is not set.
#if !CONFIG_REALTIME_ONLY
SAD_SKIP_WXH_4D_NEON_DOTPROD(16, 4)
SAD_SKIP_WXH_4D_NEON_DOTPROD(16, 64)
SAD_SKIP_WXH_4D_NEON_DOTPROD(32, 8)
SAD_SKIP_WXH_4D_NEON_DOTPROD(64, 16)
#endif  // !CONFIG_REALTIME_ONLY

#undef SAD_SKIP_WXH_4D_NEON_DOTPROD
290