/*
 * Copyright (c) 2023, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <arm_neon.h>

#include "config/aom_config.h"
#include "config/aom_dsp_rtcd.h"

#include "aom/aom_integer.h"
#include "aom_dsp/arm/dist_wtd_avg_neon.h"
#include "aom_dsp/arm/mem_neon.h"
#include "aom_dsp/arm/sum_neon.h"

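// The core of each kernel computes the SAD of a 16-byte chunk as:
//   diff = vabdq_u8(s, r);                       // per-byte |src - ref|
//   acc  = vdotq_u32(acc, diff, vdupq_n_u8(1));  // dot product with ones
// UDOT against a vector of ones sums each group of four absolute differences
// into a 32-bit accumulator lane, avoiding the longer widening-add sequences
// needed on CPUs without the dot-product extension.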
static inline unsigned int sadwxh_neon_dotprod(const uint8_t *src_ptr,
                                               int src_stride,
                                               const uint8_t *ref_ptr,
                                               int ref_stride, int w, int h) {
  // Only two accumulators are required for optimal instruction throughput of
  // the ABD, UDOT sequence on CPUs with either 2 or 4 Neon pipes.
  uint32x4_t sum[2] = { vdupq_n_u32(0), vdupq_n_u32(0) };

  int i = h;
  do {
    int j = 0;
    do {
      uint8x16_t s0, s1, r0, r1, diff0, diff1;

      s0 = vld1q_u8(src_ptr + j);
      r0 = vld1q_u8(ref_ptr + j);
      diff0 = vabdq_u8(s0, r0);
      sum[0] = vdotq_u32(sum[0], diff0, vdupq_n_u8(1));

      s1 = vld1q_u8(src_ptr + j + 16);
      r1 = vld1q_u8(ref_ptr + j + 16);
      diff1 = vabdq_u8(s1, r1);
      sum[1] = vdotq_u32(sum[1], diff1, vdupq_n_u8(1));

      j += 32;
    } while (j < w);

    src_ptr += src_stride;
    ref_ptr += ref_stride;
  } while (--i != 0);

  return horizontal_add_u32x4(vaddq_u32(sum[0], sum[1]));
}

static inline unsigned int sad128xh_neon_dotprod(const uint8_t *src_ptr,
                                                 int src_stride,
                                                 const uint8_t *ref_ptr,
                                                 int ref_stride, int h) {
  return sadwxh_neon_dotprod(src_ptr, src_stride, ref_ptr, ref_stride, 128, h);
}

static inline unsigned int sad64xh_neon_dotprod(const uint8_t *src_ptr,
                                                int src_stride,
                                                const uint8_t *ref_ptr,
                                                int ref_stride, int h) {
  return sadwxh_neon_dotprod(src_ptr, src_stride, ref_ptr, ref_stride, 64, h);
}

static inline unsigned int sad32xh_neon_dotprod(const uint8_t *src_ptr,
                                                int src_stride,
                                                const uint8_t *ref_ptr,
                                                int ref_stride, int h) {
  return sadwxh_neon_dotprod(src_ptr, src_stride, ref_ptr, ref_stride, 32, h);
}

static inline unsigned int sad16xh_neon_dotprod(const uint8_t *src_ptr,
                                                int src_stride,
                                                const uint8_t *ref_ptr,
                                                int ref_stride, int h) {
  uint32x4_t sum[2] = { vdupq_n_u32(0), vdupq_n_u32(0) };

  int i = h / 2;
  do {
    uint8x16_t s0, s1, r0, r1, diff0, diff1;

    s0 = vld1q_u8(src_ptr);
    r0 = vld1q_u8(ref_ptr);
    diff0 = vabdq_u8(s0, r0);
    sum[0] = vdotq_u32(sum[0], diff0, vdupq_n_u8(1));

    src_ptr += src_stride;
    ref_ptr += ref_stride;

    s1 = vld1q_u8(src_ptr);
    r1 = vld1q_u8(ref_ptr);
    diff1 = vabdq_u8(s1, r1);
    sum[1] = vdotq_u32(sum[1], diff1, vdupq_n_u8(1));

    src_ptr += src_stride;
    ref_ptr += ref_stride;
  } while (--i != 0);

  return horizontal_add_u32x4(vaddq_u32(sum[0], sum[1]));
}

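// SAD_WXH_NEON_DOTPROD defines the per-block-size entry points (e.g.
// aom_sad16x8_neon_dotprod) with the prototype expected by the RTCD dispatch
// in aom_dsp_rtcd.h.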
#define SAD_WXH_NEON_DOTPROD(w, h)                                         \
  unsigned int aom_sad##w##x##h##_neon_dotprod(                            \
      const uint8_t *src, int src_stride, const uint8_t *ref,              \
      int ref_stride) {                                                    \
    return sad##w##xh_neon_dotprod(src, src_stride, ref, ref_stride, (h)); \
  }

SAD_WXH_NEON_DOTPROD(16, 8)
SAD_WXH_NEON_DOTPROD(16, 16)
SAD_WXH_NEON_DOTPROD(16, 32)

SAD_WXH_NEON_DOTPROD(32, 16)
SAD_WXH_NEON_DOTPROD(32, 32)
SAD_WXH_NEON_DOTPROD(32, 64)

SAD_WXH_NEON_DOTPROD(64, 32)
SAD_WXH_NEON_DOTPROD(64, 64)
SAD_WXH_NEON_DOTPROD(64, 128)

SAD_WXH_NEON_DOTPROD(128, 64)
SAD_WXH_NEON_DOTPROD(128, 128)

#if !CONFIG_REALTIME_ONLY
SAD_WXH_NEON_DOTPROD(16, 4)
SAD_WXH_NEON_DOTPROD(16, 64)
SAD_WXH_NEON_DOTPROD(32, 8)
SAD_WXH_NEON_DOTPROD(64, 16)
#endif  // !CONFIG_REALTIME_ONLY

#undef SAD_WXH_NEON_DOTPROD

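// The "skip" variants estimate the SAD from every second row: the strides are
// doubled, the height is halved, and the result is scaled by 2 so that it
// remains comparable with a full SAD.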
#define SAD_SKIP_WXH_NEON_DOTPROD(w, h)                          \
  unsigned int aom_sad_skip_##w##x##h##_neon_dotprod(            \
      const uint8_t *src, int src_stride, const uint8_t *ref,    \
      int ref_stride) {                                          \
    return 2 * sad##w##xh_neon_dotprod(src, 2 * src_stride, ref, \
                                       2 * ref_stride, (h) / 2); \
  }

SAD_SKIP_WXH_NEON_DOTPROD(16, 8)
SAD_SKIP_WXH_NEON_DOTPROD(16, 16)
SAD_SKIP_WXH_NEON_DOTPROD(16, 32)

SAD_SKIP_WXH_NEON_DOTPROD(32, 16)
SAD_SKIP_WXH_NEON_DOTPROD(32, 32)
SAD_SKIP_WXH_NEON_DOTPROD(32, 64)

SAD_SKIP_WXH_NEON_DOTPROD(64, 32)
SAD_SKIP_WXH_NEON_DOTPROD(64, 64)
SAD_SKIP_WXH_NEON_DOTPROD(64, 128)

SAD_SKIP_WXH_NEON_DOTPROD(128, 64)
SAD_SKIP_WXH_NEON_DOTPROD(128, 128)

#if !CONFIG_REALTIME_ONLY
SAD_SKIP_WXH_NEON_DOTPROD(16, 4)
SAD_SKIP_WXH_NEON_DOTPROD(16, 64)
SAD_SKIP_WXH_NEON_DOTPROD(32, 8)
SAD_SKIP_WXH_NEON_DOTPROD(64, 16)
#endif  // !CONFIG_REALTIME_ONLY

#undef SAD_SKIP_WXH_NEON_DOTPROD

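// Compound-average SAD: the reference block is first averaged with the second
// prediction using a rounding halving add (vrhaddq_u8), and the SAD is then
// computed against that average.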
static inline unsigned int sadwxh_avg_neon_dotprod(const uint8_t *src_ptr,
                                                   int src_stride,
                                                   const uint8_t *ref_ptr,
                                                   int ref_stride, int w, int h,
                                                   const uint8_t *second_pred) {
  // Only two accumulators are required for optimal instruction throughput of
  // the ABD, UDOT sequence on CPUs with either 2 or 4 Neon pipes.
  uint32x4_t sum[2] = { vdupq_n_u32(0), vdupq_n_u32(0) };

  int i = h;
  do {
    int j = 0;
    do {
      uint8x16_t s0, s1, r0, r1, p0, p1, avg0, avg1, diff0, diff1;

      s0 = vld1q_u8(src_ptr + j);
      r0 = vld1q_u8(ref_ptr + j);
      p0 = vld1q_u8(second_pred);
      avg0 = vrhaddq_u8(r0, p0);
      diff0 = vabdq_u8(s0, avg0);
      sum[0] = vdotq_u32(sum[0], diff0, vdupq_n_u8(1));

      s1 = vld1q_u8(src_ptr + j + 16);
      r1 = vld1q_u8(ref_ptr + j + 16);
      p1 = vld1q_u8(second_pred + 16);
      avg1 = vrhaddq_u8(r1, p1);
      diff1 = vabdq_u8(s1, avg1);
      sum[1] = vdotq_u32(sum[1], diff1, vdupq_n_u8(1));

      j += 32;
      second_pred += 32;
    } while (j < w);

    src_ptr += src_stride;
    ref_ptr += ref_stride;
  } while (--i != 0);

  return horizontal_add_u32x4(vaddq_u32(sum[0], sum[1]));
}

static inline unsigned int sad128xh_avg_neon_dotprod(
    const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr,
    int ref_stride, int h, const uint8_t *second_pred) {
  return sadwxh_avg_neon_dotprod(src_ptr, src_stride, ref_ptr, ref_stride, 128,
                                 h, second_pred);
}

static inline unsigned int sad64xh_avg_neon_dotprod(
    const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr,
    int ref_stride, int h, const uint8_t *second_pred) {
  return sadwxh_avg_neon_dotprod(src_ptr, src_stride, ref_ptr, ref_stride, 64,
                                 h, second_pred);
}

static inline unsigned int sad32xh_avg_neon_dotprod(
    const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr,
    int ref_stride, int h, const uint8_t *second_pred) {
  return sadwxh_avg_neon_dotprod(src_ptr, src_stride, ref_ptr, ref_stride, 32,
                                 h, second_pred);
}

static inline unsigned int sad16xh_avg_neon_dotprod(
    const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr,
    int ref_stride, int h, const uint8_t *second_pred) {
  uint32x4_t sum[2] = { vdupq_n_u32(0), vdupq_n_u32(0) };

  int i = h / 2;
  do {
    uint8x16_t s0, s1, r0, r1, p0, p1, avg0, avg1, diff0, diff1;

    s0 = vld1q_u8(src_ptr);
    r0 = vld1q_u8(ref_ptr);
    p0 = vld1q_u8(second_pred);
    avg0 = vrhaddq_u8(r0, p0);
    diff0 = vabdq_u8(s0, avg0);
    sum[0] = vdotq_u32(sum[0], diff0, vdupq_n_u8(1));

    src_ptr += src_stride;
    ref_ptr += ref_stride;
    second_pred += 16;

    s1 = vld1q_u8(src_ptr);
    r1 = vld1q_u8(ref_ptr);
    p1 = vld1q_u8(second_pred);
    avg1 = vrhaddq_u8(r1, p1);
    diff1 = vabdq_u8(s1, avg1);
    sum[1] = vdotq_u32(sum[1], diff1, vdupq_n_u8(1));

    src_ptr += src_stride;
    ref_ptr += ref_stride;
    second_pred += 16;
  } while (--i != 0);

  return horizontal_add_u32x4(vaddq_u32(sum[0], sum[1]));
}

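// SAD_WXH_AVG_NEON_DOTPROD defines the compound-average entry points (e.g.
// aom_sad16x8_avg_neon_dotprod), which additionally take the second
// prediction buffer.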
#define SAD_WXH_AVG_NEON_DOTPROD(w, h)                                        \
  unsigned int aom_sad##w##x##h##_avg_neon_dotprod(                           \
      const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride, \
      const uint8_t *second_pred) {                                           \
    return sad##w##xh_avg_neon_dotprod(src, src_stride, ref, ref_stride, (h), \
                                       second_pred);                          \
  }

SAD_WXH_AVG_NEON_DOTPROD(16, 8)
SAD_WXH_AVG_NEON_DOTPROD(16, 16)
SAD_WXH_AVG_NEON_DOTPROD(16, 32)

SAD_WXH_AVG_NEON_DOTPROD(32, 16)
SAD_WXH_AVG_NEON_DOTPROD(32, 32)
SAD_WXH_AVG_NEON_DOTPROD(32, 64)

SAD_WXH_AVG_NEON_DOTPROD(64, 32)
SAD_WXH_AVG_NEON_DOTPROD(64, 64)
SAD_WXH_AVG_NEON_DOTPROD(64, 128)

SAD_WXH_AVG_NEON_DOTPROD(128, 64)
SAD_WXH_AVG_NEON_DOTPROD(128, 128)

#if !CONFIG_REALTIME_ONLY
SAD_WXH_AVG_NEON_DOTPROD(16, 4)
SAD_WXH_AVG_NEON_DOTPROD(16, 64)
SAD_WXH_AVG_NEON_DOTPROD(32, 8)
SAD_WXH_AVG_NEON_DOTPROD(64, 16)
#endif  // !CONFIG_REALTIME_ONLY

#undef SAD_WXH_AVG_NEON_DOTPROD

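// Distance-weighted compound SAD: dist_wtd_avg_u8x16() blends the second
// prediction with the reference using the fwd_offset/bck_offset weights from
// jcp_param (which are assumed to sum to 1 << DIST_PRECISION_BITS, as
// elsewhere in libaom), and the SAD is computed against that weighted average.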
static inline unsigned int dist_wtd_sad128xh_avg_neon_dotprod(
    const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr,
    int ref_stride, int h, const uint8_t *second_pred,
    const DIST_WTD_COMP_PARAMS *jcp_param) {
  const uint8x16_t fwd_offset = vdupq_n_u8(jcp_param->fwd_offset);
  const uint8x16_t bck_offset = vdupq_n_u8(jcp_param->bck_offset);
  // We use 8 accumulators to minimize the accumulation and loop carried
  // dependencies for better instruction throughput.
  uint32x4_t sum[8] = { vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0),
                        vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0),
                        vdupq_n_u32(0), vdupq_n_u32(0) };

  do {
    uint8x16_t s0 = vld1q_u8(src_ptr);
    uint8x16_t r0 = vld1q_u8(ref_ptr);
    uint8x16_t p0 = vld1q_u8(second_pred);
    uint8x16_t wtd_avg0 = dist_wtd_avg_u8x16(p0, r0, bck_offset, fwd_offset);
    uint8x16_t diff0 = vabdq_u8(s0, wtd_avg0);
    sum[0] = vdotq_u32(sum[0], diff0, vdupq_n_u8(1));

    uint8x16_t s1 = vld1q_u8(src_ptr + 16);
    uint8x16_t r1 = vld1q_u8(ref_ptr + 16);
    uint8x16_t p1 = vld1q_u8(second_pred + 16);
    uint8x16_t wtd_avg1 = dist_wtd_avg_u8x16(p1, r1, bck_offset, fwd_offset);
    uint8x16_t diff1 = vabdq_u8(s1, wtd_avg1);
    sum[1] = vdotq_u32(sum[1], diff1, vdupq_n_u8(1));

    uint8x16_t s2 = vld1q_u8(src_ptr + 32);
    uint8x16_t r2 = vld1q_u8(ref_ptr + 32);
    uint8x16_t p2 = vld1q_u8(second_pred + 32);
    uint8x16_t wtd_avg2 = dist_wtd_avg_u8x16(p2, r2, bck_offset, fwd_offset);
    uint8x16_t diff2 = vabdq_u8(s2, wtd_avg2);
    sum[2] = vdotq_u32(sum[2], diff2, vdupq_n_u8(1));

    uint8x16_t s3 = vld1q_u8(src_ptr + 48);
    uint8x16_t r3 = vld1q_u8(ref_ptr + 48);
    uint8x16_t p3 = vld1q_u8(second_pred + 48);
    uint8x16_t wtd_avg3 = dist_wtd_avg_u8x16(p3, r3, bck_offset, fwd_offset);
    uint8x16_t diff3 = vabdq_u8(s3, wtd_avg3);
    sum[3] = vdotq_u32(sum[3], diff3, vdupq_n_u8(1));

    uint8x16_t s4 = vld1q_u8(src_ptr + 64);
    uint8x16_t r4 = vld1q_u8(ref_ptr + 64);
    uint8x16_t p4 = vld1q_u8(second_pred + 64);
    uint8x16_t wtd_avg4 = dist_wtd_avg_u8x16(p4, r4, bck_offset, fwd_offset);
    uint8x16_t diff4 = vabdq_u8(s4, wtd_avg4);
    sum[4] = vdotq_u32(sum[4], diff4, vdupq_n_u8(1));

    uint8x16_t s5 = vld1q_u8(src_ptr + 80);
    uint8x16_t r5 = vld1q_u8(ref_ptr + 80);
    uint8x16_t p5 = vld1q_u8(second_pred + 80);
    uint8x16_t wtd_avg5 = dist_wtd_avg_u8x16(p5, r5, bck_offset, fwd_offset);
    uint8x16_t diff5 = vabdq_u8(s5, wtd_avg5);
    sum[5] = vdotq_u32(sum[5], diff5, vdupq_n_u8(1));

    uint8x16_t s6 = vld1q_u8(src_ptr + 96);
    uint8x16_t r6 = vld1q_u8(ref_ptr + 96);
    uint8x16_t p6 = vld1q_u8(second_pred + 96);
    uint8x16_t wtd_avg6 = dist_wtd_avg_u8x16(p6, r6, bck_offset, fwd_offset);
    uint8x16_t diff6 = vabdq_u8(s6, wtd_avg6);
    sum[6] = vdotq_u32(sum[6], diff6, vdupq_n_u8(1));

    uint8x16_t s7 = vld1q_u8(src_ptr + 112);
    uint8x16_t r7 = vld1q_u8(ref_ptr + 112);
    uint8x16_t p7 = vld1q_u8(second_pred + 112);
    uint8x16_t wtd_avg7 = dist_wtd_avg_u8x16(p7, r7, bck_offset, fwd_offset);
    uint8x16_t diff7 = vabdq_u8(s7, wtd_avg7);
    sum[7] = vdotq_u32(sum[7], diff7, vdupq_n_u8(1));

    src_ptr += src_stride;
    ref_ptr += ref_stride;
    second_pred += 128;
  } while (--h != 0);

  sum[0] = vaddq_u32(sum[0], sum[1]);
  sum[2] = vaddq_u32(sum[2], sum[3]);
  sum[4] = vaddq_u32(sum[4], sum[5]);
  sum[6] = vaddq_u32(sum[6], sum[7]);
  sum[0] = vaddq_u32(sum[0], sum[2]);
  sum[4] = vaddq_u32(sum[4], sum[6]);
  sum[0] = vaddq_u32(sum[0], sum[4]);
  return horizontal_add_u32x4(sum[0]);
}

static inline unsigned int dist_wtd_sad64xh_avg_neon_dotprod(
    const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr,
    int ref_stride, int h, const uint8_t *second_pred,
    const DIST_WTD_COMP_PARAMS *jcp_param) {
  const uint8x16_t fwd_offset = vdupq_n_u8(jcp_param->fwd_offset);
  const uint8x16_t bck_offset = vdupq_n_u8(jcp_param->bck_offset);
  uint32x4_t sum[4] = { vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0),
                        vdupq_n_u32(0) };

  do {
    uint8x16_t s0 = vld1q_u8(src_ptr);
    uint8x16_t r0 = vld1q_u8(ref_ptr);
    uint8x16_t p0 = vld1q_u8(second_pred);
    uint8x16_t wtd_avg0 = dist_wtd_avg_u8x16(p0, r0, bck_offset, fwd_offset);
    uint8x16_t diff0 = vabdq_u8(s0, wtd_avg0);
    sum[0] = vdotq_u32(sum[0], diff0, vdupq_n_u8(1));

    uint8x16_t s1 = vld1q_u8(src_ptr + 16);
    uint8x16_t r1 = vld1q_u8(ref_ptr + 16);
    uint8x16_t p1 = vld1q_u8(second_pred + 16);
    uint8x16_t wtd_avg1 = dist_wtd_avg_u8x16(p1, r1, bck_offset, fwd_offset);
    uint8x16_t diff1 = vabdq_u8(s1, wtd_avg1);
    sum[1] = vdotq_u32(sum[1], diff1, vdupq_n_u8(1));

    uint8x16_t s2 = vld1q_u8(src_ptr + 32);
    uint8x16_t r2 = vld1q_u8(ref_ptr + 32);
    uint8x16_t p2 = vld1q_u8(second_pred + 32);
    uint8x16_t wtd_avg2 = dist_wtd_avg_u8x16(p2, r2, bck_offset, fwd_offset);
    uint8x16_t diff2 = vabdq_u8(s2, wtd_avg2);
    sum[2] = vdotq_u32(sum[2], diff2, vdupq_n_u8(1));

    uint8x16_t s3 = vld1q_u8(src_ptr + 48);
    uint8x16_t r3 = vld1q_u8(ref_ptr + 48);
    uint8x16_t p3 = vld1q_u8(second_pred + 48);
    uint8x16_t wtd_avg3 = dist_wtd_avg_u8x16(p3, r3, bck_offset, fwd_offset);
    uint8x16_t diff3 = vabdq_u8(s3, wtd_avg3);
    sum[3] = vdotq_u32(sum[3], diff3, vdupq_n_u8(1));

    src_ptr += src_stride;
    ref_ptr += ref_stride;
    second_pred += 64;
  } while (--h != 0);

  sum[0] = vaddq_u32(sum[0], sum[1]);
  sum[2] = vaddq_u32(sum[2], sum[3]);
  sum[0] = vaddq_u32(sum[0], sum[2]);
  return horizontal_add_u32x4(sum[0]);
}

static inline unsigned int dist_wtd_sad32xh_avg_neon_dotprod(
    const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr,
    int ref_stride, int h, const uint8_t *second_pred,
    const DIST_WTD_COMP_PARAMS *jcp_param) {
  const uint8x16_t fwd_offset = vdupq_n_u8(jcp_param->fwd_offset);
  const uint8x16_t bck_offset = vdupq_n_u8(jcp_param->bck_offset);
  uint32x4_t sum[2] = { vdupq_n_u32(0), vdupq_n_u32(0) };

  do {
    uint8x16_t s0 = vld1q_u8(src_ptr);
    uint8x16_t r0 = vld1q_u8(ref_ptr);
    uint8x16_t p0 = vld1q_u8(second_pred);
    uint8x16_t wtd_avg0 = dist_wtd_avg_u8x16(p0, r0, bck_offset, fwd_offset);
    uint8x16_t diff0 = vabdq_u8(s0, wtd_avg0);
    sum[0] = vdotq_u32(sum[0], diff0, vdupq_n_u8(1));

    uint8x16_t s1 = vld1q_u8(src_ptr + 16);
    uint8x16_t r1 = vld1q_u8(ref_ptr + 16);
    uint8x16_t p1 = vld1q_u8(second_pred + 16);
    uint8x16_t wtd_avg1 = dist_wtd_avg_u8x16(p1, r1, bck_offset, fwd_offset);
    uint8x16_t diff1 = vabdq_u8(s1, wtd_avg1);
    sum[1] = vdotq_u32(sum[1], diff1, vdupq_n_u8(1));

    src_ptr += src_stride;
    ref_ptr += ref_stride;
    second_pred += 32;
  } while (--h != 0);

  sum[0] = vaddq_u32(sum[0], sum[1]);
  return horizontal_add_u32x4(sum[0]);
}

static inline unsigned int dist_wtd_sad16xh_avg_neon_dotprod(
    const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr,
    int ref_stride, int h, const uint8_t *second_pred,
    const DIST_WTD_COMP_PARAMS *jcp_param) {
  const uint8x16_t fwd_offset = vdupq_n_u8(jcp_param->fwd_offset);
  const uint8x16_t bck_offset = vdupq_n_u8(jcp_param->bck_offset);
  uint32x4_t sum[2] = { vdupq_n_u32(0), vdupq_n_u32(0) };

  int i = h / 2;
  do {
    uint8x16_t s0 = vld1q_u8(src_ptr);
    uint8x16_t r0 = vld1q_u8(ref_ptr);
    uint8x16_t p0 = vld1q_u8(second_pred);
    uint8x16_t wtd_avg0 = dist_wtd_avg_u8x16(p0, r0, bck_offset, fwd_offset);
    uint8x16_t diff0 = vabdq_u8(s0, wtd_avg0);
    sum[0] = vdotq_u32(sum[0], diff0, vdupq_n_u8(1));

    src_ptr += src_stride;
    ref_ptr += ref_stride;
    second_pred += 16;

    uint8x16_t s1 = vld1q_u8(src_ptr);
    uint8x16_t r1 = vld1q_u8(ref_ptr);
    uint8x16_t p1 = vld1q_u8(second_pred);
    uint8x16_t wtd_avg1 = dist_wtd_avg_u8x16(p1, r1, bck_offset, fwd_offset);
    uint8x16_t diff1 = vabdq_u8(s1, wtd_avg1);
    sum[1] = vdotq_u32(sum[1], diff1, vdupq_n_u8(1));

    src_ptr += src_stride;
    ref_ptr += ref_stride;
    second_pred += 16;
  } while (--i != 0);

  sum[0] = vaddq_u32(sum[0], sum[1]);
  return horizontal_add_u32x4(sum[0]);
}

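// DIST_WTD_SAD_WXH_AVG_NEON_DOTPROD defines the distance-weighted compound
// entry points (e.g. aom_dist_wtd_sad16x8_avg_neon_dotprod).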
#define DIST_WTD_SAD_WXH_AVG_NEON_DOTPROD(w, h)                               \
  unsigned int aom_dist_wtd_sad##w##x##h##_avg_neon_dotprod(                  \
      const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride, \
      const uint8_t *second_pred, const DIST_WTD_COMP_PARAMS *jcp_param) {    \
    return dist_wtd_sad##w##xh_avg_neon_dotprod(                              \
        src, src_stride, ref, ref_stride, (h), second_pred, jcp_param);       \
  }

DIST_WTD_SAD_WXH_AVG_NEON_DOTPROD(16, 8)
DIST_WTD_SAD_WXH_AVG_NEON_DOTPROD(16, 16)
DIST_WTD_SAD_WXH_AVG_NEON_DOTPROD(16, 32)

DIST_WTD_SAD_WXH_AVG_NEON_DOTPROD(32, 16)
DIST_WTD_SAD_WXH_AVG_NEON_DOTPROD(32, 32)
DIST_WTD_SAD_WXH_AVG_NEON_DOTPROD(32, 64)

DIST_WTD_SAD_WXH_AVG_NEON_DOTPROD(64, 32)
DIST_WTD_SAD_WXH_AVG_NEON_DOTPROD(64, 64)
DIST_WTD_SAD_WXH_AVG_NEON_DOTPROD(64, 128)

DIST_WTD_SAD_WXH_AVG_NEON_DOTPROD(128, 64)
DIST_WTD_SAD_WXH_AVG_NEON_DOTPROD(128, 128)

#if !CONFIG_REALTIME_ONLY
DIST_WTD_SAD_WXH_AVG_NEON_DOTPROD(16, 4)
DIST_WTD_SAD_WXH_AVG_NEON_DOTPROD(16, 64)
DIST_WTD_SAD_WXH_AVG_NEON_DOTPROD(32, 8)
DIST_WTD_SAD_WXH_AVG_NEON_DOTPROD(64, 16)
#endif  // !CONFIG_REALTIME_ONLY

#undef DIST_WTD_SAD_WXH_AVG_NEON_DOTPROD