1 /*
2 * Copyright (c) 2023, Alliance for Open Media. All rights reserved.
3 *
4 * This source code is subject to the terms of the BSD 2 Clause License and
5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6 * was not distributed with this source code in the LICENSE file, you can
7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8 * Media Patent License 1.0 was not distributed with this source code in the
9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10 */
11
12 #include <arm_neon.h>
13
14 #include "aom_dsp/arm/aom_neon_sve_bridge.h"
15 #include "aom_dsp/arm/mem_neon.h"
16 #include "config/aom_dsp_rtcd.h"
17
// Folds one 8-sample strip into the running accumulator *sse.
//
// Eight uint16_t values are loaded from each of `src` and `ref`, their
// lane-wise absolute difference is formed, and that difference vector is
// passed as both multiplicands to aom_udotq_u16 together with the current
// value of *sse. The result is stored back through `sse`.
//
// NOTE(review): aom_udotq_u16 is declared outside this file; presumably it is
// a widening dot-product accumulate, which would make this add the eight
// squared differences to *sse - confirm against its definition.
static inline void highbd_sse_8x1_neon(const uint16_t *src, const uint16_t *ref,
                                       uint64x2_t *sse) {
  const uint16x8_t diff = vabdq_u16(vld1q_u16(src), vld1q_u16(ref));

  *sse = aom_udotq_u16(*sse, diff, diff);
}
27
highbd_sse_128xh_sve(const uint16_t * src,int src_stride,const uint16_t * ref,int ref_stride,int height)28 static inline int64_t highbd_sse_128xh_sve(const uint16_t *src, int src_stride,
29 const uint16_t *ref, int ref_stride,
30 int height) {
31 uint64x2_t sse[4] = { vdupq_n_u64(0), vdupq_n_u64(0), vdupq_n_u64(0),
32 vdupq_n_u64(0) };
33
34 do {
35 highbd_sse_8x1_neon(src + 0 * 8, ref + 0 * 8, &sse[0]);
36 highbd_sse_8x1_neon(src + 1 * 8, ref + 1 * 8, &sse[1]);
37 highbd_sse_8x1_neon(src + 2 * 8, ref + 2 * 8, &sse[2]);
38 highbd_sse_8x1_neon(src + 3 * 8, ref + 3 * 8, &sse[3]);
39 highbd_sse_8x1_neon(src + 4 * 8, ref + 4 * 8, &sse[0]);
40 highbd_sse_8x1_neon(src + 5 * 8, ref + 5 * 8, &sse[1]);
41 highbd_sse_8x1_neon(src + 6 * 8, ref + 6 * 8, &sse[2]);
42 highbd_sse_8x1_neon(src + 7 * 8, ref + 7 * 8, &sse[3]);
43 highbd_sse_8x1_neon(src + 8 * 8, ref + 8 * 8, &sse[0]);
44 highbd_sse_8x1_neon(src + 9 * 8, ref + 9 * 8, &sse[1]);
45 highbd_sse_8x1_neon(src + 10 * 8, ref + 10 * 8, &sse[2]);
46 highbd_sse_8x1_neon(src + 11 * 8, ref + 11 * 8, &sse[3]);
47 highbd_sse_8x1_neon(src + 12 * 8, ref + 12 * 8, &sse[0]);
48 highbd_sse_8x1_neon(src + 13 * 8, ref + 13 * 8, &sse[1]);
49 highbd_sse_8x1_neon(src + 14 * 8, ref + 14 * 8, &sse[2]);
50 highbd_sse_8x1_neon(src + 15 * 8, ref + 15 * 8, &sse[3]);
51
52 src += src_stride;
53 ref += ref_stride;
54 } while (--height != 0);
55
56 sse[0] = vaddq_u64(sse[0], sse[1]);
57 sse[2] = vaddq_u64(sse[2], sse[3]);
58 sse[0] = vaddq_u64(sse[0], sse[2]);
59 return vaddvq_u64(sse[0]);
60 }
61
highbd_sse_64xh_sve(const uint16_t * src,int src_stride,const uint16_t * ref,int ref_stride,int height)62 static inline int64_t highbd_sse_64xh_sve(const uint16_t *src, int src_stride,
63 const uint16_t *ref, int ref_stride,
64 int height) {
65 uint64x2_t sse[4] = { vdupq_n_u64(0), vdupq_n_u64(0), vdupq_n_u64(0),
66 vdupq_n_u64(0) };
67
68 do {
69 highbd_sse_8x1_neon(src + 0 * 8, ref + 0 * 8, &sse[0]);
70 highbd_sse_8x1_neon(src + 1 * 8, ref + 1 * 8, &sse[1]);
71 highbd_sse_8x1_neon(src + 2 * 8, ref + 2 * 8, &sse[2]);
72 highbd_sse_8x1_neon(src + 3 * 8, ref + 3 * 8, &sse[3]);
73 highbd_sse_8x1_neon(src + 4 * 8, ref + 4 * 8, &sse[0]);
74 highbd_sse_8x1_neon(src + 5 * 8, ref + 5 * 8, &sse[1]);
75 highbd_sse_8x1_neon(src + 6 * 8, ref + 6 * 8, &sse[2]);
76 highbd_sse_8x1_neon(src + 7 * 8, ref + 7 * 8, &sse[3]);
77
78 src += src_stride;
79 ref += ref_stride;
80 } while (--height != 0);
81
82 sse[0] = vaddq_u64(sse[0], sse[1]);
83 sse[2] = vaddq_u64(sse[2], sse[3]);
84 sse[0] = vaddq_u64(sse[0], sse[2]);
85 return vaddvq_u64(sse[0]);
86 }
87
highbd_sse_32xh_sve(const uint16_t * src,int src_stride,const uint16_t * ref,int ref_stride,int height)88 static inline int64_t highbd_sse_32xh_sve(const uint16_t *src, int src_stride,
89 const uint16_t *ref, int ref_stride,
90 int height) {
91 uint64x2_t sse[4] = { vdupq_n_u64(0), vdupq_n_u64(0), vdupq_n_u64(0),
92 vdupq_n_u64(0) };
93
94 do {
95 highbd_sse_8x1_neon(src + 0 * 8, ref + 0 * 8, &sse[0]);
96 highbd_sse_8x1_neon(src + 1 * 8, ref + 1 * 8, &sse[1]);
97 highbd_sse_8x1_neon(src + 2 * 8, ref + 2 * 8, &sse[2]);
98 highbd_sse_8x1_neon(src + 3 * 8, ref + 3 * 8, &sse[3]);
99
100 src += src_stride;
101 ref += ref_stride;
102 } while (--height != 0);
103
104 sse[0] = vaddq_u64(sse[0], sse[1]);
105 sse[2] = vaddq_u64(sse[2], sse[3]);
106 sse[0] = vaddq_u64(sse[0], sse[2]);
107 return vaddvq_u64(sse[0]);
108 }
109
highbd_sse_16xh_sve(const uint16_t * src,int src_stride,const uint16_t * ref,int ref_stride,int height)110 static inline int64_t highbd_sse_16xh_sve(const uint16_t *src, int src_stride,
111 const uint16_t *ref, int ref_stride,
112 int height) {
113 uint64x2_t sse[2] = { vdupq_n_u64(0), vdupq_n_u64(0) };
114
115 do {
116 highbd_sse_8x1_neon(src + 0 * 8, ref + 0 * 8, &sse[0]);
117 highbd_sse_8x1_neon(src + 1 * 8, ref + 1 * 8, &sse[1]);
118
119 src += src_stride;
120 ref += ref_stride;
121 } while (--height != 0);
122
123 return vaddvq_u64(vaddq_u64(sse[0], sse[1]));
124 }
125
highbd_sse_8xh_sve(const uint16_t * src,int src_stride,const uint16_t * ref,int ref_stride,int height)126 static inline int64_t highbd_sse_8xh_sve(const uint16_t *src, int src_stride,
127 const uint16_t *ref, int ref_stride,
128 int height) {
129 uint64x2_t sse[2] = { vdupq_n_u64(0), vdupq_n_u64(0) };
130
131 do {
132 highbd_sse_8x1_neon(src + 0 * src_stride, ref + 0 * ref_stride, &sse[0]);
133 highbd_sse_8x1_neon(src + 1 * src_stride, ref + 1 * ref_stride, &sse[1]);
134
135 src += 2 * src_stride;
136 ref += 2 * ref_stride;
137 height -= 2;
138 } while (height != 0);
139
140 return vaddvq_u64(vaddq_u64(sse[0], sse[1]));
141 }
142
highbd_sse_4xh_sve(const uint16_t * src,int src_stride,const uint16_t * ref,int ref_stride,int height)143 static inline int64_t highbd_sse_4xh_sve(const uint16_t *src, int src_stride,
144 const uint16_t *ref, int ref_stride,
145 int height) {
146 uint64x2_t sse = vdupq_n_u64(0);
147
148 do {
149 uint16x8_t s = load_unaligned_u16_4x2(src, src_stride);
150 uint16x8_t r = load_unaligned_u16_4x2(ref, ref_stride);
151
152 uint16x8_t abs_diff = vabdq_u16(s, r);
153 sse = aom_udotq_u16(sse, abs_diff, abs_diff);
154
155 src += 2 * src_stride;
156 ref += 2 * ref_stride;
157 height -= 2;
158 } while (height != 0);
159
160 return vaddvq_u64(sse);
161 }
162
highbd_sse_wxh_sve(const uint16_t * src,int src_stride,const uint16_t * ref,int ref_stride,int width,int height)163 static inline int64_t highbd_sse_wxh_sve(const uint16_t *src, int src_stride,
164 const uint16_t *ref, int ref_stride,
165 int width, int height) {
166 svuint64_t sse = svdup_n_u64(0);
167 uint64_t step = svcnth();
168
169 do {
170 int w = 0;
171 const uint16_t *src_ptr = src;
172 const uint16_t *ref_ptr = ref;
173
174 do {
175 svbool_t pred = svwhilelt_b16_u32(w, width);
176 svuint16_t s = svld1_u16(pred, src_ptr);
177 svuint16_t r = svld1_u16(pred, ref_ptr);
178
179 svuint16_t abs_diff = svabd_u16_z(pred, s, r);
180
181 sse = svdot_u64(sse, abs_diff, abs_diff);
182
183 src_ptr += step;
184 ref_ptr += step;
185 w += step;
186 } while (w < width);
187
188 src += src_stride;
189 ref += ref_stride;
190 } while (--height != 0);
191
192 return svaddv_u64(svptrue_b64(), sse);
193 }
194
// Entry point: picks the kernel that matches `width` and forwards to it.
//
// `src8` and `ref8` are turned into uint16_t pointers with
// CONVERT_TO_SHORTPTR before use. Widths 4, 8, 16, 32, 64 and 128 go to their
// dedicated fixed-width helpers; any other width goes to the generic
// predicated helper, which also receives `width`. The strides and `height`
// are passed through unchanged.
//
// Returns whatever the selected helper returns.
int64_t aom_highbd_sse_sve(const uint8_t *src8, int src_stride,
                           const uint8_t *ref8, int ref_stride, int width,
                           int height) {
  const uint16_t *const src = CONVERT_TO_SHORTPTR(src8);
  const uint16_t *const ref = CONVERT_TO_SHORTPTR(ref8);

  if (width == 4) {
    return highbd_sse_4xh_sve(src, src_stride, ref, ref_stride, height);
  }
  if (width == 8) {
    return highbd_sse_8xh_sve(src, src_stride, ref, ref_stride, height);
  }
  if (width == 16) {
    return highbd_sse_16xh_sve(src, src_stride, ref, ref_stride, height);
  }
  if (width == 32) {
    return highbd_sse_32xh_sve(src, src_stride, ref, ref_stride, height);
  }
  if (width == 64) {
    return highbd_sse_64xh_sve(src, src_stride, ref, ref_stride, height);
  }
  if (width == 128) {
    return highbd_sse_128xh_sve(src, src_stride, ref, ref_stride, height);
  }

  // No dedicated kernel for this width.
  return highbd_sse_wxh_sve(src, src_stride, ref, ref_stride, width, height);
}
217