/*
 * Copyright (c) 2023, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <arm_neon.h>

#include "config/aom_config.h"
#include "config/aom_dsp_rtcd.h"

#include "aom/aom_integer.h"
#include "aom_dsp/arm/dist_wtd_avg_neon.h"
#include "aom_dsp/arm/mem_neon.h"
#include "aom_dsp/arm/sum_neon.h"

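// The *_neon_dotprod kernels below share a common pattern: compute per-byte
// absolute differences with ABD (vabdq_u8), then accumulate them with UDOT
// (vdotq_u32) against a vector of all ones, which sums each group of four
// bytes directly into a 32-bit accumulator lane without any intermediate
// widening additions.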
static inline unsigned int sadwxh_neon_dotprod(const uint8_t *src_ptr,
                                               int src_stride,
                                               const uint8_t *ref_ptr,
                                               int ref_stride, int w, int h) {
  // Only two accumulators are required for optimal instruction throughput of
  // the ABD, UDOT sequence on CPUs with either 2 or 4 Neon pipes.
  uint32x4_t sum[2] = { vdupq_n_u32(0), vdupq_n_u32(0) };

  int i = h;
  do {
    int j = 0;
    do {
      uint8x16_t s0, s1, r0, r1, diff0, diff1;

      s0 = vld1q_u8(src_ptr + j);
      r0 = vld1q_u8(ref_ptr + j);
      diff0 = vabdq_u8(s0, r0);
      sum[0] = vdotq_u32(sum[0], diff0, vdupq_n_u8(1));

      s1 = vld1q_u8(src_ptr + j + 16);
      r1 = vld1q_u8(ref_ptr + j + 16);
      diff1 = vabdq_u8(s1, r1);
      sum[1] = vdotq_u32(sum[1], diff1, vdupq_n_u8(1));

      j += 32;
    } while (j < w);

    src_ptr += src_stride;
    ref_ptr += ref_stride;
  } while (--i != 0);

  return horizontal_add_u32x4(vaddq_u32(sum[0], sum[1]));
}

static inline unsigned int sad128xh_neon_dotprod(const uint8_t *src_ptr,
                                                 int src_stride,
                                                 const uint8_t *ref_ptr,
                                                 int ref_stride, int h) {
  return sadwxh_neon_dotprod(src_ptr, src_stride, ref_ptr, ref_stride, 128, h);
}

static inline unsigned int sad64xh_neon_dotprod(const uint8_t *src_ptr,
                                                int src_stride,
                                                const uint8_t *ref_ptr,
                                                int ref_stride, int h) {
  return sadwxh_neon_dotprod(src_ptr, src_stride, ref_ptr, ref_stride, 64, h);
}

static inline unsigned int sad32xh_neon_dotprod(const uint8_t *src_ptr,
                                                int src_stride,
                                                const uint8_t *ref_ptr,
                                                int ref_stride, int h) {
  return sadwxh_neon_dotprod(src_ptr, src_stride, ref_ptr, ref_stride, 32, h);
}

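// The 16-pixel-wide path processes two rows per loop iteration, one row per
// accumulator, so h is assumed to be even here.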
static inline unsigned int sad16xh_neon_dotprod(const uint8_t *src_ptr,
                                                int src_stride,
                                                const uint8_t *ref_ptr,
                                                int ref_stride, int h) {
  uint32x4_t sum[2] = { vdupq_n_u32(0), vdupq_n_u32(0) };

  int i = h / 2;
  do {
    uint8x16_t s0, s1, r0, r1, diff0, diff1;

    s0 = vld1q_u8(src_ptr);
    r0 = vld1q_u8(ref_ptr);
    diff0 = vabdq_u8(s0, r0);
    sum[0] = vdotq_u32(sum[0], diff0, vdupq_n_u8(1));

    src_ptr += src_stride;
    ref_ptr += ref_stride;

    s1 = vld1q_u8(src_ptr);
    r1 = vld1q_u8(ref_ptr);
    diff1 = vabdq_u8(s1, r1);
    sum[1] = vdotq_u32(sum[1], diff1, vdupq_n_u8(1));

    src_ptr += src_stride;
    ref_ptr += ref_stride;
  } while (--i != 0);

  return horizontal_add_u32x4(vaddq_u32(sum[0], sum[1]));
}

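// Generate the public aom_sad<w>x<h>_neon_dotprod functions for each
// supported block size.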
#define SAD_WXH_NEON_DOTPROD(w, h)                                         \
  unsigned int aom_sad##w##x##h##_neon_dotprod(                            \
      const uint8_t *src, int src_stride, const uint8_t *ref,              \
      int ref_stride) {                                                    \
    return sad##w##xh_neon_dotprod(src, src_stride, ref, ref_stride, (h)); \
  }

SAD_WXH_NEON_DOTPROD(16, 8)
SAD_WXH_NEON_DOTPROD(16, 16)
SAD_WXH_NEON_DOTPROD(16, 32)

SAD_WXH_NEON_DOTPROD(32, 16)
SAD_WXH_NEON_DOTPROD(32, 32)
SAD_WXH_NEON_DOTPROD(32, 64)

SAD_WXH_NEON_DOTPROD(64, 32)
SAD_WXH_NEON_DOTPROD(64, 64)
SAD_WXH_NEON_DOTPROD(64, 128)

SAD_WXH_NEON_DOTPROD(128, 64)
SAD_WXH_NEON_DOTPROD(128, 128)

#if !CONFIG_REALTIME_ONLY
SAD_WXH_NEON_DOTPROD(16, 4)
SAD_WXH_NEON_DOTPROD(16, 64)
SAD_WXH_NEON_DOTPROD(32, 8)
SAD_WXH_NEON_DOTPROD(64, 16)
#endif  // !CONFIG_REALTIME_ONLY

#undef SAD_WXH_NEON_DOTPROD

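// The "skip" variants estimate the SAD from every other row: the strides are
// doubled and the height halved, and the result is scaled by 2 to keep it
// comparable with a full-block SAD.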
#define SAD_SKIP_WXH_NEON_DOTPROD(w, h)                              \
  unsigned int aom_sad_skip_##w##x##h##_neon_dotprod(                \
      const uint8_t *src, int src_stride, const uint8_t *ref,        \
      int ref_stride) {                                              \
    return 2 * sad##w##xh_neon_dotprod(src, 2 * src_stride, ref,     \
                                       2 * ref_stride, (h) / 2);     \
  }

SAD_SKIP_WXH_NEON_DOTPROD(16, 8)
SAD_SKIP_WXH_NEON_DOTPROD(16, 16)
SAD_SKIP_WXH_NEON_DOTPROD(16, 32)

SAD_SKIP_WXH_NEON_DOTPROD(32, 16)
SAD_SKIP_WXH_NEON_DOTPROD(32, 32)
SAD_SKIP_WXH_NEON_DOTPROD(32, 64)

SAD_SKIP_WXH_NEON_DOTPROD(64, 32)
SAD_SKIP_WXH_NEON_DOTPROD(64, 64)
SAD_SKIP_WXH_NEON_DOTPROD(64, 128)

SAD_SKIP_WXH_NEON_DOTPROD(128, 64)
SAD_SKIP_WXH_NEON_DOTPROD(128, 128)

#if !CONFIG_REALTIME_ONLY
SAD_SKIP_WXH_NEON_DOTPROD(16, 4)
SAD_SKIP_WXH_NEON_DOTPROD(16, 64)
SAD_SKIP_WXH_NEON_DOTPROD(32, 8)
SAD_SKIP_WXH_NEON_DOTPROD(64, 16)
#endif  // !CONFIG_REALTIME_ONLY

#undef SAD_SKIP_WXH_NEON_DOTPROD

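// The *_avg kernels first form the rounded average of the reference block and
// second_pred with VRHADD (vrhaddq_u8), then compute the SAD of the source
// against that average.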
static inline unsigned int sadwxh_avg_neon_dotprod(const uint8_t *src_ptr,
                                                   int src_stride,
                                                   const uint8_t *ref_ptr,
                                                   int ref_stride, int w, int h,
                                                   const uint8_t *second_pred) {
  // Only two accumulators are required for optimal instruction throughput of
  // the ABD, UDOT sequence on CPUs with either 2 or 4 Neon pipes.
  uint32x4_t sum[2] = { vdupq_n_u32(0), vdupq_n_u32(0) };

  int i = h;
  do {
    int j = 0;
    do {
      uint8x16_t s0, s1, r0, r1, p0, p1, avg0, avg1, diff0, diff1;

      s0 = vld1q_u8(src_ptr + j);
      r0 = vld1q_u8(ref_ptr + j);
      p0 = vld1q_u8(second_pred);
      avg0 = vrhaddq_u8(r0, p0);
      diff0 = vabdq_u8(s0, avg0);
      sum[0] = vdotq_u32(sum[0], diff0, vdupq_n_u8(1));

      s1 = vld1q_u8(src_ptr + j + 16);
      r1 = vld1q_u8(ref_ptr + j + 16);
      p1 = vld1q_u8(second_pred + 16);
      avg1 = vrhaddq_u8(r1, p1);
      diff1 = vabdq_u8(s1, avg1);
      sum[1] = vdotq_u32(sum[1], diff1, vdupq_n_u8(1));

      j += 32;
      second_pred += 32;
    } while (j < w);

    src_ptr += src_stride;
    ref_ptr += ref_stride;
  } while (--i != 0);

  return horizontal_add_u32x4(vaddq_u32(sum[0], sum[1]));
}

static inline unsigned int sad128xh_avg_neon_dotprod(
    const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr,
    int ref_stride, int h, const uint8_t *second_pred) {
  return sadwxh_avg_neon_dotprod(src_ptr, src_stride, ref_ptr, ref_stride, 128,
                                 h, second_pred);
}

static inline unsigned int sad64xh_avg_neon_dotprod(
    const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr,
    int ref_stride, int h, const uint8_t *second_pred) {
  return sadwxh_avg_neon_dotprod(src_ptr, src_stride, ref_ptr, ref_stride, 64,
                                 h, second_pred);
}

static inline unsigned int sad32xh_avg_neon_dotprod(
    const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr,
    int ref_stride, int h, const uint8_t *second_pred) {
  return sadwxh_avg_neon_dotprod(src_ptr, src_stride, ref_ptr, ref_stride, 32,
                                 h, second_pred);
}

static inline unsigned int sad16xh_avg_neon_dotprod(
    const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr,
    int ref_stride, int h, const uint8_t *second_pred) {
  uint32x4_t sum[2] = { vdupq_n_u32(0), vdupq_n_u32(0) };

  int i = h / 2;
  do {
    uint8x16_t s0, s1, r0, r1, p0, p1, avg0, avg1, diff0, diff1;

    s0 = vld1q_u8(src_ptr);
    r0 = vld1q_u8(ref_ptr);
    p0 = vld1q_u8(second_pred);
    avg0 = vrhaddq_u8(r0, p0);
    diff0 = vabdq_u8(s0, avg0);
    sum[0] = vdotq_u32(sum[0], diff0, vdupq_n_u8(1));

    src_ptr += src_stride;
    ref_ptr += ref_stride;
    second_pred += 16;

    s1 = vld1q_u8(src_ptr);
    r1 = vld1q_u8(ref_ptr);
    p1 = vld1q_u8(second_pred);
    avg1 = vrhaddq_u8(r1, p1);
    diff1 = vabdq_u8(s1, avg1);
    sum[1] = vdotq_u32(sum[1], diff1, vdupq_n_u8(1));

    src_ptr += src_stride;
    ref_ptr += ref_stride;
    second_pred += 16;
  } while (--i != 0);

  return horizontal_add_u32x4(vaddq_u32(sum[0], sum[1]));
}

#define SAD_WXH_AVG_NEON_DOTPROD(w, h)                                        \
  unsigned int aom_sad##w##x##h##_avg_neon_dotprod(                           \
      const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride, \
      const uint8_t *second_pred) {                                           \
    return sad##w##xh_avg_neon_dotprod(src, src_stride, ref, ref_stride, (h), \
                                       second_pred);                          \
  }

SAD_WXH_AVG_NEON_DOTPROD(16, 8)
SAD_WXH_AVG_NEON_DOTPROD(16, 16)
SAD_WXH_AVG_NEON_DOTPROD(16, 32)

SAD_WXH_AVG_NEON_DOTPROD(32, 16)
SAD_WXH_AVG_NEON_DOTPROD(32, 32)
SAD_WXH_AVG_NEON_DOTPROD(32, 64)

SAD_WXH_AVG_NEON_DOTPROD(64, 32)
SAD_WXH_AVG_NEON_DOTPROD(64, 64)
SAD_WXH_AVG_NEON_DOTPROD(64, 128)

SAD_WXH_AVG_NEON_DOTPROD(128, 64)
SAD_WXH_AVG_NEON_DOTPROD(128, 128)

#if !CONFIG_REALTIME_ONLY
SAD_WXH_AVG_NEON_DOTPROD(16, 4)
SAD_WXH_AVG_NEON_DOTPROD(16, 64)
SAD_WXH_AVG_NEON_DOTPROD(32, 8)
SAD_WXH_AVG_NEON_DOTPROD(64, 16)
#endif  // !CONFIG_REALTIME_ONLY

#undef SAD_WXH_AVG_NEON_DOTPROD

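// The dist_wtd_* kernels replace the simple average with a distance-weighted
// average of second_pred and the reference block, using the fwd_offset and
// bck_offset weights from jcp_param (see dist_wtd_avg_u8x16), before taking
// the SAD against the source.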
static inline unsigned int dist_wtd_sad128xh_avg_neon_dotprod(
    const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr,
    int ref_stride, int h, const uint8_t *second_pred,
    const DIST_WTD_COMP_PARAMS *jcp_param) {
  const uint8x16_t fwd_offset = vdupq_n_u8(jcp_param->fwd_offset);
  const uint8x16_t bck_offset = vdupq_n_u8(jcp_param->bck_offset);
  // We use 8 accumulators to minimize the accumulation and loop-carried
  // dependencies for better instruction throughput.
  uint32x4_t sum[8] = { vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0),
                        vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0),
                        vdupq_n_u32(0), vdupq_n_u32(0) };

  do {
    uint8x16_t s0 = vld1q_u8(src_ptr);
    uint8x16_t r0 = vld1q_u8(ref_ptr);
    uint8x16_t p0 = vld1q_u8(second_pred);
    uint8x16_t wtd_avg0 = dist_wtd_avg_u8x16(p0, r0, bck_offset, fwd_offset);
    uint8x16_t diff0 = vabdq_u8(s0, wtd_avg0);
    sum[0] = vdotq_u32(sum[0], diff0, vdupq_n_u8(1));

    uint8x16_t s1 = vld1q_u8(src_ptr + 16);
    uint8x16_t r1 = vld1q_u8(ref_ptr + 16);
    uint8x16_t p1 = vld1q_u8(second_pred + 16);
    uint8x16_t wtd_avg1 = dist_wtd_avg_u8x16(p1, r1, bck_offset, fwd_offset);
    uint8x16_t diff1 = vabdq_u8(s1, wtd_avg1);
    sum[1] = vdotq_u32(sum[1], diff1, vdupq_n_u8(1));

    uint8x16_t s2 = vld1q_u8(src_ptr + 32);
    uint8x16_t r2 = vld1q_u8(ref_ptr + 32);
    uint8x16_t p2 = vld1q_u8(second_pred + 32);
    uint8x16_t wtd_avg2 = dist_wtd_avg_u8x16(p2, r2, bck_offset, fwd_offset);
    uint8x16_t diff2 = vabdq_u8(s2, wtd_avg2);
    sum[2] = vdotq_u32(sum[2], diff2, vdupq_n_u8(1));

    uint8x16_t s3 = vld1q_u8(src_ptr + 48);
    uint8x16_t r3 = vld1q_u8(ref_ptr + 48);
    uint8x16_t p3 = vld1q_u8(second_pred + 48);
    uint8x16_t wtd_avg3 = dist_wtd_avg_u8x16(p3, r3, bck_offset, fwd_offset);
    uint8x16_t diff3 = vabdq_u8(s3, wtd_avg3);
    sum[3] = vdotq_u32(sum[3], diff3, vdupq_n_u8(1));

    uint8x16_t s4 = vld1q_u8(src_ptr + 64);
    uint8x16_t r4 = vld1q_u8(ref_ptr + 64);
    uint8x16_t p4 = vld1q_u8(second_pred + 64);
    uint8x16_t wtd_avg4 = dist_wtd_avg_u8x16(p4, r4, bck_offset, fwd_offset);
    uint8x16_t diff4 = vabdq_u8(s4, wtd_avg4);
    sum[4] = vdotq_u32(sum[4], diff4, vdupq_n_u8(1));

    uint8x16_t s5 = vld1q_u8(src_ptr + 80);
    uint8x16_t r5 = vld1q_u8(ref_ptr + 80);
    uint8x16_t p5 = vld1q_u8(second_pred + 80);
    uint8x16_t wtd_avg5 = dist_wtd_avg_u8x16(p5, r5, bck_offset, fwd_offset);
    uint8x16_t diff5 = vabdq_u8(s5, wtd_avg5);
    sum[5] = vdotq_u32(sum[5], diff5, vdupq_n_u8(1));

    uint8x16_t s6 = vld1q_u8(src_ptr + 96);
    uint8x16_t r6 = vld1q_u8(ref_ptr + 96);
    uint8x16_t p6 = vld1q_u8(second_pred + 96);
    uint8x16_t wtd_avg6 = dist_wtd_avg_u8x16(p6, r6, bck_offset, fwd_offset);
    uint8x16_t diff6 = vabdq_u8(s6, wtd_avg6);
    sum[6] = vdotq_u32(sum[6], diff6, vdupq_n_u8(1));

    uint8x16_t s7 = vld1q_u8(src_ptr + 112);
    uint8x16_t r7 = vld1q_u8(ref_ptr + 112);
    uint8x16_t p7 = vld1q_u8(second_pred + 112);
    uint8x16_t wtd_avg7 = dist_wtd_avg_u8x16(p7, r7, bck_offset, fwd_offset);
    uint8x16_t diff7 = vabdq_u8(s7, wtd_avg7);
    sum[7] = vdotq_u32(sum[7], diff7, vdupq_n_u8(1));

    src_ptr += src_stride;
    ref_ptr += ref_stride;
    second_pred += 128;
  } while (--h != 0);

  sum[0] = vaddq_u32(sum[0], sum[1]);
  sum[2] = vaddq_u32(sum[2], sum[3]);
  sum[4] = vaddq_u32(sum[4], sum[5]);
  sum[6] = vaddq_u32(sum[6], sum[7]);
  sum[0] = vaddq_u32(sum[0], sum[2]);
  sum[4] = vaddq_u32(sum[4], sum[6]);
  sum[0] = vaddq_u32(sum[0], sum[4]);
  return horizontal_add_u32x4(sum[0]);
}

static inline unsigned int dist_wtd_sad64xh_avg_neon_dotprod(
    const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr,
    int ref_stride, int h, const uint8_t *second_pred,
    const DIST_WTD_COMP_PARAMS *jcp_param) {
  const uint8x16_t fwd_offset = vdupq_n_u8(jcp_param->fwd_offset);
  const uint8x16_t bck_offset = vdupq_n_u8(jcp_param->bck_offset);
  uint32x4_t sum[4] = { vdupq_n_u32(0), vdupq_n_u32(0), vdupq_n_u32(0),
                        vdupq_n_u32(0) };

  do {
    uint8x16_t s0 = vld1q_u8(src_ptr);
    uint8x16_t r0 = vld1q_u8(ref_ptr);
    uint8x16_t p0 = vld1q_u8(second_pred);
    uint8x16_t wtd_avg0 = dist_wtd_avg_u8x16(p0, r0, bck_offset, fwd_offset);
    uint8x16_t diff0 = vabdq_u8(s0, wtd_avg0);
    sum[0] = vdotq_u32(sum[0], diff0, vdupq_n_u8(1));

    uint8x16_t s1 = vld1q_u8(src_ptr + 16);
    uint8x16_t r1 = vld1q_u8(ref_ptr + 16);
    uint8x16_t p1 = vld1q_u8(second_pred + 16);
    uint8x16_t wtd_avg1 = dist_wtd_avg_u8x16(p1, r1, bck_offset, fwd_offset);
    uint8x16_t diff1 = vabdq_u8(s1, wtd_avg1);
    sum[1] = vdotq_u32(sum[1], diff1, vdupq_n_u8(1));

    uint8x16_t s2 = vld1q_u8(src_ptr + 32);
    uint8x16_t r2 = vld1q_u8(ref_ptr + 32);
    uint8x16_t p2 = vld1q_u8(second_pred + 32);
    uint8x16_t wtd_avg2 = dist_wtd_avg_u8x16(p2, r2, bck_offset, fwd_offset);
    uint8x16_t diff2 = vabdq_u8(s2, wtd_avg2);
    sum[2] = vdotq_u32(sum[2], diff2, vdupq_n_u8(1));

    uint8x16_t s3 = vld1q_u8(src_ptr + 48);
    uint8x16_t r3 = vld1q_u8(ref_ptr + 48);
    uint8x16_t p3 = vld1q_u8(second_pred + 48);
    uint8x16_t wtd_avg3 = dist_wtd_avg_u8x16(p3, r3, bck_offset, fwd_offset);
    uint8x16_t diff3 = vabdq_u8(s3, wtd_avg3);
    sum[3] = vdotq_u32(sum[3], diff3, vdupq_n_u8(1));

    src_ptr += src_stride;
    ref_ptr += ref_stride;
    second_pred += 64;
  } while (--h != 0);

  sum[0] = vaddq_u32(sum[0], sum[1]);
  sum[2] = vaddq_u32(sum[2], sum[3]);
  sum[0] = vaddq_u32(sum[0], sum[2]);
  return horizontal_add_u32x4(sum[0]);
}

static inline unsigned int dist_wtd_sad32xh_avg_neon_dotprod(
    const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr,
    int ref_stride, int h, const uint8_t *second_pred,
    const DIST_WTD_COMP_PARAMS *jcp_param) {
  const uint8x16_t fwd_offset = vdupq_n_u8(jcp_param->fwd_offset);
  const uint8x16_t bck_offset = vdupq_n_u8(jcp_param->bck_offset);
  uint32x4_t sum[2] = { vdupq_n_u32(0), vdupq_n_u32(0) };

  do {
    uint8x16_t s0 = vld1q_u8(src_ptr);
    uint8x16_t r0 = vld1q_u8(ref_ptr);
    uint8x16_t p0 = vld1q_u8(second_pred);
    uint8x16_t wtd_avg0 = dist_wtd_avg_u8x16(p0, r0, bck_offset, fwd_offset);
    uint8x16_t diff0 = vabdq_u8(s0, wtd_avg0);
    sum[0] = vdotq_u32(sum[0], diff0, vdupq_n_u8(1));

    uint8x16_t s1 = vld1q_u8(src_ptr + 16);
    uint8x16_t r1 = vld1q_u8(ref_ptr + 16);
    uint8x16_t p1 = vld1q_u8(second_pred + 16);
    uint8x16_t wtd_avg1 = dist_wtd_avg_u8x16(p1, r1, bck_offset, fwd_offset);
    uint8x16_t diff1 = vabdq_u8(s1, wtd_avg1);
    sum[1] = vdotq_u32(sum[1], diff1, vdupq_n_u8(1));

    src_ptr += src_stride;
    ref_ptr += ref_stride;
    second_pred += 32;
  } while (--h != 0);

  sum[0] = vaddq_u32(sum[0], sum[1]);
  return horizontal_add_u32x4(sum[0]);
}

static inline unsigned int dist_wtd_sad16xh_avg_neon_dotprod(
    const uint8_t *src_ptr, int src_stride, const uint8_t *ref_ptr,
    int ref_stride, int h, const uint8_t *second_pred,
    const DIST_WTD_COMP_PARAMS *jcp_param) {
  const uint8x16_t fwd_offset = vdupq_n_u8(jcp_param->fwd_offset);
  const uint8x16_t bck_offset = vdupq_n_u8(jcp_param->bck_offset);
  uint32x4_t sum[2] = { vdupq_n_u32(0), vdupq_n_u32(0) };

  int i = h / 2;
  do {
    uint8x16_t s0 = vld1q_u8(src_ptr);
    uint8x16_t r0 = vld1q_u8(ref_ptr);
    uint8x16_t p0 = vld1q_u8(second_pred);
    uint8x16_t wtd_avg0 = dist_wtd_avg_u8x16(p0, r0, bck_offset, fwd_offset);
    uint8x16_t diff0 = vabdq_u8(s0, wtd_avg0);
    sum[0] = vdotq_u32(sum[0], diff0, vdupq_n_u8(1));

    src_ptr += src_stride;
    ref_ptr += ref_stride;
    second_pred += 16;

    uint8x16_t s1 = vld1q_u8(src_ptr);
    uint8x16_t r1 = vld1q_u8(ref_ptr);
    uint8x16_t p1 = vld1q_u8(second_pred);
    uint8x16_t wtd_avg1 = dist_wtd_avg_u8x16(p1, r1, bck_offset, fwd_offset);
    uint8x16_t diff1 = vabdq_u8(s1, wtd_avg1);
    sum[1] = vdotq_u32(sum[1], diff1, vdupq_n_u8(1));

    src_ptr += src_stride;
    ref_ptr += ref_stride;
    second_pred += 16;
  } while (--i != 0);

  sum[0] = vaddq_u32(sum[0], sum[1]);
  return horizontal_add_u32x4(sum[0]);
}

#define DIST_WTD_SAD_WXH_AVG_NEON_DOTPROD(w, h)                               \
  unsigned int aom_dist_wtd_sad##w##x##h##_avg_neon_dotprod(                  \
      const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride, \
      const uint8_t *second_pred, const DIST_WTD_COMP_PARAMS *jcp_param) {    \
    return dist_wtd_sad##w##xh_avg_neon_dotprod(                              \
        src, src_stride, ref, ref_stride, (h), second_pred, jcp_param);       \
  }

DIST_WTD_SAD_WXH_AVG_NEON_DOTPROD(16, 8)
DIST_WTD_SAD_WXH_AVG_NEON_DOTPROD(16, 16)
DIST_WTD_SAD_WXH_AVG_NEON_DOTPROD(16, 32)

DIST_WTD_SAD_WXH_AVG_NEON_DOTPROD(32, 16)
DIST_WTD_SAD_WXH_AVG_NEON_DOTPROD(32, 32)
DIST_WTD_SAD_WXH_AVG_NEON_DOTPROD(32, 64)

DIST_WTD_SAD_WXH_AVG_NEON_DOTPROD(64, 32)
DIST_WTD_SAD_WXH_AVG_NEON_DOTPROD(64, 64)
DIST_WTD_SAD_WXH_AVG_NEON_DOTPROD(64, 128)

DIST_WTD_SAD_WXH_AVG_NEON_DOTPROD(128, 64)
DIST_WTD_SAD_WXH_AVG_NEON_DOTPROD(128, 128)

#if !CONFIG_REALTIME_ONLY
DIST_WTD_SAD_WXH_AVG_NEON_DOTPROD(16, 4)
DIST_WTD_SAD_WXH_AVG_NEON_DOTPROD(16, 64)
DIST_WTD_SAD_WXH_AVG_NEON_DOTPROD(32, 8)
DIST_WTD_SAD_WXH_AVG_NEON_DOTPROD(64, 16)
#endif  // !CONFIG_REALTIME_ONLY

#undef DIST_WTD_SAD_WXH_AVG_NEON_DOTPROD