/*
 * Copyright (c) 2020, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <arm_neon.h>
#include <assert.h>

#include "aom_dsp/arm/mem_neon.h"
#include "aom_dsp/arm/sum_neon.h"
#include "config/aom_config.h"
#include "config/aom_dsp_rtcd.h"
static inline uint64_t aom_sum_squares_2d_i16_4x4_neon(const int16_t *src,
                                                       int stride) {
  int16x4_t s0 = vld1_s16(src + 0 * stride);
  int16x4_t s1 = vld1_s16(src + 1 * stride);
  int16x4_t s2 = vld1_s16(src + 2 * stride);
  int16x4_t s3 = vld1_s16(src + 3 * stride);

  int32x4_t sum_squares = vmull_s16(s0, s0);
  sum_squares = vmlal_s16(sum_squares, s1, s1);
  sum_squares = vmlal_s16(sum_squares, s2, s2);
  sum_squares = vmlal_s16(sum_squares, s3, s3);

  return horizontal_long_add_u32x4(vreinterpretq_u32_s32(sum_squares));
}
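
// As above, for 4-wide blocks of any multiple-of-4 height. Splitting the work
// across two independent accumulators helps break the dependency chain
// between consecutive multiply-accumulate instructions.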
static inline uint64_t aom_sum_squares_2d_i16_4xn_neon(const int16_t *src,
                                                       int stride, int height) {
  int32x4_t sum_squares[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };

  int h = height;
  do {
    int16x4_t s0 = vld1_s16(src + 0 * stride);
    int16x4_t s1 = vld1_s16(src + 1 * stride);
    int16x4_t s2 = vld1_s16(src + 2 * stride);
    int16x4_t s3 = vld1_s16(src + 3 * stride);

    sum_squares[0] = vmlal_s16(sum_squares[0], s0, s0);
    sum_squares[0] = vmlal_s16(sum_squares[0], s1, s1);
    sum_squares[1] = vmlal_s16(sum_squares[1], s2, s2);
    sum_squares[1] = vmlal_s16(sum_squares[1], s3, s3);

    src += 4 * stride;
    h -= 4;
  } while (h != 0);

  return horizontal_long_add_u32x4(
      vreinterpretq_u32_s32(vaddq_s32(sum_squares[0], sum_squares[1])));
}
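
// Generic kernel for multiple-of-8 widths and multiple-of-4 heights. The
// block is processed in strips of four rows; after each strip the 32-bit
// partial sums are folded into the 64-bit accumulator (vpadalq_u32), so the
// 32-bit lanes do not keep accumulating across the full block height.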
static inline uint64_t aom_sum_squares_2d_i16_nxn_neon(const int16_t *src,
                                                       int stride, int width,
                                                       int height) {
  uint64x2_t sum_squares = vdupq_n_u64(0);

  int h = height;
  do {
    int32x4_t ss_row[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };
    int w = 0;
    do {
      const int16_t *s = src + w;
      int16x8_t s0 = vld1q_s16(s + 0 * stride);
      int16x8_t s1 = vld1q_s16(s + 1 * stride);
      int16x8_t s2 = vld1q_s16(s + 2 * stride);
      int16x8_t s3 = vld1q_s16(s + 3 * stride);

      ss_row[0] = vmlal_s16(ss_row[0], vget_low_s16(s0), vget_low_s16(s0));
      ss_row[0] = vmlal_s16(ss_row[0], vget_low_s16(s1), vget_low_s16(s1));
      ss_row[0] = vmlal_s16(ss_row[0], vget_low_s16(s2), vget_low_s16(s2));
      ss_row[0] = vmlal_s16(ss_row[0], vget_low_s16(s3), vget_low_s16(s3));
      ss_row[1] = vmlal_s16(ss_row[1], vget_high_s16(s0), vget_high_s16(s0));
      ss_row[1] = vmlal_s16(ss_row[1], vget_high_s16(s1), vget_high_s16(s1));
      ss_row[1] = vmlal_s16(ss_row[1], vget_high_s16(s2), vget_high_s16(s2));
      ss_row[1] = vmlal_s16(ss_row[1], vget_high_s16(s3), vget_high_s16(s3));
      w += 8;
    } while (w < width);

    sum_squares = vpadalq_u32(
        sum_squares, vreinterpretq_u32_s32(vaddq_s32(ss_row[0], ss_row[1])));

    src += 4 * stride;
    h -= 4;
  } while (h != 0);

  return horizontal_add_u64x2(sum_squares);
}
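
// Dispatcher, normally reached through the aom_sum_squares_2d_i16 RTCD entry
// point. A minimal usage sketch (the buffer and sizes below are illustrative):
//   int16_t diff[4 * 4] = { 0 };
//   uint64_t ss = aom_sum_squares_2d_i16(diff, /*stride=*/4, /*width=*/4,
//                                        /*height=*/4);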
uint64_t aom_sum_squares_2d_i16_neon(const int16_t *src, int stride, int width,
                                     int height) {
  // 4 elements per row only requires half a SIMD register, so this must be a
  // special case, but also note that over 75% of all calls are with
  // size == 4, so it is also the common case.
  if (LIKELY(width == 4 && height == 4)) {
    return aom_sum_squares_2d_i16_4x4_neon(src, stride);
  } else if (LIKELY(width == 4 && (height & 3) == 0)) {
    return aom_sum_squares_2d_i16_4xn_neon(src, stride, height);
  } else if (LIKELY((width & 7) == 0 && (height & 3) == 0)) {
    // Generic case - width is a multiple of 8, height is a multiple of 4.
    return aom_sum_squares_2d_i16_nxn_neon(src, stride, width, height);
  } else {
    return aom_sum_squares_2d_i16_c(src, stride, width, height);
  }
}
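
// The aom_sum_sse_* kernels below mirror the sum-of-squares kernels above but
// additionally accumulate the (signed) sum of the samples, returned via *sum;
// the sum of squared samples (SSE) is the return value.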
static inline uint64_t aom_sum_sse_2d_i16_4x4_neon(const int16_t *src,
                                                   int stride, int *sum) {
  int16x4_t s0 = vld1_s16(src + 0 * stride);
  int16x4_t s1 = vld1_s16(src + 1 * stride);
  int16x4_t s2 = vld1_s16(src + 2 * stride);
  int16x4_t s3 = vld1_s16(src + 3 * stride);

  int32x4_t sse = vmull_s16(s0, s0);
  sse = vmlal_s16(sse, s1, s1);
  sse = vmlal_s16(sse, s2, s2);
  sse = vmlal_s16(sse, s3, s3);

  int32x4_t sum_01 = vaddl_s16(s0, s1);
  int32x4_t sum_23 = vaddl_s16(s2, s3);
  *sum += horizontal_add_s32x4(vaddq_s32(sum_01, sum_23));

  return horizontal_long_add_u32x4(vreinterpretq_u32_s32(sse));
}

static inline uint64_t aom_sum_sse_2d_i16_4xn_neon(const int16_t *src,
                                                   int stride, int height,
                                                   int *sum) {
  int32x4_t sse[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };
  int32x2_t sum_acc[2] = { vdup_n_s32(0), vdup_n_s32(0) };

  int h = height;
  do {
    int16x4_t s0 = vld1_s16(src + 0 * stride);
    int16x4_t s1 = vld1_s16(src + 1 * stride);
    int16x4_t s2 = vld1_s16(src + 2 * stride);
    int16x4_t s3 = vld1_s16(src + 3 * stride);

    sse[0] = vmlal_s16(sse[0], s0, s0);
    sse[0] = vmlal_s16(sse[0], s1, s1);
    sse[1] = vmlal_s16(sse[1], s2, s2);
    sse[1] = vmlal_s16(sse[1], s3, s3);

    sum_acc[0] = vpadal_s16(sum_acc[0], s0);
    sum_acc[0] = vpadal_s16(sum_acc[0], s1);
    sum_acc[1] = vpadal_s16(sum_acc[1], s2);
    sum_acc[1] = vpadal_s16(sum_acc[1], s3);

    src += 4 * stride;
    h -= 4;
  } while (h != 0);

  *sum += horizontal_add_s32x4(vcombine_s32(sum_acc[0], sum_acc[1]));
  return horizontal_long_add_u32x4(
      vreinterpretq_u32_s32(vaddq_s32(sse[0], sse[1])));
}

static inline uint64_t aom_sum_sse_2d_i16_nxn_neon(const int16_t *src,
                                                   int stride, int width,
                                                   int height, int *sum) {
  uint64x2_t sse = vdupq_n_u64(0);
  int32x4_t sum_acc = vdupq_n_s32(0);

  int h = height;
  do {
    int32x4_t sse_row[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };
    int w = 0;
    do {
      const int16_t *s = src + w;
      int16x8_t s0 = vld1q_s16(s + 0 * stride);
      int16x8_t s1 = vld1q_s16(s + 1 * stride);
      int16x8_t s2 = vld1q_s16(s + 2 * stride);
      int16x8_t s3 = vld1q_s16(s + 3 * stride);

      sse_row[0] = vmlal_s16(sse_row[0], vget_low_s16(s0), vget_low_s16(s0));
      sse_row[0] = vmlal_s16(sse_row[0], vget_low_s16(s1), vget_low_s16(s1));
      sse_row[0] = vmlal_s16(sse_row[0], vget_low_s16(s2), vget_low_s16(s2));
      sse_row[0] = vmlal_s16(sse_row[0], vget_low_s16(s3), vget_low_s16(s3));
      sse_row[1] = vmlal_s16(sse_row[1], vget_high_s16(s0), vget_high_s16(s0));
      sse_row[1] = vmlal_s16(sse_row[1], vget_high_s16(s1), vget_high_s16(s1));
      sse_row[1] = vmlal_s16(sse_row[1], vget_high_s16(s2), vget_high_s16(s2));
      sse_row[1] = vmlal_s16(sse_row[1], vget_high_s16(s3), vget_high_s16(s3));

      sum_acc = vpadalq_s16(sum_acc, s0);
      sum_acc = vpadalq_s16(sum_acc, s1);
      sum_acc = vpadalq_s16(sum_acc, s2);
      sum_acc = vpadalq_s16(sum_acc, s3);

      w += 8;
    } while (w < width);

    sse = vpadalq_u32(sse,
                      vreinterpretq_u32_s32(vaddq_s32(sse_row[0], sse_row[1])));

    src += 4 * stride;
    h -= 4;
  } while (h != 0);

  *sum += horizontal_add_s32x4(sum_acc);
  return horizontal_add_u64x2(sse);
}

uint64_t aom_sum_sse_2d_i16_neon(const int16_t *src, int stride, int width,
                                 int height, int *sum) {
  uint64_t sse;

  if (LIKELY(width == 4 && height == 4)) {
    sse = aom_sum_sse_2d_i16_4x4_neon(src, stride, sum);
  } else if (LIKELY(width == 4 && (height & 3) == 0)) {
    // width = 4, height is a multiple of 4.
    sse = aom_sum_sse_2d_i16_4xn_neon(src, stride, height, sum);
  } else if (LIKELY((width & 7) == 0 && (height & 3) == 0)) {
    // Generic case - width is a multiple of 8, height is a multiple of 4.
    sse = aom_sum_sse_2d_i16_nxn_neon(src, stride, width, height, sum);
  } else {
    sse = aom_sum_sse_2d_i16_c(src, stride, width, height, sum);
  }

  return sse;
}
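
// Sum of squares over a linear array of n int16 values, four at a time. Each
// 32-bit product vector is widened into the 64-bit accumulator straight away
// (vpadalq_u32), so the 32-bit lanes never accumulate across iterations and n
// may be arbitrarily large. Any non-multiple-of-4 tail falls back to C.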
static inline uint64_t aom_sum_squares_i16_4xn_neon(const int16_t *src,
                                                    uint32_t n) {
  uint64x2_t sum_u64 = vdupq_n_u64(0);

  int i = n;
  do {
    uint32x4_t sum;
    int16x4_t s0 = vld1_s16(src);

    sum = vreinterpretq_u32_s32(vmull_s16(s0, s0));

    sum_u64 = vpadalq_u32(sum_u64, sum);

    src += 4;
    i -= 4;
  } while (i >= 4);

  if (i > 0) {
    return horizontal_add_u64x2(sum_u64) + aom_sum_squares_i16_c(src, i);
  }
  return horizontal_add_u64x2(sum_u64);
}

static inline uint64_t aom_sum_squares_i16_8xn_neon(const int16_t *src,
                                                    uint32_t n) {
  uint64x2_t sum_u64[2] = { vdupq_n_u64(0), vdupq_n_u64(0) };

  int i = n;
  do {
    uint32x4_t sum[2];
    int16x8_t s0 = vld1q_s16(src);

    sum[0] =
        vreinterpretq_u32_s32(vmull_s16(vget_low_s16(s0), vget_low_s16(s0)));
    sum[1] =
        vreinterpretq_u32_s32(vmull_s16(vget_high_s16(s0), vget_high_s16(s0)));

    sum_u64[0] = vpadalq_u32(sum_u64[0], sum[0]);
    sum_u64[1] = vpadalq_u32(sum_u64[1], sum[1]);

    src += 8;
    i -= 8;
  } while (i >= 8);

  if (i > 0) {
    return horizontal_add_u64x2(vaddq_u64(sum_u64[0], sum_u64[1])) +
           aom_sum_squares_i16_c(src, i);
  }
  return horizontal_add_u64x2(vaddq_u64(sum_u64[0], sum_u64[1]));
}

uint64_t aom_sum_squares_i16_neon(const int16_t *src, uint32_t n) {
  // This function seems to be called only for values of n >= 64. See
  // av1/encoder/compound_type.c.
  if (LIKELY(n >= 8)) {
    return aom_sum_squares_i16_8xn_neon(src, n);
  }
  if (n >= 4) {
    return aom_sum_squares_i16_4xn_neon(src, n);
  }
  return aom_sum_squares_i16_c(src, n);
}
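
// The aom_var_2d_* helpers below return sse - sum^2 / (width * height). In
// exact arithmetic sum(x^2) - (sum x)^2 / N equals N times the variance, so
// this is the block variance scaled by the pixel count.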
static inline uint64_t aom_var_2d_u8_4xh_neon(uint8_t *src, int src_stride,
                                              int width, int height) {
  uint64_t sum = 0;
  uint64_t sse = 0;
  uint32x2_t sum_u32 = vdup_n_u32(0);
  uint32x4_t sse_u32 = vdupq_n_u32(0);

  // 255*256 = 65280, so we can accumulate up to 256 8-bit elements in a
  // 16-bit element before we need to accumulate to 32-bit elements. Since
  // we're accumulating in uint16x4_t vectors, this means we can accumulate up
  // to 4 rows of 256 elements. Therefore the limit can be computed as:
  // h_limit = (4 * 256) / width.
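  // For example, with width == 4, h_limit == 256: each 16-bit lane gains two
  // 8-bit values per two-row iteration, so after 256 rows a lane holds at
  // most 256 * 255 == 65280, which still fits in 16 bits.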
  int h_limit = (4 * 256) / width;
  int h_tmp = height > h_limit ? h_limit : height;

  int h = 0;
  do {
    uint16x4_t sum_u16 = vdup_n_u16(0);
    do {
      uint8_t *src_ptr = src;
      int w = width;
      do {
        uint8x8_t s0 = load_unaligned_u8(src_ptr, src_stride);

        sum_u16 = vpadal_u8(sum_u16, s0);

        uint16x8_t sse_u16 = vmull_u8(s0, s0);

        sse_u32 = vpadalq_u16(sse_u32, sse_u16);

        src_ptr += 8;
        w -= 8;
      } while (w >= 8);

      // Process remaining columns in the row using C.
      while (w > 0) {
        int idx = width - w;
        const uint8_t v = src[idx];
        sum += v;
        sse += v * v;
        w--;
      }

      src += 2 * src_stride;
      h += 2;
    } while (h < h_tmp && h < height);

    sum_u32 = vpadal_u16(sum_u32, sum_u16);
    h_tmp += h_limit;
  } while (h < height);

  sum += horizontal_long_add_u32x2(sum_u32);
  sse += horizontal_long_add_u32x4(sse_u32);

  return sse - sum * sum / (width * height);
}

static inline uint64_t aom_var_2d_u8_8xh_neon(uint8_t *src, int src_stride,
                                              int width, int height) {
  uint64_t sum = 0;
  uint64_t sse = 0;
  uint32x2_t sum_u32 = vdup_n_u32(0);
  uint32x4_t sse_u32 = vdupq_n_u32(0);

  // 255*256 = 65280, so we can accumulate up to 256 8-bit elements in a
  // 16-bit element before we need to accumulate to 32-bit elements. Since
  // we're accumulating in uint16x4_t vectors, this means we can accumulate up
  // to 4 rows of 256 elements. Therefore the limit can be computed as:
  // h_limit = (4 * 256) / width.
  int h_limit = (4 * 256) / width;
  int h_tmp = height > h_limit ? h_limit : height;

  int h = 0;
  do {
    uint16x4_t sum_u16 = vdup_n_u16(0);
    do {
      uint8_t *src_ptr = src;
      int w = width;
      do {
        uint8x8_t s0 = vld1_u8(src_ptr);

        sum_u16 = vpadal_u8(sum_u16, s0);

        uint16x8_t sse_u16 = vmull_u8(s0, s0);

        sse_u32 = vpadalq_u16(sse_u32, sse_u16);

        src_ptr += 8;
        w -= 8;
      } while (w >= 8);

      // Process remaining columns in the row using C.
      while (w > 0) {
        int idx = width - w;
        const uint8_t v = src[idx];
        sum += v;
        sse += v * v;
        w--;
      }

      src += src_stride;
      ++h;
    } while (h < h_tmp && h < height);

    sum_u32 = vpadal_u16(sum_u32, sum_u16);
    h_tmp += h_limit;
  } while (h < height);

  sum += horizontal_long_add_u32x2(sum_u32);
  sse += horizontal_long_add_u32x4(sse_u32);

  return sse - sum * sum / (width * height);
}

static inline uint64_t aom_var_2d_u8_16xh_neon(uint8_t *src, int src_stride,
                                               int width, int height) {
  uint64_t sum = 0;
  uint64_t sse = 0;
  uint32x4_t sum_u32 = vdupq_n_u32(0);
  uint32x4_t sse_u32[2] = { vdupq_n_u32(0), vdupq_n_u32(0) };

  // 255*256 = 65280, so we can accumulate up to 256 8-bit elements in a
  // 16-bit element before we need to accumulate to 32-bit elements. Since
  // we're accumulating in uint16x8_t vectors, this means we can accumulate up
  // to 8 rows of 256 elements. Therefore the limit can be computed as:
  // h_limit = (8 * 256) / width.
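  // For example, with width == 16, h_limit == 128: each of the eight 16-bit
  // lanes gains two 8-bit values per row, so after 128 rows a lane again
  // holds at most 256 * 255 == 65280.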
  int h_limit = (8 * 256) / width;
  int h_tmp = height > h_limit ? h_limit : height;

  int h = 0;
  do {
    uint16x8_t sum_u16 = vdupq_n_u16(0);
    do {
      int w = width;
      uint8_t *src_ptr = src;
      do {
        uint8x16_t s0 = vld1q_u8(src_ptr);

        sum_u16 = vpadalq_u8(sum_u16, s0);

        uint16x8_t sse_u16_lo = vmull_u8(vget_low_u8(s0), vget_low_u8(s0));
        uint16x8_t sse_u16_hi = vmull_u8(vget_high_u8(s0), vget_high_u8(s0));

        sse_u32[0] = vpadalq_u16(sse_u32[0], sse_u16_lo);
        sse_u32[1] = vpadalq_u16(sse_u32[1], sse_u16_hi);

        src_ptr += 16;
        w -= 16;
      } while (w >= 16);

      // Process remaining columns in the row using C.
      while (w > 0) {
        int idx = width - w;
        const uint8_t v = src[idx];
        sum += v;
        sse += v * v;
        w--;
      }

      src += src_stride;
      ++h;
    } while (h < h_tmp && h < height);

    sum_u32 = vpadalq_u16(sum_u32, sum_u16);
    h_tmp += h_limit;
  } while (h < height);

  sum += horizontal_long_add_u32x4(sum_u32);
  sse += horizontal_long_add_u32x4(vaddq_u32(sse_u32[0], sse_u32[1]));

  return sse - sum * sum / (width * height);
}

uint64_t aom_var_2d_u8_neon(uint8_t *src, int src_stride, int width,
                            int height) {
  if (width >= 16) {
    return aom_var_2d_u8_16xh_neon(src, src_stride, width, height);
  }
  if (width >= 8) {
    return aom_var_2d_u8_8xh_neon(src, src_stride, width, height);
  }
  if (width >= 4 && height % 2 == 0) {
    return aom_var_2d_u8_4xh_neon(src, src_stride, width, height);
  }
  return aom_var_2d_u8_c(src, src_stride, width, height);
}

#if CONFIG_AV1_HIGHBITDEPTH
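// In the high bit-depth build, a uint8_t *src argument actually carries a
// pointer to 16-bit samples, following the usual libaom convention;
// CONVERT_TO_SHORTPTR recovers the real pointer type.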
static inline uint64_t aom_var_2d_u16_4xh_neon(uint8_t *src, int src_stride,
                                               int width, int height) {
  uint16_t *src_u16 = CONVERT_TO_SHORTPTR(src);
  uint64_t sum = 0;
  uint64_t sse = 0;
  uint32x2_t sum_u32 = vdup_n_u32(0);
  uint64x2_t sse_u64 = vdupq_n_u64(0);

  int h = height;
  do {
    int w = width;
    uint16_t *src_ptr = src_u16;
    do {
      uint16x4_t s0 = vld1_u16(src_ptr);

      sum_u32 = vpadal_u16(sum_u32, s0);

      uint32x4_t sse_u32 = vmull_u16(s0, s0);

      sse_u64 = vpadalq_u32(sse_u64, sse_u32);

      src_ptr += 4;
      w -= 4;
    } while (w >= 4);

    // Process remaining columns in the row using C.
    while (w > 0) {
      int idx = width - w;
      const uint16_t v = src_u16[idx];
      sum += v;
      sse += v * v;
      w--;
    }

    src_u16 += src_stride;
  } while (--h != 0);

  sum += horizontal_long_add_u32x2(sum_u32);
  sse += horizontal_add_u64x2(sse_u64);

  return sse - sum * sum / (width * height);
}

static inline uint64_t aom_var_2d_u16_8xh_neon(uint8_t *src, int src_stride,
                                               int width, int height) {
  uint16_t *src_u16 = CONVERT_TO_SHORTPTR(src);
  uint64_t sum = 0;
  uint64_t sse = 0;
  uint32x4_t sum_u32 = vdupq_n_u32(0);
  uint64x2_t sse_u64[2] = { vdupq_n_u64(0), vdupq_n_u64(0) };

  int h = height;
  do {
    int w = width;
    uint16_t *src_ptr = src_u16;
    do {
      uint16x8_t s0 = vld1q_u16(src_ptr);

      sum_u32 = vpadalq_u16(sum_u32, s0);

      uint32x4_t sse_u32_lo = vmull_u16(vget_low_u16(s0), vget_low_u16(s0));
      uint32x4_t sse_u32_hi = vmull_u16(vget_high_u16(s0), vget_high_u16(s0));

      sse_u64[0] = vpadalq_u32(sse_u64[0], sse_u32_lo);
      sse_u64[1] = vpadalq_u32(sse_u64[1], sse_u32_hi);

      src_ptr += 8;
      w -= 8;
    } while (w >= 8);

    // Process remaining columns in the row using C.
    while (w > 0) {
      int idx = width - w;
      const uint16_t v = src_u16[idx];
      sum += v;
      sse += v * v;
      w--;
    }

    src_u16 += src_stride;
  } while (--h != 0);

  sum += horizontal_long_add_u32x4(sum_u32);
  sse += horizontal_add_u64x2(vaddq_u64(sse_u64[0], sse_u64[1]));

  return sse - sum * sum / (width * height);
}

uint64_t aom_var_2d_u16_neon(uint8_t *src, int src_stride, int width,
                             int height) {
  if (width >= 8) {
    return aom_var_2d_u16_8xh_neon(src, src_stride, width, height);
  }
  if (width >= 4) {
    return aom_var_2d_u16_4xh_neon(src, src_stride, width, height);
  }
  return aom_var_2d_u16_c(src, src_stride, width, height);
}
#endif  // CONFIG_AV1_HIGHBITDEPTH