1 /*
2 * Copyright (c) 2018 The WebM project authors. All Rights Reserved.
3 *
4 * Use of this source code is governed by a BSD-style license
5 * that can be found in the LICENSE file in the root of the source
6 * tree. An additional intellectual property rights grant can be found
7 * in the file PATENTS. All contributing project authors may
8 * be found in the AUTHORS file in the root of the source tree.
9 */
10
11 #include <arm_neon.h>
12 #include <assert.h>
13
14 #include "./vpx_dsp_rtcd.h"
15 #include "vpx_dsp/arm/sum_neon.h"
16
/* Computes the sum of the squares of a size x size block of int16_t values.
 *
 * src    - pointer to the top-left sample of the block.
 * stride - distance between the starts of consecutive rows, measured in
 *          int16_t elements.
 * size   - width and height of the block. A size of 4 takes a dedicated path;
 *          any other value is walked in 8x8 tiles by do/while loops that
 *          count down by 8 until reaching zero, so it presumably has to be a
 *          positive multiple of 8 (NOTE(review): confirm against callers).
 *
 * Returns the value produced by the horizontal-add helpers from sum_neon.h
 * applied to the accumulated squares.
 */
uint64_t vpx_sum_squares_2d_i16_neon(const int16_t *src, int stride, int size) {
  uint64x2_t total_u64;
  int rows_left;

  if (size == 4) {
    /* 4x4: one 4-lane vector per row, squared and accumulated in 32 bits. */
    const int16x4_t row0 = vld1_s16(src);
    const int16x4_t row1 = vld1_s16(src + 1 * stride);
    const int16x4_t row2 = vld1_s16(src + 2 * stride);
    const int16x4_t row3 = vld1_s16(src + 3 * stride);
    int32x4_t acc = vmull_s16(row0, row0);

    acc = vmlal_s16(acc, row1, row1);
    acc = vmlal_s16(acc, row2, row2);
    acc = vmlal_s16(acc, row3, row3);

    return horizontal_long_add_uint32x4(vreinterpretq_u32_s32(acc));
  }

  total_u64 = vdupq_n_u64(0);
  rows_left = size;

  do {
    /* Process one band of 8 rows, sweeping left to right in 8-column tiles. */
    const int16_t *tile = src;
    int32x4_t acc_lo = vdupq_n_s32(0);
    int32x4_t acc_hi = vdupq_n_s32(0);
    int cols_left = size;

    do {
      int r;

      for (r = 0; r < 8; ++r) {
        const int16x8_t v = vld1q_s16(tile + r * stride);
        const int16x4_t v_lo = vget_low_s16(v);
        const int16x4_t v_hi = vget_high_s16(v);

        acc_lo = vmlal_s16(acc_lo, v_lo, v_lo);
        acc_hi = vmlal_s16(acc_hi, v_hi, v_hi);
      }

      tile += 8;
      cols_left -= 8;
    } while (cols_left != 0);

    /* The 32-bit band sums are reinterpreted as unsigned and pairwise-added
     * into the 64-bit running total before the accumulators are reset. */
    total_u64 = vpadalq_u32(total_u64, vreinterpretq_u32_s32(acc_lo));
    total_u64 = vpadalq_u32(total_u64, vreinterpretq_u32_s32(acc_hi));

    src += 8 * stride;
    rows_left -= 8;
  } while (rows_left != 0);

  return horizontal_add_uint64x2(total_u64);
}
101