/*
 * Copyright (c) 2014 The WebM project authors. All rights reserved.
 * Copyright (c) 2023, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <arm_neon.h>
#include <assert.h>
#include <string.h>

#include "config/aom_config.h"
#include "config/aom_dsp_rtcd.h"

#include "aom/aom_integer.h"
#include "aom_dsp/aom_dsp_common.h"
#include "aom_dsp/aom_filter.h"
#include "aom_dsp/arm/aom_convolve8_neon.h"
#include "aom_dsp/arm/aom_filter.h"
#include "aom_dsp/arm/mem_neon.h"
#include "aom_dsp/arm/transpose_neon.h"
#include "aom_ports/mem.h"

// Filter values always sum to 128.
#define FILTER_WEIGHT 128
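
// The kernels below shift each unsigned 8-bit sample down by 128 so that it
// fits the signed 8-bit dot-product (SDOT) instructions. Because the filter
// taps f[k] sum to FILTER_WEIGHT, pre-loading the accumulator with
// 128 * FILTER_WEIGHT cancels the bias exactly:
//   sum_k f[k] * (x[k] - 128) + 128 * FILTER_WEIGHT
//     = sum_k f[k] * x[k] - 128 * FILTER_WEIGHT + 128 * FILTER_WEIGHT
//     = sum_k f[k] * x[k]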

DECLARE_ALIGNED(16, static const uint8_t, kDotProdPermuteTbl[48]) = {
  0, 1, 2,  3,  1, 2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6,
  4, 5, 6,  7,  5, 6,  7,  8,  6,  7,  8,  9,  7,  8,  9,  10,
  8, 9, 10, 11, 9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14
};
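
// Each 16-byte row of kDotProdPermuteTbl above gathers four overlapping
// 4-sample windows, so that a single SDOT instruction can accumulate four
// filter taps for four adjacent output pixels at once.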

DECLARE_ALIGNED(16, static const uint8_t, kDotProdMergeBlockTbl[48]) = {
  // Shift left and insert new last column in transposed 4x4 block.
  1, 2, 3, 16, 5, 6, 7, 20, 9, 10, 11, 24, 13, 14, 15, 28,
  // Shift left and insert two new columns in transposed 4x4 block.
  2, 3, 16, 17, 6, 7, 20, 21, 10, 11, 24, 25, 14, 15, 28, 29,
  // Shift left and insert three new columns in transposed 4x4 block.
  3, 16, 17, 18, 7, 20, 21, 22, 11, 24, 25, 26, 15, 28, 29, 30
};
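
// kDotProdMergeBlockTbl is used with vqtbl2q_s8, which indexes a pair of
// source registers: indices 0-15 select bytes from the first register and
// 16-31 from the second. Each row above therefore shifts the previous
// transposed block left and splices in newly loaded columns.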

static inline int16x4_t convolve8_4_h(const uint8x16_t samples,
                                      const int8x8_t filters,
                                      const uint8x16x2_t permute_tbl) {
  // Transform sample range to [-128, 127] for 8-bit signed dot product.
  int8x16_t samples_128 =
      vreinterpretq_s8_u8(vsubq_u8(samples, vdupq_n_u8(128)));

  // Permute samples ready for dot product.
  // { 0,  1,  2,  3,  1,  2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6 }
  // { 4,  5,  6,  7,  5,  6,  7,  8,  6,  7,  8,  9,  7,  8,  9, 10 }
  int8x16_t perm_samples[2] = { vqtbl1q_s8(samples_128, permute_tbl.val[0]),
                                vqtbl1q_s8(samples_128, permute_tbl.val[1]) };

  // Accumulate into 128 * FILTER_WEIGHT to account for range transform.
  int32x4_t acc = vdupq_n_s32(128 * FILTER_WEIGHT);
  int32x4_t sum = vdotq_lane_s32(acc, perm_samples[0], filters, 0);
  sum = vdotq_lane_s32(sum, perm_samples[1], filters, 1);

  // Further narrowing and packing is performed by the caller.
  return vqmovn_s32(sum);
}

static inline uint8x8_t convolve8_8_h(const uint8x16_t samples,
                                      const int8x8_t filters,
                                      const uint8x16x3_t permute_tbl) {
  // Transform sample range to [-128, 127] for 8-bit signed dot product.
  int8x16_t samples_128 =
      vreinterpretq_s8_u8(vsubq_u8(samples, vdupq_n_u8(128)));

  // Permute samples ready for dot product.
  // { 0,  1,  2,  3,  1,  2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6 }
  // { 4,  5,  6,  7,  5,  6,  7,  8,  6,  7,  8,  9,  7,  8,  9, 10 }
  // { 8,  9, 10, 11,  9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14 }
  int8x16_t perm_samples[3] = { vqtbl1q_s8(samples_128, permute_tbl.val[0]),
                                vqtbl1q_s8(samples_128, permute_tbl.val[1]),
                                vqtbl1q_s8(samples_128, permute_tbl.val[2]) };

  // Accumulate into 128 * FILTER_WEIGHT to account for range transform.
  int32x4_t acc = vdupq_n_s32(128 * FILTER_WEIGHT);
  // First 4 output values.
  int32x4_t sum0 = vdotq_lane_s32(acc, perm_samples[0], filters, 0);
  sum0 = vdotq_lane_s32(sum0, perm_samples[1], filters, 1);
  // Second 4 output values.
  int32x4_t sum1 = vdotq_lane_s32(acc, perm_samples[1], filters, 0);
  sum1 = vdotq_lane_s32(sum1, perm_samples[2], filters, 1);

  // Narrow and re-pack.
  int16x8_t sum = vcombine_s16(vqmovn_s32(sum0), vqmovn_s32(sum1));
  return vqrshrun_n_s16(sum, FILTER_BITS);
}
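
// For reference, each output pixel produced by the helpers above is the
// standard 8-tap convolution, rounded and saturated to [0, 255]:
//   dst[x] = clamp((sum_{k=0..7} src[x + k] * filter[k] +
//                   (1 << (FILTER_BITS - 1))) >> FILTER_BITS, 0, 255)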

static inline void convolve8_horiz_8tap_neon_dotprod(
    const uint8_t *src, ptrdiff_t src_stride, uint8_t *dst,
    ptrdiff_t dst_stride, const int16_t *filter_x, int w, int h) {
  const int8x8_t filter = vmovn_s16(vld1q_s16(filter_x));

  if (w == 4) {
    const uint8x16x2_t perm_tbl = vld1q_u8_x2(kDotProdPermuteTbl);
    do {
      uint8x16_t s0, s1, s2, s3;
      load_u8_16x4(src, src_stride, &s0, &s1, &s2, &s3);

      int16x4_t d0 = convolve8_4_h(s0, filter, perm_tbl);
      int16x4_t d1 = convolve8_4_h(s1, filter, perm_tbl);
      int16x4_t d2 = convolve8_4_h(s2, filter, perm_tbl);
      int16x4_t d3 = convolve8_4_h(s3, filter, perm_tbl);
      uint8x8_t d01 = vqrshrun_n_s16(vcombine_s16(d0, d1), FILTER_BITS);
      uint8x8_t d23 = vqrshrun_n_s16(vcombine_s16(d2, d3), FILTER_BITS);

      store_u8x4_strided_x2(dst + 0 * dst_stride, dst_stride, d01);
      store_u8x4_strided_x2(dst + 2 * dst_stride, dst_stride, d23);

      src += 4 * src_stride;
      dst += 4 * dst_stride;
      h -= 4;
    } while (h > 0);
  } else {
    const uint8x16x3_t perm_tbl = vld1q_u8_x3(kDotProdPermuteTbl);

    do {
      int width = w;
      const uint8_t *s = src;
      uint8_t *d = dst;
      do {
        uint8x16_t s0, s1, s2, s3;
        load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);

        uint8x8_t d0 = convolve8_8_h(s0, filter, perm_tbl);
        uint8x8_t d1 = convolve8_8_h(s1, filter, perm_tbl);
        uint8x8_t d2 = convolve8_8_h(s2, filter, perm_tbl);
        uint8x8_t d3 = convolve8_8_h(s3, filter, perm_tbl);

        store_u8_8x4(d, dst_stride, d0, d1, d2, d3);

        s += 8;
        d += 8;
        width -= 8;
      } while (width != 0);
      src += 4 * src_stride;
      dst += 4 * dst_stride;
      h -= 4;
    } while (h > 0);
  }
}

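// 4-tap path: all 4-tap and bilinear filter values are even, so the taps
// are halved before use. The accumulator bias and the final rounding shift
// drop by one bit to match, a single SDOT lane covers all four taps, and
// the narrower intermediate sums fit int16_t without saturation, allowing
// the plain (non-saturating) vmovn_s32 narrowing below.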
static inline int16x4_t convolve4_4_h(const uint8x16_t samples,
                                      const int8x8_t filters,
                                      const uint8x16_t permute_tbl) {
  // Transform sample range to [-128, 127] for 8-bit signed dot product.
  int8x16_t samples_128 =
      vreinterpretq_s8_u8(vsubq_u8(samples, vdupq_n_u8(128)));

  // Permute samples ready for dot product.
  // { 0,  1,  2,  3,  1,  2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6 }
  int8x16_t perm_samples = vqtbl1q_s8(samples_128, permute_tbl);

  // Accumulate into 128 * FILTER_WEIGHT to account for range transform.
  // (Divide by 2 since we halved the filter values.)
  int32x4_t acc = vdupq_n_s32(128 * FILTER_WEIGHT / 2);
  int32x4_t sum = vdotq_lane_s32(acc, perm_samples, filters, 0);

  // Further narrowing and packing is performed by the caller.
  return vmovn_s32(sum);
}

static inline uint8x8_t convolve4_8_h(const uint8x16_t samples,
                                      const int8x8_t filters,
                                      const uint8x16x2_t permute_tbl) {
  // Transform sample range to [-128, 127] for 8-bit signed dot product.
  int8x16_t samples_128 =
      vreinterpretq_s8_u8(vsubq_u8(samples, vdupq_n_u8(128)));

  // Permute samples ready for dot product.
  // { 0,  1,  2,  3,  1,  2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6 }
  // { 4,  5,  6,  7,  5,  6,  7,  8,  6,  7,  8,  9,  7,  8,  9, 10 }
  int8x16_t perm_samples[2] = { vqtbl1q_s8(samples_128, permute_tbl.val[0]),
                                vqtbl1q_s8(samples_128, permute_tbl.val[1]) };

  // Accumulate into 128 * FILTER_WEIGHT to account for range transform.
  // (Divide by 2 since we halved the filter values.)
  int32x4_t acc = vdupq_n_s32(128 * FILTER_WEIGHT / 2);
  // First 4 output values.
  int32x4_t sum0 = vdotq_lane_s32(acc, perm_samples[0], filters, 0);
  // Second 4 output values.
  int32x4_t sum1 = vdotq_lane_s32(acc, perm_samples[1], filters, 0);

  // Narrow and re-pack.
  int16x8_t sum = vcombine_s16(vmovn_s32(sum0), vmovn_s32(sum1));
  // We halved the filter values so -1 from right shift.
  return vqrshrun_n_s16(sum, FILTER_BITS - 1);
}

static inline void convolve8_horiz_4tap_neon_dotprod(
    const uint8_t *src, ptrdiff_t src_stride, uint8_t *dst,
    ptrdiff_t dst_stride, const int16_t *filter_x, int width, int height) {
  const int16x4_t x_filter = vld1_s16(filter_x + 2);
  // All 4-tap and bilinear filter values are even, so halve them to reduce
  // intermediate precision requirements.
  const int8x8_t filter = vshrn_n_s16(vcombine_s16(x_filter, vdup_n_s16(0)), 1);
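  // For example, the half-pel bilinear kernel { 64, 64 } becomes { 32, 32 }:
  // the halved taps sum to FILTER_WEIGHT / 2, matching the halved accumulator
  // bias and the FILTER_BITS - 1 rounding shift in the helpers above.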

  if (width == 4) {
    const uint8x16_t permute_tbl = vld1q_u8(kDotProdPermuteTbl);

    do {
      uint8x16_t s0, s1, s2, s3;
      load_u8_16x4(src, src_stride, &s0, &s1, &s2, &s3);

      int16x4_t t0 = convolve4_4_h(s0, filter, permute_tbl);
      int16x4_t t1 = convolve4_4_h(s1, filter, permute_tbl);
      int16x4_t t2 = convolve4_4_h(s2, filter, permute_tbl);
      int16x4_t t3 = convolve4_4_h(s3, filter, permute_tbl);
      // We halved the filter values so -1 from right shift.
      uint8x8_t d01 = vqrshrun_n_s16(vcombine_s16(t0, t1), FILTER_BITS - 1);
      uint8x8_t d23 = vqrshrun_n_s16(vcombine_s16(t2, t3), FILTER_BITS - 1);

      store_u8x4_strided_x2(dst + 0 * dst_stride, dst_stride, d01);
      store_u8x4_strided_x2(dst + 2 * dst_stride, dst_stride, d23);

      src += 4 * src_stride;
      dst += 4 * dst_stride;
      height -= 4;
    } while (height > 0);
  } else {
    const uint8x16x2_t permute_tbl = vld1q_u8_x2(kDotProdPermuteTbl);

    do {
      const uint8_t *s = src;
      uint8_t *d = dst;
      int w = width;

      do {
        uint8x16_t s0, s1, s2, s3;
        load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);

        uint8x8_t d0 = convolve4_8_h(s0, filter, permute_tbl);
        uint8x8_t d1 = convolve4_8_h(s1, filter, permute_tbl);
        uint8x8_t d2 = convolve4_8_h(s2, filter, permute_tbl);
        uint8x8_t d3 = convolve4_8_h(s3, filter, permute_tbl);

        store_u8_8x4(d, dst_stride, d0, d1, d2, d3);

        s += 8;
        d += 8;
        w -= 8;
      } while (w != 0);
      src += 4 * src_stride;
      dst += 4 * dst_stride;
      height -= 4;
    } while (height > 0);
  }
}

void aom_convolve8_horiz_neon_dotprod(const uint8_t *src, ptrdiff_t src_stride,
                                      uint8_t *dst, ptrdiff_t dst_stride,
                                      const int16_t *filter_x, int x_step_q4,
                                      const int16_t *filter_y, int y_step_q4,
                                      int w, int h) {
  assert((intptr_t)dst % 4 == 0);
  assert(dst_stride % 4 == 0);

  (void)x_step_q4;
  (void)filter_y;
  (void)y_step_q4;

  src -= ((SUBPEL_TAPS / 2) - 1);
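  // src now points (SUBPEL_TAPS / 2 - 1) = 3 columns to the left of dst. The
  // 2-tap and 4-tap paths need only 0 and 1 columns of left context, so they
  // advance src by 3 and 2 columns respectively to re-center the narrower
  // kernels.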

  int filter_taps = get_filter_taps_convolve8(filter_x);

  if (filter_taps == 2) {
    convolve8_horiz_2tap_neon(src + 3, src_stride, dst, dst_stride, filter_x, w,
                              h);
  } else if (filter_taps == 4) {
    convolve8_horiz_4tap_neon_dotprod(src + 2, src_stride, dst, dst_stride,
                                      filter_x, w, h);
  } else {
    convolve8_horiz_8tap_neon_dotprod(src, src_stride, dst, dst_stride,
                                      filter_x, w, h);
  }
}

static inline void transpose_concat_4x4(int8x8_t a0, int8x8_t a1, int8x8_t a2,
                                        int8x8_t a3, int8x16_t *b) {
  // Transpose 8-bit elements and concatenate result rows as follows:
  // a0: 00, 01, 02, 03, XX, XX, XX, XX
  // a1: 10, 11, 12, 13, XX, XX, XX, XX
  // a2: 20, 21, 22, 23, XX, XX, XX, XX
  // a3: 30, 31, 32, 33, XX, XX, XX, XX
  //
  // b: 00, 10, 20, 30, 01, 11, 21, 31, 02, 12, 22, 32, 03, 13, 23, 33

  int8x16_t a0q = vcombine_s8(a0, vdup_n_s8(0));
  int8x16_t a1q = vcombine_s8(a1, vdup_n_s8(0));
  int8x16_t a2q = vcombine_s8(a2, vdup_n_s8(0));
  int8x16_t a3q = vcombine_s8(a3, vdup_n_s8(0));

  int8x16_t a01 = vzipq_s8(a0q, a1q).val[0];
  int8x16_t a23 = vzipq_s8(a2q, a3q).val[0];

  int16x8_t a0123 =
      vzipq_s16(vreinterpretq_s16_s8(a01), vreinterpretq_s16_s8(a23)).val[0];

  *b = vreinterpretq_s8_s16(a0123);
}

static inline void transpose_concat_8x4(int8x8_t a0, int8x8_t a1, int8x8_t a2,
                                        int8x8_t a3, int8x16_t *b0,
                                        int8x16_t *b1) {
  // Transpose 8-bit elements and concatenate result rows as follows:
  // a0: 00, 01, 02, 03, 04, 05, 06, 07
  // a1: 10, 11, 12, 13, 14, 15, 16, 17
  // a2: 20, 21, 22, 23, 24, 25, 26, 27
  // a3: 30, 31, 32, 33, 34, 35, 36, 37
  //
  // b0: 00, 10, 20, 30, 01, 11, 21, 31, 02, 12, 22, 32, 03, 13, 23, 33
  // b1: 04, 14, 24, 34, 05, 15, 25, 35, 06, 16, 26, 36, 07, 17, 27, 37

  int8x16_t a0q = vcombine_s8(a0, vdup_n_s8(0));
  int8x16_t a1q = vcombine_s8(a1, vdup_n_s8(0));
  int8x16_t a2q = vcombine_s8(a2, vdup_n_s8(0));
  int8x16_t a3q = vcombine_s8(a3, vdup_n_s8(0));

  int8x16_t a01 = vzipq_s8(a0q, a1q).val[0];
  int8x16_t a23 = vzipq_s8(a2q, a3q).val[0];

  int16x8x2_t a0123 =
      vzipq_s16(vreinterpretq_s16_s8(a01), vreinterpretq_s16_s8(a23));

  *b0 = vreinterpretq_s8_s16(a0123.val[0]);
  *b1 = vreinterpretq_s8_s16(a0123.val[1]);
}
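
// Note: the zero vectors combined in above only pad the inputs out to full
// 128-bit registers; the zip sequences are arranged so that neither the
// padding nor (in the 4x4 case) the don't-care XX lanes reach the output.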
337 
convolve8_4_v(const int8x16_t samples_lo,const int8x16_t samples_hi,const int8x8_t filters)338 static inline int16x4_t convolve8_4_v(const int8x16_t samples_lo,
339                                       const int8x16_t samples_hi,
340                                       const int8x8_t filters) {
341   // The sample range transform and permutation are performed by the caller.
342 
343   // Accumulate into 128 * FILTER_WEIGHT to account for range transform.
344   int32x4_t acc = vdupq_n_s32(128 * FILTER_WEIGHT);
345   int32x4_t sum = vdotq_lane_s32(acc, samples_lo, filters, 0);
346   sum = vdotq_lane_s32(sum, samples_hi, filters, 1);
347 
348   // Further narrowing and packing is performed by the caller.
349   return vqmovn_s32(sum);
350 }
351 
convolve8_8_v(const int8x16_t samples0_lo,const int8x16_t samples0_hi,const int8x16_t samples1_lo,const int8x16_t samples1_hi,const int8x8_t filters)352 static inline uint8x8_t convolve8_8_v(const int8x16_t samples0_lo,
353                                       const int8x16_t samples0_hi,
354                                       const int8x16_t samples1_lo,
355                                       const int8x16_t samples1_hi,
356                                       const int8x8_t filters) {
357   // The sample range transform and permutation are performed by the caller.
358 
359   // Accumulate into 128 * FILTER_WEIGHT to account for range transform.
360   int32x4_t acc = vdupq_n_s32(128 * FILTER_WEIGHT);
361   // First 4 output values.
362   int32x4_t sum0 = vdotq_lane_s32(acc, samples0_lo, filters, 0);
363   sum0 = vdotq_lane_s32(sum0, samples0_hi, filters, 1);
364   // Second 4 output values.
365   int32x4_t sum1 = vdotq_lane_s32(acc, samples1_lo, filters, 0);
366   sum1 = vdotq_lane_s32(sum1, samples1_hi, filters, 1);
367 
368   // Narrow and re-pack.
369   int16x8_t sum = vcombine_s16(vqmovn_s32(sum0), vqmovn_s32(sum1));
370   return vqrshrun_n_s16(sum, FILTER_BITS);
371 }
372 
convolve8_vert_8tap_neon_dotprod(const uint8_t * src,ptrdiff_t src_stride,uint8_t * dst,ptrdiff_t dst_stride,const int16_t * filter_y,int w,int h)373 static inline void convolve8_vert_8tap_neon_dotprod(
374     const uint8_t *src, ptrdiff_t src_stride, uint8_t *dst,
375     ptrdiff_t dst_stride, const int16_t *filter_y, int w, int h) {
376   const int8x8_t filter = vmovn_s16(vld1q_s16(filter_y));
377   const uint8x16x3_t merge_block_tbl = vld1q_u8_x3(kDotProdMergeBlockTbl);
378   int8x16x2_t samples_LUT;
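
  // Each loop iteration below loads and transposes only four new rows, then
  // uses kDotProdMergeBlockTbl with vqtbl2q_s8 to splice them onto the
  // transposed block carried over from the previous iteration, avoiding a
  // full re-transpose of all eight rows.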

  if (w == 4) {
    uint8x8_t t0, t1, t2, t3, t4, t5, t6;
    load_u8_8x7(src, src_stride, &t0, &t1, &t2, &t3, &t4, &t5, &t6);
    src += 7 * src_stride;

    // Transform sample range to [-128, 127] for 8-bit signed dot product.
    int8x8_t s0 = vreinterpret_s8_u8(vsub_u8(t0, vdup_n_u8(128)));
    int8x8_t s1 = vreinterpret_s8_u8(vsub_u8(t1, vdup_n_u8(128)));
    int8x8_t s2 = vreinterpret_s8_u8(vsub_u8(t2, vdup_n_u8(128)));
    int8x8_t s3 = vreinterpret_s8_u8(vsub_u8(t3, vdup_n_u8(128)));
    int8x8_t s4 = vreinterpret_s8_u8(vsub_u8(t4, vdup_n_u8(128)));
    int8x8_t s5 = vreinterpret_s8_u8(vsub_u8(t5, vdup_n_u8(128)));
    int8x8_t s6 = vreinterpret_s8_u8(vsub_u8(t6, vdup_n_u8(128)));

    // This operation combines a conventional transpose and the sample permute
    // (see horizontal case) required before computing the dot product.
    int8x16_t s0123, s1234, s2345, s3456;
    transpose_concat_4x4(s0, s1, s2, s3, &s0123);
    transpose_concat_4x4(s1, s2, s3, s4, &s1234);
    transpose_concat_4x4(s2, s3, s4, s5, &s2345);
    transpose_concat_4x4(s3, s4, s5, s6, &s3456);

    do {
      uint8x8_t t7, t8, t9, t10;
      load_u8_8x4(src, src_stride, &t7, &t8, &t9, &t10);

      int8x8_t s7 = vreinterpret_s8_u8(vsub_u8(t7, vdup_n_u8(128)));
      int8x8_t s8 = vreinterpret_s8_u8(vsub_u8(t8, vdup_n_u8(128)));
      int8x8_t s9 = vreinterpret_s8_u8(vsub_u8(t9, vdup_n_u8(128)));
      int8x8_t s10 = vreinterpret_s8_u8(vsub_u8(t10, vdup_n_u8(128)));

      int8x16_t s4567, s5678, s6789, s78910;
      transpose_concat_4x4(s7, s8, s9, s10, &s78910);

      // Merge new data into block from previous iteration.
      samples_LUT.val[0] = s3456;
      samples_LUT.val[1] = s78910;
      s4567 = vqtbl2q_s8(samples_LUT, merge_block_tbl.val[0]);
      s5678 = vqtbl2q_s8(samples_LUT, merge_block_tbl.val[1]);
      s6789 = vqtbl2q_s8(samples_LUT, merge_block_tbl.val[2]);

      int16x4_t d0 = convolve8_4_v(s0123, s4567, filter);
      int16x4_t d1 = convolve8_4_v(s1234, s5678, filter);
      int16x4_t d2 = convolve8_4_v(s2345, s6789, filter);
      int16x4_t d3 = convolve8_4_v(s3456, s78910, filter);
      uint8x8_t d01 = vqrshrun_n_s16(vcombine_s16(d0, d1), FILTER_BITS);
      uint8x8_t d23 = vqrshrun_n_s16(vcombine_s16(d2, d3), FILTER_BITS);

      store_u8x4_strided_x2(dst + 0 * dst_stride, dst_stride, d01);
      store_u8x4_strided_x2(dst + 2 * dst_stride, dst_stride, d23);

      // Prepare block for next iteration - re-using as much as possible.
      // Shuffle everything up four rows.
      s0123 = s4567;
      s1234 = s5678;
      s2345 = s6789;
      s3456 = s78910;

      src += 4 * src_stride;
      dst += 4 * dst_stride;
      h -= 4;
    } while (h != 0);
  } else {
    do {
      int height = h;
      const uint8_t *s = src;
      uint8_t *d = dst;

      uint8x8_t t0, t1, t2, t3, t4, t5, t6;
      load_u8_8x7(s, src_stride, &t0, &t1, &t2, &t3, &t4, &t5, &t6);
      s += 7 * src_stride;

      // Transform sample range to [-128, 127] for 8-bit signed dot product.
      int8x8_t s0 = vreinterpret_s8_u8(vsub_u8(t0, vdup_n_u8(128)));
      int8x8_t s1 = vreinterpret_s8_u8(vsub_u8(t1, vdup_n_u8(128)));
      int8x8_t s2 = vreinterpret_s8_u8(vsub_u8(t2, vdup_n_u8(128)));
      int8x8_t s3 = vreinterpret_s8_u8(vsub_u8(t3, vdup_n_u8(128)));
      int8x8_t s4 = vreinterpret_s8_u8(vsub_u8(t4, vdup_n_u8(128)));
      int8x8_t s5 = vreinterpret_s8_u8(vsub_u8(t5, vdup_n_u8(128)));
      int8x8_t s6 = vreinterpret_s8_u8(vsub_u8(t6, vdup_n_u8(128)));

      // This operation combines a conventional transpose and the sample permute
      // (see horizontal case) required before computing the dot product.
      int8x16_t s0123_lo, s0123_hi, s1234_lo, s1234_hi, s2345_lo, s2345_hi,
          s3456_lo, s3456_hi;
      transpose_concat_8x4(s0, s1, s2, s3, &s0123_lo, &s0123_hi);
      transpose_concat_8x4(s1, s2, s3, s4, &s1234_lo, &s1234_hi);
      transpose_concat_8x4(s2, s3, s4, s5, &s2345_lo, &s2345_hi);
      transpose_concat_8x4(s3, s4, s5, s6, &s3456_lo, &s3456_hi);

      do {
        uint8x8_t t7, t8, t9, t10;
        load_u8_8x4(s, src_stride, &t7, &t8, &t9, &t10);

        int8x8_t s7 = vreinterpret_s8_u8(vsub_u8(t7, vdup_n_u8(128)));
        int8x8_t s8 = vreinterpret_s8_u8(vsub_u8(t8, vdup_n_u8(128)));
        int8x8_t s9 = vreinterpret_s8_u8(vsub_u8(t9, vdup_n_u8(128)));
        int8x8_t s10 = vreinterpret_s8_u8(vsub_u8(t10, vdup_n_u8(128)));

        int8x16_t s4567_lo, s4567_hi, s5678_lo, s5678_hi, s6789_lo, s6789_hi,
            s78910_lo, s78910_hi;
        transpose_concat_8x4(s7, s8, s9, s10, &s78910_lo, &s78910_hi);

        // Merge new data into block from previous iteration.
        samples_LUT.val[0] = s3456_lo;
        samples_LUT.val[1] = s78910_lo;
        s4567_lo = vqtbl2q_s8(samples_LUT, merge_block_tbl.val[0]);
        s5678_lo = vqtbl2q_s8(samples_LUT, merge_block_tbl.val[1]);
        s6789_lo = vqtbl2q_s8(samples_LUT, merge_block_tbl.val[2]);

        samples_LUT.val[0] = s3456_hi;
        samples_LUT.val[1] = s78910_hi;
        s4567_hi = vqtbl2q_s8(samples_LUT, merge_block_tbl.val[0]);
        s5678_hi = vqtbl2q_s8(samples_LUT, merge_block_tbl.val[1]);
        s6789_hi = vqtbl2q_s8(samples_LUT, merge_block_tbl.val[2]);

        uint8x8_t d0 =
            convolve8_8_v(s0123_lo, s4567_lo, s0123_hi, s4567_hi, filter);
        uint8x8_t d1 =
            convolve8_8_v(s1234_lo, s5678_lo, s1234_hi, s5678_hi, filter);
        uint8x8_t d2 =
            convolve8_8_v(s2345_lo, s6789_lo, s2345_hi, s6789_hi, filter);
        uint8x8_t d3 =
            convolve8_8_v(s3456_lo, s78910_lo, s3456_hi, s78910_hi, filter);

        store_u8_8x4(d, dst_stride, d0, d1, d2, d3);

        // Prepare block for next iteration - re-using as much as possible.
        // Shuffle everything up four rows.
        s0123_lo = s4567_lo;
        s0123_hi = s4567_hi;
        s1234_lo = s5678_lo;
        s1234_hi = s5678_hi;
        s2345_lo = s6789_lo;
        s2345_hi = s6789_hi;
        s3456_lo = s78910_lo;
        s3456_hi = s78910_hi;

        s += 4 * src_stride;
        d += 4 * dst_stride;
        height -= 4;
      } while (height != 0);
      src += 8;
      dst += 8;
      w -= 8;
    } while (w != 0);
  }
}

void aom_convolve8_vert_neon_dotprod(const uint8_t *src, ptrdiff_t src_stride,
                                     uint8_t *dst, ptrdiff_t dst_stride,
                                     const int16_t *filter_x, int x_step_q4,
                                     const int16_t *filter_y, int y_step_q4,
                                     int w, int h) {
  assert((intptr_t)dst % 4 == 0);
  assert(dst_stride % 4 == 0);

  (void)filter_x;
  (void)x_step_q4;
  (void)y_step_q4;

  src -= ((SUBPEL_TAPS / 2) - 1) * src_stride;
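  // As in the horizontal dispatcher, src now points 3 rows above dst; the
  // 2-tap and 4-tap paths advance src by 3 and 2 rows respectively to
  // re-center the narrower kernels.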

  int filter_taps = get_filter_taps_convolve8(filter_y);

  if (filter_taps == 2) {
    convolve8_vert_2tap_neon(src + 3 * src_stride, src_stride, dst, dst_stride,
                             filter_y, w, h);
  } else if (filter_taps == 4) {
    convolve8_vert_4tap_neon(src + 2 * src_stride, src_stride, dst, dst_stride,
                             filter_y, w, h);
  } else {
    convolve8_vert_8tap_neon_dotprod(src, src_stride, dst, dst_stride, filter_y,
                                     w, h);
  }
}
556