xref: /aosp_15_r20/external/libaom/aom_dsp/arm/sum_squares_neon.c (revision 77c1e3ccc04c968bd2bc212e87364f250e820521)
/*
 * Copyright (c) 2020, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <arm_neon.h>
#include <assert.h>

#include "aom_dsp/arm/mem_neon.h"
#include "aom_dsp/arm/sum_neon.h"
#include "config/aom_config.h"
#include "config/aom_dsp_rtcd.h"

static inline uint64_t aom_sum_squares_2d_i16_4x4_neon(const int16_t *src,
                                                       int stride) {
  int16x4_t s0 = vld1_s16(src + 0 * stride);
  int16x4_t s1 = vld1_s16(src + 1 * stride);
  int16x4_t s2 = vld1_s16(src + 2 * stride);
  int16x4_t s3 = vld1_s16(src + 3 * stride);

  int32x4_t sum_squares = vmull_s16(s0, s0);
  sum_squares = vmlal_s16(sum_squares, s1, s1);
  sum_squares = vmlal_s16(sum_squares, s2, s2);
  sum_squares = vmlal_s16(sum_squares, s3, s3);

  return horizontal_long_add_u32x4(vreinterpretq_u32_s32(sum_squares));
}

static inline uint64_t aom_sum_squares_2d_i16_4xn_neon(const int16_t *src,
                                                       int stride, int height) {
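  // Two parallel accumulators shorten the serial dependency chain of the
  // multiply-accumulate instructions.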
  int32x4_t sum_squares[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };

  int h = height;
  do {
    int16x4_t s0 = vld1_s16(src + 0 * stride);
    int16x4_t s1 = vld1_s16(src + 1 * stride);
    int16x4_t s2 = vld1_s16(src + 2 * stride);
    int16x4_t s3 = vld1_s16(src + 3 * stride);

    sum_squares[0] = vmlal_s16(sum_squares[0], s0, s0);
    sum_squares[0] = vmlal_s16(sum_squares[0], s1, s1);
    sum_squares[1] = vmlal_s16(sum_squares[1], s2, s2);
    sum_squares[1] = vmlal_s16(sum_squares[1], s3, s3);

    src += 4 * stride;
    h -= 4;
  } while (h != 0);

  return horizontal_long_add_u32x4(
      vreinterpretq_u32_s32(vaddq_s32(sum_squares[0], sum_squares[1])));
}

static inline uint64_t aom_sum_squares_2d_i16_nxn_neon(const int16_t *src,
                                                       int stride, int width,
                                                       int height) {
  uint64x2_t sum_squares = vdupq_n_u64(0);

  int h = height;
  do {
    int32x4_t ss_row[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };
    int w = 0;
    do {
      const int16_t *s = src + w;
      int16x8_t s0 = vld1q_s16(s + 0 * stride);
      int16x8_t s1 = vld1q_s16(s + 1 * stride);
      int16x8_t s2 = vld1q_s16(s + 2 * stride);
      int16x8_t s3 = vld1q_s16(s + 3 * stride);

      ss_row[0] = vmlal_s16(ss_row[0], vget_low_s16(s0), vget_low_s16(s0));
      ss_row[0] = vmlal_s16(ss_row[0], vget_low_s16(s1), vget_low_s16(s1));
      ss_row[0] = vmlal_s16(ss_row[0], vget_low_s16(s2), vget_low_s16(s2));
      ss_row[0] = vmlal_s16(ss_row[0], vget_low_s16(s3), vget_low_s16(s3));
      ss_row[1] = vmlal_s16(ss_row[1], vget_high_s16(s0), vget_high_s16(s0));
      ss_row[1] = vmlal_s16(ss_row[1], vget_high_s16(s1), vget_high_s16(s1));
      ss_row[1] = vmlal_s16(ss_row[1], vget_high_s16(s2), vget_high_s16(s2));
      ss_row[1] = vmlal_s16(ss_row[1], vget_high_s16(s3), vget_high_s16(s3));
      w += 8;
    } while (w < width);

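    // Widen to 64 bits after each strip of four rows so that the 32-bit
    // strip accumulators only ever hold one strip's worth of squares.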
    sum_squares = vpadalq_u32(
        sum_squares, vreinterpretq_u32_s32(vaddq_s32(ss_row[0], ss_row[1])));

    src += 4 * stride;
    h -= 4;
  } while (h != 0);

  return horizontal_add_u64x2(sum_squares);
}

uint64_t aom_sum_squares_2d_i16_neon(const int16_t *src, int stride, int width,
                                     int height) {
  // A width of 4 fills only half a SIMD register, so it must be handled as a
  // special case; note also that over 75% of all calls are with size == 4,
  // so it is the common case as well.
  if (LIKELY(width == 4 && height == 4)) {
    return aom_sum_squares_2d_i16_4x4_neon(src, stride);
  } else if (LIKELY(width == 4 && (height & 3) == 0)) {
    return aom_sum_squares_2d_i16_4xn_neon(src, stride, height);
  } else if (LIKELY((width & 7) == 0 && (height & 3) == 0)) {
    // Generic case
    return aom_sum_squares_2d_i16_nxn_neon(src, stride, width, height);
  } else {
    return aom_sum_squares_2d_i16_c(src, stride, width, height);
  }
}

static inline uint64_t aom_sum_sse_2d_i16_4x4_neon(const int16_t *src,
                                                   int stride, int *sum) {
  int16x4_t s0 = vld1_s16(src + 0 * stride);
  int16x4_t s1 = vld1_s16(src + 1 * stride);
  int16x4_t s2 = vld1_s16(src + 2 * stride);
  int16x4_t s3 = vld1_s16(src + 3 * stride);

  int32x4_t sse = vmull_s16(s0, s0);
  sse = vmlal_s16(sse, s1, s1);
  sse = vmlal_s16(sse, s2, s2);
  sse = vmlal_s16(sse, s3, s3);

  int32x4_t sum_01 = vaddl_s16(s0, s1);
  int32x4_t sum_23 = vaddl_s16(s2, s3);
  *sum += horizontal_add_s32x4(vaddq_s32(sum_01, sum_23));

  return horizontal_long_add_u32x4(vreinterpretq_u32_s32(sse));
}

static inline uint64_t aom_sum_sse_2d_i16_4xn_neon(const int16_t *src,
                                                   int stride, int height,
                                                   int *sum) {
  int32x4_t sse[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };
  int32x2_t sum_acc[2] = { vdup_n_s32(0), vdup_n_s32(0) };

  int h = height;
  do {
    int16x4_t s0 = vld1_s16(src + 0 * stride);
    int16x4_t s1 = vld1_s16(src + 1 * stride);
    int16x4_t s2 = vld1_s16(src + 2 * stride);
    int16x4_t s3 = vld1_s16(src + 3 * stride);

    sse[0] = vmlal_s16(sse[0], s0, s0);
    sse[0] = vmlal_s16(sse[0], s1, s1);
    sse[1] = vmlal_s16(sse[1], s2, s2);
    sse[1] = vmlal_s16(sse[1], s3, s3);

    sum_acc[0] = vpadal_s16(sum_acc[0], s0);
    sum_acc[0] = vpadal_s16(sum_acc[0], s1);
    sum_acc[1] = vpadal_s16(sum_acc[1], s2);
    sum_acc[1] = vpadal_s16(sum_acc[1], s3);

    src += 4 * stride;
    h -= 4;
  } while (h != 0);

  *sum += horizontal_add_s32x4(vcombine_s32(sum_acc[0], sum_acc[1]));
  return horizontal_long_add_u32x4(
      vreinterpretq_u32_s32(vaddq_s32(sse[0], sse[1])));
}

static inline uint64_t aom_sum_sse_2d_i16_nxn_neon(const int16_t *src,
                                                   int stride, int width,
                                                   int height, int *sum) {
  uint64x2_t sse = vdupq_n_u64(0);
  int32x4_t sum_acc = vdupq_n_s32(0);

  int h = height;
  do {
    int32x4_t sse_row[2] = { vdupq_n_s32(0), vdupq_n_s32(0) };
    int w = 0;
    do {
      const int16_t *s = src + w;
      int16x8_t s0 = vld1q_s16(s + 0 * stride);
      int16x8_t s1 = vld1q_s16(s + 1 * stride);
      int16x8_t s2 = vld1q_s16(s + 2 * stride);
      int16x8_t s3 = vld1q_s16(s + 3 * stride);

      sse_row[0] = vmlal_s16(sse_row[0], vget_low_s16(s0), vget_low_s16(s0));
      sse_row[0] = vmlal_s16(sse_row[0], vget_low_s16(s1), vget_low_s16(s1));
      sse_row[0] = vmlal_s16(sse_row[0], vget_low_s16(s2), vget_low_s16(s2));
      sse_row[0] = vmlal_s16(sse_row[0], vget_low_s16(s3), vget_low_s16(s3));
      sse_row[1] = vmlal_s16(sse_row[1], vget_high_s16(s0), vget_high_s16(s0));
      sse_row[1] = vmlal_s16(sse_row[1], vget_high_s16(s1), vget_high_s16(s1));
      sse_row[1] = vmlal_s16(sse_row[1], vget_high_s16(s2), vget_high_s16(s2));
      sse_row[1] = vmlal_s16(sse_row[1], vget_high_s16(s3), vget_high_s16(s3));

      sum_acc = vpadalq_s16(sum_acc, s0);
      sum_acc = vpadalq_s16(sum_acc, s1);
      sum_acc = vpadalq_s16(sum_acc, s2);
      sum_acc = vpadalq_s16(sum_acc, s3);

      w += 8;
    } while (w < width);

    sse = vpadalq_u32(sse,
                      vreinterpretq_u32_s32(vaddq_s32(sse_row[0], sse_row[1])));

    src += 4 * stride;
    h -= 4;
  } while (h != 0);

  *sum += horizontal_add_s32x4(sum_acc);
  return horizontal_add_u64x2(sse);
}

uint64_t aom_sum_sse_2d_i16_neon(const int16_t *src, int stride, int width,
                                 int height, int *sum) {
  uint64_t sse;

  if (LIKELY(width == 4 && height == 4)) {
    sse = aom_sum_sse_2d_i16_4x4_neon(src, stride, sum);
  } else if (LIKELY(width == 4 && (height & 3) == 0)) {
    // width = 4, height is a multiple of 4.
    sse = aom_sum_sse_2d_i16_4xn_neon(src, stride, height, sum);
  } else if (LIKELY((width & 7) == 0 && (height & 3) == 0)) {
    // Generic case - width is a multiple of 8, height is a multiple of 4.
    sse = aom_sum_sse_2d_i16_nxn_neon(src, stride, width, height, sum);
  } else {
    sse = aom_sum_sse_2d_i16_c(src, stride, width, height, sum);
  }

  return sse;
}

static inline uint64_t aom_sum_squares_i16_4xn_neon(const int16_t *src,
                                                    uint32_t n) {
  uint64x2_t sum_u64 = vdupq_n_u64(0);

  int i = n;
  do {
    uint32x4_t sum;
    int16x4_t s0 = vld1_s16(src);

    sum = vreinterpretq_u32_s32(vmull_s16(s0, s0));

    sum_u64 = vpadalq_u32(sum_u64, sum);

    src += 4;
    i -= 4;
  } while (i >= 4);

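  // Process the remaining (n % 4) elements using C.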
  if (i > 0) {
    return horizontal_add_u64x2(sum_u64) + aom_sum_squares_i16_c(src, i);
  }
  return horizontal_add_u64x2(sum_u64);
}

static inline uint64_t aom_sum_squares_i16_8xn_neon(const int16_t *src,
                                                    uint32_t n) {
  uint64x2_t sum_u64[2] = { vdupq_n_u64(0), vdupq_n_u64(0) };

  int i = n;
  do {
    uint32x4_t sum[2];
    int16x8_t s0 = vld1q_s16(src);

    sum[0] =
        vreinterpretq_u32_s32(vmull_s16(vget_low_s16(s0), vget_low_s16(s0)));
    sum[1] =
        vreinterpretq_u32_s32(vmull_s16(vget_high_s16(s0), vget_high_s16(s0)));

    sum_u64[0] = vpadalq_u32(sum_u64[0], sum[0]);
    sum_u64[1] = vpadalq_u32(sum_u64[1], sum[1]);

    src += 8;
    i -= 8;
  } while (i >= 8);

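  // Process the remaining (n % 8) elements using C.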
  if (i > 0) {
    return horizontal_add_u64x2(vaddq_u64(sum_u64[0], sum_u64[1])) +
           aom_sum_squares_i16_c(src, i);
  }
  return horizontal_add_u64x2(vaddq_u64(sum_u64[0], sum_u64[1]));
}

uint64_t aom_sum_squares_i16_neon(const int16_t *src, uint32_t n) {
  // This function seems to be called only for values of N >= 64. See
  // av1/encoder/compound_type.c.
  if (LIKELY(n >= 8)) {
    return aom_sum_squares_i16_8xn_neon(src, n);
  }
  if (n >= 4) {
    return aom_sum_squares_i16_4xn_neon(src, n);
  }
  return aom_sum_squares_i16_c(src, n);
}

static inline uint64_t aom_var_2d_u8_4xh_neon(uint8_t *src, int src_stride,
                                              int width, int height) {
  uint64_t sum = 0;
  uint64_t sse = 0;
  uint32x2_t sum_u32 = vdup_n_u32(0);
  uint32x4_t sse_u32 = vdupq_n_u32(0);

  // 255 * 256 = 65280, so we can accumulate up to 256 8-bit elements in a
  // 16-bit element before we need to accumulate to 32-bit elements. Since
  // we're accumulating in uint16x4_t vectors, this means we can accumulate up
  // to 4 rows of 256 elements. Therefore the limit can be computed as:
  // h_limit = (4 * 256) / width.
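  // For example, with width == 4: h_limit = (4 * 256) / 4 = 256 rows. Each
  // loop iteration consumes two rows, so sum_u16 sees at most 128 vpadal_u8
  // steps, i.e. 256 8-bit elements per 16-bit lane, and
  // 255 * 256 = 65280 <= 65535 = UINT16_MAX.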
  int h_limit = (4 * 256) / width;
  int h_tmp = height > h_limit ? h_limit : height;

  int h = 0;
  do {
    uint16x4_t sum_u16 = vdup_n_u16(0);
    do {
      uint8_t *src_ptr = src;
      int w = width;
      do {
        uint8x8_t s0 = load_unaligned_u8(src_ptr, src_stride);

        sum_u16 = vpadal_u8(sum_u16, s0);

        uint16x8_t sse_u16 = vmull_u8(s0, s0);

        sse_u32 = vpadalq_u16(sse_u32, sse_u16);

        src_ptr += 8;
        w -= 8;
      } while (w >= 8);

      // Process remaining columns in the row using C.
      while (w > 0) {
        int idx = width - w;
        const uint8_t v = src[idx];
        sum += v;
        sse += v * v;
        w--;
      }

      src += 2 * src_stride;
      h += 2;
    } while (h < h_tmp && h < height);

    sum_u32 = vpadal_u16(sum_u32, sum_u16);
    h_tmp += h_limit;
  } while (h < height);

  sum += horizontal_long_add_u32x2(sum_u32);
  sse += horizontal_long_add_u32x4(sse_u32);

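  // The result is sse - sum^2 / N with N = width * height, i.e. N times the
  // population variance (the integer division truncates).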
  return sse - sum * sum / (width * height);
}

static inline uint64_t aom_var_2d_u8_8xh_neon(uint8_t *src, int src_stride,
                                              int width, int height) {
  uint64_t sum = 0;
  uint64_t sse = 0;
  uint32x2_t sum_u32 = vdup_n_u32(0);
  uint32x4_t sse_u32 = vdupq_n_u32(0);

  // 255 * 256 = 65280, so we can accumulate up to 256 8-bit elements in a
  // 16-bit element before we need to accumulate to 32-bit elements. Since
  // we're accumulating in uint16x4_t vectors, this means we can accumulate up
  // to 4 rows of 256 elements. Therefore the limit can be computed as:
  // h_limit = (4 * 256) / width.
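  // For example, with width == 8: h_limit = (4 * 256) / 8 = 128 rows. Each
  // row contributes one vpadal_u8 step, so each 16-bit lane accumulates at
  // most 2 * 128 = 256 8-bit elements, and 255 * 256 = 65280 <= UINT16_MAX.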
  int h_limit = (4 * 256) / width;
  int h_tmp = height > h_limit ? h_limit : height;

  int h = 0;
  do {
    uint16x4_t sum_u16 = vdup_n_u16(0);
    do {
      uint8_t *src_ptr = src;
      int w = width;
      do {
        uint8x8_t s0 = vld1_u8(src_ptr);

        sum_u16 = vpadal_u8(sum_u16, s0);

        uint16x8_t sse_u16 = vmull_u8(s0, s0);

        sse_u32 = vpadalq_u16(sse_u32, sse_u16);

        src_ptr += 8;
        w -= 8;
      } while (w >= 8);

      // Process remaining columns in the row using C.
      while (w > 0) {
        int idx = width - w;
        const uint8_t v = src[idx];
        sum += v;
        sse += v * v;
        w--;
      }

      src += src_stride;
      ++h;
    } while (h < h_tmp && h < height);

    sum_u32 = vpadal_u16(sum_u32, sum_u16);
    h_tmp += h_limit;
  } while (h < height);

  sum += horizontal_long_add_u32x2(sum_u32);
  sse += horizontal_long_add_u32x4(sse_u32);

  return sse - sum * sum / (width * height);
}

static inline uint64_t aom_var_2d_u8_16xh_neon(uint8_t *src, int src_stride,
                                               int width, int height) {
  uint64_t sum = 0;
  uint64_t sse = 0;
  uint32x4_t sum_u32 = vdupq_n_u32(0);
  uint32x4_t sse_u32[2] = { vdupq_n_u32(0), vdupq_n_u32(0) };

  // 255 * 256 = 65280, so we can accumulate up to 256 8-bit elements in a
  // 16-bit element before we need to accumulate to 32-bit elements. Since
  // we're accumulating in uint16x8_t vectors, this means we can accumulate up
  // to 8 rows of 256 elements. Therefore the limit can be computed as:
  // h_limit = (8 * 256) / width.
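  // For example, with width == 16: h_limit = (8 * 256) / 16 = 128 rows. Each
  // row contributes one vpadalq_u8 step, so each 16-bit lane accumulates at
  // most 2 * 128 = 256 8-bit elements, and 255 * 256 = 65280 <= UINT16_MAX.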
  int h_limit = (8 * 256) / width;
  int h_tmp = height > h_limit ? h_limit : height;

  int h = 0;
  do {
    uint16x8_t sum_u16 = vdupq_n_u16(0);
    do {
      int w = width;
      uint8_t *src_ptr = src;
      do {
        uint8x16_t s0 = vld1q_u8(src_ptr);

        sum_u16 = vpadalq_u8(sum_u16, s0);

        uint16x8_t sse_u16_lo = vmull_u8(vget_low_u8(s0), vget_low_u8(s0));
        uint16x8_t sse_u16_hi = vmull_u8(vget_high_u8(s0), vget_high_u8(s0));

        sse_u32[0] = vpadalq_u16(sse_u32[0], sse_u16_lo);
        sse_u32[1] = vpadalq_u16(sse_u32[1], sse_u16_hi);

        src_ptr += 16;
        w -= 16;
      } while (w >= 16);

      // Process remaining columns in the row using C.
      while (w > 0) {
        int idx = width - w;
        const uint8_t v = src[idx];
        sum += v;
        sse += v * v;
        w--;
      }

      src += src_stride;
      ++h;
    } while (h < h_tmp && h < height);

    sum_u32 = vpadalq_u16(sum_u32, sum_u16);
    h_tmp += h_limit;
  } while (h < height);

  sum += horizontal_long_add_u32x4(sum_u32);
  sse += horizontal_long_add_u32x4(vaddq_u32(sse_u32[0], sse_u32[1]));

  return sse - sum * sum / (width * height);
}

uint64_t aom_var_2d_u8_neon(uint8_t *src, int src_stride, int width,
                            int height) {
  if (width >= 16) {
    return aom_var_2d_u8_16xh_neon(src, src_stride, width, height);
  }
  if (width >= 8) {
    return aom_var_2d_u8_8xh_neon(src, src_stride, width, height);
  }
  if (width >= 4 && height % 2 == 0) {
    return aom_var_2d_u8_4xh_neon(src, src_stride, width, height);
  }
  return aom_var_2d_u8_c(src, src_stride, width, height);
}

#if CONFIG_AV1_HIGHBITDEPTH
static inline uint64_t aom_var_2d_u16_4xh_neon(uint8_t *src, int src_stride,
                                               int width, int height) {
  uint16_t *src_u16 = CONVERT_TO_SHORTPTR(src);
  uint64_t sum = 0;
  uint64_t sse = 0;
  uint32x2_t sum_u32 = vdup_n_u32(0);
  uint64x2_t sse_u64 = vdupq_n_u64(0);

  int h = height;
  do {
    int w = width;
    uint16_t *src_ptr = src_u16;
    do {
      uint16x4_t s0 = vld1_u16(src_ptr);

      sum_u32 = vpadal_u16(sum_u32, s0);

      uint32x4_t sse_u32 = vmull_u16(s0, s0);

      sse_u64 = vpadalq_u32(sse_u64, sse_u32);

      src_ptr += 4;
      w -= 4;
    } while (w >= 4);

    // Process remaining columns in the row using C.
    while (w > 0) {
      int idx = width - w;
      const uint16_t v = src_u16[idx];
      sum += v;
      sse += v * v;
      w--;
    }

    src_u16 += src_stride;
  } while (--h != 0);

  sum += horizontal_long_add_u32x2(sum_u32);
  sse += horizontal_add_u64x2(sse_u64);

  return sse - sum * sum / (width * height);
}

static inline uint64_t aom_var_2d_u16_8xh_neon(uint8_t *src, int src_stride,
                                               int width, int height) {
  uint16_t *src_u16 = CONVERT_TO_SHORTPTR(src);
  uint64_t sum = 0;
  uint64_t sse = 0;
  uint32x4_t sum_u32 = vdupq_n_u32(0);
  uint64x2_t sse_u64[2] = { vdupq_n_u64(0), vdupq_n_u64(0) };

  int h = height;
  do {
    int w = width;
    uint16_t *src_ptr = src_u16;
    do {
      uint16x8_t s0 = vld1q_u16(src_ptr);

      sum_u32 = vpadalq_u16(sum_u32, s0);

      uint32x4_t sse_u32_lo = vmull_u16(vget_low_u16(s0), vget_low_u16(s0));
      uint32x4_t sse_u32_hi = vmull_u16(vget_high_u16(s0), vget_high_u16(s0));

      sse_u64[0] = vpadalq_u32(sse_u64[0], sse_u32_lo);
      sse_u64[1] = vpadalq_u32(sse_u64[1], sse_u32_hi);

      src_ptr += 8;
      w -= 8;
    } while (w >= 8);

    // Process remaining columns in the row using C.
    while (w > 0) {
      int idx = width - w;
      const uint16_t v = src_u16[idx];
      sum += v;
      sse += v * v;
      w--;
    }

    src_u16 += src_stride;
  } while (--h != 0);

  sum += horizontal_long_add_u32x4(sum_u32);
  sse += horizontal_add_u64x2(vaddq_u64(sse_u64[0], sse_u64[1]));

  return sse - sum * sum / (width * height);
}

uint64_t aom_var_2d_u16_neon(uint8_t *src, int src_stride, int width,
                             int height) {
  if (width >= 8) {
    return aom_var_2d_u16_8xh_neon(src, src_stride, width, height);
  }
  if (width >= 4) {
    return aom_var_2d_u16_4xh_neon(src, src_stride, width, height);
  }
  return aom_var_2d_u16_c(src, src_stride, width, height);
}
#endif  // CONFIG_AV1_HIGHBITDEPTH