// Copyright 2021 The libgav1 Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "src/dsp/convolve.h"
#include "src/utils/cpu.h"

#if LIBGAV1_ENABLE_NEON && LIBGAV1_MAX_BITDEPTH >= 10
#include <arm_neon.h>

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <cstring>

#include "src/dsp/arm/common_neon.h"
#include "src/dsp/constants.h"
#include "src/dsp/dsp.h"
#include "src/utils/common.h"
#include "src/utils/compiler_attributes.h"
#include "src/utils/constants.h"

namespace libgav1 {
namespace dsp {
namespace {

// Include the constants and utility functions inside the anonymous namespace.
#include "src/dsp/convolve.inc"

// Output of ConvolveTest.ShowRange below.
// Bitdepth: 10 Input range:            [       0,     1023]
//   Horizontal base upscaled range:    [  -28644,    94116]
//   Horizontal halved upscaled range:  [  -14322,    47085]
//   Horizontal downscaled range:       [   -7161,    23529]
//   Vertical upscaled range:           [-1317624,  2365176]
//   Pixel output range:                [       0,     1023]
//   Compound output range:             [    3988,    61532]
template <int num_taps>
int32x4x2_t SumOnePassTaps(const uint16x8_t* const src,
                           const int16x4_t* const taps) {
  const auto* ssrc = reinterpret_cast<const int16x8_t*>(src);
  int32x4x2_t sum;
  if (num_taps == 6) {
    // 6 taps.
    sum.val[0] = vmull_s16(vget_low_s16(ssrc[0]), taps[0]);
    sum.val[0] = vmlal_s16(sum.val[0], vget_low_s16(ssrc[1]), taps[1]);
    sum.val[0] = vmlal_s16(sum.val[0], vget_low_s16(ssrc[2]), taps[2]);
    sum.val[0] = vmlal_s16(sum.val[0], vget_low_s16(ssrc[3]), taps[3]);
    sum.val[0] = vmlal_s16(sum.val[0], vget_low_s16(ssrc[4]), taps[4]);
    sum.val[0] = vmlal_s16(sum.val[0], vget_low_s16(ssrc[5]), taps[5]);

    sum.val[1] = vmull_s16(vget_high_s16(ssrc[0]), taps[0]);
    sum.val[1] = vmlal_s16(sum.val[1], vget_high_s16(ssrc[1]), taps[1]);
    sum.val[1] = vmlal_s16(sum.val[1], vget_high_s16(ssrc[2]), taps[2]);
    sum.val[1] = vmlal_s16(sum.val[1], vget_high_s16(ssrc[3]), taps[3]);
    sum.val[1] = vmlal_s16(sum.val[1], vget_high_s16(ssrc[4]), taps[4]);
    sum.val[1] = vmlal_s16(sum.val[1], vget_high_s16(ssrc[5]), taps[5]);
  } else if (num_taps == 8) {
    // 8 taps.
    sum.val[0] = vmull_s16(vget_low_s16(ssrc[0]), taps[0]);
    sum.val[0] = vmlal_s16(sum.val[0], vget_low_s16(ssrc[1]), taps[1]);
    sum.val[0] = vmlal_s16(sum.val[0], vget_low_s16(ssrc[2]), taps[2]);
    sum.val[0] = vmlal_s16(sum.val[0], vget_low_s16(ssrc[3]), taps[3]);
    sum.val[0] = vmlal_s16(sum.val[0], vget_low_s16(ssrc[4]), taps[4]);
    sum.val[0] = vmlal_s16(sum.val[0], vget_low_s16(ssrc[5]), taps[5]);
    sum.val[0] = vmlal_s16(sum.val[0], vget_low_s16(ssrc[6]), taps[6]);
    sum.val[0] = vmlal_s16(sum.val[0], vget_low_s16(ssrc[7]), taps[7]);

    sum.val[1] = vmull_s16(vget_high_s16(ssrc[0]), taps[0]);
    sum.val[1] = vmlal_s16(sum.val[1], vget_high_s16(ssrc[1]), taps[1]);
    sum.val[1] = vmlal_s16(sum.val[1], vget_high_s16(ssrc[2]), taps[2]);
    sum.val[1] = vmlal_s16(sum.val[1], vget_high_s16(ssrc[3]), taps[3]);
    sum.val[1] = vmlal_s16(sum.val[1], vget_high_s16(ssrc[4]), taps[4]);
    sum.val[1] = vmlal_s16(sum.val[1], vget_high_s16(ssrc[5]), taps[5]);
    sum.val[1] = vmlal_s16(sum.val[1], vget_high_s16(ssrc[6]), taps[6]);
    sum.val[1] = vmlal_s16(sum.val[1], vget_high_s16(ssrc[7]), taps[7]);
  } else if (num_taps == 2) {
    // 2 taps.
    sum.val[0] = vmull_s16(vget_low_s16(ssrc[0]), taps[0]);
    sum.val[0] = vmlal_s16(sum.val[0], vget_low_s16(ssrc[1]), taps[1]);

    sum.val[1] = vmull_s16(vget_high_s16(ssrc[0]), taps[0]);
    sum.val[1] = vmlal_s16(sum.val[1], vget_high_s16(ssrc[1]), taps[1]);
  } else {
    // 4 taps.
    sum.val[0] = vmull_s16(vget_low_s16(ssrc[0]), taps[0]);
    sum.val[0] = vmlal_s16(sum.val[0], vget_low_s16(ssrc[1]), taps[1]);
    sum.val[0] = vmlal_s16(sum.val[0], vget_low_s16(ssrc[2]), taps[2]);
    sum.val[0] = vmlal_s16(sum.val[0], vget_low_s16(ssrc[3]), taps[3]);

    sum.val[1] = vmull_s16(vget_high_s16(ssrc[0]), taps[0]);
    sum.val[1] = vmlal_s16(sum.val[1], vget_high_s16(ssrc[1]), taps[1]);
    sum.val[1] = vmlal_s16(sum.val[1], vget_high_s16(ssrc[2]), taps[2]);
    sum.val[1] = vmlal_s16(sum.val[1], vget_high_s16(ssrc[3]), taps[3]);
  }
  return sum;
}

template <int num_taps>
int32x4_t SumOnePassTaps(const uint16x4_t* const src,
                         const int16x4_t* const taps) {
  const auto* ssrc = reinterpret_cast<const int16x4_t*>(src);
  int32x4_t sum;
  if (num_taps == 6) {
    // 6 taps.
    sum = vmull_s16(ssrc[0], taps[0]);
    sum = vmlal_s16(sum, ssrc[1], taps[1]);
    sum = vmlal_s16(sum, ssrc[2], taps[2]);
    sum = vmlal_s16(sum, ssrc[3], taps[3]);
    sum = vmlal_s16(sum, ssrc[4], taps[4]);
    sum = vmlal_s16(sum, ssrc[5], taps[5]);
  } else if (num_taps == 8) {
    // 8 taps.
    sum = vmull_s16(ssrc[0], taps[0]);
    sum = vmlal_s16(sum, ssrc[1], taps[1]);
    sum = vmlal_s16(sum, ssrc[2], taps[2]);
    sum = vmlal_s16(sum, ssrc[3], taps[3]);
    sum = vmlal_s16(sum, ssrc[4], taps[4]);
    sum = vmlal_s16(sum, ssrc[5], taps[5]);
    sum = vmlal_s16(sum, ssrc[6], taps[6]);
    sum = vmlal_s16(sum, ssrc[7], taps[7]);
  } else if (num_taps == 2) {
    // 2 taps.
    sum = vmull_s16(ssrc[0], taps[0]);
    sum = vmlal_s16(sum, ssrc[1], taps[1]);
  } else {
    // 4 taps.
    sum = vmull_s16(ssrc[0], taps[0]);
    sum = vmlal_s16(sum, ssrc[1], taps[1]);
    sum = vmlal_s16(sum, ssrc[2], taps[2]);
    sum = vmlal_s16(sum, ssrc[3], taps[3]);
  }
  return sum;
}

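// Horizontal filter for |width| >= 8. Each iteration loads 16 pixels and
// builds the |num_taps| shifted windows with vextq_u16, so one pair of loads
// covers 8 output pixels.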
template <int num_taps, bool is_compound, bool is_2d>
void FilterHorizontalWidth8AndUp(const uint16_t* LIBGAV1_RESTRICT src,
                                 const ptrdiff_t src_stride,
                                 void* LIBGAV1_RESTRICT const dest,
                                 const ptrdiff_t pred_stride, const int width,
                                 const int height,
                                 const int16x4_t* const v_tap) {
  auto* dest16 = static_cast<uint16_t*>(dest);
  const uint16x4_t v_max_bitdepth = vdup_n_u16((1 << kBitdepth10) - 1);
  if (is_2d) {
    int x = 0;
    do {
      const uint16_t* s = src + x;
      int y = height;
      do {  // Increasing loop counter x is better.
        const uint16x8_t src_long = vld1q_u16(s);
        const uint16x8_t src_long_hi = vld1q_u16(s + 8);
        uint16x8_t v_src[8];
        int32x4x2_t v_sum;
        if (num_taps == 6) {
          v_src[0] = src_long;
          v_src[1] = vextq_u16(src_long, src_long_hi, 1);
          v_src[2] = vextq_u16(src_long, src_long_hi, 2);
          v_src[3] = vextq_u16(src_long, src_long_hi, 3);
          v_src[4] = vextq_u16(src_long, src_long_hi, 4);
          v_src[5] = vextq_u16(src_long, src_long_hi, 5);
          v_sum = SumOnePassTaps<num_taps>(v_src, v_tap + 1);
        } else if (num_taps == 8) {
          v_src[0] = src_long;
          v_src[1] = vextq_u16(src_long, src_long_hi, 1);
          v_src[2] = vextq_u16(src_long, src_long_hi, 2);
          v_src[3] = vextq_u16(src_long, src_long_hi, 3);
          v_src[4] = vextq_u16(src_long, src_long_hi, 4);
          v_src[5] = vextq_u16(src_long, src_long_hi, 5);
          v_src[6] = vextq_u16(src_long, src_long_hi, 6);
          v_src[7] = vextq_u16(src_long, src_long_hi, 7);
          v_sum = SumOnePassTaps<num_taps>(v_src, v_tap);
        } else if (num_taps == 2) {
          v_src[0] = src_long;
          v_src[1] = vextq_u16(src_long, src_long_hi, 1);
          v_sum = SumOnePassTaps<num_taps>(v_src, v_tap + 3);
        } else {  // 4 taps
          v_src[0] = src_long;
          v_src[1] = vextq_u16(src_long, src_long_hi, 1);
          v_src[2] = vextq_u16(src_long, src_long_hi, 2);
          v_src[3] = vextq_u16(src_long, src_long_hi, 3);
          v_sum = SumOnePassTaps<num_taps>(v_src, v_tap + 2);
        }

        const int16x4_t d0 =
            vqrshrn_n_s32(v_sum.val[0], kInterRoundBitsHorizontal - 1);
        const int16x4_t d1 =
            vqrshrn_n_s32(v_sum.val[1], kInterRoundBitsHorizontal - 1);
        vst1_u16(&dest16[0], vreinterpret_u16_s16(d0));
        vst1_u16(&dest16[4], vreinterpret_u16_s16(d1));
        s += src_stride;
        dest16 += 8;
      } while (--y != 0);
      x += 8;
    } while (x < width);
    return;
  }
  int y = height;
  do {
    int x = 0;
    do {
      const uint16x8_t src_long = vld1q_u16(src + x);
      const uint16x8_t src_long_hi = vld1q_u16(src + x + 8);
      uint16x8_t v_src[8];
      int32x4x2_t v_sum;
      if (num_taps == 6) {
        v_src[0] = src_long;
        v_src[1] = vextq_u16(src_long, src_long_hi, 1);
        v_src[2] = vextq_u16(src_long, src_long_hi, 2);
        v_src[3] = vextq_u16(src_long, src_long_hi, 3);
        v_src[4] = vextq_u16(src_long, src_long_hi, 4);
        v_src[5] = vextq_u16(src_long, src_long_hi, 5);
        v_sum = SumOnePassTaps<num_taps>(v_src, v_tap + 1);
      } else if (num_taps == 8) {
        v_src[0] = src_long;
        v_src[1] = vextq_u16(src_long, src_long_hi, 1);
        v_src[2] = vextq_u16(src_long, src_long_hi, 2);
        v_src[3] = vextq_u16(src_long, src_long_hi, 3);
        v_src[4] = vextq_u16(src_long, src_long_hi, 4);
        v_src[5] = vextq_u16(src_long, src_long_hi, 5);
        v_src[6] = vextq_u16(src_long, src_long_hi, 6);
        v_src[7] = vextq_u16(src_long, src_long_hi, 7);
        v_sum = SumOnePassTaps<num_taps>(v_src, v_tap);
      } else if (num_taps == 2) {
        v_src[0] = src_long;
        v_src[1] = vextq_u16(src_long, src_long_hi, 1);
        v_sum = SumOnePassTaps<num_taps>(v_src, v_tap + 3);
      } else {  // 4 taps
        v_src[0] = src_long;
        v_src[1] = vextq_u16(src_long, src_long_hi, 1);
        v_src[2] = vextq_u16(src_long, src_long_hi, 2);
        v_src[3] = vextq_u16(src_long, src_long_hi, 3);
        v_sum = SumOnePassTaps<num_taps>(v_src, v_tap + 2);
      }
      if (is_compound) {
        const int16x4_t v_compound_offset = vdup_n_s16(kCompoundOffset);
        const int16x4_t d0 =
            vqrshrn_n_s32(v_sum.val[0], kInterRoundBitsHorizontal - 1);
        const int16x4_t d1 =
            vqrshrn_n_s32(v_sum.val[1], kInterRoundBitsHorizontal - 1);
        vst1_u16(&dest16[x],
                 vreinterpret_u16_s16(vadd_s16(d0, v_compound_offset)));
        vst1_u16(&dest16[x + 4],
                 vreinterpret_u16_s16(vadd_s16(d1, v_compound_offset)));
      } else {
        // Normally the Horizontal pass does the downshift in two passes:
        // kInterRoundBitsHorizontal - 1 and then (kFilterBits -
        // kInterRoundBitsHorizontal). Each one uses a rounding shift.
        // Combining them requires adding the rounding offset from the skipped
        // shift.
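        // In closed form: with a = kInterRoundBitsHorizontal - 1 and
        // b = kFilterBits - kInterRoundBitsHorizontal,
        //   RightShiftWithRounding(RightShiftWithRounding(sum, a), b) ==
        //   RightShiftWithRounding(sum + (1 << (a - 1)), a + b),
        // and a + b == kFilterBits - 1, the single shift used below.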
        const int32x4_t v_first_shift_rounding_bit =
            vdupq_n_s32(1 << (kInterRoundBitsHorizontal - 2));
        v_sum.val[0] = vaddq_s32(v_sum.val[0], v_first_shift_rounding_bit);
        v_sum.val[1] = vaddq_s32(v_sum.val[1], v_first_shift_rounding_bit);
        const uint16x4_t d0 = vmin_u16(
            vqrshrun_n_s32(v_sum.val[0], kFilterBits - 1), v_max_bitdepth);
        const uint16x4_t d1 = vmin_u16(
            vqrshrun_n_s32(v_sum.val[1], kFilterBits - 1), v_max_bitdepth);
        vst1_u16(&dest16[x], d0);
        vst1_u16(&dest16[x + 4], d1);
      }
      x += 8;
    } while (x < width);
    src += src_stride;
    dest16 += pred_stride;
  } while (--y != 0);
}

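// Horizontal filter for |width| == 4. Only 2 and 4 tap filters reach this
// width (see FilterHorizontal()), so a single 8-pixel load provides all the
// context a row of 4 outputs needs.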
template <int num_taps, bool is_compound, bool is_2d>
void FilterHorizontalWidth4(const uint16_t* LIBGAV1_RESTRICT src,
                            const ptrdiff_t src_stride,
                            void* LIBGAV1_RESTRICT const dest,
                            const ptrdiff_t pred_stride, const int height,
                            const int16x4_t* const v_tap) {
  auto* dest16 = static_cast<uint16_t*>(dest);
  const uint16x4_t v_max_bitdepth = vdup_n_u16((1 << kBitdepth10) - 1);
  int y = height;
  do {
    const uint16x8_t v_zero = vdupq_n_u16(0);
    uint16x4_t v_src[4];
    int32x4_t v_sum;
    const uint16x8_t src_long = vld1q_u16(src);
    v_src[0] = vget_low_u16(src_long);
    if (num_taps == 2) {
      v_src[1] = vget_low_u16(vextq_u16(src_long, v_zero, 1));
      v_sum = SumOnePassTaps<num_taps>(v_src, v_tap + 3);
    } else {
      v_src[1] = vget_low_u16(vextq_u16(src_long, v_zero, 1));
      v_src[2] = vget_low_u16(vextq_u16(src_long, v_zero, 2));
      v_src[3] = vget_low_u16(vextq_u16(src_long, v_zero, 3));
      v_sum = SumOnePassTaps<num_taps>(v_src, v_tap + 2);
    }
    if (is_compound || is_2d) {
      const int16x4_t d0 = vqrshrn_n_s32(v_sum, kInterRoundBitsHorizontal - 1);
      if (is_compound && !is_2d) {
        vst1_u16(&dest16[0], vreinterpret_u16_s16(
                                 vadd_s16(d0, vdup_n_s16(kCompoundOffset))));
      } else {
        vst1_u16(&dest16[0], vreinterpret_u16_s16(d0));
      }
    } else {
      const int32x4_t v_first_shift_rounding_bit =
          vdupq_n_s32(1 << (kInterRoundBitsHorizontal - 2));
      v_sum = vaddq_s32(v_sum, v_first_shift_rounding_bit);
      const uint16x4_t d0 =
          vmin_u16(vqrshrun_n_s32(v_sum, kFilterBits - 1), v_max_bitdepth);
      vst1_u16(&dest16[0], d0);
    }
    src += src_stride;
    dest16 += pred_stride;
  } while (--y != 0);
}

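// Horizontal filter for |width| == 2. Two input rows are zipped together so
// each multiply-accumulate produces a pair of output rows; FilterHorizontal()
// only takes this path when |is_compound| is false.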
template <int num_taps, bool is_2d>
void FilterHorizontalWidth2(const uint16_t* LIBGAV1_RESTRICT src,
                            const ptrdiff_t src_stride,
                            void* LIBGAV1_RESTRICT const dest,
                            const ptrdiff_t pred_stride, const int height,
                            const int16x4_t* const v_tap) {
  auto* dest16 = static_cast<uint16_t*>(dest);
  const uint16x4_t v_max_bitdepth = vdup_n_u16((1 << kBitdepth10) - 1);
  int y = height >> 1;
  do {
    const int16x8_t v_zero = vdupq_n_s16(0);
    const int16x8_t input0 = vreinterpretq_s16_u16(vld1q_u16(src));
    const int16x8_t input1 = vreinterpretq_s16_u16(vld1q_u16(src + src_stride));
    const int16x8x2_t input = vzipq_s16(input0, input1);
    int32x4_t v_sum;
    if (num_taps == 2) {
      v_sum = vmull_s16(vget_low_s16(input.val[0]), v_tap[3]);
      v_sum = vmlal_s16(v_sum,
                        vget_low_s16(vextq_s16(input.val[0], input.val[1], 2)),
                        v_tap[4]);
    } else {
      v_sum = vmull_s16(vget_low_s16(input.val[0]), v_tap[2]);
      v_sum = vmlal_s16(v_sum, vget_low_s16(vextq_s16(input.val[0], v_zero, 2)),
                        v_tap[3]);
      v_sum = vmlal_s16(v_sum, vget_low_s16(vextq_s16(input.val[0], v_zero, 4)),
                        v_tap[4]);
      v_sum = vmlal_s16(v_sum,
                        vget_low_s16(vextq_s16(input.val[0], input.val[1], 6)),
                        v_tap[5]);
    }
    if (is_2d) {
      const uint16x4_t d0 = vreinterpret_u16_s16(
          vqrshrn_n_s32(v_sum, kInterRoundBitsHorizontal - 1));
      dest16[0] = vget_lane_u16(d0, 0);
      dest16[1] = vget_lane_u16(d0, 2);
      dest16 += pred_stride;
      dest16[0] = vget_lane_u16(d0, 1);
      dest16[1] = vget_lane_u16(d0, 3);
      dest16 += pred_stride;
    } else {
      // Normally the Horizontal pass does the downshift in two passes:
      // kInterRoundBitsHorizontal - 1 and then (kFilterBits -
      // kInterRoundBitsHorizontal). Each one uses a rounding shift.
      // Combining them requires adding the rounding offset from the skipped
      // shift.
      const int32x4_t v_first_shift_rounding_bit =
          vdupq_n_s32(1 << (kInterRoundBitsHorizontal - 2));
      v_sum = vaddq_s32(v_sum, v_first_shift_rounding_bit);
      const uint16x4_t d0 =
          vmin_u16(vqrshrun_n_s32(v_sum, kFilterBits - 1), v_max_bitdepth);
      dest16[0] = vget_lane_u16(d0, 0);
      dest16[1] = vget_lane_u16(d0, 2);
      dest16 += pred_stride;
      dest16[0] = vget_lane_u16(d0, 1);
      dest16[1] = vget_lane_u16(d0, 3);
      dest16 += pred_stride;
    }
    src += src_stride << 1;
  } while (--y != 0);

  // The 2d filters have an odd |height| because the horizontal pass
  // generates context for the vertical pass.
  if (is_2d) {
    assert(height % 2 == 1);
    const int16x8_t input = vreinterpretq_s16_u16(vld1q_u16(src));
    int32x4_t v_sum;
    if (num_taps == 2) {
      v_sum = vmull_s16(vget_low_s16(input), v_tap[3]);
      v_sum =
          vmlal_s16(v_sum, vget_low_s16(vextq_s16(input, input, 1)), v_tap[4]);
    } else {
      v_sum = vmull_s16(vget_low_s16(input), v_tap[2]);
      v_sum =
          vmlal_s16(v_sum, vget_low_s16(vextq_s16(input, input, 1)), v_tap[3]);
      v_sum =
          vmlal_s16(v_sum, vget_low_s16(vextq_s16(input, input, 2)), v_tap[4]);
      v_sum =
          vmlal_s16(v_sum, vget_low_s16(vextq_s16(input, input, 3)), v_tap[5]);
    }
    const uint16x4_t d0 = vreinterpret_u16_s16(
        vqrshrn_n_s32(v_sum, kInterRoundBitsHorizontal - 1));
    Store2<0>(dest16, d0);
  }
}

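// Dispatches the |width| <= 4 cases. DoHorizontalPass() calls
// FilterHorizontalWidth8AndUp() directly for wider blocks.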
template <int num_taps, bool is_compound, bool is_2d>
void FilterHorizontal(const uint16_t* LIBGAV1_RESTRICT const src,
                      const ptrdiff_t src_stride,
                      void* LIBGAV1_RESTRICT const dest,
                      const ptrdiff_t pred_stride, const int width,
                      const int height, const int16x4_t* const v_tap) {
  // The horizontal pass only needs to handle the 2 and 4 tap filters when
  // |width| <= 4.
  assert(width <= 4);
  assert(num_taps == 2 || num_taps == 4);
  if (num_taps == 2 || num_taps == 4) {
    if (width == 2 && !is_compound) {
      FilterHorizontalWidth2<num_taps, is_2d>(src, src_stride, dest,
                                              pred_stride, height, v_tap);
      return;
    }
    assert(width == 4);
    FilterHorizontalWidth4<num_taps, is_compound, is_2d>(
        src, src_stride, dest, pred_stride, height, v_tap);
  } else {
    assert(false);
  }
}

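// Rows of kHalfSubPixelFilters hold 8 taps with zeros at the unused ends, so
// the shorter filters below advance |src| (and, in the width-specific
// helpers, |v_tap|) to skip the zero taps.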
template <bool is_compound = false, bool is_2d = false>
LIBGAV1_ALWAYS_INLINE void DoHorizontalPass(
    const uint16_t* LIBGAV1_RESTRICT const src, const ptrdiff_t src_stride,
    void* LIBGAV1_RESTRICT const dst, const ptrdiff_t dst_stride,
    const int width, const int height, const int filter_id,
    const int filter_index) {
  // Duplicate each tap value across a vector. The taps here are signed and
  // applied with vmull_s16/vmlal_s16, so negative taps need no special
  // handling (unlike the 8-bit path, which uses absolute values with
  // vmlal_u8/vmlsl_u8).
  int16x4_t v_tap[kSubPixelTaps];
  assert(filter_id != 0);

  for (int k = 0; k < kSubPixelTaps; ++k) {
    v_tap[k] = vdup_n_s16(kHalfSubPixelFilters[filter_index][filter_id][k]);
  }

  // Horizontal filter.
  // Filter types used for width <= 4 are different from those for width > 4.
  // When width > 4, the valid filter index range is always [0, 3].
  // When width <= 4, the valid filter index range is always [4, 5].
  if (width >= 8) {
    if (filter_index == 2) {  // 8 tap.
      FilterHorizontalWidth8AndUp<8, is_compound, is_2d>(
          src, src_stride, dst, dst_stride, width, height, v_tap);
    } else if (filter_index < 2) {  // 6 tap.
      FilterHorizontalWidth8AndUp<6, is_compound, is_2d>(
          src + 1, src_stride, dst, dst_stride, width, height, v_tap);
    } else {  // 2 tap.
      assert(filter_index == 3);
      FilterHorizontalWidth8AndUp<2, is_compound, is_2d>(
          src + 3, src_stride, dst, dst_stride, width, height, v_tap);
    }
  } else {
    if ((filter_index & 0x4) != 0) {  // 4 tap.
      // ((filter_index == 4) | (filter_index == 5))
      FilterHorizontal<4, is_compound, is_2d>(src + 2, src_stride, dst,
                                              dst_stride, width, height, v_tap);
    } else {  // 2 tap.
      assert(filter_index == 3);
      FilterHorizontal<2, is_compound, is_2d>(src + 3, src_stride, dst,
                                              dst_stride, width, height, v_tap);
    }
  }
}

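// |reference_stride| and |pred_stride| are given in bytes; they are halved
// below because this 10-bit path addresses uint16_t pixels.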
void ConvolveHorizontal_NEON(
    const void* LIBGAV1_RESTRICT const reference,
    const ptrdiff_t reference_stride, const int horizontal_filter_index,
    const int /*vertical_filter_index*/, const int horizontal_filter_id,
    const int /*vertical_filter_id*/, const int width, const int height,
    void* LIBGAV1_RESTRICT const prediction, const ptrdiff_t pred_stride) {
  const int filter_index = GetFilterIndex(horizontal_filter_index, width);
  // Set |src| to the outermost tap.
  const auto* const src =
      static_cast<const uint16_t*>(reference) - kHorizontalOffset;
  auto* const dest = static_cast<uint16_t*>(prediction);
  const ptrdiff_t src_stride = reference_stride >> 1;
  const ptrdiff_t dst_stride = pred_stride >> 1;

  DoHorizontalPass(src, src_stride, dest, dst_stride, width, height,
                   horizontal_filter_id, filter_index);
}

void ConvolveCompoundHorizontal_NEON(
    const void* LIBGAV1_RESTRICT const reference,
    const ptrdiff_t reference_stride, const int horizontal_filter_index,
    const int /*vertical_filter_index*/, const int horizontal_filter_id,
    const int /*vertical_filter_id*/, const int width, const int height,
    void* LIBGAV1_RESTRICT const prediction, const ptrdiff_t /*pred_stride*/) {
  const int filter_index = GetFilterIndex(horizontal_filter_index, width);
  const auto* const src =
      static_cast<const uint16_t*>(reference) - kHorizontalOffset;
  auto* const dest = static_cast<uint16_t*>(prediction);
  const ptrdiff_t src_stride = reference_stride >> 1;

  DoHorizontalPass</*is_compound=*/true>(src, src_stride, dest, width, width,
                                         height, horizontal_filter_id,
                                         filter_index);
}

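// Vertical filter for |width| >= 8. |srcs| holds a sliding window of
// |num_taps| rows; each iteration loads one new row and shifts the window
// down by one.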
template <int num_taps, bool is_compound = false>
void FilterVertical(const uint16_t* LIBGAV1_RESTRICT const src,
                    const ptrdiff_t src_stride,
                    void* LIBGAV1_RESTRICT const dst,
                    const ptrdiff_t dst_stride, const int width,
                    const int height, const int16x4_t* const taps) {
  const int next_row = num_taps - 1;
  const uint16x4_t v_max_bitdepth = vdup_n_u16((1 << kBitdepth10) - 1);
  auto* const dst16 = static_cast<uint16_t*>(dst);
  assert(width >= 8);

  int x = 0;
  do {
    const uint16_t* src_x = src + x;
    uint16x8_t srcs[8];
    srcs[0] = vld1q_u16(src_x);
    src_x += src_stride;
    if (num_taps >= 4) {
      srcs[1] = vld1q_u16(src_x);
      src_x += src_stride;
      srcs[2] = vld1q_u16(src_x);
      src_x += src_stride;
      if (num_taps >= 6) {
        srcs[3] = vld1q_u16(src_x);
        src_x += src_stride;
        srcs[4] = vld1q_u16(src_x);
        src_x += src_stride;
        if (num_taps == 8) {
          srcs[5] = vld1q_u16(src_x);
          src_x += src_stride;
          srcs[6] = vld1q_u16(src_x);
          src_x += src_stride;
        }
      }
    }

    // Decreasing the y loop counter produces worse code with clang.
    // Don't unroll this loop since it generates too much code and the decoder
    // is even slower.
    int y = 0;
    do {
      srcs[next_row] = vld1q_u16(src_x);
      src_x += src_stride;

      const int32x4x2_t v_sum = SumOnePassTaps<num_taps>(srcs, taps);
      if (is_compound) {
        const int16x4_t v_compound_offset = vdup_n_s16(kCompoundOffset);
        const int16x4_t d0 =
            vqrshrn_n_s32(v_sum.val[0], kInterRoundBitsHorizontal - 1);
        const int16x4_t d1 =
            vqrshrn_n_s32(v_sum.val[1], kInterRoundBitsHorizontal - 1);
        vst1_u16(dst16 + x + y * dst_stride,
                 vreinterpret_u16_s16(vadd_s16(d0, v_compound_offset)));
        vst1_u16(dst16 + x + 4 + y * dst_stride,
                 vreinterpret_u16_s16(vadd_s16(d1, v_compound_offset)));
      } else {
        const uint16x4_t d0 = vmin_u16(
            vqrshrun_n_s32(v_sum.val[0], kFilterBits - 1), v_max_bitdepth);
        const uint16x4_t d1 = vmin_u16(
            vqrshrun_n_s32(v_sum.val[1], kFilterBits - 1), v_max_bitdepth);
        vst1_u16(dst16 + x + y * dst_stride, d0);
        vst1_u16(dst16 + x + 4 + y * dst_stride, d1);
      }

      srcs[0] = srcs[1];
      if (num_taps >= 4) {
        srcs[1] = srcs[2];
        srcs[2] = srcs[3];
        if (num_taps >= 6) {
          srcs[3] = srcs[4];
          srcs[4] = srcs[5];
          if (num_taps == 8) {
            srcs[5] = srcs[6];
            srcs[6] = srcs[7];
          }
        }
      }
    } while (++y < height);
    x += 8;
  } while (x < width);
}

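// Vertical filter for |width| == 4. Two output rows per iteration are formed
// by summing the overlapping windows |srcs| and |srcs + 1|.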
template <int num_taps, bool is_compound = false>
void FilterVertical4xH(const uint16_t* LIBGAV1_RESTRICT src,
                       const ptrdiff_t src_stride,
                       void* LIBGAV1_RESTRICT const dst,
                       const ptrdiff_t dst_stride, const int height,
                       const int16x4_t* const taps) {
  const int next_row = num_taps - 1;
  const uint16x4_t v_max_bitdepth = vdup_n_u16((1 << kBitdepth10) - 1);
  auto* dst16 = static_cast<uint16_t*>(dst);

  uint16x4_t srcs[9];
  srcs[0] = vld1_u16(src);
  src += src_stride;
  if (num_taps >= 4) {
    srcs[1] = vld1_u16(src);
    src += src_stride;
    srcs[2] = vld1_u16(src);
    src += src_stride;
    if (num_taps >= 6) {
      srcs[3] = vld1_u16(src);
      src += src_stride;
      srcs[4] = vld1_u16(src);
      src += src_stride;
      if (num_taps == 8) {
        srcs[5] = vld1_u16(src);
        src += src_stride;
        srcs[6] = vld1_u16(src);
        src += src_stride;
      }
    }
  }

  int y = height;
  do {
    srcs[next_row] = vld1_u16(src);
    src += src_stride;
    srcs[num_taps] = vld1_u16(src);
    src += src_stride;

    const int32x4_t v_sum = SumOnePassTaps<num_taps>(srcs, taps);
    const int32x4_t v_sum_1 = SumOnePassTaps<num_taps>(srcs + 1, taps);
    if (is_compound) {
      const int16x4_t d0 = vqrshrn_n_s32(v_sum, kInterRoundBitsHorizontal - 1);
      const int16x4_t d1 =
          vqrshrn_n_s32(v_sum_1, kInterRoundBitsHorizontal - 1);
      vst1_u16(dst16,
               vreinterpret_u16_s16(vadd_s16(d0, vdup_n_s16(kCompoundOffset))));
      dst16 += dst_stride;
      vst1_u16(dst16,
               vreinterpret_u16_s16(vadd_s16(d1, vdup_n_s16(kCompoundOffset))));
      dst16 += dst_stride;
    } else {
      const uint16x4_t d0 =
          vmin_u16(vqrshrun_n_s32(v_sum, kFilterBits - 1), v_max_bitdepth);
      const uint16x4_t d1 =
          vmin_u16(vqrshrun_n_s32(v_sum_1, kFilterBits - 1), v_max_bitdepth);
      vst1_u16(dst16, d0);
      dst16 += dst_stride;
      vst1_u16(dst16, d1);
      dst16 += dst_stride;
    }

    srcs[0] = srcs[2];
    if (num_taps >= 4) {
      srcs[1] = srcs[3];
      srcs[2] = srcs[4];
      if (num_taps >= 6) {
        srcs[3] = srcs[5];
        srcs[4] = srcs[6];
        if (num_taps == 8) {
          srcs[5] = srcs[7];
          srcs[6] = srcs[8];
        }
      }
    }
    y -= 2;
  } while (y != 0);
}

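// Vertical filter for |width| == 2. Load2/vext_u16 pack two 2-pixel rows into
// a single uint16x4_t so each multiply-accumulate covers two output rows.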
template <int num_taps>
void FilterVertical2xH(const uint16_t* LIBGAV1_RESTRICT src,
                       const ptrdiff_t src_stride,
                       void* LIBGAV1_RESTRICT const dst,
                       const ptrdiff_t dst_stride, const int height,
                       const int16x4_t* const taps) {
  const int next_row = num_taps - 1;
  const uint16x4_t v_max_bitdepth = vdup_n_u16((1 << kBitdepth10) - 1);
  auto* dst16 = static_cast<uint16_t*>(dst);
  const uint16x4_t v_zero = vdup_n_u16(0);

  uint16x4_t srcs[9];
  srcs[0] = Load2<0>(src, v_zero);
  src += src_stride;
  if (num_taps >= 4) {
    srcs[0] = Load2<1>(src, srcs[0]);
    src += src_stride;
    srcs[2] = Load2<0>(src, v_zero);
    src += src_stride;
    srcs[1] = vext_u16(srcs[0], srcs[2], 2);
    if (num_taps >= 6) {
      srcs[2] = Load2<1>(src, srcs[2]);
      src += src_stride;
      srcs[4] = Load2<0>(src, v_zero);
      src += src_stride;
      srcs[3] = vext_u16(srcs[2], srcs[4], 2);
      if (num_taps == 8) {
        srcs[4] = Load2<1>(src, srcs[4]);
        src += src_stride;
        srcs[6] = Load2<0>(src, v_zero);
        src += src_stride;
        srcs[5] = vext_u16(srcs[4], srcs[6], 2);
      }
    }
  }

  int y = height;
  do {
    srcs[next_row - 1] = Load2<1>(src, srcs[next_row - 1]);
    src += src_stride;
    srcs[num_taps] = Load2<0>(src, v_zero);
    src += src_stride;
    srcs[next_row] = vext_u16(srcs[next_row - 1], srcs[num_taps], 2);

    const int32x4_t v_sum = SumOnePassTaps<num_taps>(srcs, taps);
    const uint16x4_t d0 =
        vmin_u16(vqrshrun_n_s32(v_sum, kFilterBits - 1), v_max_bitdepth);
    Store2<0>(dst16, d0);
    dst16 += dst_stride;
    Store2<1>(dst16, d0);
    dst16 += dst_stride;

    srcs[0] = srcs[2];
    if (num_taps >= 4) {
      srcs[1] = srcs[3];
      srcs[2] = srcs[4];
      if (num_taps >= 6) {
        srcs[3] = srcs[5];
        srcs[4] = srcs[6];
        if (num_taps == 8) {
          srcs[5] = srcs[7];
          srcs[6] = srcs[8];
        }
      }
    }
    y -= 2;
  } while (y != 0);
}

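// Second (vertical) pass of the 2D filter, applied to the int16_t
// intermediate result. vmull_lane_s16/vmlal_lane_s16 read each tap directly
// from a lane of |taps| instead of requiring one duplicated vector per tap.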
template <int num_taps, bool is_compound>
int16x8_t SimpleSum2DVerticalTaps(const int16x8_t* const src,
                                  const int16x8_t taps) {
  const int16x4_t taps_lo = vget_low_s16(taps);
  const int16x4_t taps_hi = vget_high_s16(taps);
  int32x4_t sum_lo, sum_hi;
  if (num_taps == 8) {
    sum_lo = vmull_lane_s16(vget_low_s16(src[0]), taps_lo, 0);
    sum_hi = vmull_lane_s16(vget_high_s16(src[0]), taps_lo, 0);
    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[1]), taps_lo, 1);
    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[1]), taps_lo, 1);
    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[2]), taps_lo, 2);
    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[2]), taps_lo, 2);
    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[3]), taps_lo, 3);
    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[3]), taps_lo, 3);

    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[4]), taps_hi, 0);
    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[4]), taps_hi, 0);
    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[5]), taps_hi, 1);
    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[5]), taps_hi, 1);
    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[6]), taps_hi, 2);
    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[6]), taps_hi, 2);
    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[7]), taps_hi, 3);
    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[7]), taps_hi, 3);
  } else if (num_taps == 6) {
    sum_lo = vmull_lane_s16(vget_low_s16(src[0]), taps_lo, 1);
    sum_hi = vmull_lane_s16(vget_high_s16(src[0]), taps_lo, 1);
    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[1]), taps_lo, 2);
    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[1]), taps_lo, 2);
    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[2]), taps_lo, 3);
    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[2]), taps_lo, 3);

    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[3]), taps_hi, 0);
    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[3]), taps_hi, 0);
    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[4]), taps_hi, 1);
    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[4]), taps_hi, 1);
    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[5]), taps_hi, 2);
    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[5]), taps_hi, 2);
  } else if (num_taps == 4) {
    sum_lo = vmull_lane_s16(vget_low_s16(src[0]), taps_lo, 2);
    sum_hi = vmull_lane_s16(vget_high_s16(src[0]), taps_lo, 2);
    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[1]), taps_lo, 3);
    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[1]), taps_lo, 3);

    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[2]), taps_hi, 0);
    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[2]), taps_hi, 0);
    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[3]), taps_hi, 1);
    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[3]), taps_hi, 1);
  } else if (num_taps == 2) {
    sum_lo = vmull_lane_s16(vget_low_s16(src[0]), taps_lo, 3);
    sum_hi = vmull_lane_s16(vget_high_s16(src[0]), taps_lo, 3);

    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[1]), taps_hi, 0);
    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[1]), taps_hi, 0);
  }

  if (is_compound) {
    // Output is compound, so leave signed and do not saturate. Offset will
    // accurately bring the value back into positive range.
    return vcombine_s16(
        vrshrn_n_s32(sum_lo, kInterRoundBitsCompoundVertical - 1),
        vrshrn_n_s32(sum_hi, kInterRoundBitsCompoundVertical - 1));
  }

  // Output is pixel, so saturate to clip at 0.
  return vreinterpretq_s16_u16(
      vcombine_u16(vqrshrun_n_s32(sum_lo, kInterRoundBitsVertical - 1),
                   vqrshrun_n_s32(sum_hi, kInterRoundBitsVertical - 1)));
}

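// The 2D horizontal pass stores each 8-wide column strip of the intermediate
// result contiguously (see the is_2d path of FilterHorizontalWidth8AndUp()),
// so |src| simply advances by 8 per row and rolls into the next strip.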
template <int num_taps, bool is_compound = false>
void Filter2DVerticalWidth8AndUp(const int16_t* LIBGAV1_RESTRICT src,
                                 void* LIBGAV1_RESTRICT const dst,
                                 const ptrdiff_t dst_stride, const int width,
                                 const int height, const int16x8_t taps) {
  assert(width >= 8);
  constexpr int next_row = num_taps - 1;
  const uint16x8_t v_max_bitdepth = vdupq_n_u16((1 << kBitdepth10) - 1);
  auto* const dst16 = static_cast<uint16_t*>(dst);

  int x = 0;
  do {
    int16x8_t srcs[9];
    srcs[0] = vld1q_s16(src);
    src += 8;
    if (num_taps >= 4) {
      srcs[1] = vld1q_s16(src);
      src += 8;
      srcs[2] = vld1q_s16(src);
      src += 8;
      if (num_taps >= 6) {
        srcs[3] = vld1q_s16(src);
        src += 8;
        srcs[4] = vld1q_s16(src);
        src += 8;
        if (num_taps == 8) {
          srcs[5] = vld1q_s16(src);
          src += 8;
          srcs[6] = vld1q_s16(src);
          src += 8;
        }
      }
    }

    uint16_t* d16 = dst16 + x;
    int y = height;
    do {
      srcs[next_row] = vld1q_s16(src);
      src += 8;
      srcs[next_row + 1] = vld1q_s16(src);
      src += 8;
      const int16x8_t sum0 =
          SimpleSum2DVerticalTaps<num_taps, is_compound>(srcs + 0, taps);
      const int16x8_t sum1 =
          SimpleSum2DVerticalTaps<num_taps, is_compound>(srcs + 1, taps);
      if (is_compound) {
        const int16x8_t v_compound_offset = vdupq_n_s16(kCompoundOffset);
        vst1q_u16(d16,
                  vreinterpretq_u16_s16(vaddq_s16(sum0, v_compound_offset)));
        d16 += dst_stride;
        vst1q_u16(d16,
                  vreinterpretq_u16_s16(vaddq_s16(sum1, v_compound_offset)));
        d16 += dst_stride;
      } else {
        vst1q_u16(d16, vminq_u16(vreinterpretq_u16_s16(sum0), v_max_bitdepth));
        d16 += dst_stride;
        vst1q_u16(d16, vminq_u16(vreinterpretq_u16_s16(sum1), v_max_bitdepth));
        d16 += dst_stride;
      }
      srcs[0] = srcs[2];
      if (num_taps >= 4) {
        srcs[1] = srcs[3];
        srcs[2] = srcs[4];
        if (num_taps >= 6) {
          srcs[3] = srcs[5];
          srcs[4] = srcs[6];
          if (num_taps == 8) {
            srcs[5] = srcs[7];
            srcs[6] = srcs[8];
          }
        }
      }
      y -= 2;
    } while (y != 0);
    x += 8;
  } while (x < width);
}

// Take advantage of |src_stride| == |width| to process two rows at a time.
template <int num_taps, bool is_compound = false>
void Filter2DVerticalWidth4(const int16_t* LIBGAV1_RESTRICT src,
                            void* LIBGAV1_RESTRICT const dst,
                            const ptrdiff_t dst_stride, const int height,
                            const int16x8_t taps) {
  const uint16x8_t v_max_bitdepth = vdupq_n_u16((1 << kBitdepth10) - 1);
  auto* dst16 = static_cast<uint16_t*>(dst);

  int16x8_t srcs[9];
  srcs[0] = vld1q_s16(src);
  src += 8;
  if (num_taps >= 4) {
    srcs[2] = vld1q_s16(src);
    src += 8;
    srcs[1] = vcombine_s16(vget_high_s16(srcs[0]), vget_low_s16(srcs[2]));
    if (num_taps >= 6) {
      srcs[4] = vld1q_s16(src);
      src += 8;
      srcs[3] = vcombine_s16(vget_high_s16(srcs[2]), vget_low_s16(srcs[4]));
      if (num_taps == 8) {
        srcs[6] = vld1q_s16(src);
        src += 8;
        srcs[5] = vcombine_s16(vget_high_s16(srcs[4]), vget_low_s16(srcs[6]));
      }
    }
  }

  int y = height;
  do {
    srcs[num_taps] = vld1q_s16(src);
    src += 8;
    srcs[num_taps - 1] = vcombine_s16(vget_high_s16(srcs[num_taps - 2]),
                                      vget_low_s16(srcs[num_taps]));

    const int16x8_t sum =
        SimpleSum2DVerticalTaps<num_taps, is_compound>(srcs, taps);
    if (is_compound) {
      const int16x8_t v_compound_offset = vdupq_n_s16(kCompoundOffset);
      vst1q_u16(dst16,
                vreinterpretq_u16_s16(vaddq_s16(sum, v_compound_offset)));
      dst16 += 4 << 1;
    } else {
      const uint16x8_t d0 =
          vminq_u16(vreinterpretq_u16_s16(sum), v_max_bitdepth);
      vst1_u16(dst16, vget_low_u16(d0));
      dst16 += dst_stride;
      vst1_u16(dst16, vget_high_u16(d0));
      dst16 += dst_stride;
    }

    srcs[0] = srcs[2];
    if (num_taps >= 4) {
      srcs[1] = srcs[3];
      srcs[2] = srcs[4];
      if (num_taps >= 6) {
        srcs[3] = srcs[5];
        srcs[4] = srcs[6];
        if (num_taps == 8) {
          srcs[5] = srcs[7];
          srcs[6] = srcs[8];
        }
      }
    }
    y -= 2;
  } while (y != 0);
}

// Take advantage of |src_stride| == |width| to process four rows at a time.
template <int num_taps>
void Filter2DVerticalWidth2(const int16_t* LIBGAV1_RESTRICT src,
                            void* LIBGAV1_RESTRICT const dst,
                            const ptrdiff_t dst_stride, const int height,
                            const int16x8_t taps) {
  constexpr int next_row = (num_taps < 6) ? 4 : 8;
  const uint16x8_t v_max_bitdepth = vdupq_n_u16((1 << kBitdepth10) - 1);
  auto* dst16 = static_cast<uint16_t*>(dst);

  int16x8_t srcs[9];
  srcs[0] = vld1q_s16(src);
  src += 8;
  if (num_taps >= 6) {
    srcs[4] = vld1q_s16(src);
    src += 8;
    srcs[1] = vextq_s16(srcs[0], srcs[4], 2);
    if (num_taps == 8) {
      srcs[2] = vcombine_s16(vget_high_s16(srcs[0]), vget_low_s16(srcs[4]));
      srcs[3] = vextq_s16(srcs[0], srcs[4], 6);
    }
  }

  int y = height;
  do {
    srcs[next_row] = vld1q_s16(src);
    src += 8;
    if (num_taps == 2) {
      srcs[1] = vextq_s16(srcs[0], srcs[4], 2);
    } else if (num_taps == 4) {
      srcs[1] = vextq_s16(srcs[0], srcs[4], 2);
      srcs[2] = vcombine_s16(vget_high_s16(srcs[0]), vget_low_s16(srcs[4]));
      srcs[3] = vextq_s16(srcs[0], srcs[4], 6);
    } else if (num_taps == 6) {
      srcs[2] = vcombine_s16(vget_high_s16(srcs[0]), vget_low_s16(srcs[4]));
      srcs[3] = vextq_s16(srcs[0], srcs[4], 6);
      srcs[5] = vextq_s16(srcs[4], srcs[8], 2);
    } else if (num_taps == 8) {
      srcs[5] = vextq_s16(srcs[4], srcs[8], 2);
      srcs[6] = vcombine_s16(vget_high_s16(srcs[4]), vget_low_s16(srcs[8]));
      srcs[7] = vextq_s16(srcs[4], srcs[8], 6);
    }
    const int16x8_t sum =
        SimpleSum2DVerticalTaps<num_taps, /*is_compound=*/false>(srcs, taps);
    const uint16x8_t d0 = vminq_u16(vreinterpretq_u16_s16(sum), v_max_bitdepth);
    Store2<0>(dst16, d0);
    dst16 += dst_stride;
    Store2<1>(dst16, d0);
    // When |height| <= 4 the taps are restricted to 2 and 4 tap variants.
    // Therefore we don't need to check this condition when |height| > 4.
    if (num_taps <= 4 && height == 2) return;
    dst16 += dst_stride;
    Store2<2>(dst16, d0);
    dst16 += dst_stride;
    Store2<3>(dst16, d0);
    dst16 += dst_stride;

    srcs[0] = srcs[4];
    if (num_taps == 6) {
      srcs[1] = srcs[5];
      srcs[4] = srcs[8];
    } else if (num_taps == 8) {
      srcs[1] = srcs[5];
      srcs[2] = srcs[6];
      srcs[3] = srcs[7];
      srcs[4] = srcs[8];
    }

    y -= 4;
  } while (y != 0);
}

template <int vertical_taps>
void Filter2DVertical(const int16_t* LIBGAV1_RESTRICT const intermediate_result,
                      const int width, const int height, const int16x8_t taps,
                      void* LIBGAV1_RESTRICT const prediction,
                      const ptrdiff_t pred_stride) {
  auto* const dest = static_cast<uint16_t*>(prediction);
  if (width >= 8) {
    Filter2DVerticalWidth8AndUp<vertical_taps>(
        intermediate_result, dest, pred_stride, width, height, taps);
  } else if (width == 4) {
    Filter2DVerticalWidth4<vertical_taps>(intermediate_result, dest,
                                          pred_stride, height, taps);
  } else {
    assert(width == 2);
    Filter2DVerticalWidth2<vertical_taps>(intermediate_result, dest,
                                          pred_stride, height, taps);
  }
}

void Convolve2D_NEON(const void* LIBGAV1_RESTRICT const reference,
                     const ptrdiff_t reference_stride,
                     const int horizontal_filter_index,
                     const int vertical_filter_index,
                     const int horizontal_filter_id,
                     const int vertical_filter_id, const int width,
                     const int height, void* LIBGAV1_RESTRICT const prediction,
                     const ptrdiff_t pred_stride) {
  const int horiz_filter_index = GetFilterIndex(horizontal_filter_index, width);
  const int vert_filter_index = GetFilterIndex(vertical_filter_index, height);
  const int vertical_taps = GetNumTapsInFilter(vert_filter_index);
  // The output of the horizontal filter is guaranteed to fit in 16 bits.
  int16_t intermediate_result[kMaxSuperBlockSizeInPixels *
                              (kMaxSuperBlockSizeInPixels + kSubPixelTaps - 1)];
#if LIBGAV1_MSAN
  // Quiet msan warnings. Set with random non-zero value to aid in debugging.
  memset(intermediate_result, 0x43, sizeof(intermediate_result));
#endif
  const int intermediate_height = height + vertical_taps - 1;
  const ptrdiff_t src_stride = reference_stride >> 1;
  const auto* const src = static_cast<const uint16_t*>(reference) -
                          (vertical_taps / 2 - 1) * src_stride -
                          kHorizontalOffset;
  const ptrdiff_t dest_stride = pred_stride >> 1;

  DoHorizontalPass</*is_compound=*/false, /*is_2d=*/true>(
      src, src_stride, intermediate_result, width, width, intermediate_height,
      horizontal_filter_id, horiz_filter_index);

  assert(vertical_filter_id != 0);
  const int16x8_t taps = vmovl_s8(
      vld1_s8(kHalfSubPixelFilters[vert_filter_index][vertical_filter_id]));
  if (vertical_taps == 8) {
    Filter2DVertical<8>(intermediate_result, width, height, taps, prediction,
                        dest_stride);
  } else if (vertical_taps == 6) {
    Filter2DVertical<6>(intermediate_result, width, height, taps, prediction,
                        dest_stride);
  } else if (vertical_taps == 4) {
    Filter2DVertical<4>(intermediate_result, width, height, taps, prediction,
                        dest_stride);
  } else {  // |vertical_taps| == 2
    Filter2DVertical<2>(intermediate_result, width, height, taps, prediction,
                        dest_stride);
  }
}

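// Dispatches the compound 2D vertical pass. The compound output buffer is
// written with |width| as its stride.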
template <int vertical_taps>
void Compound2DVertical(
    const int16_t* LIBGAV1_RESTRICT const intermediate_result, const int width,
    const int height, const int16x8_t taps,
    void* LIBGAV1_RESTRICT const prediction) {
  auto* const dest = static_cast<uint16_t*>(prediction);
  if (width == 4) {
    Filter2DVerticalWidth4<vertical_taps, /*is_compound=*/true>(
        intermediate_result, dest, width, height, taps);
  } else {
    Filter2DVerticalWidth8AndUp<vertical_taps, /*is_compound=*/true>(
        intermediate_result, dest, width, width, height, taps);
  }
}

void ConvolveCompound2D_NEON(
    const void* LIBGAV1_RESTRICT const reference,
    const ptrdiff_t reference_stride, const int horizontal_filter_index,
    const int vertical_filter_index, const int horizontal_filter_id,
    const int vertical_filter_id, const int width, const int height,
    void* LIBGAV1_RESTRICT const prediction, const ptrdiff_t /*pred_stride*/) {
  // The output of the horizontal filter, i.e. the intermediate_result, is
  // guaranteed to fit in int16_t.
  int16_t
      intermediate_result[(kMaxSuperBlockSizeInPixels *
                           (kMaxSuperBlockSizeInPixels + kSubPixelTaps - 1))];

  // Horizontal filter.
  // Filter types used for width <= 4 are different from those for width > 4.
  // When width > 4, the valid filter index range is always [0, 3].
  // When width <= 4, the valid filter index range is always [4, 5].
  // Similarly for height.
  const int horiz_filter_index = GetFilterIndex(horizontal_filter_index, width);
  const int vert_filter_index = GetFilterIndex(vertical_filter_index, height);
  const int vertical_taps = GetNumTapsInFilter(vert_filter_index);
  const int intermediate_height = height + vertical_taps - 1;
  const ptrdiff_t src_stride = reference_stride >> 1;
  const auto* const src = static_cast<const uint16_t*>(reference) -
                          (vertical_taps / 2 - 1) * src_stride -
                          kHorizontalOffset;

  DoHorizontalPass</*is_compound=*/true, /*is_2d=*/true>(
      src, src_stride, intermediate_result, width, width, intermediate_height,
      horizontal_filter_id, horiz_filter_index);

  // Vertical filter.
  assert(vertical_filter_id != 0);
  const int16x8_t taps = vmovl_s8(
      vld1_s8(kHalfSubPixelFilters[vert_filter_index][vertical_filter_id]));
  if (vertical_taps == 8) {
    Compound2DVertical<8>(intermediate_result, width, height, taps, prediction);
  } else if (vertical_taps == 6) {
    Compound2DVertical<6>(intermediate_result, width, height, taps, prediction);
  } else if (vertical_taps == 4) {
    Compound2DVertical<4>(intermediate_result, width, height, taps, prediction);
  } else {  // |vertical_taps| == 2
    Compound2DVertical<2>(intermediate_result, width, height, taps, prediction);
  }
}

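// Single-pass vertical filter. |src| is backed up by (vertical_taps / 2 - 1)
// rows so the tap window is centered on the destination row.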
void ConvolveVertical_NEON(
    const void* LIBGAV1_RESTRICT const reference,
    const ptrdiff_t reference_stride, const int /*horizontal_filter_index*/,
    const int vertical_filter_index, const int /*horizontal_filter_id*/,
    const int vertical_filter_id, const int width, const int height,
    void* LIBGAV1_RESTRICT const prediction, const ptrdiff_t pred_stride) {
  const int filter_index = GetFilterIndex(vertical_filter_index, height);
  const int vertical_taps = GetNumTapsInFilter(filter_index);
  const ptrdiff_t src_stride = reference_stride >> 1;
  const auto* src = static_cast<const uint16_t*>(reference) -
                    (vertical_taps / 2 - 1) * src_stride;
  auto* const dest = static_cast<uint16_t*>(prediction);
  const ptrdiff_t dest_stride = pred_stride >> 1;
  assert(vertical_filter_id != 0);

  int16x4_t taps[8];
  for (int k = 0; k < kSubPixelTaps; ++k) {
    taps[k] =
        vdup_n_s16(kHalfSubPixelFilters[filter_index][vertical_filter_id][k]);
  }

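  // |filter_index| 1 mixes 6 tap and 4 tap filters depending on
  // |vertical_filter_id|. The bool-to-int bitwise & below folds the id tests
  // into one condition, presumably so the compiler emits a single combined
  // test rather than a chain of short-circuit branches.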
1177   if (filter_index == 0) {  // 6 tap.
1178     if (width == 2) {
1179       FilterVertical2xH<6>(src, src_stride, dest, dest_stride, height,
1180                            taps + 1);
1181     } else if (width == 4) {
1182       FilterVertical4xH<6>(src, src_stride, dest, dest_stride, height,
1183                            taps + 1);
1184     } else {
1185       FilterVertical<6>(src, src_stride, dest, dest_stride, width, height,
1186                         taps + 1);
1187     }
1188   } else if ((static_cast<int>(filter_index == 1) &
1189               (static_cast<int>(vertical_filter_id == 1) |
1190                static_cast<int>(vertical_filter_id == 7) |
1191                static_cast<int>(vertical_filter_id == 8) |
1192                static_cast<int>(vertical_filter_id == 9) |
1193                static_cast<int>(vertical_filter_id == 15))) != 0) {  // 6 tap.
1194     if (width == 2) {
1195       FilterVertical2xH<6>(src, src_stride, dest, dest_stride, height,
1196                            taps + 1);
1197     } else if (width == 4) {
1198       FilterVertical4xH<6>(src, src_stride, dest, dest_stride, height,
1199                            taps + 1);
1200     } else {
1201       FilterVertical<6>(src, src_stride, dest, dest_stride, width, height,
1202                         taps + 1);
1203     }
1204   } else if (filter_index == 2) {  // 8 tap.
1205     if (width == 2) {
1206       FilterVertical2xH<8>(src, src_stride, dest, dest_stride, height, taps);
1207     } else if (width == 4) {
1208       FilterVertical4xH<8>(src, src_stride, dest, dest_stride, height, taps);
1209     } else {
1210       FilterVertical<8>(src, src_stride, dest, dest_stride, width, height,
1211                         taps);
1212     }
1213   } else if (filter_index == 3) {  // 2 tap.
1214     if (width == 2) {
1215       FilterVertical2xH<2>(src, src_stride, dest, dest_stride, height,
1216                            taps + 3);
1217     } else if (width == 4) {
1218       FilterVertical4xH<2>(src, src_stride, dest, dest_stride, height,
1219                            taps + 3);
1220     } else {
1221       FilterVertical<2>(src, src_stride, dest, dest_stride, width, height,
1222                         taps + 3);
1223     }
1224   } else {
1225     // 4 tap. When |filter_index| == 1 the |vertical_filter_id| values listed
1226     // below map to 4 tap filters.
1227     assert(filter_index == 5 || filter_index == 4 ||
1228            (filter_index == 1 &&
1229             (vertical_filter_id == 0 || vertical_filter_id == 2 ||
1230              vertical_filter_id == 3 || vertical_filter_id == 4 ||
1231              vertical_filter_id == 5 || vertical_filter_id == 6 ||
1232              vertical_filter_id == 10 || vertical_filter_id == 11 ||
1233              vertical_filter_id == 12 || vertical_filter_id == 13 ||
1234              vertical_filter_id == 14)));
1235     // According to GetNumTapsInFilter() this has 6 taps but here we are
1236     // treating it as though it has 4.
1237     if (filter_index == 1) src += src_stride;
1238     if (width == 2) {
1239       FilterVertical2xH<4>(src, src_stride, dest, dest_stride, height,
1240                            taps + 2);
1241     } else if (width == 4) {
1242       FilterVertical4xH<4>(src, src_stride, dest, dest_stride, height,
1243                            taps + 2);
1244     } else {
1245       FilterVertical<4>(src, src_stride, dest, dest_stride, width, height,
1246                         taps + 2);
1247     }
1248   }
1249 }
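
// A scalar sketch of the 4-taps-from-6 case above (Scalar4TapFromSmooth6 is a
// hypothetical name, not a libgav1 helper): for the |vertical_filter_id|
// values in the assert, the outer entries of the 6-tap kernel are zero (see
// columns 0 and 5 of GetMixed6TapFilter() below), so the nonzero taps sit in
// entries 2..5 of the 8-entry kernel (hence taps + 2) and the filter window
// starts one row lower than a full 6-tap window (hence src += src_stride).
inline int32_t Scalar4TapFromSmooth6(const uint16_t* const src,
                                     const ptrdiff_t stride,
                                     const int16_t* const half_taps8) {
  // |src| points at the row (vertical_taps / 2 - 1) == 1 above the output.
  int32_t sum = 0;
  for (int k = 0; k < 4; ++k) {
    sum += half_taps8[k + 2] * src[k * stride];
  }
  return sum;
}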
1250 
1251 void ConvolveCompoundVertical_NEON(
1252     const void* LIBGAV1_RESTRICT const reference,
1253     const ptrdiff_t reference_stride, const int /*horizontal_filter_index*/,
1254     const int vertical_filter_index, const int /*horizontal_filter_id*/,
1255     const int vertical_filter_id, const int width, const int height,
1256     void* LIBGAV1_RESTRICT const prediction, const ptrdiff_t /*pred_stride*/) {
1257   const int filter_index = GetFilterIndex(vertical_filter_index, height);
1258   const int vertical_taps = GetNumTapsInFilter(filter_index);
1259   const ptrdiff_t src_stride = reference_stride >> 1;
1260   const auto* src = static_cast<const uint16_t*>(reference) -
1261                     (vertical_taps / 2 - 1) * src_stride;
1262   auto* const dest = static_cast<uint16_t*>(prediction);
1263   assert(vertical_filter_id != 0);
1264 
1265   int16x4_t taps[8];
1266   for (int k = 0; k < kSubPixelTaps; ++k) {
1267     taps[k] =
1268         vdup_n_s16(kHalfSubPixelFilters[filter_index][vertical_filter_id][k]);
1269   }
1270 
1271   if (filter_index == 0) {  // 6 tap.
1272     if (width == 4) {
1273       FilterVertical4xH<6, /*is_compound=*/true>(src, src_stride, dest, 4,
1274                                                  height, taps + 1);
1275     } else {
1276       FilterVertical<6, /*is_compound=*/true>(src, src_stride, dest, width,
1277                                               width, height, taps + 1);
1278     }
1279   } else if ((static_cast<int>(filter_index == 1) &
1280               (static_cast<int>(vertical_filter_id == 1) |
1281                static_cast<int>(vertical_filter_id == 7) |
1282                static_cast<int>(vertical_filter_id == 8) |
1283                static_cast<int>(vertical_filter_id == 9) |
1284                static_cast<int>(vertical_filter_id == 15))) != 0) {  // 6 tap.
1285     if (width == 4) {
1286       FilterVertical4xH<6, /*is_compound=*/true>(src, src_stride, dest, 4,
1287                                                  height, taps + 1);
1288     } else {
1289       FilterVertical<6, /*is_compound=*/true>(src, src_stride, dest, width,
1290                                               width, height, taps + 1);
1291     }
1292   } else if (filter_index == 2) {  // 8 tap.
1293     if (width == 4) {
1294       FilterVertical4xH<8, /*is_compound=*/true>(src, src_stride, dest, 4,
1295                                                  height, taps);
1296     } else {
1297       FilterVertical<8, /*is_compound=*/true>(src, src_stride, dest, width,
1298                                               width, height, taps);
1299     }
1300   } else if (filter_index == 3) {  // 2 tap.
1301     if (width == 4) {
1302       FilterVertical4xH<2, /*is_compound=*/true>(src, src_stride, dest, 4,
1303                                                  height, taps + 3);
1304     } else {
1305       FilterVertical<2, /*is_compound=*/true>(src, src_stride, dest, width,
1306                                               width, height, taps + 3);
1307     }
1308   } else {
1309     // 4 tap. When |filter_index| == 1 the |vertical_filter_id| values listed
1310     // below map to 4 tap filters.
1311     assert(filter_index == 5 || filter_index == 4 ||
1312            (filter_index == 1 &&
1313             (vertical_filter_id == 2 || vertical_filter_id == 3 ||
1314              vertical_filter_id == 4 || vertical_filter_id == 5 ||
1315              vertical_filter_id == 6 || vertical_filter_id == 10 ||
1316              vertical_filter_id == 11 || vertical_filter_id == 12 ||
1317              vertical_filter_id == 13 || vertical_filter_id == 14)));
1318     // According to GetNumTapsInFilter() this has 6 taps but here we are
1319     // treating it as though it has 4.
1320     if (filter_index == 1) src += src_stride;
1321     if (width == 4) {
1322       FilterVertical4xH<4, /*is_compound=*/true>(src, src_stride, dest, 4,
1323                                                  height, taps + 2);
1324     } else {
1325       FilterVertical<4, /*is_compound=*/true>(src, src_stride, dest, width,
1326                                               width, height, taps + 2);
1327     }
1328   }
1329 }
1330 
1331 void ConvolveCompoundCopy_NEON(
1332     const void* const reference, const ptrdiff_t reference_stride,
1333     const int /*horizontal_filter_index*/, const int /*vertical_filter_index*/,
1334     const int /*horizontal_filter_id*/, const int /*vertical_filter_id*/,
1335     const int width, const int height, void* const prediction,
1336     const ptrdiff_t /*pred_stride*/) {
1337   const auto* src = static_cast<const uint16_t*>(reference);
1338   const ptrdiff_t src_stride = reference_stride >> 1;
1339   auto* dest = static_cast<uint16_t*>(prediction);
1340   constexpr int final_shift =
1341       kInterRoundBitsVertical - kInterRoundBitsCompoundVertical;
1342   const uint16x8_t offset =
1343       vdupq_n_u16((1 << kBitdepth10) + (1 << (kBitdepth10 - 1)));
1344 
1345   if (width >= 16) {
1346     int y = height;
1347     do {
1348       int x = 0;
1349       int w = width;
1350       do {
1351         const uint16x8_t v_src_lo = vld1q_u16(&src[x]);
1352         const uint16x8_t v_src_hi = vld1q_u16(&src[x + 8]);
1353         const uint16x8_t v_sum_lo = vaddq_u16(v_src_lo, offset);
1354         const uint16x8_t v_sum_hi = vaddq_u16(v_src_hi, offset);
1355         const uint16x8_t v_dest_lo = vshlq_n_u16(v_sum_lo, final_shift);
1356         const uint16x8_t v_dest_hi = vshlq_n_u16(v_sum_hi, final_shift);
1357         vst1q_u16(&dest[x], v_dest_lo);
1358         vst1q_u16(&dest[x + 8], v_dest_hi);
1359         x += 16;
1360         w -= 16;
1361       } while (w != 0);
1362       src += src_stride;
1363       dest += width;
1364     } while (--y != 0);
1365   } else if (width == 8) {
1366     int y = height;
1367     do {
1368       const uint16x8_t v_src_lo = vld1q_u16(&src[0]);
1369       const uint16x8_t v_src_hi = vld1q_u16(&src[src_stride]);
1370       const uint16x8_t v_sum_lo = vaddq_u16(v_src_lo, offset);
1371       const uint16x8_t v_sum_hi = vaddq_u16(v_src_hi, offset);
1372       const uint16x8_t v_dest_lo = vshlq_n_u16(v_sum_lo, final_shift);
1373       const uint16x8_t v_dest_hi = vshlq_n_u16(v_sum_hi, final_shift);
1374       vst1q_u16(&dest[0], v_dest_lo);
1375       vst1q_u16(&dest[8], v_dest_hi);
1376       src += src_stride << 1;
1377       dest += 16;
1378       y -= 2;
1379     } while (y != 0);
1380   } else {  // width == 4
1381     int y = height;
1382     do {
1383       const uint16x4_t v_src_lo = vld1_u16(&src[0]);
1384       const uint16x4_t v_src_hi = vld1_u16(&src[src_stride]);
1385       const uint16x4_t v_sum_lo = vadd_u16(v_src_lo, vget_low_u16(offset));
1386       const uint16x4_t v_sum_hi = vadd_u16(v_src_hi, vget_low_u16(offset));
1387       const uint16x4_t v_dest_lo = vshl_n_u16(v_sum_lo, final_shift);
1388       const uint16x4_t v_dest_hi = vshl_n_u16(v_sum_hi, final_shift);
1389       vst1_u16(&dest[0], v_dest_lo);
1390       vst1_u16(&dest[4], v_dest_hi);
1391       src += src_stride << 1;
1392       dest += 8;
1393       y -= 2;
1394     } while (y != 0);
1395   }
1396 }
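
// A scalar model of the compound copy above (ScalarCompoundCopyPixel is a
// hypothetical name for illustration): each input pixel is biased into the
// compound output range and scaled by the difference in inter rounding bits,
// matching the vaddq_u16/vshlq_n_u16 pairs in the vector code.
inline uint16_t ScalarCompoundCopyPixel(const uint16_t src_pixel) {
  constexpr int kShift =
      kInterRoundBitsVertical - kInterRoundBitsCompoundVertical;
  constexpr uint16_t kOffset = (1 << kBitdepth10) + (1 << (kBitdepth10 - 1));
  return static_cast<uint16_t>((src_pixel + kOffset) << kShift);
}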
1397 
1398 inline void HalfAddHorizontal(const uint16_t* LIBGAV1_RESTRICT const src,
1399                               uint16_t* LIBGAV1_RESTRICT const dst) {
1400   const uint16x8_t left = vld1q_u16(src);
1401   const uint16x8_t right = vld1q_u16(src + 1);
1402   vst1q_u16(dst, vrhaddq_u16(left, right));
1403 }
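
// Scalar equivalent of the rounding halving add used above (ScalarHalfAdd is
// a hypothetical name): vrhaddq_u16 computes (a + b + 1) >> 1 per lane
// without risk of intermediate overflow.
inline uint16_t ScalarHalfAdd(const uint16_t a, const uint16_t b) {
  // Widen to 32 bits so the rounding term cannot overflow 16 bits.
  return static_cast<uint16_t>((static_cast<uint32_t>(a) + b + 1) >> 1);
}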
1404 
1405 inline void HalfAddHorizontal16(const uint16_t* LIBGAV1_RESTRICT const src,
1406                                 uint16_t* LIBGAV1_RESTRICT const dst) {
1407   HalfAddHorizontal(src, dst);
1408   HalfAddHorizontal(src + 8, dst + 8);
1409 }
1410 
1411 template <int width>
1412 inline void IntraBlockCopyHorizontal(const uint16_t* LIBGAV1_RESTRICT src,
1413                                      const ptrdiff_t src_stride,
1414                                      const int height,
1415                                      uint16_t* LIBGAV1_RESTRICT dst,
1416                                      const ptrdiff_t dst_stride) {
1417   const ptrdiff_t src_remainder_stride = src_stride - (width - 16);
1418   const ptrdiff_t dst_remainder_stride = dst_stride - (width - 16);
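  // The loop body advances |src| and |dst| by 16 between consecutive
  // HalfAddHorizontal16() calls, i.e. by (width - 16) per row in total, so
  // the remainder strides return both pointers to the start of the next row.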
1419 
1420   int y = height;
1421   do {
1422     HalfAddHorizontal16(src, dst);
1423     if (width >= 32) {
1424       src += 16;
1425       dst += 16;
1426       HalfAddHorizontal16(src, dst);
1427       if (width >= 64) {
1428         src += 16;
1429         dst += 16;
1430         HalfAddHorizontal16(src, dst);
1431         src += 16;
1432         dst += 16;
1433         HalfAddHorizontal16(src, dst);
1434         if (width == 128) {
1435           src += 16;
1436           dst += 16;
1437           HalfAddHorizontal16(src, dst);
1438           src += 16;
1439           dst += 16;
1440           HalfAddHorizontal16(src, dst);
1441           src += 16;
1442           dst += 16;
1443           HalfAddHorizontal16(src, dst);
1444           src += 16;
1445           dst += 16;
1446           HalfAddHorizontal16(src, dst);
1447         }
1448       }
1449     }
1450     src += src_remainder_stride;
1451     dst += dst_remainder_stride;
1452   } while (--y != 0);
1453 }
1454 
1455 void ConvolveIntraBlockCopyHorizontal_NEON(
1456     const void* LIBGAV1_RESTRICT const reference,
1457     const ptrdiff_t reference_stride, const int /*horizontal_filter_index*/,
1458     const int /*vertical_filter_index*/, const int /*subpixel_x*/,
1459     const int /*subpixel_y*/, const int width, const int height,
1460     void* LIBGAV1_RESTRICT const prediction, const ptrdiff_t pred_stride) {
1461   assert(width >= 4 && width <= kMaxSuperBlockSizeInPixels);
1462   assert(height >= 4 && height <= kMaxSuperBlockSizeInPixels);
1463   const auto* src = static_cast<const uint16_t*>(reference);
1464   auto* dest = static_cast<uint16_t*>(prediction);
1465   const ptrdiff_t src_stride = reference_stride >> 1;
1466   const ptrdiff_t dst_stride = pred_stride >> 1;
1467 
1468   if (width == 128) {
1469     IntraBlockCopyHorizontal<128>(src, src_stride, height, dest, dst_stride);
1470   } else if (width == 64) {
1471     IntraBlockCopyHorizontal<64>(src, src_stride, height, dest, dst_stride);
1472   } else if (width == 32) {
1473     IntraBlockCopyHorizontal<32>(src, src_stride, height, dest, dst_stride);
1474   } else if (width == 16) {
1475     IntraBlockCopyHorizontal<16>(src, src_stride, height, dest, dst_stride);
1476   } else if (width == 8) {
1477     int y = height;
1478     do {
1479       HalfAddHorizontal(src, dest);
1480       src += src_stride;
1481       dest += dst_stride;
1482     } while (--y != 0);
1483   } else {  // width == 4
1484     int y = height;
1485     do {
1486       uint16x4x2_t left;
1487       uint16x4x2_t right;
1488       left.val[0] = vld1_u16(src);
1489       right.val[0] = vld1_u16(src + 1);
1490       src += src_stride;
1491       left.val[1] = vld1_u16(src);
1492       right.val[1] = vld1_u16(src + 1);
1493       src += src_stride;
1494 
1495       vst1_u16(dest, vrhadd_u16(left.val[0], right.val[0]));
1496       dest += dst_stride;
1497       vst1_u16(dest, vrhadd_u16(left.val[1], right.val[1]));
1498       dest += dst_stride;
1499       y -= 2;
1500     } while (y != 0);
1501   }
1502 }
1503 
1504 template <int width>
1505 inline void IntraBlockCopyVertical(const uint16_t* LIBGAV1_RESTRICT src,
1506                                    const ptrdiff_t src_stride, const int height,
1507                                    uint16_t* LIBGAV1_RESTRICT dst,
1508                                    const ptrdiff_t dst_stride) {
1509   const ptrdiff_t src_remainder_stride = src_stride - (width - 8);
1510   const ptrdiff_t dst_remainder_stride = dst_stride - (width - 8);
1511   uint16x8_t row[8], below[8];
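  // |row| caches the previous source row so each row is loaded only once;
  // after each average is stored, the |below| values become the new |row|.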
1512 
1513   row[0] = vld1q_u16(src);
1514   if (width >= 16) {
1515     src += 8;
1516     row[1] = vld1q_u16(src);
1517     if (width >= 32) {
1518       src += 8;
1519       row[2] = vld1q_u16(src);
1520       src += 8;
1521       row[3] = vld1q_u16(src);
1522       if (width == 64) {
1523         src += 8;
1524         row[4] = vld1q_u16(src);
1525         src += 8;
1526         row[5] = vld1q_u16(src);
1527         src += 8;
1528         row[6] = vld1q_u16(src);
1529         src += 8;
1530         row[7] = vld1q_u16(src);
1531       }
1532     }
1533   }
1534   src += src_remainder_stride;
1535 
1536   int y = height;
1537   do {
1538     below[0] = vld1q_u16(src);
1539     if (width >= 16) {
1540       src += 8;
1541       below[1] = vld1q_u16(src);
1542       if (width >= 32) {
1543         src += 8;
1544         below[2] = vld1q_u16(src);
1545         src += 8;
1546         below[3] = vld1q_u16(src);
1547         if (width == 64) {
1548           src += 8;
1549           below[4] = vld1q_u16(src);
1550           src += 8;
1551           below[5] = vld1q_u16(src);
1552           src += 8;
1553           below[6] = vld1q_u16(src);
1554           src += 8;
1555           below[7] = vld1q_u16(src);
1556         }
1557       }
1558     }
1559     src += src_remainder_stride;
1560 
1561     vst1q_u16(dst, vrhaddq_u16(row[0], below[0]));
1562     row[0] = below[0];
1563     if (width >= 16) {
1564       dst += 8;
1565       vst1q_u16(dst, vrhaddq_u16(row[1], below[1]));
1566       row[1] = below[1];
1567       if (width >= 32) {
1568         dst += 8;
1569         vst1q_u16(dst, vrhaddq_u16(row[2], below[2]));
1570         row[2] = below[2];
1571         dst += 8;
1572         vst1q_u16(dst, vrhaddq_u16(row[3], below[3]));
1573         row[3] = below[3];
1574         if (width >= 64) {
1575           dst += 8;
1576           vst1q_u16(dst, vrhaddq_u16(row[4], below[4]));
1577           row[4] = below[4];
1578           dst += 8;
1579           vst1q_u16(dst, vrhaddq_u16(row[5], below[5]));
1580           row[5] = below[5];
1581           dst += 8;
1582           vst1q_u16(dst, vrhaddq_u16(row[6], below[6]));
1583           row[6] = below[6];
1584           dst += 8;
1585           vst1q_u16(dst, vrhaddq_u16(row[7], below[7]));
1586           row[7] = below[7];
1587         }
1588       }
1589     }
1590     dst += dst_remainder_stride;
1591   } while (--y != 0);
1592 }
1593 
1594 void ConvolveIntraBlockCopyVertical_NEON(
1595     const void* LIBGAV1_RESTRICT const reference,
1596     const ptrdiff_t reference_stride, const int /*horizontal_filter_index*/,
1597     const int /*vertical_filter_index*/, const int /*horizontal_filter_id*/,
1598     const int /*vertical_filter_id*/, const int width, const int height,
1599     void* LIBGAV1_RESTRICT const prediction, const ptrdiff_t pred_stride) {
1600   assert(width >= 4 && width <= kMaxSuperBlockSizeInPixels);
1601   assert(height >= 4 && height <= kMaxSuperBlockSizeInPixels);
1602   const auto* src = static_cast<const uint16_t*>(reference);
1603   auto* dest = static_cast<uint16_t*>(prediction);
1604   const ptrdiff_t src_stride = reference_stride >> 1;
1605   const ptrdiff_t dst_stride = pred_stride >> 1;
1606 
1607   if (width == 128) {
1608     // Due to register pressure, process the width as two 64xH blocks.
1609     for (int i = 0; i < 2; ++i) {
1610       IntraBlockCopyVertical<64>(src, src_stride, height, dest, dst_stride);
1611       src += 64;
1612       dest += 64;
1613     }
1614   } else if (width == 64) {
1615     IntraBlockCopyVertical<64>(src, src_stride, height, dest, dst_stride);
1616   } else if (width == 32) {
1617     IntraBlockCopyVertical<32>(src, src_stride, height, dest, dst_stride);
1618   } else if (width == 16) {
1619     IntraBlockCopyVertical<16>(src, src_stride, height, dest, dst_stride);
1620   } else if (width == 8) {
1621     IntraBlockCopyVertical<8>(src, src_stride, height, dest, dst_stride);
1622   } else {  // width == 4
1623     uint16x4_t row = vld1_u16(src);
1624     src += src_stride;
1625     int y = height;
1626     do {
1627       const uint16x4_t below = vld1_u16(src);
1628       src += src_stride;
1629       vst1_u16(dest, vrhadd_u16(row, below));
1630       dest += dst_stride;
1631       row = below;
1632     } while (--y != 0);
1633   }
1634 }
1635 
1636 template <int width>
1637 inline void IntraBlockCopy2D(const uint16_t* LIBGAV1_RESTRICT src,
1638                              const ptrdiff_t src_stride, const int height,
1639                              uint16_t* LIBGAV1_RESTRICT dst,
1640                              const ptrdiff_t dst_stride) {
1641   const ptrdiff_t src_remainder_stride = src_stride - (width - 8);
1642   const ptrdiff_t dst_remainder_stride = dst_stride - (width - 8);
1643   uint16x8_t row[16];
1644   row[0] = vaddq_u16(vld1q_u16(src), vld1q_u16(src + 1));
1645   if (width >= 16) {
1646     src += 8;
1647     row[1] = vaddq_u16(vld1q_u16(src), vld1q_u16(src + 1));
1648     if (width >= 32) {
1649       src += 8;
1650       row[2] = vaddq_u16(vld1q_u16(src), vld1q_u16(src + 1));
1651       src += 8;
1652       row[3] = vaddq_u16(vld1q_u16(src), vld1q_u16(src + 1));
1653       if (width >= 64) {
1654         src += 8;
1655         row[4] = vaddq_u16(vld1q_u16(src), vld1q_u16(src + 1));
1656         src += 8;
1657         row[5] = vaddq_u16(vld1q_u16(src), vld1q_u16(src + 1));
1658         src += 8;
1659         row[6] = vaddq_u16(vld1q_u16(src), vld1q_u16(src + 1));
1660         src += 8;
1661         row[7] = vaddq_u16(vld1q_u16(src), vld1q_u16(src + 1));
1662         if (width == 128) {
1663           src += 8;
1664           row[8] = vaddq_u16(vld1q_u16(src), vld1q_u16(src + 1));
1665           src += 8;
1666           row[9] = vaddq_u16(vld1q_u16(src), vld1q_u16(src + 1));
1667           src += 8;
1668           row[10] = vaddq_u16(vld1q_u16(src), vld1q_u16(src + 1));
1669           src += 8;
1670           row[11] = vaddq_u16(vld1q_u16(src), vld1q_u16(src + 1));
1671           src += 8;
1672           row[12] = vaddq_u16(vld1q_u16(src), vld1q_u16(src + 1));
1673           src += 8;
1674           row[13] = vaddq_u16(vld1q_u16(src), vld1q_u16(src + 1));
1675           src += 8;
1676           row[14] = vaddq_u16(vld1q_u16(src), vld1q_u16(src + 1));
1677           src += 8;
1678           row[15] = vaddq_u16(vld1q_u16(src), vld1q_u16(src + 1));
1679         }
1680       }
1681     }
1682   }
1683   src += src_remainder_stride;
1684 
1685   int y = height;
1686   do {
1687     const uint16x8_t below_0 = vaddq_u16(vld1q_u16(src), vld1q_u16(src + 1));
1688     vst1q_u16(dst, vrshrq_n_u16(vaddq_u16(row[0], below_0), 2));
1689     row[0] = below_0;
1690     if (width >= 16) {
1691       src += 8;
1692       dst += 8;
1693 
1694       const uint16x8_t below_1 = vaddq_u16(vld1q_u16(src), vld1q_u16(src + 1));
1695       vst1q_u16(dst, vrshrq_n_u16(vaddq_u16(row[1], below_1), 2));
1696       row[1] = below_1;
1697       if (width >= 32) {
1698         src += 8;
1699         dst += 8;
1700 
1701         const uint16x8_t below_2 =
1702             vaddq_u16(vld1q_u16(src), vld1q_u16(src + 1));
1703         vst1q_u16(dst, vrshrq_n_u16(vaddq_u16(row[2], below_2), 2));
1704         row[2] = below_2;
1705         src += 8;
1706         dst += 8;
1707 
1708         const uint16x8_t below_3 =
1709             vaddq_u16(vld1q_u16(src), vld1q_u16(src + 1));
1710         vst1q_u16(dst, vrshrq_n_u16(vaddq_u16(row[3], below_3), 2));
1711         row[3] = below_3;
1712         if (width >= 64) {
1713           src += 8;
1714           dst += 8;
1715 
1716           const uint16x8_t below_4 =
1717               vaddq_u16(vld1q_u16(src), vld1q_u16(src + 1));
1718           vst1q_u16(dst, vrshrq_n_u16(vaddq_u16(row[4], below_4), 2));
1719           row[4] = below_4;
1720           src += 8;
1721           dst += 8;
1722 
1723           const uint16x8_t below_5 =
1724               vaddq_u16(vld1q_u16(src), vld1q_u16(src + 1));
1725           vst1q_u16(dst, vrshrq_n_u16(vaddq_u16(row[5], below_5), 2));
1726           row[5] = below_5;
1727           src += 8;
1728           dst += 8;
1729 
1730           const uint16x8_t below_6 =
1731               vaddq_u16(vld1q_u16(src), vld1q_u16(src + 1));
1732           vst1q_u16(dst, vrshrq_n_u16(vaddq_u16(row[6], below_6), 2));
1733           row[6] = below_6;
1734           src += 8;
1735           dst += 8;
1736 
1737           const uint16x8_t below_7 =
1738               vaddq_u16(vld1q_u16(src), vld1q_u16(src + 1));
1739           vst1q_u16(dst, vrshrq_n_u16(vaddq_u16(row[7], below_7), 2));
1740           row[7] = below_7;
1741           if (width == 128) {
1742             src += 8;
1743             dst += 8;
1744 
1745             const uint16x8_t below_8 =
1746                 vaddq_u16(vld1q_u16(src), vld1q_u16(src + 1));
1747             vst1q_u16(dst, vrshrq_n_u16(vaddq_u16(row[8], below_8), 2));
1748             row[8] = below_8;
1749             src += 8;
1750             dst += 8;
1751 
1752             const uint16x8_t below_9 =
1753                 vaddq_u16(vld1q_u16(src), vld1q_u16(src + 1));
1754             vst1q_u16(dst, vrshrq_n_u16(vaddq_u16(row[9], below_9), 2));
1755             row[9] = below_9;
1756             src += 8;
1757             dst += 8;
1758 
1759             const uint16x8_t below_10 =
1760                 vaddq_u16(vld1q_u16(src), vld1q_u16(src + 1));
1761             vst1q_u16(dst, vrshrq_n_u16(vaddq_u16(row[10], below_10), 2));
1762             row[10] = below_10;
1763             src += 8;
1764             dst += 8;
1765 
1766             const uint16x8_t below_11 =
1767                 vaddq_u16(vld1q_u16(src), vld1q_u16(src + 1));
1768             vst1q_u16(dst, vrshrq_n_u16(vaddq_u16(row[11], below_11), 2));
1769             row[11] = below_11;
1770             src += 8;
1771             dst += 8;
1772 
1773             const uint16x8_t below_12 =
1774                 vaddq_u16(vld1q_u16(src), vld1q_u16(src + 1));
1775             vst1q_u16(dst, vrshrq_n_u16(vaddq_u16(row[12], below_12), 2));
1776             row[12] = below_12;
1777             src += 8;
1778             dst += 8;
1779 
1780             const uint16x8_t below_13 =
1781                 vaddq_u16(vld1q_u16(src), vld1q_u16(src + 1));
1782             vst1q_u16(dst, vrshrq_n_u16(vaddq_u16(row[13], below_13), 2));
1783             row[13] = below_13;
1784             src += 8;
1785             dst += 8;
1786 
1787             const uint16x8_t below_14 =
1788                 vaddq_u16(vld1q_u16(src), vld1q_u16(src + 1));
1789             vst1q_u16(dst, vrshrq_n_u16(vaddq_u16(row[14], below_14), 2));
1790             row[14] = below_14;
1791             src += 8;
1792             dst += 8;
1793 
1794             const uint16x8_t below_15 =
1795                 vaddq_u16(vld1q_u16(src), vld1q_u16(src + 1));
1796             vst1q_u16(dst, vrshrq_n_u16(vaddq_u16(row[15], below_15), 2));
1797             row[15] = below_15;
1798           }
1799         }
1800       }
1801     }
1802     src += src_remainder_stride;
1803     dst += dst_remainder_stride;
1804   } while (--y != 0);
1805 }
1806 
1807 void ConvolveIntraBlockCopy2D_NEON(
1808     const void* LIBGAV1_RESTRICT const reference,
1809     const ptrdiff_t reference_stride, const int /*horizontal_filter_index*/,
1810     const int /*vertical_filter_index*/, const int /*horizontal_filter_id*/,
1811     const int /*vertical_filter_id*/, const int width, const int height,
1812     void* LIBGAV1_RESTRICT const prediction, const ptrdiff_t pred_stride) {
1813   assert(width >= 4 && width <= kMaxSuperBlockSizeInPixels);
1814   assert(height >= 4 && height <= kMaxSuperBlockSizeInPixels);
1815   const auto* src = static_cast<const uint16_t*>(reference);
1816   auto* dest = static_cast<uint16_t*>(prediction);
1817   const ptrdiff_t src_stride = reference_stride >> 1;
1818   const ptrdiff_t dst_stride = pred_stride >> 1;
1819 
1820   // Note: vertical access of height + 1 rows is allowed. Because this
1821   // function is only used for the u/v planes of intra block copy, such access
1822   // is guaranteed to stay within the prediction block.
1823 
1824   if (width == 128) {
1825     IntraBlockCopy2D<128>(src, src_stride, height, dest, dst_stride);
1826   } else if (width == 64) {
1827     IntraBlockCopy2D<64>(src, src_stride, height, dest, dst_stride);
1828   } else if (width == 32) {
1829     IntraBlockCopy2D<32>(src, src_stride, height, dest, dst_stride);
1830   } else if (width == 16) {
1831     IntraBlockCopy2D<16>(src, src_stride, height, dest, dst_stride);
1832   } else if (width == 8) {
1833     IntraBlockCopy2D<8>(src, src_stride, height, dest, dst_stride);
1834   } else {  // width == 4
1835     uint16x4_t row0 = vadd_u16(vld1_u16(src), vld1_u16(src + 1));
1836     src += src_stride;
1837 
1838     int y = height;
1839     do {
1840       const uint16x4_t row1 = vadd_u16(vld1_u16(src), vld1_u16(src + 1));
1841       src += src_stride;
1842       const uint16x4_t row2 = vadd_u16(vld1_u16(src), vld1_u16(src + 1));
1843       src += src_stride;
1844       const uint16x4_t result_01 = vrshr_n_u16(vadd_u16(row0, row1), 2);
1845       const uint16x4_t result_12 = vrshr_n_u16(vadd_u16(row1, row2), 2);
1846       vst1_u16(dest, result_01);
1847       dest += dst_stride;
1848       vst1_u16(dest, result_12);
1849       dest += dst_stride;
1850       row0 = row2;
1851       y -= 2;
1852     } while (y != 0);
1853   }
1854 }
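
// A scalar model of the 2D intra block copy filter above
// (ScalarIntraBlockCopy2DPixel is a hypothetical name): horizontal pairs are
// summed per row, then two row sums are combined with a rounded shift,
// matching vaddq_u16 followed by vrshrq_n_u16(sum, 2).
inline uint16_t ScalarIntraBlockCopy2DPixel(const uint16_t* const row,
                                            const uint16_t* const below,
                                            const int x) {
  return static_cast<uint16_t>(
      (row[x] + row[x + 1] + below[x] + below[x + 1] + 2) >> 2);
}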
1855 
1856 // -----------------------------------------------------------------------------
1857 // Scaled Convolve
1858 
1859 // There are many opportunities for overreading in scaled convolve, because the
1860 // range of starting points for filter windows is anywhere from 0 to 16 for 8
1861 // destination pixels, and the window sizes range from 2 to 8. To accommodate
1862 // this range concisely, |grade_x| denotes the maximum number of whole source
1863 // pixels that a single |step_x| increment can traverse, i.e. 1 or 2. When
1864 // grade_x is 2, we are guaranteed to exceed 8 whole source pixels for every
1865 // 8 |step_x| increments. The first load covers the initial elements of
1866 // src_x, while the final load covers the taps.
1867 template <int grade_x>
1868 inline uint8x16x3_t LoadSrcVals(const uint16_t* const src_x) {
1869   uint8x16x3_t ret;
1870   // When fractional step size is less than or equal to 1, the rightmost
1871   // starting value for a filter may be at position 7. For an 8-tap filter, the
1872   // rightmost value for the final tap may be at position 14. Therefore we load
1873   // 2 vectors of eight 16-bit values.
1874   ret.val[0] = vreinterpretq_u8_u16(vld1q_u16(src_x));
1875   ret.val[1] = vreinterpretq_u8_u16(vld1q_u16(src_x + 8));
1876 #if LIBGAV1_MSAN
1877   // Initialize to quiet msan warnings when grade_x <= 1.
1878   ret.val[2] = vdupq_n_u8(0);
1879 #endif
1880   if (grade_x > 1) {
1881     // When fractional step size is greater than 1 (up to 2), the rightmost
1882     // starting value for a filter may be at position 15. For an 8-tap filter,
1883     // the rightmost value for the final tap may be at position 22. Therefore we
1884     // load 3 vectors of eight 16-bit values.
1885     ret.val[2] = vreinterpretq_u8_u16(vld1q_u16(src_x + 16));
1886   }
1887   return ret;
1888 }
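
// How a caller might select |grade_x| (an assumption for illustration;
// GetGradeX is a hypothetical name): a step of 1 << kScaleSubPixelBits is
// exactly one source pixel per destination pixel, so any larger step can
// traverse up to two source pixels per increment.
constexpr int GetGradeX(const int step_x) {
  return (step_x > (1 << kScaleSubPixelBits)) ? 2 : 1;
}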
1889 
1890 // Assemble 4 values corresponding to one tap position across multiple filters.
1891 // This is a simple case because the maximum offset is 8 and only the smaller
1892 // filters operate on 4xH blocks.
1893 inline uint16x4_t PermuteSrcVals(const uint8x16x3_t src_bytes,
1894                                  const uint8x8_t indices) {
1895   const uint8x16x2_t src_bytes2 = {src_bytes.val[0], src_bytes.val[1]};
1896   return vreinterpret_u16_u8(VQTbl2U8(src_bytes2, indices));
1897 }
1898 
1899 // Assemble 8 values corresponding to one tap position across multiple filters.
1900 // This requires a lot of workaround on A32 architectures, so it may be worth
1901 // using an overall different algorithm for that architecture.
1902 template <int grade_x>
1903 inline uint16x8_t PermuteSrcVals(const uint8x16x3_t src_bytes,
1904                                  const uint8x16_t indices) {
1905   if (grade_x == 1) {
1906     const uint8x16x2_t src_bytes2 = {src_bytes.val[0], src_bytes.val[1]};
1907     return vreinterpretq_u16_u8(VQTbl2QU8(src_bytes2, indices));
1908   }
1909   return vreinterpretq_u16_u8(VQTbl3QU8(src_bytes, indices));
1910 }
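
// Scalar model of the table-lookup gathers above (ScalarPermuteSrcVal is a
// hypothetical name): the loaded vectors form one little-endian byte pool,
// and each 16-bit lane is assembled from the two byte positions named by the
// interleaved indices.
inline uint16_t ScalarPermuteSrcVal(const uint8_t* const byte_pool,
                                    const uint8_t index_lo,
                                    const uint8_t index_hi) {
  return static_cast<uint16_t>(byte_pool[index_lo] |
                               (byte_pool[index_hi] << 8));
}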
1911 
1912 // Pre-transpose the 2 tap filters in |kAbsHalfSubPixelFilters|[3].
1913 // Although the taps must eventually be widened to 16-bit values, they are
1914 // arranged by table lookup, which is more expensive for larger types, so it
1915 // is cheaper to keep them 8-bit and lengthen them in-loop. |tap_index|
1916 // refers to the index within a kernel applied to a single value.
1917 inline int8x16_t GetPositive2TapFilter(const int tap_index) {
1918   assert(tap_index < 2);
1919   alignas(
1920       16) static constexpr int8_t kAbsHalfSubPixel2TapFilterColumns[2][16] = {
1921       {64, 60, 56, 52, 48, 44, 40, 36, 32, 28, 24, 20, 16, 12, 8, 4},
1922       {0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60}};
1923 
1924   return vld1q_s8(kAbsHalfSubPixel2TapFilterColumns[tap_index]);
1925 }
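
// For example, filter_id 3 selects lane 3 of each column above, i.e. the
// halved 2-tap kernel {52, 12}. A scalar sketch of those columns (assuming
// the spec's bilinear kernel {128 - 8 * id, 8 * id} before halving;
// Scalar2TapHalfTap is a hypothetical name):
inline int16_t Scalar2TapHalfTap(const int filter_id, const int tap_index) {
  return static_cast<int16_t>((tap_index == 0) ? 64 - 4 * filter_id
                                               : 4 * filter_id);
}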
1926 
1927 template <int grade_x>
1928 inline void ConvolveKernelHorizontal2Tap(
1929     const uint16_t* LIBGAV1_RESTRICT const src, const ptrdiff_t src_stride,
1930     const int width, const int subpixel_x, const int step_x,
1931     const int intermediate_height, int16_t* LIBGAV1_RESTRICT intermediate) {
1932   // Account for the 0-taps that precede the 2 nonzero taps in the spec.
1933   const int kernel_offset = 3;
1934   const int ref_x = subpixel_x >> kScaleSubPixelBits;
1935   const int step_x8 = step_x << 3;
1936   const int8x16_t filter_taps0 = GetPositive2TapFilter(0);
1937   const int8x16_t filter_taps1 = GetPositive2TapFilter(1);
1938   const uint16x8_t index_steps = vmulq_n_u16(
1939       vmovl_u8(vcreate_u8(0x0706050403020100)), static_cast<uint16_t>(step_x));
1940   const uint8x8_t filter_index_mask = vdup_n_u8(kSubPixelMask);
1941 
1942   int p = subpixel_x;
1943   if (width <= 4) {
1944     const uint16_t* src_y = src;
1945     // Only add steps to the 10-bit truncated p to avoid overflow.
1946     const uint16x8_t p_fraction = vdupq_n_u16(p & 1023);
1947     const uint16x8_t subpel_index_offsets = vaddq_u16(index_steps, p_fraction);
1948     const uint8x8_t filter_indices =
1949         vand_u8(vshrn_n_u16(subpel_index_offsets, 6), filter_index_mask);
1950     // Each lane of taps[k] corresponds to one output value along the
1951     // row, containing kSubPixelFilters[filter_index][filter_id][k], where
1952     // filter_id depends on x.
1953     const int16x4_t taps[2] = {
1954         vget_low_s16(vmovl_s8(VQTbl1S8(filter_taps0, filter_indices))),
1955         vget_low_s16(vmovl_s8(VQTbl1S8(filter_taps1, filter_indices)))};
1956     // Lower byte of Nth value is at position 2*N.
1957     // Narrowing shift is not available here because the maximum shift
1958     // parameter is 8.
1959     const uint8x8_t src_indices0 = vshl_n_u8(
1960         vmovn_u16(vshrq_n_u16(subpel_index_offsets, kScaleSubPixelBits)), 1);
1961     // Upper byte of Nth value is at position 2*N+1.
1962     const uint8x8_t src_indices1 = vadd_u8(src_indices0, vdup_n_u8(1));
1963     // Only 4 values needed.
1964     const uint8x8_t src_indices = InterleaveLow8(src_indices0, src_indices1);
1965     const uint8x8_t src_lookup[2] = {src_indices,
1966                                      vadd_u8(src_indices, vdup_n_u8(2))};
1967 
1968     int y = intermediate_height;
1969     do {
1970       const uint16_t* src_x =
1971           src_y + (p >> kScaleSubPixelBits) - ref_x + kernel_offset;
1972       // Load a pool of samples to select from using stepped indices.
1973       const uint8x16x3_t src_bytes = LoadSrcVals<1>(src_x);
1974       // Each lane corresponds to a different filter kernel.
1975       const uint16x4_t src[2] = {PermuteSrcVals(src_bytes, src_lookup[0]),
1976                                  PermuteSrcVals(src_bytes, src_lookup[1])};
1977 
1978       vst1_s16(intermediate,
1979                vrshrn_n_s32(SumOnePassTaps</*num_taps=*/2>(src, taps),
1980                             kInterRoundBitsHorizontal - 1));
1981       src_y = AddByteStride(src_y, src_stride);
1982       intermediate += kIntermediateStride;
1983     } while (--y != 0);
1984     return;
1985   }
1986 
1987   // |width| >= 8
1988   int16_t* intermediate_x = intermediate;
1989   int x = 0;
1990   do {
1991     const uint16_t* src_x =
1992         src + (p >> kScaleSubPixelBits) - ref_x + kernel_offset;
1993     // Only add steps to the 10-bit truncated p to avoid overflow.
1994     const uint16x8_t p_fraction = vdupq_n_u16(p & 1023);
1995     const uint16x8_t subpel_index_offsets = vaddq_u16(index_steps, p_fraction);
1996     const uint8x8_t filter_indices =
1997         vand_u8(vshrn_n_u16(subpel_index_offsets, kFilterIndexShift),
1998                 filter_index_mask);
1999     // Each lane of taps[k] corresponds to one output value along the
2000     // row, containing kSubPixelFilters[filter_index][filter_id][k], where
2001     // filter_id depends on x.
2002     const int16x8_t taps[2] = {
2003         vmovl_s8(VQTbl1S8(filter_taps0, filter_indices)),
2004         vmovl_s8(VQTbl1S8(filter_taps1, filter_indices))};
2005     const int16x4_t taps_low[2] = {vget_low_s16(taps[0]),
2006                                    vget_low_s16(taps[1])};
2007     const int16x4_t taps_high[2] = {vget_high_s16(taps[0]),
2008                                     vget_high_s16(taps[1])};
2009     // Lower byte of Nth value is at position 2*N.
2010     const uint8x8_t src_indices0 = vshl_n_u8(
2011         vmovn_u16(vshrq_n_u16(subpel_index_offsets, kScaleSubPixelBits)), 1);
2012     // Upper byte of Nth value is at position 2*N+1.
2013     const uint8x8_t src_indices1 = vadd_u8(src_indices0, vdup_n_u8(1));
2014     const uint8x8x2_t src_indices_zip = vzip_u8(src_indices0, src_indices1);
2015     const uint8x16_t src_indices =
2016         vcombine_u8(src_indices_zip.val[0], src_indices_zip.val[1]);
2017     const uint8x16_t src_lookup[2] = {src_indices,
2018                                       vaddq_u8(src_indices, vdupq_n_u8(2))};
2019 
2020     int y = intermediate_height;
2021     do {
2022       // Load a pool of samples to select from using stepped indices.
2023       const uint8x16x3_t src_bytes = LoadSrcVals<grade_x>(src_x);
2024       // Each lane corresponds to a different filter kernel.
2025       const uint16x8_t src[2] = {
2026           PermuteSrcVals<grade_x>(src_bytes, src_lookup[0]),
2027           PermuteSrcVals<grade_x>(src_bytes, src_lookup[1])};
2028       const uint16x4_t src_low[2] = {vget_low_u16(src[0]),
2029                                      vget_low_u16(src[1])};
2030       const uint16x4_t src_high[2] = {vget_high_u16(src[0]),
2031                                       vget_high_u16(src[1])};
2032 
2033       vst1_s16(intermediate_x,
2034                vrshrn_n_s32(SumOnePassTaps</*num_taps=*/2>(src_low, taps_low),
2035                             kInterRoundBitsHorizontal - 1));
2036       vst1_s16(intermediate_x + 4,
2037                vrshrn_n_s32(SumOnePassTaps</*num_taps=*/2>(src_high, taps_high),
2038                             kInterRoundBitsHorizontal - 1));
2039       // Avoid right shifting the stride.
2040       src_x = AddByteStride(src_x, src_stride);
2041       intermediate_x += kIntermediateStride;
2042     } while (--y != 0);
2043     x += 8;
2044     p += step_x8;
2045   } while (x < width);
2046 }
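
// Scalar model of the per-lane index math above (ScalarScaledIndices is a
// hypothetical name): for output column n the accumulated position is
// p + n * step_x; its high bits select the starting source sample and the
// bits below them select the filter id.
inline void ScalarScaledIndices(const int p, const int step_x, const int n,
                                int* const src_index, int* const filter_id) {
  const int position = p + n * step_x;
  *src_index = position >> kScaleSubPixelBits;
  *filter_id = (position >> kFilterIndexShift) & kSubPixelMask;
}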
2047 
2048 // Pre-transpose the 4 tap filters in |kAbsHalfSubPixelFilters|[5].
2049 inline int8x16_t GetPositive4TapFilter(const int tap_index) {
2050   assert(tap_index < 4);
2051   alignas(
2052       16) static constexpr int8_t kSubPixel4TapPositiveFilterColumns[4][16] = {
2053       {0, 15, 13, 11, 10, 9, 8, 7, 6, 6, 5, 4, 3, 2, 2, 1},
2054       {64, 31, 31, 31, 30, 29, 28, 27, 26, 24, 23, 22, 21, 20, 18, 17},
2055       {0, 17, 18, 20, 21, 22, 23, 24, 26, 27, 28, 29, 30, 31, 31, 31},
2056       {0, 1, 2, 2, 3, 4, 5, 6, 6, 7, 8, 9, 10, 11, 13, 15}};
2057 
2058   return vld1q_s8(kSubPixel4TapPositiveFilterColumns[tap_index]);
2059 }
2060 
2061 // This filter is only possible when width <= 4.
2062 inline void ConvolveKernelHorizontalPositive4Tap(
2063     const uint16_t* LIBGAV1_RESTRICT const src, const ptrdiff_t src_stride,
2064     const int subpixel_x, const int step_x, const int intermediate_height,
2065     int16_t* LIBGAV1_RESTRICT intermediate) {
2066   // Account for the 0-taps that precede the 4 nonzero taps in the spec.
2067   const int kernel_offset = 2;
2068   const int ref_x = subpixel_x >> kScaleSubPixelBits;
2069   const int8x16_t filter_taps0 = GetPositive4TapFilter(0);
2070   const int8x16_t filter_taps1 = GetPositive4TapFilter(1);
2071   const int8x16_t filter_taps2 = GetPositive4TapFilter(2);
2072   const int8x16_t filter_taps3 = GetPositive4TapFilter(3);
2073   const uint16x8_t index_steps = vmulq_n_u16(
2074       vmovl_u8(vcreate_u8(0x0706050403020100)), static_cast<uint16_t>(step_x));
2075   const uint8x8_t filter_index_mask = vdup_n_u8(kSubPixelMask);
2076 
2077   int p = subpixel_x;
2078   // Only add steps to the 10-bit truncated p to avoid overflow.
2079   const uint16x8_t p_fraction = vdupq_n_u16(p & 1023);
2080   const uint16x8_t subpel_index_offsets = vaddq_u16(index_steps, p_fraction);
2081   const uint8x8_t filter_indices =
2082       vand_u8(vshrn_n_u16(subpel_index_offsets, 6), filter_index_mask);
2083   // Each lane of taps[k] corresponds to one output value along the row,
2084   // containing kSubPixelFilters[filter_index][filter_id][k], where filter_id
2085   // depends on x.
2086   const int16x4_t taps[4] = {
2087       vget_low_s16(vmovl_s8(VQTbl1S8(filter_taps0, filter_indices))),
2088       vget_low_s16(vmovl_s8(VQTbl1S8(filter_taps1, filter_indices))),
2089       vget_low_s16(vmovl_s8(VQTbl1S8(filter_taps2, filter_indices))),
2090       vget_low_s16(vmovl_s8(VQTbl1S8(filter_taps3, filter_indices)))};
2091   // Lower byte of Nth value is at position 2*N.
2092   // Narrowing shift is not available here because the maximum shift
2093   // parameter is 8.
2094   const uint8x8_t src_indices0 = vshl_n_u8(
2095       vmovn_u16(vshrq_n_u16(subpel_index_offsets, kScaleSubPixelBits)), 1);
2096   // Upper byte of Nth value is at position 2*N+1.
2097   const uint8x8_t src_indices1 = vadd_u8(src_indices0, vdup_n_u8(1));
2098   // Only 4 values needed.
2099   const uint8x8_t src_indices_base = InterleaveLow8(src_indices0, src_indices1);
2100 
2101   uint8x8_t src_lookup[4];
2102   const uint8x8_t two = vdup_n_u8(2);
2103   src_lookup[0] = src_indices_base;
2104   for (int i = 1; i < 4; ++i) {
2105     src_lookup[i] = vadd_u8(src_lookup[i - 1], two);
2106   }
2107 
2108   const uint16_t* src_y =
2109       src + (p >> kScaleSubPixelBits) - ref_x + kernel_offset;
2110   int y = intermediate_height;
2111   do {
2112     // Load a pool of samples to select from using stepped indices.
2113     const uint8x16x3_t src_bytes = LoadSrcVals<1>(src_y);
2114     // Each lane corresponds to a different filter kernel.
2115     const uint16x4_t src[4] = {PermuteSrcVals(src_bytes, src_lookup[0]),
2116                                PermuteSrcVals(src_bytes, src_lookup[1]),
2117                                PermuteSrcVals(src_bytes, src_lookup[2]),
2118                                PermuteSrcVals(src_bytes, src_lookup[3])};
2119 
2120     vst1_s16(intermediate,
2121              vrshrn_n_s32(SumOnePassTaps</*num_taps=*/4>(src, taps),
2122                           kInterRoundBitsHorizontal - 1));
2123     src_y = AddByteStride(src_y, src_stride);
2124     intermediate += kIntermediateStride;
2125   } while (--y != 0);
2126 }
2127 
2128 // Pre-transpose the 4 tap filters in |kAbsHalfSubPixelFilters|[4].
2129 inline int8x16_t GetSigned4TapFilter(const int tap_index) {
2130   assert(tap_index < 4);
2131   alignas(16) static constexpr int8_t
2132       kAbsHalfSubPixel4TapSignedFilterColumns[4][16] = {
2133           {-0, -2, -4, -5, -6, -6, -7, -6, -6, -5, -5, -5, -4, -3, -2, -1},
2134           {64, 63, 61, 58, 55, 51, 47, 42, 38, 33, 29, 24, 19, 14, 9, 4},
2135           {0, 4, 9, 14, 19, 24, 29, 33, 38, 42, 47, 51, 55, 58, 61, 63},
2136           {-0, -1, -2, -3, -4, -5, -5, -5, -6, -6, -7, -6, -6, -5, -4, -2}};
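  // The sign is stored directly in the table entries; "-0" simply documents
  // columns whose taps are never positive.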
2137 
2138   return vld1q_s8(kAbsHalfSubPixel4TapSignedFilterColumns[tap_index]);
2139 }
2140 
2141 // This filter is only possible when width <= 4.
2142 inline void ConvolveKernelHorizontalSigned4Tap(
2143     const uint16_t* LIBGAV1_RESTRICT const src, const ptrdiff_t src_stride,
2144     const int subpixel_x, const int step_x, const int intermediate_height,
2145     int16_t* LIBGAV1_RESTRICT intermediate) {
2146   const int kernel_offset = 2;
2147   const int ref_x = subpixel_x >> kScaleSubPixelBits;
2148   const uint8x8_t filter_index_mask = vdup_n_u8(kSubPixelMask);
2149   const int8x16_t filter_taps0 = GetSigned4TapFilter(0);
2150   const int8x16_t filter_taps1 = GetSigned4TapFilter(1);
2151   const int8x16_t filter_taps2 = GetSigned4TapFilter(2);
2152   const int8x16_t filter_taps3 = GetSigned4TapFilter(3);
2153   const uint16x8_t index_steps = vmulq_n_u16(
2154       vmovl_u8(vcreate_u8(0x0706050403020100)), static_cast<uint16_t>(step_x));
2155 
2156   const int p = subpixel_x;
2157   // Only add steps to the 10-bit truncated p to avoid overflow.
2158   const uint16x8_t p_fraction = vdupq_n_u16(p & 1023);
2159   const uint16x8_t subpel_index_offsets = vaddq_u16(index_steps, p_fraction);
2160   const uint8x8_t filter_indices =
2161       vand_u8(vshrn_n_u16(subpel_index_offsets, 6), filter_index_mask);
2162   // Each lane of taps[k] corresponds to one output value along the row,
2163   // containing kSubPixelFilters[filter_index][filter_id][k], where filter_id
2164   // depends on x.
2165   const int16x4_t taps[4] = {
2166       vget_low_s16(vmovl_s8(VQTbl1S8(filter_taps0, filter_indices))),
2167       vget_low_s16(vmovl_s8(VQTbl1S8(filter_taps1, filter_indices))),
2168       vget_low_s16(vmovl_s8(VQTbl1S8(filter_taps2, filter_indices))),
2169       vget_low_s16(vmovl_s8(VQTbl1S8(filter_taps3, filter_indices)))};
2170   // Lower byte of Nth value is at position 2*N.
2171   // Narrowing shift is not available here because the maximum shift
2172   // parameter is 8.
2173   const uint8x8_t src_indices0 = vshl_n_u8(
2174       vmovn_u16(vshrq_n_u16(subpel_index_offsets, kScaleSubPixelBits)), 1);
2175   // Upper byte of Nth value is at position 2*N+1.
2176   const uint8x8_t src_indices1 = vadd_u8(src_indices0, vdup_n_u8(1));
2177   // Only 4 values needed.
2178   const uint8x8_t src_indices_base = InterleaveLow8(src_indices0, src_indices1);
2179 
2180   uint8x8_t src_lookup[4];
2181   const uint8x8_t two = vdup_n_u8(2);
2182   src_lookup[0] = src_indices_base;
2183   for (int i = 1; i < 4; ++i) {
2184     src_lookup[i] = vadd_u8(src_lookup[i - 1], two);
2185   }
2186 
2187   const uint16_t* src_y =
2188       src + (p >> kScaleSubPixelBits) - ref_x + kernel_offset;
2189   int y = intermediate_height;
2190   do {
2191     // Load a pool of samples to select from using stepped indices.
2192     const uint8x16x3_t src_bytes = LoadSrcVals<1>(src_y);
2193     // Each lane corresponds to a different filter kernel.
2194     const uint16x4_t src[4] = {PermuteSrcVals(src_bytes, src_lookup[0]),
2195                                PermuteSrcVals(src_bytes, src_lookup[1]),
2196                                PermuteSrcVals(src_bytes, src_lookup[2]),
2197                                PermuteSrcVals(src_bytes, src_lookup[3])};
2198 
2199     vst1_s16(intermediate,
2200              vrshrn_n_s32(SumOnePassTaps</*num_taps=*/4>(src, taps),
2201                           kInterRoundBitsHorizontal - 1));
2202     src_y = AddByteStride(src_y, src_stride);
2203     intermediate += kIntermediateStride;
2204   } while (--y != 0);
2205 }
2206 
2207 // Pre-transpose the 6 tap filters in |kAbsHalfSubPixelFilters|[0].
2208 inline int8x16_t GetSigned6TapFilter(const int tap_index) {
2209   assert(tap_index < 6);
2210   alignas(16) static constexpr int8_t
2211       kAbsHalfSubPixel6TapSignedFilterColumns[6][16] = {
2212           {0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0},
2213           {-0, -3, -5, -6, -7, -7, -8, -7, -7, -6, -6, -6, -5, -4, -2, -1},
2214           {64, 63, 61, 58, 55, 51, 47, 42, 38, 33, 29, 24, 19, 14, 9, 4},
2215           {0, 4, 9, 14, 19, 24, 29, 33, 38, 42, 47, 51, 55, 58, 61, 63},
2216           {-0, -1, -2, -4, -5, -6, -6, -6, -7, -7, -8, -7, -7, -6, -5, -3},
2217           {0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}};
2218 
2219   return vld1q_s8(kAbsHalfSubPixel6TapSignedFilterColumns[tap_index]);
2220 }
2221 
2222 // This filter is only possible when width >= 8.
2223 template <int grade_x>
2224 inline void ConvolveKernelHorizontalSigned6Tap(
2225     const uint16_t* LIBGAV1_RESTRICT const src, const ptrdiff_t src_stride,
2226     const int width, const int subpixel_x, const int step_x,
2227     const int intermediate_height,
2228     int16_t* LIBGAV1_RESTRICT const intermediate) {
2229   const int kernel_offset = 1;
2230   const uint8x8_t filter_index_mask = vdup_n_u8(kSubPixelMask);
2231   const int ref_x = subpixel_x >> kScaleSubPixelBits;
2232   const int step_x8 = step_x << 3;
2233   int8x16_t filter_taps[6];
2234   for (int i = 0; i < 6; ++i) {
2235     filter_taps[i] = GetSigned6TapFilter(i);
2236   }
2237   const uint16x8_t index_steps = vmulq_n_u16(
2238       vmovl_u8(vcreate_u8(0x0706050403020100)), static_cast<uint16_t>(step_x));
2239 
2240   int16_t* intermediate_x = intermediate;
2241   int x = 0;
2242   int p = subpixel_x;
2243   do {
2244     const uint16_t* src_x =
2245         src + (p >> kScaleSubPixelBits) - ref_x + kernel_offset;
2246     // Only add steps to the 10-bit truncated p to avoid overflow.
2247     const uint16x8_t p_fraction = vdupq_n_u16(p & 1023);
2248     const uint16x8_t subpel_index_offsets = vaddq_u16(index_steps, p_fraction);
2249     const uint8x8_t filter_indices =
2250         vand_u8(vshrn_n_u16(subpel_index_offsets, kFilterIndexShift),
2251                 filter_index_mask);
2252 
2253     // Each lane of taps_(low|high)[k] corresponds to one output value
2254     // along the row, containing kSubPixelFilters[filter_index][filter_id][k],
2255     // where filter_id depends on x.
2256     int16x4_t taps_low[6];
2257     int16x4_t taps_high[6];
2258     for (int i = 0; i < 6; ++i) {
2259       const int16x8_t taps_i =
2260           vmovl_s8(VQTbl1S8(filter_taps[i], filter_indices));
2261       taps_low[i] = vget_low_s16(taps_i);
2262       taps_high[i] = vget_high_s16(taps_i);
2263     }
2264 
2265     // Lower byte of Nth value is at position 2*N.
2266     const uint8x8_t src_indices0 = vshl_n_u8(
2267         vmovn_u16(vshrq_n_u16(subpel_index_offsets, kScaleSubPixelBits)), 1);
2268     // Upper byte of Nth value is at position 2*N+1.
2269     const uint8x8_t src_indices1 = vadd_u8(src_indices0, vdup_n_u8(1));
2270     const uint8x8x2_t src_indices_zip = vzip_u8(src_indices0, src_indices1);
2271     const uint8x16_t src_indices_base =
2272         vcombine_u8(src_indices_zip.val[0], src_indices_zip.val[1]);
2273 
2274     uint8x16_t src_lookup[6];
2275     const uint8x16_t two = vdupq_n_u8(2);
2276     src_lookup[0] = src_indices_base;
2277     for (int i = 1; i < 6; ++i) {
2278       src_lookup[i] = vaddq_u8(src_lookup[i - 1], two);
2279     }
2280 
2281     int y = intermediate_height;
2282     do {
2283       // Load a pool of samples to select from using stepped indices.
2284       const uint8x16x3_t src_bytes = LoadSrcVals<grade_x>(src_x);
2285 
2286       uint16x4_t src_low[6];
2287       uint16x4_t src_high[6];
2288       for (int i = 0; i < 6; ++i) {
2289         const uint16x8_t src_i =
2290             PermuteSrcVals<grade_x>(src_bytes, src_lookup[i]);
2291         src_low[i] = vget_low_u16(src_i);
2292         src_high[i] = vget_high_u16(src_i);
2293       }
2294 
2295       vst1_s16(intermediate_x,
2296                vrshrn_n_s32(SumOnePassTaps</*num_taps=*/6>(src_low, taps_low),
2297                             kInterRoundBitsHorizontal - 1));
2298       vst1_s16(intermediate_x + 4,
2299                vrshrn_n_s32(SumOnePassTaps</*num_taps=*/6>(src_high, taps_high),
2300                             kInterRoundBitsHorizontal - 1));
2301       // Avoid right shifting the stride.
2302       src_x = AddByteStride(src_x, src_stride);
2303       intermediate_x += kIntermediateStride;
2304     } while (--y != 0);
2305     x += 8;
2306     p += step_x8;
2307   } while (x < width);
2308 }
2309 
2310 // Pre-transpose the 6 tap filters in |kAbsHalfSubPixelFilters|[1]. This filter
2311 // has mixed positive and negative outer taps depending on the filter id.
2312 inline int8x16_t GetMixed6TapFilter(const int tap_index) {
2313   assert(tap_index < 6);
2314   alignas(16) static constexpr int8_t
2315       kAbsHalfSubPixel6TapMixedFilterColumns[6][16] = {
2316           {0, 1, 0, 0, 0, 0, 0, -1, -1, 0, 0, 0, 0, 0, 0, 0},
2317           {0, 14, 13, 11, 10, 9, 8, 8, 7, 6, 5, 4, 3, 2, 2, 1},
2318           {64, 31, 31, 31, 30, 29, 28, 27, 26, 24, 23, 22, 21, 20, 18, 17},
2319           {0, 17, 18, 20, 21, 22, 23, 24, 26, 27, 28, 29, 30, 31, 31, 31},
2320           {0, 1, 2, 2, 3, 4, 5, 6, 7, 8, 8, 9, 10, 11, 13, 14},
2321           {0, 0, 0, 0, 0, 0, 0, 0, -1, -1, 0, 0, 0, 0, 0, 1}};
2322 
2323   return vld1q_s8(kAbsHalfSubPixel6TapMixedFilterColumns[tap_index]);
2324 }
2325 
2326 // This filter is only possible when width >= 8.
2327 template <int grade_x>
2328 inline void ConvolveKernelHorizontalMixed6Tap(
2329     const uint16_t* LIBGAV1_RESTRICT const src, const ptrdiff_t src_stride,
2330     const int width, const int subpixel_x, const int step_x,
2331     const int intermediate_height,
2332     int16_t* LIBGAV1_RESTRICT const intermediate) {
  const int kernel_offset = 1;
  const uint8x8_t filter_index_mask = vdup_n_u8(kSubPixelMask);
  const int ref_x = subpixel_x >> kScaleSubPixelBits;
  const int step_x8 = step_x << 3;
  int8x16_t filter_taps[6];
  for (int i = 0; i < 6; ++i) {
    filter_taps[i] = GetMixed6TapFilter(i);
  }
  const uint16x8_t index_steps = vmulq_n_u16(
      vmovl_u8(vcreate_u8(0x0706050403020100)), static_cast<uint16_t>(step_x));

  int16_t* intermediate_x = intermediate;
  int x = 0;
  int p = subpixel_x;
  do {
    const uint16_t* src_x =
        src + (p >> kScaleSubPixelBits) - ref_x + kernel_offset;
    // Only add steps to the 10-bit truncated p to avoid overflow.
    const uint16x8_t p_fraction = vdupq_n_u16(p & 1023);
    const uint16x8_t subpel_index_offsets = vaddq_u16(index_steps, p_fraction);

    const uint8x8_t filter_indices =
        vand_u8(vshrn_n_u16(subpel_index_offsets, kFilterIndexShift),
                filter_index_mask);
    // Each lane of taps_(low|high)[k] corresponds to one output value along
    // the row, containing kSubPixelFilters[filter_index][filter_id][k], where
    // filter_id depends on x.
    int16x4_t taps_low[6];
    int16x4_t taps_high[6];
    for (int i = 0; i < 6; ++i) {
      const int16x8_t taps = vmovl_s8(VQTbl1S8(filter_taps[i], filter_indices));
      taps_low[i] = vget_low_s16(taps);
      taps_high[i] = vget_high_s16(taps);
    }

    // Lower byte of Nth value is at position 2*N.
    const uint8x8_t src_indices0 = vshl_n_u8(
        vmovn_u16(vshrq_n_u16(subpel_index_offsets, kScaleSubPixelBits)), 1);
    // Upper byte of Nth value is at position 2*N+1.
    const uint8x8_t src_indices1 = vadd_u8(src_indices0, vdup_n_u8(1));
    const uint8x8x2_t src_indices_zip = vzip_u8(src_indices0, src_indices1);
    const uint8x16_t src_indices_base =
        vcombine_u8(src_indices_zip.val[0], src_indices_zip.val[1]);

    uint8x16_t src_lookup[6];
    const uint8x16_t two = vdupq_n_u8(2);
    src_lookup[0] = src_indices_base;
    for (int i = 1; i < 6; ++i) {
      src_lookup[i] = vaddq_u8(src_lookup[i - 1], two);
    }

    int y = intermediate_height;
    do {
      // Load a pool of samples to select from using stepped indices.
      const uint8x16x3_t src_bytes = LoadSrcVals<grade_x>(src_x);

      uint16x4_t src_low[6];
      uint16x4_t src_high[6];
      for (int i = 0; i < 6; ++i) {
        const uint16x8_t src_i =
            PermuteSrcVals<grade_x>(src_bytes, src_lookup[i]);
        src_low[i] = vget_low_u16(src_i);
        src_high[i] = vget_high_u16(src_i);
      }

      vst1_s16(intermediate_x,
               vrshrn_n_s32(SumOnePassTaps</*num_taps=*/6>(src_low, taps_low),
                            kInterRoundBitsHorizontal - 1));
      vst1_s16(intermediate_x + 4,
               vrshrn_n_s32(SumOnePassTaps</*num_taps=*/6>(src_high, taps_high),
                            kInterRoundBitsHorizontal - 1));
      // Avoid right shifting the stride.
      src_x = AddByteStride(src_x, src_stride);
      intermediate_x += kIntermediateStride;
    } while (--y != 0);
    x += 8;
    p += step_x8;
  } while (x < width);
}

// Pre-transpose the 8 tap filters in |kAbsHalfSubPixelFilters|[2].
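// As with GetMixed6TapFilter(), row k holds tap k across all 16 filter ids.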
inline int8x16_t GetSigned8TapFilter(const int tap_index) {
  assert(tap_index < 8);
  alignas(16) static constexpr int8_t
      kAbsHalfSubPixel8TapSignedFilterColumns[8][16] = {
          {-0, -1, -1, -1, -2, -2, -2, -2, -2, -1, -1, -1, -1, -1, -1, -0},
          {0, 1, 3, 4, 5, 5, 5, 5, 6, 5, 4, 4, 3, 3, 2, 1},
          {-0, -3, -6, -9, -11, -11, -12, -12, -12, -11, -10, -9, -7, -5, -3,
           -1},
          {64, 63, 62, 60, 58, 54, 50, 45, 40, 35, 30, 24, 19, 13, 8, 4},
          {0, 4, 8, 13, 19, 24, 30, 35, 40, 45, 50, 54, 58, 60, 62, 63},
          {-0, -1, -3, -5, -7, -9, -10, -11, -12, -12, -12, -11, -11, -9, -6,
           -3},
          {0, 1, 2, 3, 3, 4, 4, 5, 6, 5, 5, 5, 5, 4, 3, 1},
          {-0, -0, -1, -1, -1, -1, -1, -1, -2, -2, -2, -2, -2, -1, -1, -1}};

  return vld1q_s8(kAbsHalfSubPixel8TapSignedFilterColumns[tap_index]);
}

// This filter is only possible when width >= 8.
template <int grade_x>
inline void ConvolveKernelHorizontalSigned8Tap(
    const uint16_t* LIBGAV1_RESTRICT const src, const ptrdiff_t src_stride,
    const int width, const int subpixel_x, const int step_x,
    const int intermediate_height,
    int16_t* LIBGAV1_RESTRICT const intermediate) {
  const uint8x8_t filter_index_mask = vdup_n_u8(kSubPixelMask);
  const int ref_x = subpixel_x >> kScaleSubPixelBits;
  const int step_x8 = step_x << 3;
  int8x16_t filter_taps[8];
  for (int i = 0; i < 8; ++i) {
    filter_taps[i] = GetSigned8TapFilter(i);
  }
  const uint16x8_t index_steps = vmulq_n_u16(
      vmovl_u8(vcreate_u8(0x0706050403020100)), static_cast<uint16_t>(step_x));
  int16_t* intermediate_x = intermediate;
  int x = 0;
  int p = subpixel_x;
  do {
    const uint16_t* src_x = src + (p >> kScaleSubPixelBits) - ref_x;
    // Only add steps to the 10-bit truncated p to avoid overflow.
    const uint16x8_t p_fraction = vdupq_n_u16(p & 1023);
    const uint16x8_t subpel_index_offsets = vaddq_u16(index_steps, p_fraction);

    const uint8x8_t filter_indices =
        vand_u8(vshrn_n_u16(subpel_index_offsets, kFilterIndexShift),
                filter_index_mask);

    // Lower byte of Nth value is at position 2*N.
    const uint8x8_t src_indices0 = vshl_n_u8(
        vmovn_u16(vshrq_n_u16(subpel_index_offsets, kScaleSubPixelBits)), 1);
    // Upper byte of Nth value is at position 2*N+1.
    const uint8x8_t src_indices1 = vadd_u8(src_indices0, vdup_n_u8(1));
    const uint8x8x2_t src_indices_zip = vzip_u8(src_indices0, src_indices1);
    const uint8x16_t src_indices_base =
        vcombine_u8(src_indices_zip.val[0], src_indices_zip.val[1]);

    uint8x16_t src_lookup[8];
    const uint8x16_t two = vdupq_n_u8(2);
    src_lookup[0] = src_indices_base;
    for (int i = 1; i < 8; ++i) {
      src_lookup[i] = vaddq_u8(src_lookup[i - 1], two);
    }
    // Each lane of taps_(low|high)[k] corresponds to one output value along
    // the row, containing kSubPixelFilters[filter_index][filter_id][k], where
    // filter_id depends on x.
    int16x4_t taps_low[8];
    int16x4_t taps_high[8];
    for (int i = 0; i < 8; ++i) {
      const int16x8_t taps = vmovl_s8(VQTbl1S8(filter_taps[i], filter_indices));
      taps_low[i] = vget_low_s16(taps);
      taps_high[i] = vget_high_s16(taps);
    }

    int y = intermediate_height;
    do {
      // Load a pool of samples to select from using stepped indices.
      const uint8x16x3_t src_bytes = LoadSrcVals<grade_x>(src_x);

      uint16x4_t src_low[8];
      uint16x4_t src_high[8];
      for (int i = 0; i < 8; ++i) {
        const uint16x8_t src_i =
            PermuteSrcVals<grade_x>(src_bytes, src_lookup[i]);
        src_low[i] = vget_low_u16(src_i);
        src_high[i] = vget_high_u16(src_i);
      }

      vst1_s16(intermediate_x,
               vrshrn_n_s32(SumOnePassTaps</*num_taps=*/8>(src_low, taps_low),
                            kInterRoundBitsHorizontal - 1));
      vst1_s16(intermediate_x + 4,
               vrshrn_n_s32(SumOnePassTaps</*num_taps=*/8>(src_high, taps_high),
                            kInterRoundBitsHorizontal - 1));
      // Avoid right shifting the stride.
      src_x = AddByteStride(src_x, src_stride);
      intermediate_x += kIntermediateStride;
    } while (--y != 0);
    x += 8;
    p += step_x8;
  } while (x < width);
}

// Process 16 bit inputs and output 32 bits.
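// The 8-lane |taps| vector is zero-padded around shorter kernels, so the
// first live tap sits at lane (8 - num_taps) / 2: lane 1 for 6 taps, lane 2
// for 4 taps and lane 3 for 2 taps, matching the lane arguments below.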
template <int num_taps, bool is_compound>
inline int16x4_t Sum2DVerticalTaps4(const int16x4_t* const src,
                                    const int16x8_t taps) {
  const int16x4_t taps_lo = vget_low_s16(taps);
  const int16x4_t taps_hi = vget_high_s16(taps);
  int32x4_t sum;
  if (num_taps == 8) {
    sum = vmull_lane_s16(src[0], taps_lo, 0);
    sum = vmlal_lane_s16(sum, src[1], taps_lo, 1);
    sum = vmlal_lane_s16(sum, src[2], taps_lo, 2);
    sum = vmlal_lane_s16(sum, src[3], taps_lo, 3);
    sum = vmlal_lane_s16(sum, src[4], taps_hi, 0);
    sum = vmlal_lane_s16(sum, src[5], taps_hi, 1);
    sum = vmlal_lane_s16(sum, src[6], taps_hi, 2);
    sum = vmlal_lane_s16(sum, src[7], taps_hi, 3);
  } else if (num_taps == 6) {
    sum = vmull_lane_s16(src[0], taps_lo, 1);
    sum = vmlal_lane_s16(sum, src[1], taps_lo, 2);
    sum = vmlal_lane_s16(sum, src[2], taps_lo, 3);
    sum = vmlal_lane_s16(sum, src[3], taps_hi, 0);
    sum = vmlal_lane_s16(sum, src[4], taps_hi, 1);
    sum = vmlal_lane_s16(sum, src[5], taps_hi, 2);
  } else if (num_taps == 4) {
    sum = vmull_lane_s16(src[0], taps_lo, 2);
    sum = vmlal_lane_s16(sum, src[1], taps_lo, 3);
    sum = vmlal_lane_s16(sum, src[2], taps_hi, 0);
    sum = vmlal_lane_s16(sum, src[3], taps_hi, 1);
  } else if (num_taps == 2) {
    sum = vmull_lane_s16(src[0], taps_lo, 3);
    sum = vmlal_lane_s16(sum, src[1], taps_hi, 0);
  }

  if (is_compound) {
    return vrshrn_n_s32(sum, kInterRoundBitsCompoundVertical - 1);
  }

  return vreinterpret_s16_u16(vqrshrun_n_s32(sum, kInterRoundBitsVertical - 1));
}

template <int num_taps, int grade_y, int width, bool is_compound>
void ConvolveVerticalScale2Or4xH(const int16_t* LIBGAV1_RESTRICT const src,
                                 const int subpixel_y, const int filter_index,
                                 const int step_y, const int height,
                                 void* LIBGAV1_RESTRICT const dest,
                                 const ptrdiff_t dest_stride) {
  static_assert(width == 2 || width == 4, "");
  // We increment stride with the 8-bit pointer and then reinterpret to avoid
  // shifting |dest_stride|.
  auto* dest_y = static_cast<uint16_t*>(dest);
  // In compound mode, |dest_stride| is based on the size of uint16_t, rather
  // than bytes.
  auto* compound_dest_y = static_cast<uint16_t*>(dest);
  // This stride always corresponds to int16_t.
  constexpr ptrdiff_t src_stride = kIntermediateStride;
  const int16_t* src_y = src;
  int16x4_t s[num_taps + grade_y];

  int p = subpixel_y & 1023;
  int prev_p = p;
  int y = height;
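  // Each iteration of the loop below produces two output rows, reusing the
  // |num_taps| source rows loaded at the top for the second row (shifted by
  // |p_diff|), so |height| must be even.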
  do {
    for (int i = 0; i < num_taps; ++i) {
      s[i] = vld1_s16(src_y + i * src_stride);
    }
    int filter_id = (p >> 6) & kSubPixelMask;
    int16x8_t filter =
        vmovl_s8(vld1_s8(kHalfSubPixelFilters[filter_index][filter_id]));
    int16x4_t sums = Sum2DVerticalTaps4<num_taps, is_compound>(s, filter);
    if (is_compound) {
      assert(width != 2);
      // This offset potentially overflows into the sign bit, but should yield
      // the correct unsigned value.
      const uint16x4_t result =
          vreinterpret_u16_s16(vadd_s16(sums, vdup_n_s16(kCompoundOffset)));
      vst1_u16(compound_dest_y, result);
      compound_dest_y += dest_stride;
    } else {
      const uint16x4_t result = vmin_u16(vreinterpret_u16_s16(sums),
                                         vdup_n_u16((1 << kBitdepth10) - 1));
      if (width == 2) {
        Store2<0>(dest_y, result);
      } else {
        vst1_u16(dest_y, result);
      }
      dest_y = AddByteStride(dest_y, dest_stride);
    }
    p += step_y;
    const int p_diff =
        (p >> kScaleSubPixelBits) - (prev_p >> kScaleSubPixelBits);
    prev_p = p;
    // Here we load extra source in case it is needed. If |p_diff| == 0, these
    // values will be unused, but it's faster to load than to branch.
    s[num_taps] = vld1_s16(src_y + num_taps * src_stride);
    if (grade_y > 1) {
      s[num_taps + 1] = vld1_s16(src_y + (num_taps + 1) * src_stride);
    }

    filter_id = (p >> 6) & kSubPixelMask;
    filter = vmovl_s8(vld1_s8(kHalfSubPixelFilters[filter_index][filter_id]));
    sums = Sum2DVerticalTaps4<num_taps, is_compound>(&s[p_diff], filter);
    if (is_compound) {
      assert(width != 2);
      const uint16x4_t result =
          vreinterpret_u16_s16(vadd_s16(sums, vdup_n_s16(kCompoundOffset)));
      vst1_u16(compound_dest_y, result);
      compound_dest_y += dest_stride;
    } else {
      const uint16x4_t result = vmin_u16(vreinterpret_u16_s16(sums),
                                         vdup_n_u16((1 << kBitdepth10) - 1));
      if (width == 2) {
        Store2<0>(dest_y, result);
      } else {
        vst1_u16(dest_y, result);
      }
      dest_y = AddByteStride(dest_y, dest_stride);
    }
    p += step_y;
    src_y = src + (p >> kScaleSubPixelBits) * src_stride;
    prev_p = p;
    y -= 2;
  } while (y != 0);
}

template <int num_taps, int grade_y, bool is_compound>
void ConvolveVerticalScale(const int16_t* LIBGAV1_RESTRICT const source,
                           const int intermediate_height, const int width,
                           const int subpixel_y, const int filter_index,
                           const int step_y, const int height,
                           void* LIBGAV1_RESTRICT const dest,
                           const ptrdiff_t dest_stride) {
  // This stride always corresponds to int16_t.
  constexpr ptrdiff_t src_stride = kIntermediateStride;

  int16x8_t s[num_taps + 2];

  const int16_t* src = source;
  int x = 0;
  do {
    const int16_t* src_y = src;
    int p = subpixel_y & 1023;
    int prev_p = p;
    // We increment stride with the 8-bit pointer and then reinterpret to avoid
    // shifting |dest_stride|.
    auto* dest_y = static_cast<uint16_t*>(dest) + x;
    // In compound mode, |dest_stride| is based on the size of uint16_t, rather
    // than bytes.
    auto* compound_dest_y = static_cast<uint16_t*>(dest) + x;
    int y = height;
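    // As in ConvolveVerticalScale2Or4xH(), two output rows are produced per
    // iteration, so |height| must be even.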
    do {
      for (int i = 0; i < num_taps; ++i) {
        s[i] = vld1q_s16(src_y + i * src_stride);
      }
      int filter_id = (p >> 6) & kSubPixelMask;
      int16x8_t filter =
          vmovl_s8(vld1_s8(kHalfSubPixelFilters[filter_index][filter_id]));
      int16x8_t sums =
          SimpleSum2DVerticalTaps<num_taps, is_compound>(s, filter);
      if (is_compound) {
        // This offset potentially overflows int16_t, but should yield the
        // correct unsigned value.
        const uint16x8_t result = vreinterpretq_u16_s16(
            vaddq_s16(sums, vdupq_n_s16(kCompoundOffset)));
        vst1q_u16(compound_dest_y, result);
        compound_dest_y += dest_stride;
      } else {
        const uint16x8_t result = vminq_u16(
            vreinterpretq_u16_s16(sums), vdupq_n_u16((1 << kBitdepth10) - 1));
        vst1q_u16(dest_y, result);
        dest_y = AddByteStride(dest_y, dest_stride);
      }
      p += step_y;
      const int p_diff =
          (p >> kScaleSubPixelBits) - (prev_p >> kScaleSubPixelBits);
      prev_p = p;
      // Here we load extra source in case it is needed. If |p_diff| == 0, these
      // values will be unused, but it's faster to load than to branch.
      s[num_taps] = vld1q_s16(src_y + num_taps * src_stride);
      if (grade_y > 1) {
        s[num_taps + 1] = vld1q_s16(src_y + (num_taps + 1) * src_stride);
      }

      filter_id = (p >> 6) & kSubPixelMask;
      filter = vmovl_s8(vld1_s8(kHalfSubPixelFilters[filter_index][filter_id]));
      sums = SimpleSum2DVerticalTaps<num_taps, is_compound>(&s[p_diff], filter);
      if (is_compound) {
        assert(width != 2);
        const uint16x8_t result = vreinterpretq_u16_s16(
            vaddq_s16(sums, vdupq_n_s16(kCompoundOffset)));
        vst1q_u16(compound_dest_y, result);
        compound_dest_y += dest_stride;
      } else {
        const uint16x8_t result = vminq_u16(
            vreinterpretq_u16_s16(sums), vdupq_n_u16((1 << kBitdepth10) - 1));
        vst1q_u16(dest_y, result);
        dest_y = AddByteStride(dest_y, dest_stride);
      }
      p += step_y;
      src_y = src + (p >> kScaleSubPixelBits) * src_stride;
      prev_p = p;

      y -= 2;
    } while (y != 0);
    src += kIntermediateStride * intermediate_height;
    x += 8;
  } while (x < width);
}

template <bool is_compound>
void ConvolveScale2D_NEON(const void* LIBGAV1_RESTRICT const reference,
                          const ptrdiff_t reference_stride,
                          const int horizontal_filter_index,
                          const int vertical_filter_index, const int subpixel_x,
                          const int subpixel_y, const int step_x,
                          const int step_y, const int width, const int height,
                          void* LIBGAV1_RESTRICT const prediction,
                          const ptrdiff_t pred_stride) {
  const int horiz_filter_index = GetFilterIndex(horizontal_filter_index, width);
  const int vert_filter_index = GetFilterIndex(vertical_filter_index, height);
  assert(step_x <= 2048);
  assert(step_y <= 2048);
  const int num_vert_taps = GetNumTapsInFilter(vert_filter_index);
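  // |intermediate_height| covers the integer position of the last vertical
  // subpel sample, rounded up, plus the vertical filter's support.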
  const int intermediate_height =
      (((height - 1) * step_y + (1 << kScaleSubPixelBits) - 1) >>
       kScaleSubPixelBits) +
      num_vert_taps;
  int16_t intermediate_result[kIntermediateAllocWidth *
                              (2 * kIntermediateAllocWidth + 8)];
#if LIBGAV1_MSAN
  // Quiet msan warnings. Set with random non-zero value to aid in debugging.
  memset(intermediate_result, 0x54, sizeof(intermediate_result));
#endif
  // Horizontal filter.
  // Filter types used for width <= 4 are different from those for width > 4.
  // When width > 4, the valid filter index range is always [0, 3].
  // When width <= 4, the valid filter index range is always [3, 5].
  // The same applies to height and vertical filter index.
  int filter_index = GetFilterIndex(horizontal_filter_index, width);
  int16_t* intermediate = intermediate_result;
  const ptrdiff_t src_stride = reference_stride;
  const auto* src = static_cast<const uint16_t*>(reference);
  const int vert_kernel_offset = (8 - num_vert_taps) / 2;
  src = AddByteStride(src, vert_kernel_offset * src_stride);

  // Derive the maximum value of |step_x| at which all source values fit in one
  // 16-byte (8-value) load. The final index must satisfy
  // src_x + |num_taps| - 1 < 16. step_x * 7 is the final base subpel index for
  // the shuffle mask for filter inputs in each iteration on large blocks. When
  // step_x is large, we need a larger structure and use a larger table lookup
  // in order to gather all filter inputs.
  const int num_horiz_taps = GetNumTapsInFilter(horiz_filter_index);
  // |num_taps| - 1 is the shuffle index of the final filter input.
  const int kernel_start_ceiling = 16 - num_horiz_taps;
  // This truncated quotient |grade_x_threshold| selects |step_x| such that:
  // (step_x * 7) >> kScaleSubPixelBits < single load limit
  const int grade_x_threshold =
      (kernel_start_ceiling << kScaleSubPixelBits) / 7;
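  // For example, assuming kScaleSubPixelBits == 10 and the 8-tap filter
  // (kernel_start_ceiling == 8), grade_x_threshold == (8 << 10) / 7 == 1170;
  // any larger |step_x| takes the grade_x == 2 paths below.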

  switch (filter_index) {
    case 0:
      if (step_x > grade_x_threshold) {
        ConvolveKernelHorizontalSigned6Tap<2>(
            src, src_stride, width, subpixel_x, step_x, intermediate_height,
            intermediate);
      } else {
        ConvolveKernelHorizontalSigned6Tap<1>(
            src, src_stride, width, subpixel_x, step_x, intermediate_height,
            intermediate);
      }
      break;
    case 1:
      if (step_x > grade_x_threshold) {
        ConvolveKernelHorizontalMixed6Tap<2>(src, src_stride, width, subpixel_x,
                                             step_x, intermediate_height,
                                             intermediate);
      } else {
        ConvolveKernelHorizontalMixed6Tap<1>(src, src_stride, width, subpixel_x,
                                             step_x, intermediate_height,
                                             intermediate);
      }
      break;
    case 2:
      if (step_x > grade_x_threshold) {
        ConvolveKernelHorizontalSigned8Tap<2>(
            src, src_stride, width, subpixel_x, step_x, intermediate_height,
            intermediate);
      } else {
        ConvolveKernelHorizontalSigned8Tap<1>(
            src, src_stride, width, subpixel_x, step_x, intermediate_height,
            intermediate);
      }
      break;
    case 3:
      if (step_x > grade_x_threshold) {
        ConvolveKernelHorizontal2Tap<2>(src, src_stride, width, subpixel_x,
                                        step_x, intermediate_height,
                                        intermediate);
      } else {
        ConvolveKernelHorizontal2Tap<1>(src, src_stride, width, subpixel_x,
                                        step_x, intermediate_height,
                                        intermediate);
      }
      break;
    case 4:
      assert(width <= 4);
      ConvolveKernelHorizontalSigned4Tap(src, src_stride, subpixel_x, step_x,
                                         intermediate_height, intermediate);
      break;
    default:
      assert(filter_index == 5);
      ConvolveKernelHorizontalPositive4Tap(src, src_stride, subpixel_x, step_x,
                                           intermediate_height, intermediate);
  }

  // Vertical filter.
  filter_index = GetFilterIndex(vertical_filter_index, height);
  intermediate = intermediate_result;
  switch (filter_index) {
    case 0:
    case 1:
      if (step_y <= 1024) {
        if (!is_compound && width == 2) {
          ConvolveVerticalScale2Or4xH<6, 1, 2, is_compound>(
              intermediate, subpixel_y, filter_index, step_y, height,
              prediction, pred_stride);
        } else if (width == 4) {
          ConvolveVerticalScale2Or4xH<6, 1, 4, is_compound>(
              intermediate, subpixel_y, filter_index, step_y, height,
              prediction, pred_stride);
        } else {
          ConvolveVerticalScale<6, 1, is_compound>(
              intermediate, intermediate_height, width, subpixel_y,
              filter_index, step_y, height, prediction, pred_stride);
        }
      } else {
        if (!is_compound && width == 2) {
          ConvolveVerticalScale2Or4xH<6, 2, 2, is_compound>(
              intermediate, subpixel_y, filter_index, step_y, height,
              prediction, pred_stride);
        } else if (width == 4) {
          ConvolveVerticalScale2Or4xH<6, 2, 4, is_compound>(
              intermediate, subpixel_y, filter_index, step_y, height,
              prediction, pred_stride);
        } else {
          ConvolveVerticalScale<6, 2, is_compound>(
              intermediate, intermediate_height, width, subpixel_y,
              filter_index, step_y, height, prediction, pred_stride);
        }
      }
      break;
    case 2:
      if (step_y <= 1024) {
        if (!is_compound && width == 2) {
          ConvolveVerticalScale2Or4xH<8, 1, 2, is_compound>(
              intermediate, subpixel_y, filter_index, step_y, height,
              prediction, pred_stride);
        } else if (width == 4) {
          ConvolveVerticalScale2Or4xH<8, 1, 4, is_compound>(
              intermediate, subpixel_y, filter_index, step_y, height,
              prediction, pred_stride);
        } else {
          ConvolveVerticalScale<8, 1, is_compound>(
              intermediate, intermediate_height, width, subpixel_y,
              filter_index, step_y, height, prediction, pred_stride);
        }
      } else {
        if (!is_compound && width == 2) {
          ConvolveVerticalScale2Or4xH<8, 2, 2, is_compound>(
              intermediate, subpixel_y, filter_index, step_y, height,
              prediction, pred_stride);
        } else if (width == 4) {
          ConvolveVerticalScale2Or4xH<8, 2, 4, is_compound>(
              intermediate, subpixel_y, filter_index, step_y, height,
              prediction, pred_stride);
        } else {
          ConvolveVerticalScale<8, 2, is_compound>(
              intermediate, intermediate_height, width, subpixel_y,
              filter_index, step_y, height, prediction, pred_stride);
        }
      }
      break;
    case 3:
      if (step_y <= 1024) {
        if (!is_compound && width == 2) {
          ConvolveVerticalScale2Or4xH<2, 1, 2, is_compound>(
              intermediate, subpixel_y, filter_index, step_y, height,
              prediction, pred_stride);
        } else if (width == 4) {
          ConvolveVerticalScale2Or4xH<2, 1, 4, is_compound>(
              intermediate, subpixel_y, filter_index, step_y, height,
              prediction, pred_stride);
        } else {
          ConvolveVerticalScale<2, 1, is_compound>(
              intermediate, intermediate_height, width, subpixel_y,
              filter_index, step_y, height, prediction, pred_stride);
        }
      } else {
        if (!is_compound && width == 2) {
          ConvolveVerticalScale2Or4xH<2, 2, 2, is_compound>(
              intermediate, subpixel_y, filter_index, step_y, height,
              prediction, pred_stride);
        } else if (width == 4) {
          ConvolveVerticalScale2Or4xH<2, 2, 4, is_compound>(
              intermediate, subpixel_y, filter_index, step_y, height,
              prediction, pred_stride);
        } else {
          ConvolveVerticalScale<2, 2, is_compound>(
              intermediate, intermediate_height, width, subpixel_y,
              filter_index, step_y, height, prediction, pred_stride);
        }
      }
      break;
    default:
      assert(filter_index == 4 || filter_index == 5);
      assert(height <= 4);
      if (step_y <= 1024) {
        if (!is_compound && width == 2) {
          ConvolveVerticalScale2Or4xH<4, 1, 2, is_compound>(
              intermediate, subpixel_y, filter_index, step_y, height,
              prediction, pred_stride);
        } else if (width == 4) {
          ConvolveVerticalScale2Or4xH<4, 1, 4, is_compound>(
              intermediate, subpixel_y, filter_index, step_y, height,
              prediction, pred_stride);
        } else {
          ConvolveVerticalScale<4, 1, is_compound>(
              intermediate, intermediate_height, width, subpixel_y,
              filter_index, step_y, height, prediction, pred_stride);
        }
      } else {
        if (!is_compound && width == 2) {
          ConvolveVerticalScale2Or4xH<4, 2, 2, is_compound>(
              intermediate, subpixel_y, filter_index, step_y, height,
              prediction, pred_stride);
        } else if (width == 4) {
          ConvolveVerticalScale2Or4xH<4, 2, 4, is_compound>(
              intermediate, subpixel_y, filter_index, step_y, height,
              prediction, pred_stride);
        } else {
          ConvolveVerticalScale<4, 2, is_compound>(
              intermediate, intermediate_height, width, subpixel_y,
              filter_index, step_y, height, prediction, pred_stride);
        }
      }
  }
}

void Init10bpp() {
  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10);
  assert(dsp != nullptr);
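  // The convolve table is indexed as
  // [is_intra_block_copy][is_compound][has_vertical_filter]
  // [has_horizontal_filter].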
  dsp->convolve[0][0][0][1] = ConvolveHorizontal_NEON;
  dsp->convolve[0][0][1][0] = ConvolveVertical_NEON;
  dsp->convolve[0][0][1][1] = Convolve2D_NEON;

  dsp->convolve[0][1][0][0] = ConvolveCompoundCopy_NEON;
  dsp->convolve[0][1][0][1] = ConvolveCompoundHorizontal_NEON;
  dsp->convolve[0][1][1][0] = ConvolveCompoundVertical_NEON;
  dsp->convolve[0][1][1][1] = ConvolveCompound2D_NEON;

  dsp->convolve[1][0][0][1] = ConvolveIntraBlockCopyHorizontal_NEON;
  dsp->convolve[1][0][1][0] = ConvolveIntraBlockCopyVertical_NEON;
  dsp->convolve[1][0][1][1] = ConvolveIntraBlockCopy2D_NEON;

  dsp->convolve_scale[0] = ConvolveScale2D_NEON<false>;
  dsp->convolve_scale[1] = ConvolveScale2D_NEON<true>;
}

}  // namespace

void ConvolveInit10bpp_NEON() { Init10bpp(); }

}  // namespace dsp
}  // namespace libgav1

#else   // !(LIBGAV1_ENABLE_NEON && LIBGAV1_MAX_BITDEPTH >= 10)

namespace libgav1 {
namespace dsp {

void ConvolveInit10bpp_NEON() {}

}  // namespace dsp
}  // namespace libgav1
#endif  // LIBGAV1_ENABLE_NEON && LIBGAV1_MAX_BITDEPTH >= 10