xref: /aosp_15_r20/external/libaom/aom_dsp/arm/highbd_sse_sve.c (revision 77c1e3ccc04c968bd2bc212e87364f250e820521)
1 /*
2  * Copyright (c) 2023, Alliance for Open Media. All rights reserved.
3  *
4  * This source code is subject to the terms of the BSD 2 Clause License and
5  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6  * was not distributed with this source code in the LICENSE file, you can
7  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8  * Media Patent License 1.0 was not distributed with this source code in the
9  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10  */
11 
12 #include <arm_neon.h>
13 
14 #include "aom_dsp/arm/aom_neon_sve_bridge.h"
15 #include "aom_dsp/arm/mem_neon.h"
16 #include "config/aom_dsp_rtcd.h"
17 
// Accumulate the sum of squared differences between eight consecutive
// 16-bit samples of `src` and `ref` into the accumulator `*sse`.
static inline void highbd_sse_8x1_neon(const uint16_t *src, const uint16_t *ref,
                                       uint64x2_t *sse) {
  // |src - ref| per lane; dotting the vector with itself squares each lane
  // and adds the products into the 64-bit accumulator.
  const uint16x8_t diff = vabdq_u16(vld1q_u16(src), vld1q_u16(ref));
  *sse = aom_udotq_u16(*sse, diff, diff);
}
27 
// Sum of squared differences over a 128-wide block of 16-bit samples.
// Each row is processed as sixteen 8-sample chunks, and chunk j is
// accumulated into acc[j % 4]. Spreading the work over four accumulators
// presumably shortens the dependency chain between successive dot products.
// `height` must be positive: the row loop is a do/while counting down to 0.
// Strides are in samples (uint16_t elements).
static inline int64_t highbd_sse_128xh_sve(const uint16_t *src, int src_stride,
                                           const uint16_t *ref, int ref_stride,
                                           int height) {
  uint64x2_t acc[4];
  for (int k = 0; k < 4; ++k) {
    acc[k] = vdupq_n_u64(0);
  }

  do {
    for (int chunk = 0; chunk < 16; ++chunk) {
      highbd_sse_8x1_neon(src + 8 * chunk, ref + 8 * chunk, &acc[chunk & 3]);
    }
    src += src_stride;
    ref += ref_stride;
    height--;
  } while (height != 0);

  const uint64x2_t total =
      vaddq_u64(vaddq_u64(acc[0], acc[1]), vaddq_u64(acc[2], acc[3]));
  return vaddvq_u64(total);
}
61 
// Sum of squared differences over a 64-wide block of 16-bit samples.
// Each row is eight 8-sample chunks, and chunk j goes into acc[j % 4].
// `height` must be positive: the row loop is a do/while counting down to 0.
// Strides are in samples (uint16_t elements).
static inline int64_t highbd_sse_64xh_sve(const uint16_t *src, int src_stride,
                                          const uint16_t *ref, int ref_stride,
                                          int height) {
  uint64x2_t acc[4];
  for (int k = 0; k < 4; ++k) {
    acc[k] = vdupq_n_u64(0);
  }

  do {
    for (int chunk = 0; chunk < 8; ++chunk) {
      const int offset = chunk * 8;
      highbd_sse_8x1_neon(src + offset, ref + offset, &acc[chunk % 4]);
    }
    src += src_stride;
    ref += ref_stride;
  } while (--height != 0);

  const uint64x2_t lo = vaddq_u64(acc[0], acc[1]);
  const uint64x2_t hi = vaddq_u64(acc[2], acc[3]);
  return vaddvq_u64(vaddq_u64(lo, hi));
}
87 
// Sum of squared differences over a 32-wide block of 16-bit samples.
// Each row is four 8-sample chunks, each with its own accumulator.
// `height` must be positive: the row loop is a do/while counting down to 0.
// Strides are in samples (uint16_t elements).
static inline int64_t highbd_sse_32xh_sve(const uint16_t *src, int src_stride,
                                          const uint16_t *ref, int ref_stride,
                                          int height) {
  uint64x2_t acc0 = vdupq_n_u64(0);
  uint64x2_t acc1 = vdupq_n_u64(0);
  uint64x2_t acc2 = vdupq_n_u64(0);
  uint64x2_t acc3 = vdupq_n_u64(0);

  do {
    highbd_sse_8x1_neon(src, ref, &acc0);
    highbd_sse_8x1_neon(src + 8, ref + 8, &acc1);
    highbd_sse_8x1_neon(src + 16, ref + 16, &acc2);
    highbd_sse_8x1_neon(src + 24, ref + 24, &acc3);

    src += src_stride;
    ref += ref_stride;
  } while (--height);

  const uint64x2_t pair01 = vaddq_u64(acc0, acc1);
  const uint64x2_t pair23 = vaddq_u64(acc2, acc3);
  return vaddvq_u64(vaddq_u64(pair01, pair23));
}
109 
// Sum of squared differences over a 16-wide block of 16-bit samples.
// The left and right 8-sample halves of each row use separate accumulators.
// `height` must be positive: the row loop is a do/while counting down to 0.
// Strides are in samples (uint16_t elements).
static inline int64_t highbd_sse_16xh_sve(const uint16_t *src, int src_stride,
                                          const uint16_t *ref, int ref_stride,
                                          int height) {
  uint64x2_t acc_left = vdupq_n_u64(0);
  uint64x2_t acc_right = vdupq_n_u64(0);

  do {
    highbd_sse_8x1_neon(src, ref, &acc_left);
    highbd_sse_8x1_neon(src + 8, ref + 8, &acc_right);

    src += src_stride;
    ref += ref_stride;
  } while (--height);

  const uint64x2_t total = vaddq_u64(acc_left, acc_right);
  return vaddvq_u64(total);
}
125 
// Sum of squared differences over an 8-wide block of 16-bit samples.
// Rows are consumed two at a time, with even and odd rows in separate
// accumulators. A trailing row left over by an odd `height` is handled after
// the main loop, so any height >= 0 terminates. A height of 0 yields 0.
// Strides are in samples (uint16_t elements).
static inline int64_t highbd_sse_8xh_sve(const uint16_t *src, int src_stride,
                                         const uint16_t *ref, int ref_stride,
                                         int height) {
  uint64x2_t sse[2] = { vdupq_n_u64(0), vdupq_n_u64(0) };

  while (height >= 2) {
    highbd_sse_8x1_neon(src + 0 * src_stride, ref + 0 * ref_stride, &sse[0]);
    highbd_sse_8x1_neon(src + 1 * src_stride, ref + 1 * ref_stride, &sse[1]);

    src += 2 * src_stride;
    ref += 2 * ref_stride;
    height -= 2;
  }

  // Odd height: the final row. Stepping by two alone would jump past zero,
  // never terminating and reading out of bounds.
  if (height == 1) {
    highbd_sse_8x1_neon(src, ref, &sse[0]);
  }

  return vaddvq_u64(vaddq_u64(sse[0], sse[1]));
}
142 
// Sum of squared differences over a 4-wide block of 16-bit samples.
// Two 4-sample rows are packed into one 8-lane vector per iteration via
// load_unaligned_u16_4x2. A trailing row left over by an odd `height` is
// handled after the main loop, so any height >= 0 terminates. A height of 0
// yields 0. Strides are in samples (uint16_t elements).
static inline int64_t highbd_sse_4xh_sve(const uint16_t *src, int src_stride,
                                         const uint16_t *ref, int ref_stride,
                                         int height) {
  uint64x2_t sse = vdupq_n_u64(0);

  while (height >= 2) {
    uint16x8_t s = load_unaligned_u16_4x2(src, src_stride);
    uint16x8_t r = load_unaligned_u16_4x2(ref, ref_stride);

    uint16x8_t abs_diff = vabdq_u16(s, r);
    sse = aom_udotq_u16(sse, abs_diff, abs_diff);

    src += 2 * src_stride;
    ref += 2 * ref_stride;
    height -= 2;
  }

  // Odd height: the final row has only four samples. The upper four lanes
  // are zero, so they add nothing to the dot product. Without this, stepping
  // by two would skip past zero and read out of bounds forever.
  if (height == 1) {
    const uint16x4_t s = vld1_u16(src);
    const uint16x4_t r = vld1_u16(ref);
    const uint16x8_t abs_diff = vcombine_u16(vabd_u16(s, r), vdup_n_u16(0));
    sse = aom_udotq_u16(sse, abs_diff, abs_diff);
  }

  return vaddvq_u64(sse);
}
162 
// Generic sum of squared differences for any block width.
// Uses vector-length-agnostic SVE: each row is walked one SVE vector of
// 16-bit lanes at a time. A while-less-than predicate masks lanes at or
// beyond `width` in the last, possibly partial, vector.
// `height` must be positive: the row loop is a do/while counting down to 0.
// Strides are in samples (uint16_t elements).
static inline int64_t highbd_sse_wxh_sve(const uint16_t *src, int src_stride,
                                         const uint16_t *ref, int ref_stride,
                                         int width, int height) {
  // Number of 16-bit lanes in one SVE vector on this implementation.
  const uint64_t lanes = svcnth();
  svuint64_t acc = svdup_n_u64(0);

  do {
    const uint16_t *s_row = src;
    const uint16_t *r_row = ref;
    int col = 0;

    do {
      const svbool_t active = svwhilelt_b16_u32(col, width);
      const svuint16_t s_vec = svld1_u16(active, s_row);
      const svuint16_t r_vec = svld1_u16(active, r_row);
      // Inactive lanes are zeroed, so they contribute nothing to the dot.
      const svuint16_t diff = svabd_u16_z(active, s_vec, r_vec);
      acc = svdot_u64(acc, diff, diff);

      s_row += lanes;
      r_row += lanes;
      col += lanes;
    } while (col < width);

    src += src_stride;
    ref += ref_stride;
  } while (--height != 0);

  return svaddv_u64(svptrue_b64(), acc);
}
194 
// SVE entry point: sum of squared differences between two high-bitdepth
// blocks of `width` x `height` samples. `src8` and `ref8` are passed through
// CONVERT_TO_SHORTPTR to recover the uint16_t sample pointers. Strides are in
// samples. Widths 4/8/16/32/64/128 use a dedicated fixed-width kernel. Any
// other width falls back to the predicated generic SVE loop.
int64_t aom_highbd_sse_sve(const uint8_t *src8, int src_stride,
                           const uint8_t *ref8, int ref_stride, int width,
                           int height) {
  const uint16_t *const src = CONVERT_TO_SHORTPTR(src8);
  const uint16_t *const ref = CONVERT_TO_SHORTPTR(ref8);

  if (width == 4) {
    return highbd_sse_4xh_sve(src, src_stride, ref, ref_stride, height);
  }
  if (width == 8) {
    return highbd_sse_8xh_sve(src, src_stride, ref, ref_stride, height);
  }
  if (width == 16) {
    return highbd_sse_16xh_sve(src, src_stride, ref, ref_stride, height);
  }
  if (width == 32) {
    return highbd_sse_32xh_sve(src, src_stride, ref, ref_stride, height);
  }
  if (width == 64) {
    return highbd_sse_64xh_sve(src, src_stride, ref, ref_stride, height);
  }
  if (width == 128) {
    return highbd_sse_128xh_sve(src, src_stride, ref, ref_stride, height);
  }
  return highbd_sse_wxh_sve(src, src_stride, ref, ref_stride, width, height);
}
217