xref: /aosp_15_r20/external/libvpx/vpx_dsp/arm/highbd_vpx_convolve8_sve.c (revision fb1b10ab9aebc7c7068eedab379b749d7e3900be)
1 /*
2  *  Copyright (c) 2024 The WebM project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #include <assert.h>
12 #include <arm_neon.h>
13 
14 #include "./vpx_config.h"
15 #include "./vpx_dsp_rtcd.h"
16 
17 #include "vpx/vpx_integer.h"
18 #include "vpx_dsp/arm/highbd_convolve8_sve.h"
19 #include "vpx_dsp/arm/mem_neon.h"
20 #include "vpx_dsp/arm/transpose_neon.h"
21 #include "vpx_dsp/arm/vpx_neon_sve_bridge.h"
22 
23 DECLARE_ALIGNED(16, static const uint16_t, kTblConv4_8[8]) = { 0, 2, 4, 6,
24                                                                1, 3, 5, 7 };
25 
highbd_convolve_4tap_horiz_sve(const uint16_t * src,ptrdiff_t src_stride,uint16_t * dst,ptrdiff_t dst_stride,int w,int h,const int16x4_t filters,int bd)26 static INLINE void highbd_convolve_4tap_horiz_sve(
27     const uint16_t *src, ptrdiff_t src_stride, uint16_t *dst,
28     ptrdiff_t dst_stride, int w, int h, const int16x4_t filters, int bd) {
29   const int16x8_t filter = vcombine_s16(filters, vdup_n_s16(0));
30 
31   if (w == 4) {
32     const uint16x4_t max = vdup_n_u16((1 << bd) - 1);
33     const int16_t *s = (const int16_t *)src;
34     uint16_t *d = dst;
35 
36     do {
37       int16x4_t s0[4], s1[4], s2[4], s3[4];
38       load_s16_4x4(s + 0 * src_stride, 1, &s0[0], &s0[1], &s0[2], &s0[3]);
39       load_s16_4x4(s + 1 * src_stride, 1, &s1[0], &s1[1], &s1[2], &s1[3]);
40       load_s16_4x4(s + 2 * src_stride, 1, &s2[0], &s2[1], &s2[2], &s2[3]);
41       load_s16_4x4(s + 3 * src_stride, 1, &s3[0], &s3[1], &s3[2], &s3[3]);
42 
43       uint16x4_t d0 = highbd_convolve4_4_sve(s0, filter, max);
44       uint16x4_t d1 = highbd_convolve4_4_sve(s1, filter, max);
45       uint16x4_t d2 = highbd_convolve4_4_sve(s2, filter, max);
46       uint16x4_t d3 = highbd_convolve4_4_sve(s3, filter, max);
47 
48       store_u16_4x4(d, dst_stride, d0, d1, d2, d3);
49 
50       s += 4 * src_stride;
51       d += 4 * dst_stride;
52       h -= 4;
53     } while (h != 0);
54   } else {
55     const uint16x8_t max = vdupq_n_u16((1 << bd) - 1);
56     const uint16x8_t idx = vld1q_u16(kTblConv4_8);
57 
58     do {
59       const int16_t *s = (const int16_t *)src;
60       uint16_t *d = dst;
61       int width = w;
62 
63       do {
64         int16x8_t s0[4], s1[4], s2[4], s3[4];
65         load_s16_8x4(s + 0 * src_stride, 1, &s0[0], &s0[1], &s0[2], &s0[3]);
66         load_s16_8x4(s + 1 * src_stride, 1, &s1[0], &s1[1], &s1[2], &s1[3]);
67         load_s16_8x4(s + 2 * src_stride, 1, &s2[0], &s2[1], &s2[2], &s2[3]);
68         load_s16_8x4(s + 3 * src_stride, 1, &s3[0], &s3[1], &s3[2], &s3[3]);
69 
70         uint16x8_t d0 = highbd_convolve4_8_sve(s0, filter, max, idx);
71         uint16x8_t d1 = highbd_convolve4_8_sve(s1, filter, max, idx);
72         uint16x8_t d2 = highbd_convolve4_8_sve(s2, filter, max, idx);
73         uint16x8_t d3 = highbd_convolve4_8_sve(s3, filter, max, idx);
74 
75         store_u16_8x4(d, dst_stride, d0, d1, d2, d3);
76 
77         s += 8;
78         d += 8;
79         width -= 8;
80       } while (width != 0);
81 
82       src += 4 * src_stride;
83       dst += 4 * dst_stride;
84       h -= 4;
85     } while (h != 0);
86   }
87 }
88 
highbd_convolve_8tap_horiz_sve(const uint16_t * src,ptrdiff_t src_stride,uint16_t * dst,ptrdiff_t dst_stride,int w,int h,const int16x8_t filters,int bd)89 static INLINE void highbd_convolve_8tap_horiz_sve(
90     const uint16_t *src, ptrdiff_t src_stride, uint16_t *dst,
91     ptrdiff_t dst_stride, int w, int h, const int16x8_t filters, int bd) {
92   if (w == 4) {
93     const uint16x4_t max = vdup_n_u16((1 << bd) - 1);
94     const int16_t *s = (const int16_t *)src;
95     uint16_t *d = dst;
96 
97     do {
98       int16x8_t s0[4], s1[4], s2[4], s3[4];
99       load_s16_8x4(s + 0 * src_stride, 1, &s0[0], &s0[1], &s0[2], &s0[3]);
100       load_s16_8x4(s + 1 * src_stride, 1, &s1[0], &s1[1], &s1[2], &s1[3]);
101       load_s16_8x4(s + 2 * src_stride, 1, &s2[0], &s2[1], &s2[2], &s2[3]);
102       load_s16_8x4(s + 3 * src_stride, 1, &s3[0], &s3[1], &s3[2], &s3[3]);
103 
104       uint16x4_t d0 = highbd_convolve8_4(s0, filters, max);
105       uint16x4_t d1 = highbd_convolve8_4(s1, filters, max);
106       uint16x4_t d2 = highbd_convolve8_4(s2, filters, max);
107       uint16x4_t d3 = highbd_convolve8_4(s3, filters, max);
108 
109       store_u16_4x4(d, dst_stride, d0, d1, d2, d3);
110 
111       s += 4 * src_stride;
112       d += 4 * dst_stride;
113       h -= 4;
114     } while (h != 0);
115   } else {
116     const uint16x8_t max = vdupq_n_u16((1 << bd) - 1);
117 
118     do {
119       const int16_t *s = (const int16_t *)src;
120       uint16_t *d = dst;
121       int width = w;
122 
123       do {
124         int16x8_t s0[8], s1[8], s2[8], s3[8];
125         load_s16_8x8(s + 0 * src_stride, 1, &s0[0], &s0[1], &s0[2], &s0[3],
126                      &s0[4], &s0[5], &s0[6], &s0[7]);
127         load_s16_8x8(s + 1 * src_stride, 1, &s1[0], &s1[1], &s1[2], &s1[3],
128                      &s1[4], &s1[5], &s1[6], &s1[7]);
129         load_s16_8x8(s + 2 * src_stride, 1, &s2[0], &s2[1], &s2[2], &s2[3],
130                      &s2[4], &s2[5], &s2[6], &s2[7]);
131         load_s16_8x8(s + 3 * src_stride, 1, &s3[0], &s3[1], &s3[2], &s3[3],
132                      &s3[4], &s3[5], &s3[6], &s3[7]);
133 
134         uint16x8_t d0 = highbd_convolve8_8(s0, filters, max);
135         uint16x8_t d1 = highbd_convolve8_8(s1, filters, max);
136         uint16x8_t d2 = highbd_convolve8_8(s2, filters, max);
137         uint16x8_t d3 = highbd_convolve8_8(s3, filters, max);
138 
139         store_u16_8x4(d, dst_stride, d0, d1, d2, d3);
140 
141         s += 8;
142         d += 8;
143         width -= 8;
144       } while (width != 0);
145 
146       src += 4 * src_stride;
147       dst += 4 * dst_stride;
148       h -= 4;
149     } while (h != 0);
150   }
151 }
152 
// High bit-depth horizontal convolution entry point.
//
// Any x_step_q4 other than 16 is forwarded unchanged to
// vpx_highbd_convolve8_horiz_c(). Otherwise the kernel filter[x0_q4] is
// inspected with vpx_get_filter_taps(): a result of at most 4 uses the 4-tap
// helper with taps 2..5 of the kernel and a source pointer moved back by one
// sample; anything larger uses the 8-tap helper with the full kernel and a
// source pointer moved back by three samples.
//
// y0_q4 and y_step_q4 are only passed through to the C fallback.
void vpx_highbd_convolve8_horiz_sve(const uint16_t *src, ptrdiff_t src_stride,
                                    uint16_t *dst, ptrdiff_t dst_stride,
                                    const InterpKernel *filter, int x0_q4,
                                    int x_step_q4, int y0_q4, int y_step_q4,
                                    int w, int h, int bd) {
  if (x_step_q4 != 16) {
    vpx_highbd_convolve8_horiz_c(src, src_stride, dst, dst_stride, filter,
                                 x0_q4, x_step_q4, y0_q4, y_step_q4, w, h, bd);
    return;
  }

  assert((intptr_t)dst % 4 == 0);
  assert(dst_stride % 4 == 0);
  assert(x_step_q4 == 16);

  (void)x_step_q4;
  (void)y0_q4;
  (void)y_step_q4;

  const int16_t *const kernel = filter[x0_q4];

  if (vpx_get_filter_taps(kernel) > 4) {
    highbd_convolve_8tap_horiz_sve(src - 3, src_stride, dst, dst_stride, w, h,
                                   vld1q_s16(kernel), bd);
    return;
  }

  // Only the middle four coefficients of the kernel are loaded here.
  highbd_convolve_4tap_horiz_sve(src - 1, src_stride, dst, dst_stride, w, h,
                                 vld1_s16(kernel + 2), bd);
}
182 
vpx_highbd_convolve8_avg_horiz_sve(const uint16_t * src,ptrdiff_t src_stride,uint16_t * dst,ptrdiff_t dst_stride,const InterpKernel * filter,int x0_q4,int x_step_q4,int y0_q4,int y_step_q4,int w,int h,int bd)183 void vpx_highbd_convolve8_avg_horiz_sve(const uint16_t *src,
184                                         ptrdiff_t src_stride, uint16_t *dst,
185                                         ptrdiff_t dst_stride,
186                                         const InterpKernel *filter, int x0_q4,
187                                         int x_step_q4, int y0_q4, int y_step_q4,
188                                         int w, int h, int bd) {
189   if (x_step_q4 != 16) {
190     vpx_highbd_convolve8_avg_horiz_c(src, src_stride, dst, dst_stride, filter,
191                                      x0_q4, x_step_q4, y0_q4, y_step_q4, w, h,
192                                      bd);
193     return;
194   }
195   assert((intptr_t)dst % 4 == 0);
196   assert(dst_stride % 4 == 0);
197 
198   const int16x8_t filters = vld1q_s16(filter[x0_q4]);
199 
200   src -= 3;
201 
202   if (w == 4) {
203     const uint16x4_t max = vdup_n_u16((1 << bd) - 1);
204     const int16_t *s = (const int16_t *)src;
205     uint16_t *d = dst;
206 
207     do {
208       int16x8_t s0[4], s1[4], s2[4], s3[4];
209       load_s16_8x4(s + 0 * src_stride, 1, &s0[0], &s0[1], &s0[2], &s0[3]);
210       load_s16_8x4(s + 1 * src_stride, 1, &s1[0], &s1[1], &s1[2], &s1[3]);
211       load_s16_8x4(s + 2 * src_stride, 1, &s2[0], &s2[1], &s2[2], &s2[3]);
212       load_s16_8x4(s + 3 * src_stride, 1, &s3[0], &s3[1], &s3[2], &s3[3]);
213 
214       uint16x4_t d0 = highbd_convolve8_4(s0, filters, max);
215       uint16x4_t d1 = highbd_convolve8_4(s1, filters, max);
216       uint16x4_t d2 = highbd_convolve8_4(s2, filters, max);
217       uint16x4_t d3 = highbd_convolve8_4(s3, filters, max);
218 
219       d0 = vrhadd_u16(d0, vld1_u16(d + 0 * dst_stride));
220       d1 = vrhadd_u16(d1, vld1_u16(d + 1 * dst_stride));
221       d2 = vrhadd_u16(d2, vld1_u16(d + 2 * dst_stride));
222       d3 = vrhadd_u16(d3, vld1_u16(d + 3 * dst_stride));
223 
224       store_u16_4x4(d, dst_stride, d0, d1, d2, d3);
225 
226       s += 4 * src_stride;
227       d += 4 * dst_stride;
228       h -= 4;
229     } while (h != 0);
230   } else {
231     const uint16x8_t max = vdupq_n_u16((1 << bd) - 1);
232 
233     do {
234       const int16_t *s = (const int16_t *)src;
235       uint16_t *d = dst;
236       int width = w;
237 
238       do {
239         int16x8_t s0[8], s1[8], s2[8], s3[8];
240         load_s16_8x8(s + 0 * src_stride, 1, &s0[0], &s0[1], &s0[2], &s0[3],
241                      &s0[4], &s0[5], &s0[6], &s0[7]);
242         load_s16_8x8(s + 1 * src_stride, 1, &s1[0], &s1[1], &s1[2], &s1[3],
243                      &s1[4], &s1[5], &s1[6], &s1[7]);
244         load_s16_8x8(s + 2 * src_stride, 1, &s2[0], &s2[1], &s2[2], &s2[3],
245                      &s2[4], &s2[5], &s2[6], &s2[7]);
246         load_s16_8x8(s + 3 * src_stride, 1, &s3[0], &s3[1], &s3[2], &s3[3],
247                      &s3[4], &s3[5], &s3[6], &s3[7]);
248 
249         uint16x8_t d0 = highbd_convolve8_8(s0, filters, max);
250         uint16x8_t d1 = highbd_convolve8_8(s1, filters, max);
251         uint16x8_t d2 = highbd_convolve8_8(s2, filters, max);
252         uint16x8_t d3 = highbd_convolve8_8(s3, filters, max);
253 
254         d0 = vrhaddq_u16(d0, vld1q_u16(d + 0 * dst_stride));
255         d1 = vrhaddq_u16(d1, vld1q_u16(d + 1 * dst_stride));
256         d2 = vrhaddq_u16(d2, vld1q_u16(d + 2 * dst_stride));
257         d3 = vrhaddq_u16(d3, vld1q_u16(d + 3 * dst_stride));
258 
259         store_u16_8x4(d, dst_stride, d0, d1, d2, d3);
260 
261         s += 8;
262         d += 8;
263         width -= 8;
264       } while (width != 0);
265 
266       src += 4 * src_stride;
267       dst += 4 * dst_stride;
268       h -= 4;
269     } while (h != 0);
270   }
271 }
272