xref: /aosp_15_r20/external/libaom/av1/common/arm/highbd_compound_convolve_sve2.c (revision 77c1e3ccc04c968bd2bc212e87364f250e820521)
1 /*
2  * Copyright (c) 2024, Alliance for Open Media. All rights reserved.
3  *
4  * This source code is subject to the terms of the BSD 2 Clause License and
5  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6  * was not distributed with this source code in the LICENSE file, you can
7  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8  * Media Patent License 1.0 was not distributed with this source code in the
9  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10  */
11 
12 #include <assert.h>
13 #include <arm_neon.h>
14 
15 #include "config/aom_config.h"
16 #include "config/av1_rtcd.h"
17 
18 #include "aom_dsp/aom_dsp_common.h"
19 #include "aom_dsp/arm/aom_neon_sve_bridge.h"
20 #include "aom_dsp/arm/aom_neon_sve2_bridge.h"
21 #include "aom_dsp/arm/mem_neon.h"
22 #include "aom_ports/mem.h"
23 #include "av1/common/convolve.h"
24 #include "av1/common/filter.h"
25 #include "av1/common/filter.h"
26 #include "av1/common/arm/highbd_compound_convolve_neon.h"
27 #include "av1/common/arm/highbd_convolve_neon.h"
28 #include "av1/common/arm/highbd_convolve_sve2.h"
29 
// Permute indices used with aom_tbl_s16 to gather four overlapping 4-sample
// windows for the 4-tap horizontal dot products below. The 4-tap paths load
// only the first 16 entries (vld1q_u16_x2); the second half wraps its indices
// modulo 8 — NOTE(review): those wrapped rows are unused in this file,
// presumably kept for layout parity with the other convolve tables; confirm
// against the sibling SVE2 convolve sources.
DECLARE_ALIGNED(16, static const uint16_t, kDotProdTbl[32]) = {
  0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6,
  4, 5, 6, 7, 5, 6, 7, 0, 6, 7, 0, 1, 7, 0, 1, 2,
};
34 
highbd_12_convolve8_8_x(int16x8_t s0[8],int16x8_t filter,int64x2_t offset)35 static inline uint16x8_t highbd_12_convolve8_8_x(int16x8_t s0[8],
36                                                  int16x8_t filter,
37                                                  int64x2_t offset) {
38   int64x2_t sum[8];
39   sum[0] = aom_sdotq_s16(offset, s0[0], filter);
40   sum[1] = aom_sdotq_s16(offset, s0[1], filter);
41   sum[2] = aom_sdotq_s16(offset, s0[2], filter);
42   sum[3] = aom_sdotq_s16(offset, s0[3], filter);
43   sum[4] = aom_sdotq_s16(offset, s0[4], filter);
44   sum[5] = aom_sdotq_s16(offset, s0[5], filter);
45   sum[6] = aom_sdotq_s16(offset, s0[6], filter);
46   sum[7] = aom_sdotq_s16(offset, s0[7], filter);
47 
48   sum[0] = vpaddq_s64(sum[0], sum[1]);
49   sum[2] = vpaddq_s64(sum[2], sum[3]);
50   sum[4] = vpaddq_s64(sum[4], sum[5]);
51   sum[6] = vpaddq_s64(sum[6], sum[7]);
52 
53   int32x4_t sum0123 = vcombine_s32(vmovn_s64(sum[0]), vmovn_s64(sum[2]));
54   int32x4_t sum4567 = vcombine_s32(vmovn_s64(sum[4]), vmovn_s64(sum[6]));
55 
56   return vcombine_u16(vqrshrun_n_s32(sum0123, ROUND0_BITS + 2),
57                       vqrshrun_n_s32(sum4567, ROUND0_BITS + 2));
58 }
59 
// Horizontal 8-tap compound (dist_wtd) convolution for 12-bit sources.
// Writes the intermediate, offset-biased result consumed by the averaging
// stage. Processes the frame in 8x4 tiles; assumes width % 8 == 0 and
// height % 4 == 0 (the loops have no tail handling).
static inline void highbd_12_dist_wtd_convolve_x_8tap_sve2(
    const uint16_t *src, int src_stride, uint16_t *dst, int dst_stride,
    int width, int height, const int16_t *x_filter_ptr) {
  // Compound offset placed in lane 0 only: the kernel's vpaddq_s64 sums both
  // accumulator lanes, so a single-lane seed applies the offset exactly once
  // per output pixel.
  const int64x1_t offset_vec =
      vcreate_s64((1 << (12 + FILTER_BITS)) + (1 << (12 + FILTER_BITS - 1)));
  const int64x2_t offset_lo = vcombine_s64(offset_vec, vdup_n_s64(0));

  const int16x8_t filter = vld1q_s16(x_filter_ptr);

  do {
    const int16_t *s = (const int16_t *)src;
    uint16_t *d = dst;
    int w = width;

    do {
      // Load eight overlapping 8-sample windows per row (element stride 1),
      // one window per output pixel of the 8-wide tile.
      int16x8_t s0[8], s1[8], s2[8], s3[8];
      load_s16_8x8(s + 0 * src_stride, 1, &s0[0], &s0[1], &s0[2], &s0[3],
                   &s0[4], &s0[5], &s0[6], &s0[7]);
      load_s16_8x8(s + 1 * src_stride, 1, &s1[0], &s1[1], &s1[2], &s1[3],
                   &s1[4], &s1[5], &s1[6], &s1[7]);
      load_s16_8x8(s + 2 * src_stride, 1, &s2[0], &s2[1], &s2[2], &s2[3],
                   &s2[4], &s2[5], &s2[6], &s2[7]);
      load_s16_8x8(s + 3 * src_stride, 1, &s3[0], &s3[1], &s3[2], &s3[3],
                   &s3[4], &s3[5], &s3[6], &s3[7]);

      uint16x8_t d0 = highbd_12_convolve8_8_x(s0, filter, offset_lo);
      uint16x8_t d1 = highbd_12_convolve8_8_x(s1, filter, offset_lo);
      uint16x8_t d2 = highbd_12_convolve8_8_x(s2, filter, offset_lo);
      uint16x8_t d3 = highbd_12_convolve8_8_x(s3, filter, offset_lo);

      store_u16_8x4(d, dst_stride, d0, d1, d2, d3);

      s += 8;
      d += 8;
      w -= 8;
    } while (w != 0);
    src += 4 * src_stride;
    dst += 4 * dst_stride;
    height -= 4;
  } while (height != 0);
}
101 
highbd_convolve8_8_x(int16x8_t s0[8],int16x8_t filter,int64x2_t offset)102 static inline uint16x8_t highbd_convolve8_8_x(int16x8_t s0[8], int16x8_t filter,
103                                               int64x2_t offset) {
104   int64x2_t sum[8];
105   sum[0] = aom_sdotq_s16(offset, s0[0], filter);
106   sum[1] = aom_sdotq_s16(offset, s0[1], filter);
107   sum[2] = aom_sdotq_s16(offset, s0[2], filter);
108   sum[3] = aom_sdotq_s16(offset, s0[3], filter);
109   sum[4] = aom_sdotq_s16(offset, s0[4], filter);
110   sum[5] = aom_sdotq_s16(offset, s0[5], filter);
111   sum[6] = aom_sdotq_s16(offset, s0[6], filter);
112   sum[7] = aom_sdotq_s16(offset, s0[7], filter);
113 
114   sum[0] = vpaddq_s64(sum[0], sum[1]);
115   sum[2] = vpaddq_s64(sum[2], sum[3]);
116   sum[4] = vpaddq_s64(sum[4], sum[5]);
117   sum[6] = vpaddq_s64(sum[6], sum[7]);
118 
119   int32x4_t sum0123 = vcombine_s32(vmovn_s64(sum[0]), vmovn_s64(sum[2]));
120   int32x4_t sum4567 = vcombine_s32(vmovn_s64(sum[4]), vmovn_s64(sum[6]));
121 
122   return vcombine_u16(vqrshrun_n_s32(sum0123, ROUND0_BITS),
123                       vqrshrun_n_s32(sum4567, ROUND0_BITS));
124 }
125 
// Horizontal 8-tap compound (dist_wtd) convolution for 8/10-bit sources
// (bit depth passed in as bd). Writes the intermediate, offset-biased result
// consumed by the averaging stage. Processes 8x4 tiles; assumes
// width % 8 == 0 and height % 4 == 0.
static inline void highbd_dist_wtd_convolve_x_8tap_sve2(
    const uint16_t *src, int src_stride, uint16_t *dst, int dst_stride,
    int width, int height, const int16_t *x_filter_ptr, const int bd) {
  // Compound offset in lane 0 only — the kernel's pairwise add sums both
  // lanes, so the offset is applied exactly once per output pixel.
  const int64x1_t offset_vec =
      vcreate_s64((1 << (bd + FILTER_BITS)) + (1 << (bd + FILTER_BITS - 1)));
  const int64x2_t offset_lo = vcombine_s64(offset_vec, vdup_n_s64(0));

  const int16x8_t filter = vld1q_s16(x_filter_ptr);

  do {
    const int16_t *s = (const int16_t *)src;
    uint16_t *d = dst;
    int w = width;

    do {
      // Eight overlapping 8-sample windows per row, one per output pixel.
      int16x8_t s0[8], s1[8], s2[8], s3[8];
      load_s16_8x8(s + 0 * src_stride, 1, &s0[0], &s0[1], &s0[2], &s0[3],
                   &s0[4], &s0[5], &s0[6], &s0[7]);
      load_s16_8x8(s + 1 * src_stride, 1, &s1[0], &s1[1], &s1[2], &s1[3],
                   &s1[4], &s1[5], &s1[6], &s1[7]);
      load_s16_8x8(s + 2 * src_stride, 1, &s2[0], &s2[1], &s2[2], &s2[3],
                   &s2[4], &s2[5], &s2[6], &s2[7]);
      load_s16_8x8(s + 3 * src_stride, 1, &s3[0], &s3[1], &s3[2], &s3[3],
                   &s3[4], &s3[5], &s3[6], &s3[7]);

      uint16x8_t d0 = highbd_convolve8_8_x(s0, filter, offset_lo);
      uint16x8_t d1 = highbd_convolve8_8_x(s1, filter, offset_lo);
      uint16x8_t d2 = highbd_convolve8_8_x(s2, filter, offset_lo);
      uint16x8_t d3 = highbd_convolve8_8_x(s3, filter, offset_lo);

      store_u16_8x4(d, dst_stride, d0, d1, d2, d3);

      s += 8;
      d += 8;
      w -= 8;
    } while (w != 0);
    src += 4 * src_stride;
    dst += 4 * dst_stride;
    height -= 4;
  } while (height != 0);
}
167 
// clang-format off
// aom_tbl_u16 permutation that moves even-indexed lanes to the low half and
// odd-indexed lanes to the high half. Undoes the (0,4)/(1,5)/(2,6)/(3,7)
// interleaving produced by the 4-tap 8-wide kernels below.
DECLARE_ALIGNED(16, static const uint16_t, kDeinterleaveTbl[8]) = {
  0, 2, 4, 6, 1, 3, 5, 7,
};
// clang-format on
173 
highbd_12_convolve4_4_x(int16x8_t s0,int16x8_t filter,int64x2_t offset,uint16x8x2_t permute_tbl)174 static inline uint16x4_t highbd_12_convolve4_4_x(int16x8_t s0, int16x8_t filter,
175                                                  int64x2_t offset,
176                                                  uint16x8x2_t permute_tbl) {
177   int16x8_t permuted_samples0 = aom_tbl_s16(s0, permute_tbl.val[0]);
178   int16x8_t permuted_samples1 = aom_tbl_s16(s0, permute_tbl.val[1]);
179 
180   int64x2_t sum01 = aom_svdot_lane_s16(offset, permuted_samples0, filter, 0);
181   int64x2_t sum23 = aom_svdot_lane_s16(offset, permuted_samples1, filter, 0);
182 
183   int32x4_t sum0123 = vcombine_s32(vmovn_s64(sum01), vmovn_s64(sum23));
184 
185   return vqrshrun_n_s32(sum0123, ROUND0_BITS + 2);
186 }
187 
// 4-tap horizontal convolution of eight 12-bit pixels. The dot products
// produce interleaved pixel pairs; `tbl` restores natural ordering.
static inline uint16x8_t highbd_12_convolve4_8_x(int16x8_t s0[4],
                                                 int16x8_t filter,
                                                 int64x2_t offset,
                                                 uint16x8_t tbl) {
  // Each accumulator holds the pixel pair (i, i + 4).
  int64x2_t acc04 = aom_svdot_lane_s16(offset, s0[0], filter, 0);
  int64x2_t acc15 = aom_svdot_lane_s16(offset, s0[1], filter, 0);
  int64x2_t acc26 = aom_svdot_lane_s16(offset, s0[2], filter, 0);
  int64x2_t acc37 = aom_svdot_lane_s16(offset, s0[3], filter, 0);

  int32x4_t acc_lo = vcombine_s32(vmovn_s64(acc04), vmovn_s64(acc15));
  int32x4_t acc_hi = vcombine_s32(vmovn_s64(acc26), vmovn_s64(acc37));

  uint16x8_t interleaved =
      vcombine_u16(vqrshrun_n_s32(acc_lo, ROUND0_BITS + 2),
                   vqrshrun_n_s32(acc_hi, ROUND0_BITS + 2));
  // Deinterleave back to pixel order 0..7.
  return aom_tbl_u16(interleaved, tbl);
}
204 
// Horizontal 4-tap compound (dist_wtd) convolution for 12-bit sources.
// width == 4 uses a permute-based path; wider blocks use 8-wide tiles with a
// deinterleave fixup. Assumes height % 4 == 0 and, for width > 4,
// width % 8 == 0.
static inline void highbd_12_dist_wtd_convolve_x_4tap_sve2(
    const uint16_t *src, int src_stride, uint16_t *dst, int dst_stride,
    int width, int height, const int16_t *x_filter_ptr) {
  // Both lanes carry the offset here: svdot_lane keeps one pixel per 64-bit
  // lane (no pairwise add), so every lane needs its own seed.
  const int64x2_t offset =
      vdupq_n_s64((1 << (12 + FILTER_BITS)) + (1 << (12 + FILTER_BITS - 1)));

  // The 4 active taps live at positions 2..5 of the 8-tap kernel; pad the
  // upper half with zeros so only filter lane 0 is meaningful.
  const int16x4_t x_filter = vld1_s16(x_filter_ptr + 2);
  const int16x8_t filter = vcombine_s16(x_filter, vdup_n_s16(0));

  if (width == 4) {
    uint16x8x2_t permute_tbl = vld1q_u16_x2(kDotProdTbl);

    const int16_t *s = (const int16_t *)(src);

    do {
      int16x8_t s0, s1, s2, s3;
      load_s16_8x4(s, src_stride, &s0, &s1, &s2, &s3);

      uint16x4_t d0 = highbd_12_convolve4_4_x(s0, filter, offset, permute_tbl);
      uint16x4_t d1 = highbd_12_convolve4_4_x(s1, filter, offset, permute_tbl);
      uint16x4_t d2 = highbd_12_convolve4_4_x(s2, filter, offset, permute_tbl);
      uint16x4_t d3 = highbd_12_convolve4_4_x(s3, filter, offset, permute_tbl);

      store_u16_4x4(dst, dst_stride, d0, d1, d2, d3);

      s += 4 * src_stride;
      dst += 4 * dst_stride;
      height -= 4;
    } while (height != 0);
  } else {
    uint16x8_t idx = vld1q_u16(kDeinterleaveTbl);

    do {
      const int16_t *s = (const int16_t *)(src);
      uint16_t *d = dst;
      int w = width;

      do {
        // Four overlapping windows per row (element stride 1).
        int16x8_t s0[4], s1[4], s2[4], s3[4];
        load_s16_8x4(s + 0 * src_stride, 1, &s0[0], &s0[1], &s0[2], &s0[3]);
        load_s16_8x4(s + 1 * src_stride, 1, &s1[0], &s1[1], &s1[2], &s1[3]);
        load_s16_8x4(s + 2 * src_stride, 1, &s2[0], &s2[1], &s2[2], &s2[3]);
        load_s16_8x4(s + 3 * src_stride, 1, &s3[0], &s3[1], &s3[2], &s3[3]);

        uint16x8_t d0 = highbd_12_convolve4_8_x(s0, filter, offset, idx);
        uint16x8_t d1 = highbd_12_convolve4_8_x(s1, filter, offset, idx);
        uint16x8_t d2 = highbd_12_convolve4_8_x(s2, filter, offset, idx);
        uint16x8_t d3 = highbd_12_convolve4_8_x(s3, filter, offset, idx);

        store_u16_8x4(d, dst_stride, d0, d1, d2, d3);

        s += 8;
        d += 8;
        w -= 8;
      } while (w != 0);
      src += 4 * src_stride;
      dst += 4 * dst_stride;
      height -= 4;
    } while (height != 0);
  }
}
266 
highbd_convolve4_4_x(int16x8_t s0,int16x8_t filter,int64x2_t offset,uint16x8x2_t permute_tbl)267 static inline uint16x4_t highbd_convolve4_4_x(int16x8_t s0, int16x8_t filter,
268                                               int64x2_t offset,
269                                               uint16x8x2_t permute_tbl) {
270   int16x8_t permuted_samples0 = aom_tbl_s16(s0, permute_tbl.val[0]);
271   int16x8_t permuted_samples1 = aom_tbl_s16(s0, permute_tbl.val[1]);
272 
273   int64x2_t sum01 = aom_svdot_lane_s16(offset, permuted_samples0, filter, 0);
274   int64x2_t sum23 = aom_svdot_lane_s16(offset, permuted_samples1, filter, 0);
275 
276   int32x4_t sum0123 = vcombine_s32(vmovn_s64(sum01), vmovn_s64(sum23));
277 
278   return vqrshrun_n_s32(sum0123, ROUND0_BITS);
279 }
280 
// 4-tap horizontal convolution of eight pixels for 8/10-bit sources; `tbl`
// deinterleaves the (i, i+4) pixel pairs produced by the dot products.
static inline uint16x8_t highbd_convolve4_8_x(int16x8_t s0[4], int16x8_t filter,
                                              int64x2_t offset,
                                              uint16x8_t tbl) {
  int64x2_t acc04 = aom_svdot_lane_s16(offset, s0[0], filter, 0);
  int64x2_t acc15 = aom_svdot_lane_s16(offset, s0[1], filter, 0);
  int64x2_t acc26 = aom_svdot_lane_s16(offset, s0[2], filter, 0);
  int64x2_t acc37 = aom_svdot_lane_s16(offset, s0[3], filter, 0);

  int32x4_t acc_lo = vcombine_s32(vmovn_s64(acc04), vmovn_s64(acc15));
  int32x4_t acc_hi = vcombine_s32(vmovn_s64(acc26), vmovn_s64(acc37));

  uint16x8_t interleaved = vcombine_u16(vqrshrun_n_s32(acc_lo, ROUND0_BITS),
                                        vqrshrun_n_s32(acc_hi, ROUND0_BITS));
  // Restore natural pixel order 0..7.
  return aom_tbl_u16(interleaved, tbl);
}
296 
// Horizontal 4-tap compound (dist_wtd) convolution for 8/10-bit sources
// (bit depth passed in as bd). width == 4 uses a permute-based path; wider
// blocks use 8-wide tiles with a deinterleave fixup. Assumes height % 4 == 0
// and, for width > 4, width % 8 == 0.
static inline void highbd_dist_wtd_convolve_x_4tap_sve2(
    const uint16_t *src, int src_stride, uint16_t *dst, int dst_stride,
    int width, int height, const int16_t *x_filter_ptr, const int bd) {
  // Offset in both lanes: svdot_lane keeps one pixel per 64-bit lane, so
  // each lane needs its own seed (unlike the 8-tap pairwise-add path).
  const int64x2_t offset =
      vdupq_n_s64((1 << (bd + FILTER_BITS)) + (1 << (bd + FILTER_BITS - 1)));

  // Active taps are at positions 2..5 of the kernel; zero-pad the rest.
  const int16x4_t x_filter = vld1_s16(x_filter_ptr + 2);
  const int16x8_t filter = vcombine_s16(x_filter, vdup_n_s16(0));

  if (width == 4) {
    uint16x8x2_t permute_tbl = vld1q_u16_x2(kDotProdTbl);

    const int16_t *s = (const int16_t *)(src);

    do {
      int16x8_t s0, s1, s2, s3;
      load_s16_8x4(s, src_stride, &s0, &s1, &s2, &s3);

      uint16x4_t d0 = highbd_convolve4_4_x(s0, filter, offset, permute_tbl);
      uint16x4_t d1 = highbd_convolve4_4_x(s1, filter, offset, permute_tbl);
      uint16x4_t d2 = highbd_convolve4_4_x(s2, filter, offset, permute_tbl);
      uint16x4_t d3 = highbd_convolve4_4_x(s3, filter, offset, permute_tbl);

      store_u16_4x4(dst, dst_stride, d0, d1, d2, d3);

      s += 4 * src_stride;
      dst += 4 * dst_stride;
      height -= 4;
    } while (height != 0);
  } else {
    uint16x8_t idx = vld1q_u16(kDeinterleaveTbl);

    do {
      const int16_t *s = (const int16_t *)(src);
      uint16_t *d = dst;
      int w = width;

      do {
        // Four overlapping windows per row (element stride 1).
        int16x8_t s0[4], s1[4], s2[4], s3[4];
        load_s16_8x4(s + 0 * src_stride, 1, &s0[0], &s0[1], &s0[2], &s0[3]);
        load_s16_8x4(s + 1 * src_stride, 1, &s1[0], &s1[1], &s1[2], &s1[3]);
        load_s16_8x4(s + 2 * src_stride, 1, &s2[0], &s2[1], &s2[2], &s2[3]);
        load_s16_8x4(s + 3 * src_stride, 1, &s3[0], &s3[1], &s3[2], &s3[3]);

        uint16x8_t d0 = highbd_convolve4_8_x(s0, filter, offset, idx);
        uint16x8_t d1 = highbd_convolve4_8_x(s1, filter, offset, idx);
        uint16x8_t d2 = highbd_convolve4_8_x(s2, filter, offset, idx);
        uint16x8_t d3 = highbd_convolve4_8_x(s3, filter, offset, idx);

        store_u16_8x4(d, dst_stride, d0, d1, d2, d3);

        s += 8;
        d += 8;
        w -= 8;
      } while (w != 0);
      src += 4 * src_stride;
      dst += 4 * dst_stride;
      height -= 4;
    } while (height != 0);
  }
}
358 
// Public entry point for the horizontal high-bitdepth dist_wtd convolution.
// Dispatches on filter tap count, bit depth, and whether compound averaging
// is requested:
//   - 6-tap filters fall back to the NEON implementation (no SVE2 kernel).
//   - bd == 12 uses kernels with the shift baked in (ROUND0_BITS + 2);
//     other depths use the generic-bd kernels.
//   - do_average: convolve into the on-stack im_block, then blend with the
//     existing conv_params->dst via (dist-weighted) comp-avg into dst.
//     Otherwise the convolution result is written straight to
//     conv_params->dst.
void av1_highbd_dist_wtd_convolve_x_sve2(
    const uint16_t *src, int src_stride, uint16_t *dst, int dst_stride, int w,
    int h, const InterpFilterParams *filter_params_x, const int subpel_x_qn,
    ConvolveParams *conv_params, int bd) {
  DECLARE_ALIGNED(16, uint16_t,
                  im_block[(MAX_SB_SIZE + MAX_FILTER_TAP) * MAX_SB_SIZE]);
  CONV_BUF_TYPE *dst16 = conv_params->dst;
  const int x_filter_taps = get_filter_tap(filter_params_x, subpel_x_qn);

  if (x_filter_taps == 6) {
    // No SVE2 path for 6-tap filters; defer to the NEON implementation.
    av1_highbd_dist_wtd_convolve_x_neon(src, src_stride, dst, dst_stride, w, h,
                                        filter_params_x, subpel_x_qn,
                                        conv_params, bd);
    return;
  }

  int dst16_stride = conv_params->dst_stride;
  const int im_stride = MAX_SB_SIZE;
  // Center the filter window on the output pixel.
  const int horiz_offset = filter_params_x->taps / 2 - 1;
  assert(FILTER_BITS == COMPOUND_ROUND1_BITS);

  const int16_t *x_filter_ptr = av1_get_interp_filter_subpel_kernel(
      filter_params_x, subpel_x_qn & SUBPEL_MASK);

  src -= horiz_offset;

  if (bd == 12) {
    if (conv_params->do_average) {
      if (x_filter_taps <= 4) {
        // The 4-tap kernels read filter taps 2..5, so advance src by 2 to
        // keep the source window aligned with the active taps.
        highbd_12_dist_wtd_convolve_x_4tap_sve2(src + 2, src_stride, im_block,
                                                im_stride, w, h, x_filter_ptr);
      } else {
        highbd_12_dist_wtd_convolve_x_8tap_sve2(src, src_stride, im_block,
                                                im_stride, w, h, x_filter_ptr);
      }

      if (conv_params->use_dist_wtd_comp_avg) {
        highbd_12_dist_wtd_comp_avg_neon(im_block, im_stride, dst, dst_stride,
                                         w, h, conv_params);

      } else {
        highbd_12_comp_avg_neon(im_block, im_stride, dst, dst_stride, w, h,
                                conv_params);
      }
    } else {
      if (x_filter_taps <= 4) {
        highbd_12_dist_wtd_convolve_x_4tap_sve2(
            src + 2, src_stride, dst16, dst16_stride, w, h, x_filter_ptr);
      } else {
        highbd_12_dist_wtd_convolve_x_8tap_sve2(
            src, src_stride, dst16, dst16_stride, w, h, x_filter_ptr);
      }
    }
  } else {
    if (conv_params->do_average) {
      if (x_filter_taps <= 4) {
        highbd_dist_wtd_convolve_x_4tap_sve2(src + 2, src_stride, im_block,
                                             im_stride, w, h, x_filter_ptr, bd);
      } else {
        highbd_dist_wtd_convolve_x_8tap_sve2(src, src_stride, im_block,
                                             im_stride, w, h, x_filter_ptr, bd);
      }

      if (conv_params->use_dist_wtd_comp_avg) {
        highbd_dist_wtd_comp_avg_neon(im_block, im_stride, dst, dst_stride, w,
                                      h, conv_params, bd);
      } else {
        highbd_comp_avg_neon(im_block, im_stride, dst, dst_stride, w, h,
                             conv_params, bd);
      }
    } else {
      if (x_filter_taps <= 4) {
        highbd_dist_wtd_convolve_x_4tap_sve2(
            src + 2, src_stride, dst16, dst16_stride, w, h, x_filter_ptr, bd);
      } else {
        highbd_dist_wtd_convolve_x_8tap_sve2(
            src, src_stride, dst16, dst16_stride, w, h, x_filter_ptr, bd);
      }
    }
  }
}
440 
highbd_12_convolve8_4_y(int16x8_t samples_lo[2],int16x8_t samples_hi[2],int16x8_t filter,int64x2_t offset)441 static inline uint16x4_t highbd_12_convolve8_4_y(int16x8_t samples_lo[2],
442                                                  int16x8_t samples_hi[2],
443                                                  int16x8_t filter,
444                                                  int64x2_t offset) {
445   int64x2_t sum01 = aom_svdot_lane_s16(offset, samples_lo[0], filter, 0);
446   sum01 = aom_svdot_lane_s16(sum01, samples_hi[0], filter, 1);
447 
448   int64x2_t sum23 = aom_svdot_lane_s16(offset, samples_lo[1], filter, 0);
449   sum23 = aom_svdot_lane_s16(sum23, samples_hi[1], filter, 1);
450 
451   int32x4_t sum0123 = vcombine_s32(vmovn_s64(sum01), vmovn_s64(sum23));
452 
453   return vqrshrun_n_s32(sum0123, ROUND0_BITS + 2);
454 }
455 
highbd_12_convolve8_8_y(int16x8_t samples_lo[4],int16x8_t samples_hi[4],int16x8_t filter,int64x2_t offset)456 static inline uint16x8_t highbd_12_convolve8_8_y(int16x8_t samples_lo[4],
457                                                  int16x8_t samples_hi[4],
458                                                  int16x8_t filter,
459                                                  int64x2_t offset) {
460   int64x2_t sum01 = aom_svdot_lane_s16(offset, samples_lo[0], filter, 0);
461   sum01 = aom_svdot_lane_s16(sum01, samples_hi[0], filter, 1);
462 
463   int64x2_t sum23 = aom_svdot_lane_s16(offset, samples_lo[1], filter, 0);
464   sum23 = aom_svdot_lane_s16(sum23, samples_hi[1], filter, 1);
465 
466   int64x2_t sum45 = aom_svdot_lane_s16(offset, samples_lo[2], filter, 0);
467   sum45 = aom_svdot_lane_s16(sum45, samples_hi[2], filter, 1);
468 
469   int64x2_t sum67 = aom_svdot_lane_s16(offset, samples_lo[3], filter, 0);
470   sum67 = aom_svdot_lane_s16(sum67, samples_hi[3], filter, 1);
471 
472   int32x4_t sum0123 = vcombine_s32(vmovn_s64(sum01), vmovn_s64(sum23));
473   int32x4_t sum4567 = vcombine_s32(vmovn_s64(sum45), vmovn_s64(sum67));
474 
475   return vcombine_u16(vqrshrun_n_s32(sum0123, ROUND0_BITS + 2),
476                       vqrshrun_n_s32(sum4567, ROUND0_BITS + 2));
477 }
478 
// Vertical 8-tap compound (dist_wtd) convolution for 12-bit sources.
// Keeps a rolling window of transposed source rows so each 4-row iteration
// only loads four new rows; the merge tables splice the new rows into the
// previous window. Assumes height % 4 == 0 and, for width > 4,
// width % 8 == 0.
static inline void highbd_12_dist_wtd_convolve_y_8tap_sve2(
    const uint16_t *src, int src_stride, uint16_t *dst, int dst_stride,
    int width, int height, const int16_t *y_filter_ptr) {
  // Compound offset in both lanes (one pixel per 64-bit lane here).
  const int64x2_t offset =
      vdupq_n_s64((1 << (12 + FILTER_BITS)) + (1 << (12 + FILTER_BITS - 1)));
  const int16x8_t y_filter = vld1q_s16(y_filter_ptr);

  uint16x8x3_t merge_block_tbl = vld1q_u16_x3(kDotProdMergeBlockTbl);
  // Scale indices by size of the true vector length to avoid reading from an
  // 'undefined' portion of a vector on a system with SVE vectors > 128-bit.
  uint16x8_t correction0 =
      vreinterpretq_u16_u64(vdupq_n_u64(svcnth() * 0x0001000000000000ULL));
  merge_block_tbl.val[0] = vaddq_u16(merge_block_tbl.val[0], correction0);
  uint16x8_t correction1 =
      vreinterpretq_u16_u64(vdupq_n_u64(svcnth() * 0x0001000100000000ULL));
  merge_block_tbl.val[1] = vaddq_u16(merge_block_tbl.val[1], correction1);

  uint16x8_t correction2 =
      vreinterpretq_u16_u64(vdupq_n_u64(svcnth() * 0x0001000100010000ULL));
  merge_block_tbl.val[2] = vaddq_u16(merge_block_tbl.val[2], correction2);

  if (width == 4) {
    int16_t *s = (int16_t *)src;
    // Prime the window with the first 7 rows (8-tap filter needs 7 rows of
    // history before the first output row).
    int16x4_t s0, s1, s2, s3, s4, s5, s6;
    load_s16_4x7(s, src_stride, &s0, &s1, &s2, &s3, &s4, &s5, &s6);
    s += 7 * src_stride;

    // This operation combines a conventional transpose and the sample permute
    // required before computing the dot product.
    int16x8_t s0123[2], s1234[2], s2345[2], s3456[2];
    transpose_concat_4x4(s0, s1, s2, s3, s0123);
    transpose_concat_4x4(s1, s2, s3, s4, s1234);
    transpose_concat_4x4(s2, s3, s4, s5, s2345);
    transpose_concat_4x4(s3, s4, s5, s6, s3456);

    do {
      int16x4_t s7, s8, s9, s10;
      load_s16_4x4(s, src_stride, &s7, &s8, &s9, &s10);

      int16x8_t s4567[2], s5678[2], s6789[2], s789A[2];
      // Transpose and shuffle the 4 lines that were loaded.
      transpose_concat_4x4(s7, s8, s9, s10, s789A);

      // Merge new data into block from previous iteration.
      aom_tbl2x2_s16(s3456, s789A, merge_block_tbl.val[0], s4567);
      aom_tbl2x2_s16(s3456, s789A, merge_block_tbl.val[1], s5678);
      aom_tbl2x2_s16(s3456, s789A, merge_block_tbl.val[2], s6789);

      uint16x4_t d0 = highbd_12_convolve8_4_y(s0123, s4567, y_filter, offset);
      uint16x4_t d1 = highbd_12_convolve8_4_y(s1234, s5678, y_filter, offset);
      uint16x4_t d2 = highbd_12_convolve8_4_y(s2345, s6789, y_filter, offset);
      uint16x4_t d3 = highbd_12_convolve8_4_y(s3456, s789A, y_filter, offset);

      store_u16_4x4(dst, dst_stride, d0, d1, d2, d3);

      // Prepare block for next iteration - re-using as much as possible.
      // Shuffle everything up four rows.
      s0123[0] = s4567[0];
      s0123[1] = s4567[1];
      s1234[0] = s5678[0];
      s1234[1] = s5678[1];
      s2345[0] = s6789[0];
      s2345[1] = s6789[1];
      s3456[0] = s789A[0];
      s3456[1] = s789A[1];

      s += 4 * src_stride;
      dst += 4 * dst_stride;
      height -= 4;
    } while (height != 0);
  } else {
    // Wider blocks: process one 8-pixel column strip at a time, full height,
    // then advance 8 pixels to the right.
    do {
      int h = height;
      int16_t *s = (int16_t *)src;
      uint16_t *d = dst;

      int16x8_t s0, s1, s2, s3, s4, s5, s6;
      load_s16_8x7(s, src_stride, &s0, &s1, &s2, &s3, &s4, &s5, &s6);
      s += 7 * src_stride;

      // This operation combines a conventional transpose and the sample permute
      // required before computing the dot product.
      int16x8_t s0123[4], s1234[4], s2345[4], s3456[4];
      transpose_concat_8x4(s0, s1, s2, s3, s0123);
      transpose_concat_8x4(s1, s2, s3, s4, s1234);
      transpose_concat_8x4(s2, s3, s4, s5, s2345);
      transpose_concat_8x4(s3, s4, s5, s6, s3456);

      do {
        int16x8_t s7, s8, s9, s10;
        load_s16_8x4(s, src_stride, &s7, &s8, &s9, &s10);
        int16x8_t s4567[4], s5678[4], s6789[4], s789A[4];

        // Transpose and shuffle the 4 lines that were loaded.
        transpose_concat_8x4(s7, s8, s9, s10, s789A);

        // Merge new data into block from previous iteration.
        aom_tbl2x4_s16(s3456, s789A, merge_block_tbl.val[0], s4567);
        aom_tbl2x4_s16(s3456, s789A, merge_block_tbl.val[1], s5678);
        aom_tbl2x4_s16(s3456, s789A, merge_block_tbl.val[2], s6789);

        uint16x8_t d0 = highbd_12_convolve8_8_y(s0123, s4567, y_filter, offset);
        uint16x8_t d1 = highbd_12_convolve8_8_y(s1234, s5678, y_filter, offset);
        uint16x8_t d2 = highbd_12_convolve8_8_y(s2345, s6789, y_filter, offset);
        uint16x8_t d3 = highbd_12_convolve8_8_y(s3456, s789A, y_filter, offset);

        store_u16_8x4(d, dst_stride, d0, d1, d2, d3);

        // Prepare block for next iteration - re-using as much as possible.
        // Shuffle everything up four rows.
        s0123[0] = s4567[0];
        s0123[1] = s4567[1];
        s0123[2] = s4567[2];
        s0123[3] = s4567[3];
        s1234[0] = s5678[0];
        s1234[1] = s5678[1];
        s1234[2] = s5678[2];
        s1234[3] = s5678[3];
        s2345[0] = s6789[0];
        s2345[1] = s6789[1];
        s2345[2] = s6789[2];
        s2345[3] = s6789[3];
        s3456[0] = s789A[0];
        s3456[1] = s789A[1];
        s3456[2] = s789A[2];
        s3456[3] = s789A[3];

        s += 4 * src_stride;
        d += 4 * dst_stride;
        h -= 4;
      } while (h != 0);
      src += 8;
      dst += 8;
      width -= 8;
    } while (width != 0);
  }
}
616 
highbd_convolve8_4_y(int16x8_t samples_lo[2],int16x8_t samples_hi[2],int16x8_t filter,int64x2_t offset)617 static inline uint16x4_t highbd_convolve8_4_y(int16x8_t samples_lo[2],
618                                               int16x8_t samples_hi[2],
619                                               int16x8_t filter,
620                                               int64x2_t offset) {
621   int64x2_t sum01 = aom_svdot_lane_s16(offset, samples_lo[0], filter, 0);
622   sum01 = aom_svdot_lane_s16(sum01, samples_hi[0], filter, 1);
623 
624   int64x2_t sum23 = aom_svdot_lane_s16(offset, samples_lo[1], filter, 0);
625   sum23 = aom_svdot_lane_s16(sum23, samples_hi[1], filter, 1);
626 
627   int32x4_t sum0123 = vcombine_s32(vmovn_s64(sum01), vmovn_s64(sum23));
628 
629   return vqrshrun_n_s32(sum0123, ROUND0_BITS);
630 }
631 
highbd_convolve8_8_y(int16x8_t samples_lo[4],int16x8_t samples_hi[4],int16x8_t filter,int64x2_t offset)632 static inline uint16x8_t highbd_convolve8_8_y(int16x8_t samples_lo[4],
633                                               int16x8_t samples_hi[4],
634                                               int16x8_t filter,
635                                               int64x2_t offset) {
636   int64x2_t sum01 = aom_svdot_lane_s16(offset, samples_lo[0], filter, 0);
637   sum01 = aom_svdot_lane_s16(sum01, samples_hi[0], filter, 1);
638 
639   int64x2_t sum23 = aom_svdot_lane_s16(offset, samples_lo[1], filter, 0);
640   sum23 = aom_svdot_lane_s16(sum23, samples_hi[1], filter, 1);
641 
642   int64x2_t sum45 = aom_svdot_lane_s16(offset, samples_lo[2], filter, 0);
643   sum45 = aom_svdot_lane_s16(sum45, samples_hi[2], filter, 1);
644 
645   int64x2_t sum67 = aom_svdot_lane_s16(offset, samples_lo[3], filter, 0);
646   sum67 = aom_svdot_lane_s16(sum67, samples_hi[3], filter, 1);
647 
648   int32x4_t sum0123 = vcombine_s32(vmovn_s64(sum01), vmovn_s64(sum23));
649   int32x4_t sum4567 = vcombine_s32(vmovn_s64(sum45), vmovn_s64(sum67));
650 
651   return vcombine_u16(vqrshrun_n_s32(sum0123, ROUND0_BITS),
652                       vqrshrun_n_s32(sum4567, ROUND0_BITS));
653 }
654 
// Vertical 8-tap compound ("dist_wtd") convolution for bitdepths other than
// 12 (the 12-bit path has its own specialisation). Writes un-averaged
// intermediate values (compound offset included) to dst; the averaging stage
// combines them later.
// NOTE(review): assumes width is 4 or a multiple of 8 and height a multiple
// of 4 — TODO confirm against callers.
static inline void highbd_dist_wtd_convolve_y_8tap_sve2(
    const uint16_t *src, int src_stride, uint16_t *dst, int dst_stride,
    int width, int height, const int16_t *y_filter_ptr, const int bd) {
  // Compound offset added into every accumulator before the final rounding
  // shift in the highbd_convolve8_*_y helpers.
  const int64x2_t offset =
      vdupq_n_s64((1 << (bd + FILTER_BITS)) + (1 << (bd + FILTER_BITS - 1)));
  const int16x8_t y_filter = vld1q_s16(y_filter_ptr);

  uint16x8x3_t merge_block_tbl = vld1q_u16_x3(kDotProdMergeBlockTbl);
  // Scale indices by size of the true vector length to avoid reading from an
  // 'undefined' portion of a vector on a system with SVE vectors > 128-bit.
  uint16x8_t correction0 =
      vreinterpretq_u16_u64(vdupq_n_u64(svcnth() * 0x0001000000000000ULL));
  merge_block_tbl.val[0] = vaddq_u16(merge_block_tbl.val[0], correction0);
  uint16x8_t correction1 =
      vreinterpretq_u16_u64(vdupq_n_u64(svcnth() * 0x0001000100000000ULL));
  merge_block_tbl.val[1] = vaddq_u16(merge_block_tbl.val[1], correction1);

  uint16x8_t correction2 =
      vreinterpretq_u16_u64(vdupq_n_u64(svcnth() * 0x0001000100010000ULL));
  merge_block_tbl.val[2] = vaddq_u16(merge_block_tbl.val[2], correction2);

  if (width == 4) {
    int16_t *s = (int16_t *)src;
    int16x4_t s0, s1, s2, s3, s4, s5, s6;
    // Prime the 7-row history needed before the first 4 outputs.
    load_s16_4x7(s, src_stride, &s0, &s1, &s2, &s3, &s4, &s5, &s6);
    s += 7 * src_stride;

    // This operation combines a conventional transpose and the sample permute
    // required before computing the dot product.
    int16x8_t s0123[2], s1234[2], s2345[2], s3456[2];
    transpose_concat_4x4(s0, s1, s2, s3, s0123);
    transpose_concat_4x4(s1, s2, s3, s4, s1234);
    transpose_concat_4x4(s2, s3, s4, s5, s2345);
    transpose_concat_4x4(s3, s4, s5, s6, s3456);

    do {
      int16x4_t s7, s8, s9, s10;
      load_s16_4x4(s, src_stride, &s7, &s8, &s9, &s10);

      int16x8_t s4567[2], s5678[2], s6789[2], s789A[2];
      // Transpose and shuffle the 4 lines that were loaded.
      transpose_concat_4x4(s7, s8, s9, s10, s789A);

      // Merge new data into block from previous iteration.
      aom_tbl2x2_s16(s3456, s789A, merge_block_tbl.val[0], s4567);
      aom_tbl2x2_s16(s3456, s789A, merge_block_tbl.val[1], s5678);
      aom_tbl2x2_s16(s3456, s789A, merge_block_tbl.val[2], s6789);

      uint16x4_t d0 = highbd_convolve8_4_y(s0123, s4567, y_filter, offset);
      uint16x4_t d1 = highbd_convolve8_4_y(s1234, s5678, y_filter, offset);
      uint16x4_t d2 = highbd_convolve8_4_y(s2345, s6789, y_filter, offset);
      uint16x4_t d3 = highbd_convolve8_4_y(s3456, s789A, y_filter, offset);

      store_u16_4x4(dst, dst_stride, d0, d1, d2, d3);

      // Prepare block for next iteration - re-using as much as possible.
      // Shuffle everything up four rows.
      s0123[0] = s4567[0];
      s0123[1] = s4567[1];
      s1234[0] = s5678[0];
      s1234[1] = s5678[1];
      s2345[0] = s6789[0];
      s2345[1] = s6789[1];
      s3456[0] = s789A[0];
      s3456[1] = s789A[1];

      s += 4 * src_stride;
      dst += 4 * dst_stride;
      height -= 4;
    } while (height != 0);
  } else {
    // Wider blocks: process in 8-wide columns, full height per column.
    do {
      int h = height;
      int16_t *s = (int16_t *)src;
      uint16_t *d = dst;

      int16x8_t s0, s1, s2, s3, s4, s5, s6;
      load_s16_8x7(s, src_stride, &s0, &s1, &s2, &s3, &s4, &s5, &s6);
      s += 7 * src_stride;

      // This operation combines a conventional transpose and the sample permute
      // required before computing the dot product.
      int16x8_t s0123[4], s1234[4], s2345[4], s3456[4];
      transpose_concat_8x4(s0, s1, s2, s3, s0123);
      transpose_concat_8x4(s1, s2, s3, s4, s1234);
      transpose_concat_8x4(s2, s3, s4, s5, s2345);
      transpose_concat_8x4(s3, s4, s5, s6, s3456);

      do {
        int16x8_t s7, s8, s9, s10;
        load_s16_8x4(s, src_stride, &s7, &s8, &s9, &s10);
        int16x8_t s4567[4], s5678[4], s6789[4], s789A[4];

        // Transpose and shuffle the 4 lines that were loaded.
        transpose_concat_8x4(s7, s8, s9, s10, s789A);

        // Merge new data into block from previous iteration.
        aom_tbl2x4_s16(s3456, s789A, merge_block_tbl.val[0], s4567);
        aom_tbl2x4_s16(s3456, s789A, merge_block_tbl.val[1], s5678);
        aom_tbl2x4_s16(s3456, s789A, merge_block_tbl.val[2], s6789);

        uint16x8_t d0 = highbd_convolve8_8_y(s0123, s4567, y_filter, offset);
        uint16x8_t d1 = highbd_convolve8_8_y(s1234, s5678, y_filter, offset);
        uint16x8_t d2 = highbd_convolve8_8_y(s2345, s6789, y_filter, offset);
        uint16x8_t d3 = highbd_convolve8_8_y(s3456, s789A, y_filter, offset);

        store_u16_8x4(d, dst_stride, d0, d1, d2, d3);

        // Prepare block for next iteration - re-using as much as possible.
        // Shuffle everything up four rows.
        s0123[0] = s4567[0];
        s0123[1] = s4567[1];
        s0123[2] = s4567[2];
        s0123[3] = s4567[3];
        s1234[0] = s5678[0];
        s1234[1] = s5678[1];
        s1234[2] = s5678[2];
        s1234[3] = s5678[3];
        s2345[0] = s6789[0];
        s2345[1] = s6789[1];
        s2345[2] = s6789[2];
        s2345[3] = s6789[3];
        s3456[0] = s789A[0];
        s3456[1] = s789A[1];
        s3456[2] = s789A[2];
        s3456[3] = s789A[3];

        s += 4 * src_stride;
        d += 4 * dst_stride;
        h -= 4;
      } while (h != 0);
      src += 8;
      dst += 8;
      width -= 8;
    } while (width != 0);
  }
}
792 
// Public entry point for the high-bitdepth compound vertical convolution.
// Dispatches to the SVE2 8-tap kernels (12-bit or generic), optionally
// followed by the Neon (weighted) averaging stage; non-8-tap filters fall
// back to the Neon implementation.
void av1_highbd_dist_wtd_convolve_y_sve2(
    const uint16_t *src, int src_stride, uint16_t *dst, int dst_stride, int w,
    int h, const InterpFilterParams *filter_params_y, const int subpel_y_qn,
    ConvolveParams *conv_params, int bd) {
  // Only 8-tap kernels are implemented with SVE2.
  if (get_filter_tap(filter_params_y, subpel_y_qn) != 8) {
    av1_highbd_dist_wtd_convolve_y_neon(src, src_stride, dst, dst_stride, w, h,
                                        filter_params_y, subpel_y_qn,
                                        conv_params, bd);
    return;
  }

  DECLARE_ALIGNED(16, uint16_t,
                  im_block[(MAX_SB_SIZE + MAX_FILTER_TAP) * MAX_SB_SIZE]);
  assert(FILTER_BITS == COMPOUND_ROUND1_BITS);

  const int16_t *y_filter_ptr = av1_get_interp_filter_subpel_kernel(
      filter_params_y, subpel_y_qn & SUBPEL_MASK);

  // Start filtering above the first output row so the tap window is centred.
  const uint16_t *src_start =
      src - (filter_params_y->taps / 2 - 1) * src_stride;

  CONV_BUF_TYPE *dst16 = conv_params->dst;
  const int dst16_stride = conv_params->dst_stride;

  if (bd == 12) {
    if (!conv_params->do_average) {
      // First prediction: write the intermediate values straight out.
      highbd_12_dist_wtd_convolve_y_8tap_sve2(src_start, src_stride, dst16,
                                              dst16_stride, w, h, y_filter_ptr);
      return;
    }
    // Second prediction: convolve into a scratch block, then average.
    highbd_12_dist_wtd_convolve_y_8tap_sve2(src_start, src_stride, im_block,
                                            MAX_SB_SIZE, w, h, y_filter_ptr);
    if (conv_params->use_dist_wtd_comp_avg) {
      highbd_12_dist_wtd_comp_avg_neon(im_block, MAX_SB_SIZE, dst, dst_stride,
                                       w, h, conv_params);
    } else {
      highbd_12_comp_avg_neon(im_block, MAX_SB_SIZE, dst, dst_stride, w, h,
                              conv_params);
    }
    return;
  }

  if (!conv_params->do_average) {
    highbd_dist_wtd_convolve_y_8tap_sve2(src_start, src_stride, dst16,
                                         dst16_stride, w, h, y_filter_ptr, bd);
    return;
  }
  highbd_dist_wtd_convolve_y_8tap_sve2(src_start, src_stride, im_block,
                                       MAX_SB_SIZE, w, h, y_filter_ptr, bd);
  if (conv_params->use_dist_wtd_comp_avg) {
    highbd_dist_wtd_comp_avg_neon(im_block, MAX_SB_SIZE, dst, dst_stride, w, h,
                                  conv_params, bd);
  } else {
    highbd_comp_avg_neon(im_block, MAX_SB_SIZE, dst, dst_stride, w, h,
                         conv_params, bd);
  }
}
851 
// Horizontal 8-tap pass of the 2D compound convolution, 12-bit
// specialisation. Produces the intermediate rows consumed by the vertical
// pass. NOTE(review): assumes width is a multiple of 8 — TODO confirm at
// call sites.
static inline void highbd_12_dist_wtd_convolve_2d_horiz_8tap_sve2(
    const uint16_t *src, int src_stride, uint16_t *dst, int dst_stride,
    int width, int height, const int16_t *x_filter_ptr) {
  // Offset folded into every accumulator before the rounding shift; bd is
  // fixed at 12 in this specialisation.
  const int64x2_t offset = vdupq_n_s64(1 << (12 + FILTER_BITS - 2));
  const int16x8_t filter = vld1q_s16(x_filter_ptr);

  // We are only doing 8-tap and 4-tap vertical convolutions, therefore we know
  // that im_h % 4 = 3, so we can do the loop across the whole block 4 rows at
  // a time and then process the last 3 rows separately.

  do {
    const int16_t *s = (const int16_t *)src;
    uint16_t *d = dst;
    int w = width;

    do {
      // Load 8 consecutive (stride 1) vectors per row: one per output phase.
      int16x8_t s0[8], s1[8], s2[8], s3[8];
      load_s16_8x8(s + 0 * src_stride, 1, &s0[0], &s0[1], &s0[2], &s0[3],
                   &s0[4], &s0[5], &s0[6], &s0[7]);
      load_s16_8x8(s + 1 * src_stride, 1, &s1[0], &s1[1], &s1[2], &s1[3],
                   &s1[4], &s1[5], &s1[6], &s1[7]);
      load_s16_8x8(s + 2 * src_stride, 1, &s2[0], &s2[1], &s2[2], &s2[3],
                   &s2[4], &s2[5], &s2[6], &s2[7]);
      load_s16_8x8(s + 3 * src_stride, 1, &s3[0], &s3[1], &s3[2], &s3[3],
                   &s3[4], &s3[5], &s3[6], &s3[7]);

      uint16x8_t d0 = highbd_12_convolve8_8_x(s0, filter, offset);
      uint16x8_t d1 = highbd_12_convolve8_8_x(s1, filter, offset);
      uint16x8_t d2 = highbd_12_convolve8_8_x(s2, filter, offset);
      uint16x8_t d3 = highbd_12_convolve8_8_x(s3, filter, offset);

      store_u16_8x4(d, dst_stride, d0, d1, d2, d3);

      s += 8;
      d += 8;
      w -= 8;
    } while (w != 0);
    src += 4 * src_stride;
    dst += 4 * dst_stride;
    height -= 4;
  } while (height > 4);

  // Process final 3 rows.
  const int16_t *s = (const int16_t *)src;
  do {
    int16x8_t s0[8], s1[8], s2[8];
    load_s16_8x8(s + 0 * src_stride, 1, &s0[0], &s0[1], &s0[2], &s0[3], &s0[4],
                 &s0[5], &s0[6], &s0[7]);
    load_s16_8x8(s + 1 * src_stride, 1, &s1[0], &s1[1], &s1[2], &s1[3], &s1[4],
                 &s1[5], &s1[6], &s1[7]);
    load_s16_8x8(s + 2 * src_stride, 1, &s2[0], &s2[1], &s2[2], &s2[3], &s2[4],
                 &s2[5], &s2[6], &s2[7]);

    uint16x8_t d0 = highbd_12_convolve8_8_x(s0, filter, offset);
    uint16x8_t d1 = highbd_12_convolve8_8_x(s1, filter, offset);
    uint16x8_t d2 = highbd_12_convolve8_8_x(s2, filter, offset);

    store_u16_8x3(dst, dst_stride, d0, d1, d2);
    s += 8;
    dst += 8;
    width -= 8;
  } while (width != 0);
}
915 
// Horizontal 8-tap pass of the 2D compound convolution for an arbitrary
// bitdepth. Produces the intermediate rows consumed by the vertical pass.
static inline void highbd_dist_wtd_convolve_2d_horiz_8tap_sve2(
    const uint16_t *src, int src_stride, uint16_t *dst, int dst_stride,
    int width, int height, const int16_t *x_filter_ptr, const int bd) {
  // Offset folded into every accumulator before the rounding shift.
  const int64x2_t offset = vdupq_n_s64(1 << (bd + FILTER_BITS - 2));
  const int16x8_t filter = vld1q_s16(x_filter_ptr);

  // The vertical pass is 8-tap or 4-tap, so the intermediate height
  // satisfies im_h % 4 == 3: convolve four rows per iteration, then handle
  // the trailing three rows afterwards.
  do {
    const int16_t *src_row = (const int16_t *)src;
    uint16_t *dst_row = dst;
    int remaining = width;

    do {
      // Eight consecutive (stride 1) vectors per row, one per output phase.
      int16x8_t rows[4][8];
      for (int r = 0; r < 4; ++r) {
        load_s16_8x8(src_row + r * src_stride, 1, &rows[r][0], &rows[r][1],
                     &rows[r][2], &rows[r][3], &rows[r][4], &rows[r][5],
                     &rows[r][6], &rows[r][7]);
      }

      uint16x8_t d0 = highbd_convolve8_8_x(rows[0], filter, offset);
      uint16x8_t d1 = highbd_convolve8_8_x(rows[1], filter, offset);
      uint16x8_t d2 = highbd_convolve8_8_x(rows[2], filter, offset);
      uint16x8_t d3 = highbd_convolve8_8_x(rows[3], filter, offset);

      store_u16_8x4(dst_row, dst_stride, d0, d1, d2, d3);

      src_row += 8;
      dst_row += 8;
      remaining -= 8;
    } while (remaining != 0);
    src += 4 * src_stride;
    dst += 4 * dst_stride;
    height -= 4;
  } while (height > 4);

  // Trailing three rows.
  const int16_t *src_row = (const int16_t *)src;
  do {
    int16x8_t rows[3][8];
    for (int r = 0; r < 3; ++r) {
      load_s16_8x8(src_row + r * src_stride, 1, &rows[r][0], &rows[r][1],
                   &rows[r][2], &rows[r][3], &rows[r][4], &rows[r][5],
                   &rows[r][6], &rows[r][7]);
    }

    uint16x8_t d0 = highbd_convolve8_8_x(rows[0], filter, offset);
    uint16x8_t d1 = highbd_convolve8_8_x(rows[1], filter, offset);
    uint16x8_t d2 = highbd_convolve8_8_x(rows[2], filter, offset);

    store_u16_8x3(dst, dst_stride, d0, d1, d2);
    src_row += 8;
    dst += 8;
    width -= 8;
  } while (width != 0);
}
979 
// Horizontal 4-tap pass of the 2D compound convolution, 12-bit
// specialisation. Produces the intermediate rows consumed by the vertical
// pass.
static inline void highbd_12_dist_wtd_convolve_2d_horiz_4tap_sve2(
    const uint16_t *src, int src_stride, uint16_t *dst, int dst_stride,
    int width, int height, const int16_t *x_filter_ptr) {
  // Offset folded into every accumulator before the rounding shift; bd is
  // fixed at 12 in this specialisation.
  const int64x2_t offset = vdupq_n_s64(1 << (12 + FILTER_BITS - 1));
  // The 4 non-zero taps sit in the middle of the 8-tap coefficient array;
  // pad the upper half of the vector with zeros.
  const int16x4_t x_filter = vld1_s16(x_filter_ptr + 2);
  const int16x8_t filter = vcombine_s16(x_filter, vdup_n_s16(0));

  // We are only doing 8-tap and 4-tap vertical convolutions, therefore we know
  // that im_h % 4 = 3, so we can do the loop across the whole block 4 rows at
  // a time and then process the last 3 rows separately.

  if (width == 4) {
    uint16x8x2_t permute_tbl = vld1q_u16_x2(kDotProdTbl);

    const int16_t *s = (const int16_t *)(src);

    do {
      int16x8_t s0, s1, s2, s3;
      load_s16_8x4(s, src_stride, &s0, &s1, &s2, &s3);

      uint16x4_t d0 = highbd_12_convolve4_4_x(s0, filter, offset, permute_tbl);
      uint16x4_t d1 = highbd_12_convolve4_4_x(s1, filter, offset, permute_tbl);
      uint16x4_t d2 = highbd_12_convolve4_4_x(s2, filter, offset, permute_tbl);
      uint16x4_t d3 = highbd_12_convolve4_4_x(s3, filter, offset, permute_tbl);

      store_u16_4x4(dst, dst_stride, d0, d1, d2, d3);

      s += 4 * src_stride;
      dst += 4 * dst_stride;
      height -= 4;
    } while (height > 4);

    // Process final 3 rows.
    int16x8_t s0, s1, s2;
    load_s16_8x3(s, src_stride, &s0, &s1, &s2);

    uint16x4_t d0 = highbd_12_convolve4_4_x(s0, filter, offset, permute_tbl);
    uint16x4_t d1 = highbd_12_convolve4_4_x(s1, filter, offset, permute_tbl);
    uint16x4_t d2 = highbd_12_convolve4_4_x(s2, filter, offset, permute_tbl);

    store_u16_4x3(dst, dst_stride, d0, d1, d2);

  } else {
    // Wider blocks use a deinterleave table instead of the permute table.
    uint16x8_t idx = vld1q_u16(kDeinterleaveTbl);

    do {
      const int16_t *s = (const int16_t *)(src);
      uint16_t *d = dst;
      int w = width;

      do {
        // Four consecutive (stride 1) vectors per row, one per output phase.
        int16x8_t s0[4], s1[4], s2[4], s3[4];
        load_s16_8x4(s + 0 * src_stride, 1, &s0[0], &s0[1], &s0[2], &s0[3]);
        load_s16_8x4(s + 1 * src_stride, 1, &s1[0], &s1[1], &s1[2], &s1[3]);
        load_s16_8x4(s + 2 * src_stride, 1, &s2[0], &s2[1], &s2[2], &s2[3]);
        load_s16_8x4(s + 3 * src_stride, 1, &s3[0], &s3[1], &s3[2], &s3[3]);

        uint16x8_t d0 = highbd_12_convolve4_8_x(s0, filter, offset, idx);
        uint16x8_t d1 = highbd_12_convolve4_8_x(s1, filter, offset, idx);
        uint16x8_t d2 = highbd_12_convolve4_8_x(s2, filter, offset, idx);
        uint16x8_t d3 = highbd_12_convolve4_8_x(s3, filter, offset, idx);

        store_u16_8x4(d, dst_stride, d0, d1, d2, d3);

        s += 8;
        d += 8;
        w -= 8;
      } while (w != 0);
      src += 4 * src_stride;
      dst += 4 * dst_stride;
      height -= 4;
    } while (height > 4);

    // Process final 3 rows.
    const int16_t *s = (const int16_t *)(src);

    do {
      int16x8_t s0[4], s1[4], s2[4];
      load_s16_8x4(s + 0 * src_stride, 1, &s0[0], &s0[1], &s0[2], &s0[3]);
      load_s16_8x4(s + 1 * src_stride, 1, &s1[0], &s1[1], &s1[2], &s1[3]);
      load_s16_8x4(s + 2 * src_stride, 1, &s2[0], &s2[1], &s2[2], &s2[3]);

      uint16x8_t d0 = highbd_12_convolve4_8_x(s0, filter, offset, idx);
      uint16x8_t d1 = highbd_12_convolve4_8_x(s1, filter, offset, idx);
      uint16x8_t d2 = highbd_12_convolve4_8_x(s2, filter, offset, idx);

      store_u16_8x3(dst, dst_stride, d0, d1, d2);

      s += 8;
      dst += 8;
      width -= 8;
    } while (width != 0);
  }
}
1074 
// Horizontal 4-tap pass of the 2D compound convolution for an arbitrary
// bitdepth. Produces the intermediate rows consumed by the vertical pass.
static inline void highbd_dist_wtd_convolve_2d_horiz_4tap_sve2(
    const uint16_t *src, int src_stride, uint16_t *dst, int dst_stride,
    int width, int height, const int16_t *x_filter_ptr, const int bd) {
  // Offset folded into every accumulator before the rounding shift.
  const int64x2_t offset = vdupq_n_s64(1 << (bd + FILTER_BITS - 1));
  // The 4 non-zero taps sit in the middle of the 8-tap coefficient array;
  // pad the upper half of the vector with zeros.
  const int16x8_t filter =
      vcombine_s16(vld1_s16(x_filter_ptr + 2), vdup_n_s16(0));

  // The vertical pass is 8-tap or 4-tap, so the intermediate height
  // satisfies im_h % 4 == 3: convolve four rows per iteration, then handle
  // the trailing three rows afterwards.
  if (width == 4) {
    uint16x8x2_t permute_tbl = vld1q_u16_x2(kDotProdTbl);
    const int16_t *src_row = (const int16_t *)src;

    do {
      int16x8_t r0, r1, r2, r3;
      load_s16_8x4(src_row, src_stride, &r0, &r1, &r2, &r3);

      uint16x4_t d0 = highbd_convolve4_4_x(r0, filter, offset, permute_tbl);
      uint16x4_t d1 = highbd_convolve4_4_x(r1, filter, offset, permute_tbl);
      uint16x4_t d2 = highbd_convolve4_4_x(r2, filter, offset, permute_tbl);
      uint16x4_t d3 = highbd_convolve4_4_x(r3, filter, offset, permute_tbl);

      store_u16_4x4(dst, dst_stride, d0, d1, d2, d3);

      src_row += 4 * src_stride;
      dst += 4 * dst_stride;
      height -= 4;
    } while (height > 4);

    // Trailing three rows.
    int16x8_t r0, r1, r2;
    load_s16_8x3(src_row, src_stride, &r0, &r1, &r2);

    uint16x4_t d0 = highbd_convolve4_4_x(r0, filter, offset, permute_tbl);
    uint16x4_t d1 = highbd_convolve4_4_x(r1, filter, offset, permute_tbl);
    uint16x4_t d2 = highbd_convolve4_4_x(r2, filter, offset, permute_tbl);

    store_u16_4x3(dst, dst_stride, d0, d1, d2);
  } else {
    // Wider blocks use a deinterleave table instead of the permute table.
    uint16x8_t idx = vld1q_u16(kDeinterleaveTbl);

    do {
      const int16_t *src_row = (const int16_t *)src;
      uint16_t *dst_row = dst;
      int remaining = width;

      do {
        // Four consecutive (stride 1) vectors per row, one per output phase.
        int16x8_t rows[4][4];
        for (int r = 0; r < 4; ++r) {
          load_s16_8x4(src_row + r * src_stride, 1, &rows[r][0], &rows[r][1],
                       &rows[r][2], &rows[r][3]);
        }

        uint16x8_t d0 = highbd_convolve4_8_x(rows[0], filter, offset, idx);
        uint16x8_t d1 = highbd_convolve4_8_x(rows[1], filter, offset, idx);
        uint16x8_t d2 = highbd_convolve4_8_x(rows[2], filter, offset, idx);
        uint16x8_t d3 = highbd_convolve4_8_x(rows[3], filter, offset, idx);

        store_u16_8x4(dst_row, dst_stride, d0, d1, d2, d3);

        src_row += 8;
        dst_row += 8;
        remaining -= 8;
      } while (remaining != 0);
      src += 4 * src_stride;
      dst += 4 * dst_stride;
      height -= 4;
    } while (height > 4);

    // Trailing three rows.
    const int16_t *src_row = (const int16_t *)src;

    do {
      int16x8_t rows[3][4];
      for (int r = 0; r < 3; ++r) {
        load_s16_8x4(src_row + r * src_stride, 1, &rows[r][0], &rows[r][1],
                     &rows[r][2], &rows[r][3]);
      }

      uint16x8_t d0 = highbd_convolve4_8_x(rows[0], filter, offset, idx);
      uint16x8_t d1 = highbd_convolve4_8_x(rows[1], filter, offset, idx);
      uint16x8_t d2 = highbd_convolve4_8_x(rows[2], filter, offset, idx);

      store_u16_8x3(dst, dst_stride, d0, d1, d2);

      src_row += 8;
      dst += 8;
      width -= 8;
    } while (width != 0);
  }
}
1168 
// 8-tap vertical dot product for 4 output pixels of the 2D compound path.
// The low filter lane covers samples_lo and the high lane covers samples_hi;
// each 64-bit accumulator yields two 32-bit partial results.
static inline uint16x4_t highbd_convolve8_4_2d_v(int16x8_t samples_lo[2],
                                                 int16x8_t samples_hi[2],
                                                 int16x8_t filter,
                                                 int64x2_t offset) {
  int64x2_t acc[2];
  acc[0] = aom_svdot_lane_s16(offset, samples_lo[0], filter, 0);
  acc[0] = aom_svdot_lane_s16(acc[0], samples_hi[0], filter, 1);
  acc[1] = aom_svdot_lane_s16(offset, samples_lo[1], filter, 0);
  acc[1] = aom_svdot_lane_s16(acc[1], samples_hi[1], filter, 1);

  int32x4_t narrowed = vcombine_s32(vmovn_s64(acc[0]), vmovn_s64(acc[1]));

  // Round-shift and saturate to the unsigned 16-bit intermediate range.
  return vqrshrun_n_s32(narrowed, COMPOUND_ROUND1_BITS);
}
1183 
// 8-tap vertical dot product for 8 output pixels of the 2D compound path.
// The low filter lane covers samples_lo and the high lane covers samples_hi;
// each 64-bit accumulator yields two 32-bit partial results.
static inline uint16x8_t highbd_convolve8_8_2d_v(int16x8_t samples_lo[4],
                                                 int16x8_t samples_hi[4],
                                                 int16x8_t filter,
                                                 int64x2_t offset) {
  int64x2_t acc[4];
  for (int i = 0; i < 4; ++i) {
    acc[i] = aom_svdot_lane_s16(offset, samples_lo[i], filter, 0);
    acc[i] = aom_svdot_lane_s16(acc[i], samples_hi[i], filter, 1);
  }

  int32x4_t lo = vcombine_s32(vmovn_s64(acc[0]), vmovn_s64(acc[1]));
  int32x4_t hi = vcombine_s32(vmovn_s64(acc[2]), vmovn_s64(acc[3]));

  // Round-shift and saturate to the unsigned 16-bit intermediate range.
  return vcombine_u16(vqrshrun_n_s32(lo, COMPOUND_ROUND1_BITS),
                      vqrshrun_n_s32(hi, COMPOUND_ROUND1_BITS));
}
1206 
// Vertical 8-tap pass of the 2D compound convolution. Consumes the
// intermediate rows produced by the horizontal pass and writes un-averaged
// compound values (offset included) to dst.
// NOTE(review): assumes width is 4 or a multiple of 8 and height a multiple
// of 4 — TODO confirm against callers.
static inline void highbd_dist_wtd_convolve_2d_vert_8tap_sve2(
    const uint16_t *src, int src_stride, uint16_t *dst, int dst_stride,
    int width, int height, const int16_t *y_filter_ptr, int offset) {
  const int16x8_t y_filter = vld1q_s16(y_filter_ptr);
  // Offset folded into every accumulator before the rounding shift.
  const int64x2_t offset_s64 = vdupq_n_s64(offset);

  uint16x8x3_t merge_block_tbl = vld1q_u16_x3(kDotProdMergeBlockTbl);
  // Scale indices by size of the true vector length to avoid reading from an
  // 'undefined' portion of a vector on a system with SVE vectors > 128-bit.
  uint16x8_t correction0 =
      vreinterpretq_u16_u64(vdupq_n_u64(svcnth() * 0x0001000000000000ULL));
  merge_block_tbl.val[0] = vaddq_u16(merge_block_tbl.val[0], correction0);

  uint16x8_t correction1 =
      vreinterpretq_u16_u64(vdupq_n_u64(svcnth() * 0x0001000100000000ULL));
  merge_block_tbl.val[1] = vaddq_u16(merge_block_tbl.val[1], correction1);

  uint16x8_t correction2 =
      vreinterpretq_u16_u64(vdupq_n_u64(svcnth() * 0x0001000100010000ULL));
  merge_block_tbl.val[2] = vaddq_u16(merge_block_tbl.val[2], correction2);

  if (width == 4) {
    int16_t *s = (int16_t *)src;
    int16x4_t s0, s1, s2, s3, s4, s5, s6;
    // Prime the 7-row history needed before the first 4 outputs.
    load_s16_4x7(s, src_stride, &s0, &s1, &s2, &s3, &s4, &s5, &s6);
    s += 7 * src_stride;

    // This operation combines a conventional transpose and the sample permute
    // required before computing the dot product.
    int16x8_t s0123[2], s1234[2], s2345[2], s3456[2];
    transpose_concat_4x4(s0, s1, s2, s3, s0123);
    transpose_concat_4x4(s1, s2, s3, s4, s1234);
    transpose_concat_4x4(s2, s3, s4, s5, s2345);
    transpose_concat_4x4(s3, s4, s5, s6, s3456);

    do {
      int16x4_t s7, s8, s9, s10;
      load_s16_4x4(s, src_stride, &s7, &s8, &s9, &s10);

      int16x8_t s4567[2], s5678[2], s6789[2], s789A[2];
      // Transpose and shuffle the 4 lines that were loaded.
      transpose_concat_4x4(s7, s8, s9, s10, s789A);

      // Merge new data into block from previous iteration.
      aom_tbl2x2_s16(s3456, s789A, merge_block_tbl.val[0], s4567);
      aom_tbl2x2_s16(s3456, s789A, merge_block_tbl.val[1], s5678);
      aom_tbl2x2_s16(s3456, s789A, merge_block_tbl.val[2], s6789);

      uint16x4_t d0 =
          highbd_convolve8_4_2d_v(s0123, s4567, y_filter, offset_s64);
      uint16x4_t d1 =
          highbd_convolve8_4_2d_v(s1234, s5678, y_filter, offset_s64);
      uint16x4_t d2 =
          highbd_convolve8_4_2d_v(s2345, s6789, y_filter, offset_s64);
      uint16x4_t d3 =
          highbd_convolve8_4_2d_v(s3456, s789A, y_filter, offset_s64);

      store_u16_4x4(dst, dst_stride, d0, d1, d2, d3);

      // Prepare block for next iteration - re-using as much as possible.
      // Shuffle everything up four rows.
      s0123[0] = s4567[0];
      s0123[1] = s4567[1];
      s1234[0] = s5678[0];
      s1234[1] = s5678[1];
      s2345[0] = s6789[0];
      s2345[1] = s6789[1];
      s3456[0] = s789A[0];
      s3456[1] = s789A[1];

      s += 4 * src_stride;
      dst += 4 * dst_stride;
      height -= 4;
    } while (height != 0);
  } else {
    // Wider blocks: process in 8-wide columns, full height per column.
    do {
      int h = height;
      int16_t *s = (int16_t *)src;
      uint16_t *d = dst;

      int16x8_t s0, s1, s2, s3, s4, s5, s6;
      load_s16_8x7(s, src_stride, &s0, &s1, &s2, &s3, &s4, &s5, &s6);
      s += 7 * src_stride;

      // This operation combines a conventional transpose and the sample permute
      // required before computing the dot product.
      int16x8_t s0123[4], s1234[4], s2345[4], s3456[4];
      transpose_concat_8x4(s0, s1, s2, s3, s0123);
      transpose_concat_8x4(s1, s2, s3, s4, s1234);
      transpose_concat_8x4(s2, s3, s4, s5, s2345);
      transpose_concat_8x4(s3, s4, s5, s6, s3456);

      do {
        int16x8_t s7, s8, s9, s10;
        load_s16_8x4(s, src_stride, &s7, &s8, &s9, &s10);
        int16x8_t s4567[4], s5678[4], s6789[4], s789A[4];

        // Transpose and shuffle the 4 lines that were loaded.
        transpose_concat_8x4(s7, s8, s9, s10, s789A);

        // Merge new data into block from previous iteration.
        aom_tbl2x4_s16(s3456, s789A, merge_block_tbl.val[0], s4567);
        aom_tbl2x4_s16(s3456, s789A, merge_block_tbl.val[1], s5678);
        aom_tbl2x4_s16(s3456, s789A, merge_block_tbl.val[2], s6789);

        uint16x8_t d0 =
            highbd_convolve8_8_2d_v(s0123, s4567, y_filter, offset_s64);
        uint16x8_t d1 =
            highbd_convolve8_8_2d_v(s1234, s5678, y_filter, offset_s64);
        uint16x8_t d2 =
            highbd_convolve8_8_2d_v(s2345, s6789, y_filter, offset_s64);
        uint16x8_t d3 =
            highbd_convolve8_8_2d_v(s3456, s789A, y_filter, offset_s64);

        store_u16_8x4(d, dst_stride, d0, d1, d2, d3);

        // Prepare block for next iteration - re-using as much as possible.
        // Shuffle everything up four rows.
        s0123[0] = s4567[0];
        s0123[1] = s4567[1];
        s0123[2] = s4567[2];
        s0123[3] = s4567[3];
        s1234[0] = s5678[0];
        s1234[1] = s5678[1];
        s1234[2] = s5678[2];
        s1234[3] = s5678[3];
        s2345[0] = s6789[0];
        s2345[1] = s6789[1];
        s2345[2] = s6789[2];
        s2345[3] = s6789[3];
        s3456[0] = s789A[0];
        s3456[1] = s789A[1];
        s3456[2] = s789A[2];
        s3456[3] = s789A[3];

        s += 4 * src_stride;
        d += 4 * dst_stride;
        h -= 4;
      } while (h != 0);
      src += 8;
      dst += 8;
      width -= 8;
    } while (width != 0);
  }
}
1352 
// 4-tap vertical multiply-accumulate for 4 output pixels of the 2D compound
// path; s0..s3 are four consecutive rows of one 4-wide column.
static inline uint16x4_t highbd_convolve4_4_2d_v(
    const int16x4_t s0, const int16x4_t s1, const int16x4_t s2,
    const int16x4_t s3, const int16x4_t filter, const int32x4_t offset) {
  // Accumulate each tap on top of the pre-loaded offset.
  int32x4_t acc = vmlal_lane_s16(offset, s0, filter, 0);
  acc = vmlal_lane_s16(acc, s1, filter, 1);
  acc = vmlal_lane_s16(acc, s2, filter, 2);
  acc = vmlal_lane_s16(acc, s3, filter, 3);

  // Round-shift and saturate to the unsigned 16-bit intermediate range.
  return vqrshrun_n_s32(acc, COMPOUND_ROUND1_BITS);
}
1363 
// 4-tap vertical multiply-accumulate for 8 output pixels of the 2D compound
// path; s0..s3 are four consecutive rows of one 8-wide column, processed as
// low/high halves.
static inline uint16x8_t highbd_convolve4_8_2d_v(
    const int16x8_t s0, const int16x8_t s1, const int16x8_t s2,
    const int16x8_t s3, const int16x4_t filter, const int32x4_t offset) {
  // Low half of the 8 output pixels.
  int32x4_t acc_lo = vmlal_lane_s16(offset, vget_low_s16(s0), filter, 0);
  acc_lo = vmlal_lane_s16(acc_lo, vget_low_s16(s1), filter, 1);
  acc_lo = vmlal_lane_s16(acc_lo, vget_low_s16(s2), filter, 2);
  acc_lo = vmlal_lane_s16(acc_lo, vget_low_s16(s3), filter, 3);

  // High half of the 8 output pixels.
  int32x4_t acc_hi = vmlal_lane_s16(offset, vget_high_s16(s0), filter, 0);
  acc_hi = vmlal_lane_s16(acc_hi, vget_high_s16(s1), filter, 1);
  acc_hi = vmlal_lane_s16(acc_hi, vget_high_s16(s2), filter, 2);
  acc_hi = vmlal_lane_s16(acc_hi, vget_high_s16(s3), filter, 3);

  // Round-shift and saturate to the unsigned 16-bit intermediate range.
  return vcombine_u16(vqrshrun_n_s32(acc_lo, COMPOUND_ROUND1_BITS),
                      vqrshrun_n_s32(acc_hi, COMPOUND_ROUND1_BITS));
}
1380 
// Vertical 4-tap pass of the dist-wtd 2D convolution. Reads the horizontal
// pass output (as int16), filters 4 rows at a time, and writes uint16 results
// with the compound offset applied.
// Assumes h is a multiple of 4 and, in the wide path, w a multiple of 8 —
// TODO confirm against callers.
static inline void highbd_dist_wtd_convolve_2d_vert_4tap_neon(
    const uint16_t *src_ptr, int src_stride, uint16_t *dst_ptr, int dst_stride,
    int w, int h, const int16_t *y_filter_ptr, const int offset) {
  // The 4 vertical taps live in the middle of the 8-tap kernel array.
  const int16x4_t y_filter = vld1_s16(y_filter_ptr + 2);
  const int32x4_t offset_vec = vdupq_n_s32(offset);

  if (w == 4) {
    const int16_t *src = (const int16_t *)src_ptr;
    uint16_t *dst = dst_ptr;

    // Prime the three rows of history needed before the first output row.
    int16x4_t r0, r1, r2;
    load_s16_4x3(src, src_stride, &r0, &r1, &r2);
    src += 3 * src_stride;

    do {
      int16x4_t r3, r4, r5, r6;
      load_s16_4x4(src, src_stride, &r3, &r4, &r5, &r6);

      const uint16x4_t d0 =
          highbd_convolve4_4_2d_v(r0, r1, r2, r3, y_filter, offset_vec);
      const uint16x4_t d1 =
          highbd_convolve4_4_2d_v(r1, r2, r3, r4, y_filter, offset_vec);
      const uint16x4_t d2 =
          highbd_convolve4_4_2d_v(r2, r3, r4, r5, y_filter, offset_vec);
      const uint16x4_t d3 =
          highbd_convolve4_4_2d_v(r3, r4, r5, r6, y_filter, offset_vec);

      store_u16_4x4(dst, dst_stride, d0, d1, d2, d3);

      // Rotate the row history forward by four rows.
      r0 = r4;
      r1 = r5;
      r2 = r6;

      src += 4 * src_stride;
      dst += 4 * dst_stride;
      h -= 4;
    } while (h != 0);
  } else {
    // Wide blocks: process one 8-pixel-wide column strip at a time.
    do {
      const int16_t *src = (const int16_t *)src_ptr;
      uint16_t *dst = dst_ptr;
      int rows = h;

      int16x8_t r0, r1, r2;
      load_s16_8x3(src, src_stride, &r0, &r1, &r2);
      src += 3 * src_stride;

      do {
        int16x8_t r3, r4, r5, r6;
        load_s16_8x4(src, src_stride, &r3, &r4, &r5, &r6);

        const uint16x8_t d0 =
            highbd_convolve4_8_2d_v(r0, r1, r2, r3, y_filter, offset_vec);
        const uint16x8_t d1 =
            highbd_convolve4_8_2d_v(r1, r2, r3, r4, y_filter, offset_vec);
        const uint16x8_t d2 =
            highbd_convolve4_8_2d_v(r2, r3, r4, r5, y_filter, offset_vec);
        const uint16x8_t d3 =
            highbd_convolve4_8_2d_v(r3, r4, r5, r6, y_filter, offset_vec);

        store_u16_8x4(dst, dst_stride, d0, d1, d2, d3);

        // Rotate the row history forward by four rows.
        r0 = r4;
        r1 = r5;
        r2 = r6;

        src += 4 * src_stride;
        dst += 4 * dst_stride;
        rows -= 4;
      } while (rows != 0);
      src_ptr += 8;
      dst_ptr += 8;
      w -= 8;
    } while (w != 0);
  }
}
1457 
// Dist-wtd 2D (horizontal + vertical) high-bitdepth convolution, SVE2 entry
// point. Runs a horizontal pass into an intermediate buffer, then a vertical
// pass; when conv_params->do_average is set the result is averaged/blended
// into dst, otherwise it is written to the compound buffer conv_params->dst.
void av1_highbd_dist_wtd_convolve_2d_sve2(
    const uint16_t *src, int src_stride, uint16_t *dst, int dst_stride, int w,
    int h, const InterpFilterParams *filter_params_x,
    const InterpFilterParams *filter_params_y, const int subpel_x_qn,
    const int subpel_y_qn, ConvolveParams *conv_params, int bd) {
  // im_block holds the horizontal pass output; im_block2 holds the vertical
  // pass output when an averaging stage follows.
  DECLARE_ALIGNED(16, uint16_t,
                  im_block[(MAX_SB_SIZE + MAX_FILTER_TAP) * MAX_SB_SIZE]);
  DECLARE_ALIGNED(16, uint16_t,
                  im_block2[(MAX_SB_SIZE + MAX_FILTER_TAP) * MAX_SB_SIZE]);

  CONV_BUF_TYPE *dst16 = conv_params->dst;
  int dst16_stride = conv_params->dst_stride;
  const int x_filter_taps = get_filter_tap(filter_params_x, subpel_x_qn);
  // Only 4- and 8-tap kernels are implemented below, so round up to 4.
  const int clamped_x_taps = x_filter_taps < 4 ? 4 : x_filter_taps;

  const int y_filter_taps = get_filter_tap(filter_params_y, subpel_y_qn);
  const int clamped_y_taps = y_filter_taps < 4 ? 4 : y_filter_taps;

  // 6-tap filters have no SVE2 path here; defer to the Neon implementation.
  if (x_filter_taps == 6 || y_filter_taps == 6) {
    av1_highbd_dist_wtd_convolve_2d_neon(
        src, src_stride, dst, dst_stride, w, h, filter_params_x,
        filter_params_y, subpel_x_qn, subpel_y_qn, conv_params, bd);
    return;
  }

  // The horizontal pass must produce enough extra rows to feed the vertical
  // filter's taps.
  const int im_h = h + clamped_y_taps - 1;
  const int im_stride = MAX_SB_SIZE;
  const int vert_offset = clamped_y_taps / 2 - 1;
  const int horiz_offset = clamped_x_taps / 2 - 1;
  const int y_offset_bits = bd + 2 * FILTER_BITS - conv_params->round_0;
  const int round_offset_conv_y = (1 << y_offset_bits);

  // Back the source pointer up so the filter window is centered on src.
  const uint16_t *src_ptr = src - vert_offset * src_stride - horiz_offset;

  const int16_t *x_filter_ptr = av1_get_interp_filter_subpel_kernel(
      filter_params_x, subpel_x_qn & SUBPEL_MASK);
  const int16_t *y_filter_ptr = av1_get_interp_filter_subpel_kernel(
      filter_params_y, subpel_y_qn & SUBPEL_MASK);

  // Horizontal pass: bd == 12 uses dedicated kernels (different rounding
  // constants baked in); other bitdepths take bd as a parameter.
  if (bd == 12) {
    if (x_filter_taps <= 4) {
      highbd_12_dist_wtd_convolve_2d_horiz_4tap_sve2(
          src_ptr, src_stride, im_block, im_stride, w, im_h, x_filter_ptr);
    } else {
      highbd_12_dist_wtd_convolve_2d_horiz_8tap_sve2(
          src_ptr, src_stride, im_block, im_stride, w, im_h, x_filter_ptr);
    }
  } else {
    if (x_filter_taps <= 4) {
      highbd_dist_wtd_convolve_2d_horiz_4tap_sve2(
          src_ptr, src_stride, im_block, im_stride, w, im_h, x_filter_ptr, bd);
    } else {
      highbd_dist_wtd_convolve_2d_horiz_8tap_sve2(
          src_ptr, src_stride, im_block, im_stride, w, im_h, x_filter_ptr, bd);
    }
  }

  if (conv_params->do_average) {
    // Vertical pass into im_block2, then average/blend into dst.
    if (y_filter_taps <= 4) {
      highbd_dist_wtd_convolve_2d_vert_4tap_neon(im_block, im_stride, im_block2,
                                                 im_stride, w, h, y_filter_ptr,
                                                 round_offset_conv_y);
    } else {
      highbd_dist_wtd_convolve_2d_vert_8tap_sve2(im_block, im_stride, im_block2,
                                                 im_stride, w, h, y_filter_ptr,
                                                 round_offset_conv_y);
    }
    if (conv_params->use_dist_wtd_comp_avg) {
      // Distance-weighted blend with the reference in conv_params->dst.
      if (bd == 12) {
        highbd_12_dist_wtd_comp_avg_neon(im_block2, im_stride, dst, dst_stride,
                                         w, h, conv_params);

      } else {
        highbd_dist_wtd_comp_avg_neon(im_block2, im_stride, dst, dst_stride, w,
                                      h, conv_params, bd);
      }
    } else {
      // Plain (unweighted) compound average.
      if (bd == 12) {
        highbd_12_comp_avg_neon(im_block2, im_stride, dst, dst_stride, w, h,
                                conv_params);

      } else {
        highbd_comp_avg_neon(im_block2, im_stride, dst, dst_stride, w, h,
                             conv_params, bd);
      }
    }
  } else {
    // First prediction of the compound pair: write straight to the compound
    // buffer; a later call with do_average set will blend against it.
    if (y_filter_taps <= 4) {
      highbd_dist_wtd_convolve_2d_vert_4tap_neon(
          im_block, im_stride, dst16, dst16_stride, w, h, y_filter_ptr,
          round_offset_conv_y);
    } else {
      highbd_dist_wtd_convolve_2d_vert_8tap_sve2(
          im_block, im_stride, dst16, dst16_stride, w, h, y_filter_ptr,
          round_offset_conv_y);
    }
  }
}
1556