// Copyright 2019 The libgav1 Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "src/dsp/convolve.h"
#include "src/utils/cpu.h"

#if LIBGAV1_ENABLE_NEON

#include <arm_neon.h>

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>  // For the memset() used under LIBGAV1_MSAN below.

#include "src/dsp/arm/common_neon.h"
#include "src/dsp/constants.h"
#include "src/dsp/dsp.h"
#include "src/utils/common.h"
#include "src/utils/compiler_attributes.h"

namespace libgav1 {
namespace dsp {
namespace low_bitdepth {
namespace {

// Include the constants and utility functions inside the anonymous namespace.
#include "src/dsp/convolve.inc"

// Multiply every entry in |src[]| by the corresponding entry in |taps[]| and
// sum. The filters in |taps[]| are pre-shifted by 1. This prevents the final
// sum from exceeding the range of int16_t.
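// (Halving the taps halves the worst-case magnitude: with 8-bit inputs, 255
// times the sum of the positive half-taps stays below 1 << 15, so the
// uint16_t accumulator reinterprets cleanly as int16_t.)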
template <int filter_index, bool negative_outside_taps = false>
int16x8_t SumOnePassTaps(const uint8x8_t* const src,
                         const uint8x8_t* const taps) {
  uint16x8_t sum;
  if (filter_index == 0) {
    // 6 taps. + - + + - +
    sum = vmull_u8(src[0], taps[0]);
    // Unsigned overflow will result in a valid int16_t value.
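    // (vmlsl_u8 subtracts modulo 1 << 16; because the true signed result fits
    // in int16_t, the wrapped bit pattern reinterprets to the correct value
    // at the end of this function.)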
    sum = vmlsl_u8(sum, src[1], taps[1]);
    sum = vmlal_u8(sum, src[2], taps[2]);
    sum = vmlal_u8(sum, src[3], taps[3]);
    sum = vmlsl_u8(sum, src[4], taps[4]);
    sum = vmlal_u8(sum, src[5], taps[5]);
  } else if (filter_index == 1 && negative_outside_taps) {
    // 6 taps. - + + + + -
    // Set a base we can subtract from.
    sum = vmull_u8(src[1], taps[1]);
    sum = vmlsl_u8(sum, src[0], taps[0]);
    sum = vmlal_u8(sum, src[2], taps[2]);
    sum = vmlal_u8(sum, src[3], taps[3]);
    sum = vmlal_u8(sum, src[4], taps[4]);
    sum = vmlsl_u8(sum, src[5], taps[5]);
  } else if (filter_index == 1) {
    // 6 taps. All are positive.
    sum = vmull_u8(src[0], taps[0]);
    sum = vmlal_u8(sum, src[1], taps[1]);
    sum = vmlal_u8(sum, src[2], taps[2]);
    sum = vmlal_u8(sum, src[3], taps[3]);
    sum = vmlal_u8(sum, src[4], taps[4]);
    sum = vmlal_u8(sum, src[5], taps[5]);
  } else if (filter_index == 2) {
    // 8 taps. - + - + + - + -
    sum = vmull_u8(src[1], taps[1]);
    sum = vmlsl_u8(sum, src[0], taps[0]);
    sum = vmlsl_u8(sum, src[2], taps[2]);
    sum = vmlal_u8(sum, src[3], taps[3]);
    sum = vmlal_u8(sum, src[4], taps[4]);
    sum = vmlsl_u8(sum, src[5], taps[5]);
    sum = vmlal_u8(sum, src[6], taps[6]);
    sum = vmlsl_u8(sum, src[7], taps[7]);
  } else if (filter_index == 3) {
    // 2 taps. All are positive.
    sum = vmull_u8(src[0], taps[0]);
    sum = vmlal_u8(sum, src[1], taps[1]);
  } else if (filter_index == 4) {
    // 4 taps. - + + -
    sum = vmull_u8(src[1], taps[1]);
    sum = vmlsl_u8(sum, src[0], taps[0]);
    sum = vmlal_u8(sum, src[2], taps[2]);
    sum = vmlsl_u8(sum, src[3], taps[3]);
  } else if (filter_index == 5) {
    // 4 taps. All are positive.
    sum = vmull_u8(src[0], taps[0]);
    sum = vmlal_u8(sum, src[1], taps[1]);
    sum = vmlal_u8(sum, src[2], taps[2]);
    sum = vmlal_u8(sum, src[3], taps[3]);
  }
  return vreinterpretq_s16_u16(sum);
}

template <int filter_index, bool negative_outside_taps, bool is_2d,
          bool is_compound>
void FilterHorizontalWidth8AndUp(const uint8_t* LIBGAV1_RESTRICT src,
                                 const ptrdiff_t src_stride,
                                 void* LIBGAV1_RESTRICT const dest,
                                 const ptrdiff_t pred_stride, const int width,
                                 const int height,
                                 const uint8x8_t* const v_tap) {
  auto* dest8 = static_cast<uint8_t*>(dest);
  auto* dest16 = static_cast<uint16_t*>(dest);
  if (!is_2d) {
    int y = height;
    do {
      int x = 0;
      do {  // Increasing loop counter x is better.
        const uint8x16_t src_long = vld1q_u8(src + x);
        uint8x8_t v_src[8];
        int16x8_t sum;
        if (filter_index < 2) {
          v_src[0] = vget_low_u8(src_long);
          v_src[1] = vget_low_u8(vextq_u8(src_long, src_long, 1));
          v_src[2] = vget_low_u8(vextq_u8(src_long, src_long, 2));
          v_src[3] = vget_low_u8(vextq_u8(src_long, src_long, 3));
          v_src[4] = vget_low_u8(vextq_u8(src_long, src_long, 4));
          v_src[5] = vget_low_u8(vextq_u8(src_long, src_long, 5));
          sum = SumOnePassTaps<filter_index, negative_outside_taps>(v_src,
                                                                    v_tap + 1);
        } else if (filter_index == 2) {
          v_src[0] = vget_low_u8(src_long);
          v_src[1] = vget_low_u8(vextq_u8(src_long, src_long, 1));
          v_src[2] = vget_low_u8(vextq_u8(src_long, src_long, 2));
          v_src[3] = vget_low_u8(vextq_u8(src_long, src_long, 3));
          v_src[4] = vget_low_u8(vextq_u8(src_long, src_long, 4));
          v_src[5] = vget_low_u8(vextq_u8(src_long, src_long, 5));
          v_src[6] = vget_low_u8(vextq_u8(src_long, src_long, 6));
          v_src[7] = vget_low_u8(vextq_u8(src_long, src_long, 7));
          sum = SumOnePassTaps<filter_index, false>(v_src, v_tap);
        } else if (filter_index == 3) {
          v_src[0] = vget_low_u8(src_long);
          v_src[1] = vget_low_u8(vextq_u8(src_long, src_long, 1));
          sum = SumOnePassTaps<filter_index, false>(v_src, v_tap + 3);
        } else if (filter_index > 3) {
          v_src[0] = vget_low_u8(src_long);
          v_src[1] = vget_low_u8(vextq_u8(src_long, src_long, 1));
          v_src[2] = vget_low_u8(vextq_u8(src_long, src_long, 2));
          v_src[3] = vget_low_u8(vextq_u8(src_long, src_long, 3));
          sum = SumOnePassTaps<filter_index, false>(v_src, v_tap + 2);
        }
        if (is_compound) {
          const uint16x8_t v_sum = vreinterpretq_u16_s16(
              vrshrq_n_s16(sum, kInterRoundBitsHorizontal - 1));
          vst1q_u16(&dest16[x], v_sum);
        } else {
          // Normally the Horizontal pass does the downshift in two passes:
          // kInterRoundBitsHorizontal - 1 and then (kFilterBits -
          // kInterRoundBitsHorizontal). Each one uses a rounding shift.
          // Combining them requires adding the rounding offset from the skipped
          // shift.
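          // (For example, with the usual low-bitdepth constants kFilterBits
          // == 7 and kInterRoundBitsHorizontal == 3, the skipped shift would
          // have added 1 << 1 before shifting by 2; vqrshrun_n_s16 below then
          // adds 1 << 5 before shifting by 6.)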
          constexpr int first_shift_rounding_bit =
              1 << (kInterRoundBitsHorizontal - 2);
          sum = vaddq_s16(sum, vdupq_n_s16(first_shift_rounding_bit));
          const uint8x8_t result = vqrshrun_n_s16(sum, kFilterBits - 1);
          vst1_u8(&dest8[x], result);
        }
        x += 8;
      } while (x < width);
      src += src_stride;
      dest8 += pred_stride;
      dest16 += pred_stride;
    } while (--y != 0);
  } else {
    int x = 0;
    do {
      const uint8_t* s = src + x;
      int y = height;
      do {
        const uint8x16_t src_long = vld1q_u8(s);
        uint8x8_t v_src[8];
        int16x8_t sum;
        if (filter_index < 2) {
          v_src[0] = vget_low_u8(src_long);
          v_src[1] = vget_low_u8(vextq_u8(src_long, src_long, 1));
          v_src[2] = vget_low_u8(vextq_u8(src_long, src_long, 2));
          v_src[3] = vget_low_u8(vextq_u8(src_long, src_long, 3));
          v_src[4] = vget_low_u8(vextq_u8(src_long, src_long, 4));
          v_src[5] = vget_low_u8(vextq_u8(src_long, src_long, 5));
          sum = SumOnePassTaps<filter_index, negative_outside_taps>(v_src,
                                                                    v_tap + 1);
        } else if (filter_index == 2) {
          v_src[0] = vget_low_u8(src_long);
          v_src[1] = vget_low_u8(vextq_u8(src_long, src_long, 1));
          v_src[2] = vget_low_u8(vextq_u8(src_long, src_long, 2));
          v_src[3] = vget_low_u8(vextq_u8(src_long, src_long, 3));
          v_src[4] = vget_low_u8(vextq_u8(src_long, src_long, 4));
          v_src[5] = vget_low_u8(vextq_u8(src_long, src_long, 5));
          v_src[6] = vget_low_u8(vextq_u8(src_long, src_long, 6));
          v_src[7] = vget_low_u8(vextq_u8(src_long, src_long, 7));
          sum = SumOnePassTaps<filter_index, false>(v_src, v_tap);
        } else if (filter_index == 3) {
          v_src[0] = vget_low_u8(src_long);
          v_src[1] = vget_low_u8(vextq_u8(src_long, src_long, 1));
          sum = SumOnePassTaps<filter_index, false>(v_src, v_tap + 3);
        } else if (filter_index > 3) {
          v_src[0] = vget_low_u8(src_long);
          v_src[1] = vget_low_u8(vextq_u8(src_long, src_long, 1));
          v_src[2] = vget_low_u8(vextq_u8(src_long, src_long, 2));
          v_src[3] = vget_low_u8(vextq_u8(src_long, src_long, 3));
          sum = SumOnePassTaps<filter_index, false>(v_src, v_tap + 2);
        }
        const uint16x8_t v_sum = vreinterpretq_u16_s16(
            vrshrq_n_s16(sum, kInterRoundBitsHorizontal - 1));
        vst1q_u16(dest16, v_sum);
        s += src_stride;
        dest16 += 8;
      } while (--y != 0);
      x += 8;
    } while (x < width);
  }
}

template <int filter_index, bool is_2d, bool is_compound>
void FilterHorizontalWidth4(const uint8_t* LIBGAV1_RESTRICT src,
                            const ptrdiff_t src_stride,
                            void* LIBGAV1_RESTRICT const dest,
                            const ptrdiff_t pred_stride, const int height,
                            const uint8x8_t* const v_tap) {
  auto* dest8 = static_cast<uint8_t*>(dest);
  auto* dest16 = static_cast<uint16_t*>(dest);
  int y = height;
  do {
    uint8x8_t v_src[4];
    int16x8_t sum;
    v_src[0] = vld1_u8(src);
    if (filter_index == 3) {
      v_src[1] = RightShiftVector<1 * 8>(v_src[0]);
      sum = SumOnePassTaps<filter_index, false>(v_src, v_tap + 3);
    } else {
      v_src[1] = RightShiftVector<1 * 8>(v_src[0]);
      v_src[2] = RightShiftVector<2 * 8>(v_src[0]);
      v_src[3] = RightShiftVector<3 * 8>(v_src[0]);
      sum = SumOnePassTaps<filter_index, false>(v_src, v_tap + 2);
    }
    if (is_2d || is_compound) {
      const uint16x4_t v_sum = vreinterpret_u16_s16(
          vrshr_n_s16(vget_low_s16(sum), kInterRoundBitsHorizontal - 1));
      vst1_u16(dest16, v_sum);
    } else {
      constexpr int first_shift_rounding_bit =
          1 << (kInterRoundBitsHorizontal - 2);
      sum = vaddq_s16(sum, vdupq_n_s16(first_shift_rounding_bit));
      const uint8x8_t result = vqrshrun_n_s16(sum, kFilterBits - 1);
      StoreLo4(&dest8[0], result);
    }
    src += src_stride;
    dest8 += pred_stride;
    dest16 += pred_stride;
  } while (--y != 0);
}

template <int filter_index, bool is_2d>
void FilterHorizontalWidth2(const uint8_t* LIBGAV1_RESTRICT src,
                            const ptrdiff_t src_stride,
                            void* LIBGAV1_RESTRICT const dest,
                            const ptrdiff_t pred_stride, const int height,
                            const uint8x8_t* const v_tap) {
  auto* dest8 = static_cast<uint8_t*>(dest);
  auto* dest16 = static_cast<uint16_t*>(dest);
  int y = height >> 1;
  do {
    const uint8x8_t input0 = vld1_u8(src);
    const uint8x8_t input1 = vld1_u8(src + src_stride);
    const uint8x8x2_t input = vzip_u8(input0, input1);
    uint16x8_t sum;
    if (filter_index == 3) {
      // tap signs : + +
      sum = vmull_u8(input.val[0], v_tap[3]);
      sum = vmlal_u8(sum, vext_u8(input.val[0], input.val[1], 2), v_tap[4]);
    } else if (filter_index == 4) {
      // tap signs : - + + -
      sum = vmull_u8(RightShiftVector<2 * 8>(input.val[0]), v_tap[3]);
      sum = vmlsl_u8(sum, input.val[0], v_tap[2]);
      sum = vmlal_u8(sum, RightShiftVector<4 * 8>(input.val[0]), v_tap[4]);
      sum = vmlsl_u8(sum, vext_u8(input.val[0], input.val[1], 6), v_tap[5]);
    } else {
      // tap signs : + + + +
      sum = vmull_u8(input.val[0], v_tap[2]);
      sum = vmlal_u8(sum, RightShiftVector<2 * 8>(input.val[0]), v_tap[3]);
      sum = vmlal_u8(sum, RightShiftVector<4 * 8>(input.val[0]), v_tap[4]);
      sum = vmlal_u8(sum, vext_u8(input.val[0], input.val[1], 6), v_tap[5]);
    }
    int16x8_t s = vreinterpretq_s16_u16(sum);
    if (is_2d) {
      const uint16x8_t v_sum =
          vreinterpretq_u16_s16(vrshrq_n_s16(s, kInterRoundBitsHorizontal - 1));
      dest16[0] = vgetq_lane_u16(v_sum, 0);
      dest16[1] = vgetq_lane_u16(v_sum, 2);
      dest16 += pred_stride;
      dest16[0] = vgetq_lane_u16(v_sum, 1);
      dest16[1] = vgetq_lane_u16(v_sum, 3);
      dest16 += pred_stride;
    } else {
      // Normally the Horizontal pass does the downshift in two passes:
      // kInterRoundBitsHorizontal - 1 and then (kFilterBits -
      // kInterRoundBitsHorizontal). Each one uses a rounding shift.
      // Combining them requires adding the rounding offset from the skipped
      // shift.
      constexpr int first_shift_rounding_bit =
          1 << (kInterRoundBitsHorizontal - 2);
      s = vaddq_s16(s, vdupq_n_s16(first_shift_rounding_bit));
      const uint8x8_t result = vqrshrun_n_s16(s, kFilterBits - 1);
      dest8[0] = vget_lane_u8(result, 0);
      dest8[1] = vget_lane_u8(result, 2);
      dest8 += pred_stride;
      dest8[0] = vget_lane_u8(result, 1);
      dest8[1] = vget_lane_u8(result, 3);
      dest8 += pred_stride;
    }
    src += src_stride << 1;
  } while (--y != 0);

  // The 2d filters have an odd |height| because the horizontal pass
  // generates context for the vertical pass.
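  // (For the 2d path |height| here is the intermediate height,
  // height + vertical_taps - 1, which is odd for the even block heights.)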
  if (is_2d) {
    assert(height % 2 == 1);
    const uint8x8_t input = vld1_u8(src);
    uint16x8_t sum;
    if (filter_index == 3) {
      sum = vmull_u8(input, v_tap[3]);
      sum = vmlal_u8(sum, RightShiftVector<1 * 8>(input), v_tap[4]);
    } else if (filter_index == 4) {
      sum = vmull_u8(RightShiftVector<1 * 8>(input), v_tap[3]);
      sum = vmlsl_u8(sum, input, v_tap[2]);
      sum = vmlal_u8(sum, RightShiftVector<2 * 8>(input), v_tap[4]);
      sum = vmlsl_u8(sum, RightShiftVector<3 * 8>(input), v_tap[5]);
    } else {
      assert(filter_index == 5);
      sum = vmull_u8(input, v_tap[2]);
      sum = vmlal_u8(sum, RightShiftVector<1 * 8>(input), v_tap[3]);
      sum = vmlal_u8(sum, RightShiftVector<2 * 8>(input), v_tap[4]);
      sum = vmlal_u8(sum, RightShiftVector<3 * 8>(input), v_tap[5]);
    }
    // |sum| contains an int16_t value.
    sum = vreinterpretq_u16_s16(vrshrq_n_s16(vreinterpretq_s16_u16(sum),
                                             kInterRoundBitsHorizontal - 1));
    Store2<0>(dest16, sum);
  }
}

template <int filter_index, bool negative_outside_taps, bool is_2d,
          bool is_compound>
void FilterHorizontal(const uint8_t* LIBGAV1_RESTRICT const src,
                      const ptrdiff_t src_stride,
                      void* LIBGAV1_RESTRICT const dest,
                      const ptrdiff_t pred_stride, const int width,
                      const int height, const uint8x8_t* const v_tap) {
  assert(width < 8 || filter_index <= 3);
  // Don't simplify the redundant if conditions with the template parameters,
  // which helps the compiler generate compact code.
  if (width >= 8 && filter_index <= 3) {
    FilterHorizontalWidth8AndUp<filter_index, negative_outside_taps, is_2d,
                                is_compound>(src, src_stride, dest, pred_stride,
                                             width, height, v_tap);
    return;
  }

  // When |width| <= 4, the horizontal pass only needs to handle the 2- and
  // 4-tap filters.
  assert(width <= 4);
  assert(filter_index >= 3 && filter_index <= 5);
  if (filter_index >= 3 && filter_index <= 5) {
    if (width == 2 && !is_compound) {
      FilterHorizontalWidth2<filter_index, is_2d>(src, src_stride, dest,
                                                  pred_stride, height, v_tap);
      return;
    }
    assert(width == 4);
    FilterHorizontalWidth4<filter_index, is_2d, is_compound>(
        src, src_stride, dest, pred_stride, height, v_tap);
  }
}

// Process 16 bit inputs, accumulating in 32 bits and narrowing back to
// 16 bits on output.
template <int num_taps, bool is_compound>
inline int16x4_t Sum2DVerticalTaps4(const int16x4_t* const src,
                                    const int16x8_t taps) {
  const int16x4_t taps_lo = vget_low_s16(taps);
  const int16x4_t taps_hi = vget_high_s16(taps);
  int32x4_t sum;
  if (num_taps == 8) {
    sum = vmull_lane_s16(src[0], taps_lo, 0);
    sum = vmlal_lane_s16(sum, src[1], taps_lo, 1);
    sum = vmlal_lane_s16(sum, src[2], taps_lo, 2);
    sum = vmlal_lane_s16(sum, src[3], taps_lo, 3);
    sum = vmlal_lane_s16(sum, src[4], taps_hi, 0);
    sum = vmlal_lane_s16(sum, src[5], taps_hi, 1);
    sum = vmlal_lane_s16(sum, src[6], taps_hi, 2);
    sum = vmlal_lane_s16(sum, src[7], taps_hi, 3);
  } else if (num_taps == 6) {
    sum = vmull_lane_s16(src[0], taps_lo, 1);
    sum = vmlal_lane_s16(sum, src[1], taps_lo, 2);
    sum = vmlal_lane_s16(sum, src[2], taps_lo, 3);
    sum = vmlal_lane_s16(sum, src[3], taps_hi, 0);
    sum = vmlal_lane_s16(sum, src[4], taps_hi, 1);
    sum = vmlal_lane_s16(sum, src[5], taps_hi, 2);
  } else if (num_taps == 4) {
    sum = vmull_lane_s16(src[0], taps_lo, 2);
    sum = vmlal_lane_s16(sum, src[1], taps_lo, 3);
    sum = vmlal_lane_s16(sum, src[2], taps_hi, 0);
    sum = vmlal_lane_s16(sum, src[3], taps_hi, 1);
  } else if (num_taps == 2) {
    sum = vmull_lane_s16(src[0], taps_lo, 3);
    sum = vmlal_lane_s16(sum, src[1], taps_hi, 0);
  }

  if (is_compound) {
    return vqrshrn_n_s32(sum, kInterRoundBitsCompoundVertical - 1);
  }

  return vqrshrn_n_s32(sum, kInterRoundBitsVertical - 1);
}

template <int num_taps, bool is_compound>
int16x8_t SimpleSum2DVerticalTaps(const int16x8_t* const src,
                                  const int16x8_t taps) {
  const int16x4_t taps_lo = vget_low_s16(taps);
  const int16x4_t taps_hi = vget_high_s16(taps);
  int32x4_t sum_lo, sum_hi;
  if (num_taps == 8) {
    sum_lo = vmull_lane_s16(vget_low_s16(src[0]), taps_lo, 0);
    sum_hi = vmull_lane_s16(vget_high_s16(src[0]), taps_lo, 0);
    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[1]), taps_lo, 1);
    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[1]), taps_lo, 1);
    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[2]), taps_lo, 2);
    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[2]), taps_lo, 2);
    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[3]), taps_lo, 3);
    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[3]), taps_lo, 3);

    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[4]), taps_hi, 0);
    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[4]), taps_hi, 0);
    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[5]), taps_hi, 1);
    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[5]), taps_hi, 1);
    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[6]), taps_hi, 2);
    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[6]), taps_hi, 2);
    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[7]), taps_hi, 3);
    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[7]), taps_hi, 3);
  } else if (num_taps == 6) {
    sum_lo = vmull_lane_s16(vget_low_s16(src[0]), taps_lo, 1);
    sum_hi = vmull_lane_s16(vget_high_s16(src[0]), taps_lo, 1);
    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[1]), taps_lo, 2);
    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[1]), taps_lo, 2);
    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[2]), taps_lo, 3);
    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[2]), taps_lo, 3);

    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[3]), taps_hi, 0);
    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[3]), taps_hi, 0);
    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[4]), taps_hi, 1);
    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[4]), taps_hi, 1);
    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[5]), taps_hi, 2);
    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[5]), taps_hi, 2);
  } else if (num_taps == 4) {
    sum_lo = vmull_lane_s16(vget_low_s16(src[0]), taps_lo, 2);
    sum_hi = vmull_lane_s16(vget_high_s16(src[0]), taps_lo, 2);
    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[1]), taps_lo, 3);
    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[1]), taps_lo, 3);

    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[2]), taps_hi, 0);
    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[2]), taps_hi, 0);
    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[3]), taps_hi, 1);
    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[3]), taps_hi, 1);
  } else if (num_taps == 2) {
    sum_lo = vmull_lane_s16(vget_low_s16(src[0]), taps_lo, 3);
    sum_hi = vmull_lane_s16(vget_high_s16(src[0]), taps_lo, 3);

    sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(src[1]), taps_hi, 0);
    sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(src[1]), taps_hi, 0);
  }

  if (is_compound) {
    return vcombine_s16(
        vqrshrn_n_s32(sum_lo, kInterRoundBitsCompoundVertical - 1),
        vqrshrn_n_s32(sum_hi, kInterRoundBitsCompoundVertical - 1));
  }

  return vcombine_s16(vqrshrn_n_s32(sum_lo, kInterRoundBitsVertical - 1),
                      vqrshrn_n_s32(sum_hi, kInterRoundBitsVertical - 1));
}

template <int num_taps, bool is_compound = false>
void Filter2DVerticalWidth8AndUp(const uint16_t* LIBGAV1_RESTRICT src,
                                 void* LIBGAV1_RESTRICT const dst,
                                 const ptrdiff_t dst_stride, const int width,
                                 const int height, const int16x8_t taps) {
  assert(width >= 8);
  constexpr int next_row = num_taps - 1;
  auto* const dst8 = static_cast<uint8_t*>(dst);
  auto* const dst16 = static_cast<uint16_t*>(dst);

  int x = 0;
  do {
    int16x8_t srcs[9];
    srcs[0] = vreinterpretq_s16_u16(vld1q_u16(src));
    src += 8;
    if (num_taps >= 4) {
      srcs[1] = vreinterpretq_s16_u16(vld1q_u16(src));
      src += 8;
      srcs[2] = vreinterpretq_s16_u16(vld1q_u16(src));
      src += 8;
      if (num_taps >= 6) {
        srcs[3] = vreinterpretq_s16_u16(vld1q_u16(src));
        src += 8;
        srcs[4] = vreinterpretq_s16_u16(vld1q_u16(src));
        src += 8;
        if (num_taps == 8) {
          srcs[5] = vreinterpretq_s16_u16(vld1q_u16(src));
          src += 8;
          srcs[6] = vreinterpretq_s16_u16(vld1q_u16(src));
          src += 8;
        }
      }
    }

    uint8_t* d8 = dst8 + x;
    uint16_t* d16 = dst16 + x;
    int y = height;
    do {
      srcs[next_row] = vreinterpretq_s16_u16(vld1q_u16(src));
      src += 8;
      srcs[next_row + 1] = vreinterpretq_s16_u16(vld1q_u16(src));
      src += 8;
      const int16x8_t sum0 =
          SimpleSum2DVerticalTaps<num_taps, is_compound>(srcs + 0, taps);
      const int16x8_t sum1 =
          SimpleSum2DVerticalTaps<num_taps, is_compound>(srcs + 1, taps);
      if (is_compound) {
        vst1q_u16(d16, vreinterpretq_u16_s16(sum0));
        d16 += dst_stride;
        vst1q_u16(d16, vreinterpretq_u16_s16(sum1));
        d16 += dst_stride;
      } else {
        vst1_u8(d8, vqmovun_s16(sum0));
        d8 += dst_stride;
        vst1_u8(d8, vqmovun_s16(sum1));
        d8 += dst_stride;
      }
      srcs[0] = srcs[2];
      if (num_taps >= 4) {
        srcs[1] = srcs[3];
        srcs[2] = srcs[4];
        if (num_taps >= 6) {
          srcs[3] = srcs[5];
          srcs[4] = srcs[6];
          if (num_taps == 8) {
            srcs[5] = srcs[7];
            srcs[6] = srcs[8];
          }
        }
      }
      y -= 2;
    } while (y != 0);
    x += 8;
  } while (x < width);
}

// Take advantage of |src_stride| == |width| to process two rows at a time.
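// (With width 4 the horizontal pass wrote rows contiguously, so each 8-lane
// load spans two 4-pixel rows; the odd |srcs| entries are stitched together
// from the halves of their neighbors.)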
template <int num_taps, bool is_compound = false>
void Filter2DVerticalWidth4(const uint16_t* LIBGAV1_RESTRICT src,
                            void* LIBGAV1_RESTRICT const dst,
                            const ptrdiff_t dst_stride, const int height,
                            const int16x8_t taps) {
  auto* dst8 = static_cast<uint8_t*>(dst);
  auto* dst16 = static_cast<uint16_t*>(dst);

  int16x8_t srcs[9];
  srcs[0] = vreinterpretq_s16_u16(vld1q_u16(src));
  src += 8;
  if (num_taps >= 4) {
    srcs[2] = vreinterpretq_s16_u16(vld1q_u16(src));
    src += 8;
    srcs[1] = vcombine_s16(vget_high_s16(srcs[0]), vget_low_s16(srcs[2]));
    if (num_taps >= 6) {
      srcs[4] = vreinterpretq_s16_u16(vld1q_u16(src));
      src += 8;
      srcs[3] = vcombine_s16(vget_high_s16(srcs[2]), vget_low_s16(srcs[4]));
      if (num_taps == 8) {
        srcs[6] = vreinterpretq_s16_u16(vld1q_u16(src));
        src += 8;
        srcs[5] = vcombine_s16(vget_high_s16(srcs[4]), vget_low_s16(srcs[6]));
      }
    }
  }

  int y = height;
  do {
    srcs[num_taps] = vreinterpretq_s16_u16(vld1q_u16(src));
    src += 8;
    srcs[num_taps - 1] = vcombine_s16(vget_high_s16(srcs[num_taps - 2]),
                                      vget_low_s16(srcs[num_taps]));

    const int16x8_t sum =
        SimpleSum2DVerticalTaps<num_taps, is_compound>(srcs, taps);
    if (is_compound) {
      const uint16x8_t results = vreinterpretq_u16_s16(sum);
      vst1q_u16(dst16, results);
      dst16 += 4 << 1;
    } else {
      const uint8x8_t results = vqmovun_s16(sum);

      StoreLo4(dst8, results);
      dst8 += dst_stride;
      StoreHi4(dst8, results);
      dst8 += dst_stride;
    }

    srcs[0] = srcs[2];
    if (num_taps >= 4) {
      srcs[1] = srcs[3];
      srcs[2] = srcs[4];
      if (num_taps >= 6) {
        srcs[3] = srcs[5];
        srcs[4] = srcs[6];
        if (num_taps == 8) {
          srcs[5] = srcs[7];
          srcs[6] = srcs[8];
        }
      }
    }
    y -= 2;
  } while (y != 0);
}

// Take advantage of |src_stride| == |width| to process four rows at a time.
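// (Here one 8-lane load spans four 2-pixel rows; vextq_s16 and vcombine_s16
// below synthesize the vectors for the intermediate row offsets.)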
template <int num_taps>
void Filter2DVerticalWidth2(const uint16_t* LIBGAV1_RESTRICT src,
                            void* LIBGAV1_RESTRICT const dst,
                            const ptrdiff_t dst_stride, const int height,
                            const int16x8_t taps) {
  constexpr int next_row = (num_taps < 6) ? 4 : 8;

  auto* dst8 = static_cast<uint8_t*>(dst);

  int16x8_t srcs[9];
  srcs[0] = vreinterpretq_s16_u16(vld1q_u16(src));
  src += 8;
  if (num_taps >= 6) {
    srcs[4] = vreinterpretq_s16_u16(vld1q_u16(src));
    src += 8;
    srcs[1] = vextq_s16(srcs[0], srcs[4], 2);
    if (num_taps == 8) {
      srcs[2] = vcombine_s16(vget_high_s16(srcs[0]), vget_low_s16(srcs[4]));
      srcs[3] = vextq_s16(srcs[0], srcs[4], 6);
    }
  }

  int y = 0;
  do {
    srcs[next_row] = vreinterpretq_s16_u16(vld1q_u16(src));
    src += 8;
    if (num_taps == 2) {
      srcs[1] = vextq_s16(srcs[0], srcs[4], 2);
    } else if (num_taps == 4) {
      srcs[1] = vextq_s16(srcs[0], srcs[4], 2);
      srcs[2] = vcombine_s16(vget_high_s16(srcs[0]), vget_low_s16(srcs[4]));
      srcs[3] = vextq_s16(srcs[0], srcs[4], 6);
    } else if (num_taps == 6) {
      srcs[2] = vcombine_s16(vget_high_s16(srcs[0]), vget_low_s16(srcs[4]));
      srcs[3] = vextq_s16(srcs[0], srcs[4], 6);
      srcs[5] = vextq_s16(srcs[4], srcs[8], 2);
    } else if (num_taps == 8) {
      srcs[5] = vextq_s16(srcs[4], srcs[8], 2);
      srcs[6] = vcombine_s16(vget_high_s16(srcs[4]), vget_low_s16(srcs[8]));
      srcs[7] = vextq_s16(srcs[4], srcs[8], 6);
    }

    const int16x8_t sum =
        SimpleSum2DVerticalTaps<num_taps, /*is_compound=*/false>(srcs, taps);
    const uint8x8_t results = vqmovun_s16(sum);

    Store2<0>(dst8, results);
    dst8 += dst_stride;
    Store2<1>(dst8, results);
    // When |height| <= 4 the taps are restricted to 2 and 4 tap variants.
    // Therefore we don't need to check this condition when |height| > 4.
    if (num_taps <= 4 && height == 2) return;
    dst8 += dst_stride;
    Store2<2>(dst8, results);
    dst8 += dst_stride;
    Store2<3>(dst8, results);
    dst8 += dst_stride;

    srcs[0] = srcs[4];
    if (num_taps == 6) {
      srcs[1] = srcs[5];
      srcs[4] = srcs[8];
    } else if (num_taps == 8) {
      srcs[1] = srcs[5];
      srcs[2] = srcs[6];
      srcs[3] = srcs[7];
      srcs[4] = srcs[8];
    }

    y += 4;
  } while (y < height);
}

template <bool is_2d = false, bool is_compound = false>
LIBGAV1_ALWAYS_INLINE void DoHorizontalPass(
    const uint8_t* LIBGAV1_RESTRICT const src, const ptrdiff_t src_stride,
    void* LIBGAV1_RESTRICT const dst, const ptrdiff_t dst_stride,
    const int width, const int height, const int filter_id,
    const int filter_index) {
  // Duplicate the absolute value for each tap.  Negative taps are corrected
  // by using the vmlsl_u8 instruction.  Positive taps use vmlal_u8.
  uint8x8_t v_tap[kSubPixelTaps];
  assert(filter_id != 0);

  for (int k = 0; k < kSubPixelTaps; ++k) {
    v_tap[k] = vdup_n_u8(kAbsHalfSubPixelFilters[filter_index][filter_id][k]);
  }

  if (filter_index == 2) {  // 8 tap.
    FilterHorizontal<2, true, is_2d, is_compound>(
        src, src_stride, dst, dst_stride, width, height, v_tap);
  } else if (filter_index == 1) {  // 6 tap.
    // Check if outside taps are positive.
    if ((filter_id == 1) | (filter_id == 15)) {
      FilterHorizontal<1, false, is_2d, is_compound>(
          src + 1, src_stride, dst, dst_stride, width, height, v_tap);
    } else {
      FilterHorizontal<1, true, is_2d, is_compound>(
          src + 1, src_stride, dst, dst_stride, width, height, v_tap);
    }
  } else if (filter_index == 0) {  // 6 tap.
    FilterHorizontal<0, true, is_2d, is_compound>(
        src + 1, src_stride, dst, dst_stride, width, height, v_tap);
  } else if (filter_index == 4) {  // 4 tap.
    FilterHorizontal<4, true, is_2d, is_compound>(
        src + 2, src_stride, dst, dst_stride, width, height, v_tap);
  } else if (filter_index == 5) {  // 4 tap.
    FilterHorizontal<5, true, is_2d, is_compound>(
        src + 2, src_stride, dst, dst_stride, width, height, v_tap);
  } else {  // 2 tap.
    FilterHorizontal<3, true, is_2d, is_compound>(
        src + 3, src_stride, dst, dst_stride, width, height, v_tap);
  }
}

template <int vertical_taps>
void Filter2DVertical(
    const uint16_t* LIBGAV1_RESTRICT const intermediate_result, const int width,
    const int height, const int16x8_t taps,
    void* LIBGAV1_RESTRICT const prediction, const ptrdiff_t pred_stride) {
  auto* const dest = static_cast<uint8_t*>(prediction);
  if (width >= 8) {
    Filter2DVerticalWidth8AndUp<vertical_taps>(
        intermediate_result, dest, pred_stride, width, height, taps);
  } else if (width == 4) {
    Filter2DVerticalWidth4<vertical_taps>(intermediate_result, dest,
                                          pred_stride, height, taps);
  } else {
    assert(width == 2);
    Filter2DVerticalWidth2<vertical_taps>(intermediate_result, dest,
                                          pred_stride, height, taps);
  }
}

void Convolve2D_NEON(const void* LIBGAV1_RESTRICT const reference,
                     const ptrdiff_t reference_stride,
                     const int horizontal_filter_index,
                     const int vertical_filter_index,
                     const int horizontal_filter_id,
                     const int vertical_filter_id, const int width,
                     const int height, void* LIBGAV1_RESTRICT const prediction,
                     const ptrdiff_t pred_stride) {
  const int horiz_filter_index = GetFilterIndex(horizontal_filter_index, width);
  const int vert_filter_index = GetFilterIndex(vertical_filter_index, height);
  const int vertical_taps = GetNumTapsInFilter(vert_filter_index);

  // The output of the horizontal filter is guaranteed to fit in 16 bits.
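  // (|intermediate_result| is sized for the worst case: a full superblock
  // width by height + kSubPixelTaps - 1 intermediate rows for an 8-tap
  // vertical filter.)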
  uint16_t
      intermediate_result[kMaxSuperBlockSizeInPixels *
                          (kMaxSuperBlockSizeInPixels + kSubPixelTaps - 1)];
#if LIBGAV1_MSAN
  // Quiet msan warnings. Set with random non-zero value to aid in debugging.
  memset(intermediate_result, 0x33, sizeof(intermediate_result));
#endif
  const int intermediate_height = height + vertical_taps - 1;
  const ptrdiff_t src_stride = reference_stride;
  const auto* const src = static_cast<const uint8_t*>(reference) -
                          (vertical_taps / 2 - 1) * src_stride -
                          kHorizontalOffset;

  DoHorizontalPass</*is_2d=*/true>(src, src_stride, intermediate_result, width,
                                   width, intermediate_height,
                                   horizontal_filter_id, horiz_filter_index);

  // Vertical filter.
  assert(vertical_filter_id != 0);
  const int16x8_t taps = vmovl_s8(
      vld1_s8(kHalfSubPixelFilters[vert_filter_index][vertical_filter_id]));
  if (vertical_taps == 8) {
    Filter2DVertical<8>(intermediate_result, width, height, taps, prediction,
                        pred_stride);
  } else if (vertical_taps == 6) {
    Filter2DVertical<6>(intermediate_result, width, height, taps, prediction,
                        pred_stride);
  } else if (vertical_taps == 4) {
    Filter2DVertical<4>(intermediate_result, width, height, taps, prediction,
                        pred_stride);
  } else {  // |vertical_taps| == 2
    Filter2DVertical<2>(intermediate_result, width, height, taps, prediction,
                        pred_stride);
  }
}

// There are many opportunities for overreading in scaled convolve, because the
// range of starting points for filter windows is anywhere from 0 to 16 for 8
// destination pixels, and the window sizes range from 2 to 8. To accommodate
// this range concisely, we use |grade_x| to mean the most steps in src that can
// be traversed in a single |step_x| increment, i.e. 1 or 2. When grade_x is 2,
// we are guaranteed to exceed 8 whole steps in src for every 8 |step_x|
// increments. The first load covers the initial elements of src_x, while the
// final load covers the taps.
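// (For example, when |step_x| corresponds to between one and two whole source
// pixels per destination pixel, 8 destination pixels can span up to 16 source
// pixels plus the filter taps, which the 16 + 8 byte loads below cover.)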
template <int grade_x>
inline uint8x8x3_t LoadSrcVals(const uint8_t* const src_x) {
  uint8x8x3_t ret;
  const uint8x16_t src_val = vld1q_u8(src_x);
  ret.val[0] = vget_low_u8(src_val);
  ret.val[1] = vget_high_u8(src_val);
#if LIBGAV1_MSAN
  // Initialize to quiet msan warnings when grade_x <= 1.
  ret.val[2] = vdup_n_u8(0);
#endif
  if (grade_x > 1) {
    ret.val[2] = vld1_u8(src_x + 16);
  }
  return ret;
}

// Pre-transpose the 2 tap filters in |kAbsHalfSubPixelFilters|[3]
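// so that column k of the table holds |tap k| for filter_ids 0 through 15;
// one VQTbl1U8 lookup per tap then selects each lane's coefficient by its
// filter_id.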
inline uint8x16_t GetPositive2TapFilter(const int tap_index) {
  assert(tap_index < 2);
  alignas(
      16) static constexpr uint8_t kAbsHalfSubPixel2TapFilterColumns[2][16] = {
      {64, 60, 56, 52, 48, 44, 40, 36, 32, 28, 24, 20, 16, 12, 8, 4},
      {0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60}};

  return vld1q_u8(kAbsHalfSubPixel2TapFilterColumns[tap_index]);
}

template <int grade_x>
inline void ConvolveKernelHorizontal2Tap(
    const uint8_t* LIBGAV1_RESTRICT const src, const ptrdiff_t src_stride,
    const int width, const int subpixel_x, const int step_x,
    const int intermediate_height, int16_t* LIBGAV1_RESTRICT intermediate) {
  // Account for the 0-taps that precede the 2 nonzero taps.
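  // (In the 8-tap filter layout the 2-tap kernel sits at positions 3 and 4,
  // so the source window starts 3 pixels in.)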
  const int kernel_offset = 3;
  const int ref_x = subpixel_x >> kScaleSubPixelBits;
  const int step_x8 = step_x << 3;
  const uint8x16_t filter_taps0 = GetPositive2TapFilter(0);
  const uint8x16_t filter_taps1 = GetPositive2TapFilter(1);
  const uint16x8_t index_steps = vmulq_n_u16(
      vmovl_u8(vcreate_u8(0x0706050403020100)), static_cast<uint16_t>(step_x));
  const uint8x8_t filter_index_mask = vdup_n_u8(kSubPixelMask);

  int p = subpixel_x;
  if (width <= 4) {
    const uint8_t* src_x =
        &src[(p >> kScaleSubPixelBits) - ref_x + kernel_offset];
    // Only add steps to the 10-bit truncated p to avoid overflow.
    const uint16x8_t p_fraction = vdupq_n_u16(p & 1023);
    const uint16x8_t subpel_index_offsets = vaddq_u16(index_steps, p_fraction);
    const uint8x8_t filter_indices =
        vand_u8(vshrn_n_u16(subpel_index_offsets, 6), filter_index_mask);
    // This is a special case. The 2-tap filter has no negative taps, so we
    // can use unsigned values.
    // For each x, a lane of tapsK has
    // kSubPixelFilters[filter_index][filter_id][k], where filter_id depends
    // on x.
    const uint8x8_t taps[2] = {VQTbl1U8(filter_taps0, filter_indices),
                               VQTbl1U8(filter_taps1, filter_indices)};
    int y = intermediate_height;
    do {
      // Load a pool of samples to select from using stepped indices.
      const uint8x16_t src_vals = vld1q_u8(src_x);
      const uint8x8_t src_indices =
          vmovn_u16(vshrq_n_u16(subpel_index_offsets, kScaleSubPixelBits));

      // For each x, a lane of srcK contains src_x[k].
      const uint8x8_t src[2] = {
          VQTbl1U8(src_vals, src_indices),
          VQTbl1U8(src_vals, vadd_u8(src_indices, vdup_n_u8(1)))};

      vst1q_s16(intermediate,
                vrshrq_n_s16(SumOnePassTaps</*filter_index=*/3>(src, taps),
                             kInterRoundBitsHorizontal - 1));
      src_x += src_stride;
      intermediate += kIntermediateStride;
    } while (--y != 0);
    return;
  }

  // |width| >= 8
  int x = 0;
  do {
    const uint8_t* src_x =
        &src[(p >> kScaleSubPixelBits) - ref_x + kernel_offset];
    // Only add steps to the 10-bit truncated p to avoid overflow.
    const uint16x8_t p_fraction = vdupq_n_u16(p & 1023);
    const uint16x8_t subpel_index_offsets = vaddq_u16(index_steps, p_fraction);
    const uint8x8_t filter_indices =
        vand_u8(vshrn_n_u16(subpel_index_offsets, kFilterIndexShift),
                filter_index_mask);
    // This is a special case. The 2-tap filter has no negative taps, so we
    // can use unsigned values.
    // For each x, a lane of tapsK has
    // kSubPixelFilters[filter_index][filter_id][k], where filter_id depends
    // on x.
    const uint8x8_t taps[2] = {VQTbl1U8(filter_taps0, filter_indices),
                               VQTbl1U8(filter_taps1, filter_indices)};
    int y = intermediate_height;
    do {
      // Load a pool of samples to select from using stepped indices.
      const uint8x8x3_t src_vals = LoadSrcVals<grade_x>(src_x);
      const uint8x8_t src_indices =
          vmovn_u16(vshrq_n_u16(subpel_index_offsets, kScaleSubPixelBits));

      // For each x, a lane of srcK contains src_x[k].
      const uint8x8_t src[2] = {
          vtbl3_u8(src_vals, src_indices),
          vtbl3_u8(src_vals, vadd_u8(src_indices, vdup_n_u8(1)))};

      vst1q_s16(intermediate,
                vrshrq_n_s16(SumOnePassTaps</*filter_index=*/3>(src, taps),
                             kInterRoundBitsHorizontal - 1));
      src_x += src_stride;
      intermediate += kIntermediateStride;
    } while (--y != 0);
    x += 8;
    p += step_x8;
  } while (x < width);
}

// Pre-transpose the 4 tap filters in |kAbsHalfSubPixelFilters|[5].
inline uint8x16_t GetPositive4TapFilter(const int tap_index) {
  assert(tap_index < 4);
  alignas(
      16) static constexpr uint8_t kSubPixel4TapPositiveFilterColumns[4][16] = {
      {0, 15, 13, 11, 10, 9, 8, 7, 6, 6, 5, 4, 3, 2, 2, 1},
      {64, 31, 31, 31, 30, 29, 28, 27, 26, 24, 23, 22, 21, 20, 18, 17},
      {0, 17, 18, 20, 21, 22, 23, 24, 26, 27, 28, 29, 30, 31, 31, 31},
      {0, 1, 2, 2, 3, 4, 5, 6, 6, 7, 8, 9, 10, 11, 13, 15}};

  return vld1q_u8(kSubPixel4TapPositiveFilterColumns[tap_index]);
}

// This filter is only possible when width <= 4.
void ConvolveKernelHorizontalPositive4Tap(
    const uint8_t* LIBGAV1_RESTRICT const src, const ptrdiff_t src_stride,
    const int subpixel_x, const int step_x, const int intermediate_height,
    int16_t* LIBGAV1_RESTRICT intermediate) {
  const int kernel_offset = 2;
  const int ref_x = subpixel_x >> kScaleSubPixelBits;
  const uint8x8_t filter_index_mask = vdup_n_u8(kSubPixelMask);
  const uint8x16_t filter_taps0 = GetPositive4TapFilter(0);
  const uint8x16_t filter_taps1 = GetPositive4TapFilter(1);
  const uint8x16_t filter_taps2 = GetPositive4TapFilter(2);
  const uint8x16_t filter_taps3 = GetPositive4TapFilter(3);
  const uint16x8_t index_steps = vmulq_n_u16(
      vmovl_u8(vcreate_u8(0x0706050403020100)), static_cast<uint16_t>(step_x));
  const int p = subpixel_x;
  // The first filter (filter_id 0) is special: a lone center tap, 128 before
  // the pre-halving.
  const uint8_t* src_x =
      &src[(p >> kScaleSubPixelBits) - ref_x + kernel_offset];
  // Only add steps to the 10-bit truncated p to avoid overflow.
  const uint16x8_t p_fraction = vdupq_n_u16(p & 1023);
  const uint16x8_t subpel_index_offsets = vaddq_u16(index_steps, p_fraction);
  const uint8x8_t filter_indices = vand_u8(
      vshrn_n_u16(subpel_index_offsets, kFilterIndexShift), filter_index_mask);
  // Note that filter_id depends on x.
  // For each x, tapsK has kSubPixelFilters[filter_index][filter_id][k].
  const uint8x8_t taps[4] = {VQTbl1U8(filter_taps0, filter_indices),
                             VQTbl1U8(filter_taps1, filter_indices),
                             VQTbl1U8(filter_taps2, filter_indices),
                             VQTbl1U8(filter_taps3, filter_indices)};

  const uint8x8_t src_indices =
      vmovn_u16(vshrq_n_u16(subpel_index_offsets, kScaleSubPixelBits));
  int y = intermediate_height;
  do {
    // Load a pool of samples to select from using stepped index vectors.
    const uint8x16_t src_vals = vld1q_u8(src_x);

    // For each x, src[k] contains src_x[k] for k = 0..3.
    // Whereas taps come from different arrays, src pixels are drawn from the
    // same contiguous line.
    const uint8x8_t src[4] = {
        VQTbl1U8(src_vals, src_indices),
        VQTbl1U8(src_vals, vadd_u8(src_indices, vdup_n_u8(1))),
        VQTbl1U8(src_vals, vadd_u8(src_indices, vdup_n_u8(2))),
        VQTbl1U8(src_vals, vadd_u8(src_indices, vdup_n_u8(3)))};

    vst1q_s16(intermediate,
              vrshrq_n_s16(SumOnePassTaps</*filter_index=*/5>(src, taps),
                           kInterRoundBitsHorizontal - 1));

    src_x += src_stride;
    intermediate += kIntermediateStride;
  } while (--y != 0);
}

// Pre-transpose the 4 tap filters in |kAbsHalfSubPixelFilters|[4].
inline uint8x16_t GetSigned4TapFilter(const int tap_index) {
  assert(tap_index < 4);
  alignas(16) static constexpr uint8_t
      kAbsHalfSubPixel4TapSignedFilterColumns[4][16] = {
          {0, 2, 4, 5, 6, 6, 7, 6, 6, 5, 5, 5, 4, 3, 2, 1},
          {64, 63, 61, 58, 55, 51, 47, 42, 38, 33, 29, 24, 19, 14, 9, 4},
          {0, 4, 9, 14, 19, 24, 29, 33, 38, 42, 47, 51, 55, 58, 61, 63},
          {0, 1, 2, 3, 4, 5, 5, 5, 6, 6, 7, 6, 6, 5, 4, 2}};

  return vld1q_u8(kAbsHalfSubPixel4TapSignedFilterColumns[tap_index]);
}

// This filter is only possible when width <= 4.
inline void ConvolveKernelHorizontalSigned4Tap(
    const uint8_t* LIBGAV1_RESTRICT const src, const ptrdiff_t src_stride,
    const int subpixel_x, const int step_x, const int intermediate_height,
    int16_t* LIBGAV1_RESTRICT intermediate) {
  const int kernel_offset = 2;
  const int ref_x = subpixel_x >> kScaleSubPixelBits;
  const uint8x8_t filter_index_mask = vdup_n_u8(kSubPixelMask);
  const uint8x16_t filter_taps0 = GetSigned4TapFilter(0);
  const uint8x16_t filter_taps1 = GetSigned4TapFilter(1);
  const uint8x16_t filter_taps2 = GetSigned4TapFilter(2);
  const uint8x16_t filter_taps3 = GetSigned4TapFilter(3);
  const uint16x4_t index_steps = vmul_n_u16(vcreate_u16(0x0003000200010000),
                                            static_cast<uint16_t>(step_x));

  const int p = subpixel_x;
  const uint8_t* src_x =
      &src[(p >> kScaleSubPixelBits) - ref_x + kernel_offset];
  // Only add steps to the 10-bit truncated p to avoid overflow.
  const uint16x4_t p_fraction = vdup_n_u16(p & 1023);
  const uint16x4_t subpel_index_offsets = vadd_u16(index_steps, p_fraction);
  const uint8x8_t filter_index_offsets = vshrn_n_u16(
      vcombine_u16(subpel_index_offsets, vdup_n_u16(0)), kFilterIndexShift);
  const uint8x8_t filter_indices =
      vand_u8(filter_index_offsets, filter_index_mask);
  // Note that filter_id depends on x.
  // For each x, tapsK has kSubPixelFilters[filter_index][filter_id][k].
  const uint8x8_t taps[4] = {VQTbl1U8(filter_taps0, filter_indices),
                             VQTbl1U8(filter_taps1, filter_indices),
                             VQTbl1U8(filter_taps2, filter_indices),
                             VQTbl1U8(filter_taps3, filter_indices)};

  const uint8x8_t src_indices_base =
      vshr_n_u8(filter_index_offsets, kScaleSubPixelBits - kFilterIndexShift);

  const uint8x8_t src_indices[4] = {src_indices_base,
                                    vadd_u8(src_indices_base, vdup_n_u8(1)),
                                    vadd_u8(src_indices_base, vdup_n_u8(2)),
                                    vadd_u8(src_indices_base, vdup_n_u8(3))};

  int y = intermediate_height;
  do {
    // Load a pool of samples to select from using stepped indices.
    const uint8x16_t src_vals = vld1q_u8(src_x);

    // For each x, src[k] contains src_x[k] for k = 0..3.
1073     // Whereas taps come from different arrays, src pixels are drawn from the
1074     // same contiguous line.
1075     const uint8x8_t src[4] = {
1076         VQTbl1U8(src_vals, src_indices[0]), VQTbl1U8(src_vals, src_indices[1]),
1077         VQTbl1U8(src_vals, src_indices[2]), VQTbl1U8(src_vals, src_indices[3])};
1078 
1079     vst1q_s16(intermediate,
1080               vrshrq_n_s16(SumOnePassTaps</*filter_index=*/4>(src, taps),
1081                            kInterRoundBitsHorizontal - 1));
1082     src_x += src_stride;
1083     intermediate += kIntermediateStride;
1084   } while (--y != 0);
1085 }
1086 
1087 // Pre-transpose the 6 tap filters in |kAbsHalfSubPixelFilters|[0].
GetSigned6TapFilter(const int tap_index)1088 inline uint8x16_t GetSigned6TapFilter(const int tap_index) {
1089   assert(tap_index < 6);
1090   alignas(16) static constexpr uint8_t
1091       kAbsHalfSubPixel6TapSignedFilterColumns[6][16] = {
1092           {0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0},
1093           {0, 3, 5, 6, 7, 7, 8, 7, 7, 6, 6, 6, 5, 4, 2, 1},
1094           {64, 63, 61, 58, 55, 51, 47, 42, 38, 33, 29, 24, 19, 14, 9, 4},
1095           {0, 4, 9, 14, 19, 24, 29, 33, 38, 42, 47, 51, 55, 58, 61, 63},
1096           {0, 1, 2, 4, 5, 6, 6, 6, 7, 7, 8, 7, 7, 6, 5, 3},
1097           {0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}};
1098 
1099   return vld1q_u8(kAbsHalfSubPixel6TapSignedFilterColumns[tap_index]);
1100 }
1101 
1102 // This filter is only possible when width >= 8.
1103 template <int grade_x>
ConvolveKernelHorizontalSigned6Tap(const uint8_t * LIBGAV1_RESTRICT const src,const ptrdiff_t src_stride,const int width,const int subpixel_x,const int step_x,const int intermediate_height,int16_t * LIBGAV1_RESTRICT const intermediate)1104 inline void ConvolveKernelHorizontalSigned6Tap(
1105     const uint8_t* LIBGAV1_RESTRICT const src, const ptrdiff_t src_stride,
1106     const int width, const int subpixel_x, const int step_x,
1107     const int intermediate_height,
1108     int16_t* LIBGAV1_RESTRICT const intermediate) {
1109   const int kernel_offset = 1;
1110   const uint8x8_t one = vdup_n_u8(1);
1111   const uint8x8_t filter_index_mask = vdup_n_u8(kSubPixelMask);
1112   const int ref_x = subpixel_x >> kScaleSubPixelBits;
1113   const int step_x8 = step_x << 3;
1114   uint8x16_t filter_taps[6];
1115   for (int i = 0; i < 6; ++i) {
1116     filter_taps[i] = GetSigned6TapFilter(i);
1117   }
1118   const uint16x8_t index_steps = vmulq_n_u16(
1119       vmovl_u8(vcreate_u8(0x0706050403020100)), static_cast<uint16_t>(step_x));
1120 
1121   int16_t* intermediate_x = intermediate;
1122   int x = 0;
1123   int p = subpixel_x;
1124   do {
1125     // Avoid overloading outside the reference boundaries. This means
1126     // |trailing_width| can be up to 24.
1127     const uint8_t* src_x =
1128         &src[(p >> kScaleSubPixelBits) - ref_x + kernel_offset];
1129     // Only add steps to the 10-bit truncated p to avoid overflow.
1130     const uint16x8_t p_fraction = vdupq_n_u16(p & 1023);
1131     const uint16x8_t subpel_index_offsets = vaddq_u16(index_steps, p_fraction);
1132     const uint8x8_t src_indices =
1133         vmovn_u16(vshrq_n_u16(subpel_index_offsets, kScaleSubPixelBits));
    uint8x8_t src_lookup[6];
    src_lookup[0] = src_indices;
    for (int i = 1; i < 6; ++i) {
      src_lookup[i] = vadd_u8(src_lookup[i - 1], one);
    }

    const uint8x8_t filter_indices =
        vand_u8(vshrn_n_u16(subpel_index_offsets, kFilterIndexShift),
                filter_index_mask);
    // For each x, a lane of taps[k] has
    // kSubPixelFilters[filter_index][filter_id][k], where filter_id depends
    // on x.
    uint8x8_t taps[6];
    for (int i = 0; i < 6; ++i) {
      taps[i] = VQTbl1U8(filter_taps[i], filter_indices);
    }
    int y = intermediate_height;
    do {
      // Load a pool of samples to select from using stepped indices.
      const uint8x8x3_t src_vals = LoadSrcVals<grade_x>(src_x);

      const uint8x8_t src[6] = {
          vtbl3_u8(src_vals, src_lookup[0]), vtbl3_u8(src_vals, src_lookup[1]),
          vtbl3_u8(src_vals, src_lookup[2]), vtbl3_u8(src_vals, src_lookup[3]),
          vtbl3_u8(src_vals, src_lookup[4]), vtbl3_u8(src_vals, src_lookup[5])};

      vst1q_s16(intermediate_x,
                vrshrq_n_s16(SumOnePassTaps</*filter_index=*/0>(src, taps),
                             kInterRoundBitsHorizontal - 1));
      src_x += src_stride;
      intermediate_x += kIntermediateStride;
    } while (--y != 0);
    x += 8;
    p += step_x8;
  } while (x < width);
}

// Pre-transpose the 6 tap filters in |kAbsHalfSubPixelFilters|[1]. This filter
// has mixed positive and negative outer taps which are handled in
// GetMixed6TapFilter().
inline uint8x16_t GetPositive6TapFilter(const int tap_index) {
  // Only the four positive inner-tap columns are stored.
  assert(tap_index < 4);
  alignas(16) static constexpr uint8_t
      kAbsHalfSubPixel6TapPositiveFilterColumns[4][16] = {
          {0, 14, 13, 11, 10, 9, 8, 8, 7, 6, 5, 4, 3, 2, 2, 1},
          {64, 31, 31, 31, 30, 29, 28, 27, 26, 24, 23, 22, 21, 20, 18, 17},
          {0, 17, 18, 20, 21, 22, 23, 24, 26, 27, 28, 29, 30, 31, 31, 31},
          {0, 1, 2, 2, 3, 4, 5, 6, 7, 8, 8, 9, 10, 11, 13, 14}};

  return vld1q_u8(kAbsHalfSubPixel6TapPositiveFilterColumns[tap_index]);
}

inline int8x16_t GetMixed6TapFilter(const int tap_index) {
  assert(tap_index < 2);
  alignas(
      16) static constexpr int8_t kHalfSubPixel6TapMixedFilterColumns[2][16] = {
      {0, 1, 0, 0, 0, 0, 0, -1, -1, 0, 0, 0, 0, 0, 0, 0},
      {0, 0, 0, 0, 0, 0, 0, 0, -1, -1, 0, 0, 0, 0, 0, 1}};

  return vld1q_s8(kHalfSubPixel6TapMixedFilterColumns[tap_index]);
}

// This filter is only possible when width >= 8.
template <int grade_x>
inline void ConvolveKernelHorizontalMixed6Tap(
    const uint8_t* LIBGAV1_RESTRICT const src, const ptrdiff_t src_stride,
    const int width, const int subpixel_x, const int step_x,
    const int intermediate_height,
    int16_t* LIBGAV1_RESTRICT const intermediate) {
  const int kernel_offset = 1;
  const uint8x8_t one = vdup_n_u8(1);
  const uint8x8_t filter_index_mask = vdup_n_u8(kSubPixelMask);
  const int ref_x = subpixel_x >> kScaleSubPixelBits;
  const int step_x8 = step_x << 3;
  uint8x8_t taps[4];
  int16x8_t mixed_taps[2];
  uint8x16_t positive_filter_taps[4];
  for (int i = 0; i < 4; ++i) {
    positive_filter_taps[i] = GetPositive6TapFilter(i);
  }
  int8x16_t mixed_filter_taps[2];
  mixed_filter_taps[0] = GetMixed6TapFilter(0);
  mixed_filter_taps[1] = GetMixed6TapFilter(1);
  const uint16x8_t index_steps = vmulq_n_u16(
      vmovl_u8(vcreate_u8(0x0706050403020100)), static_cast<uint16_t>(step_x));

  int16_t* intermediate_x = intermediate;
  int x = 0;
  int p = subpixel_x;
  do {
    const uint8_t* src_x =
        &src[(p >> kScaleSubPixelBits) - ref_x + kernel_offset];
    // Only add steps to the 10-bit truncated p to avoid overflow.
    const uint16x8_t p_fraction = vdupq_n_u16(p & 1023);
    const uint16x8_t subpel_index_offsets = vaddq_u16(index_steps, p_fraction);
    const uint8x8_t src_indices =
        vmovn_u16(vshrq_n_u16(subpel_index_offsets, kScaleSubPixelBits));
    uint8x8_t src_lookup[6];
    src_lookup[0] = src_indices;
    for (int i = 1; i < 6; ++i) {
      src_lookup[i] = vadd_u8(src_lookup[i - 1], one);
    }

    const uint8x8_t filter_indices =
        vand_u8(vshrn_n_u16(subpel_index_offsets, kFilterIndexShift),
                filter_index_mask);
    // For each x, a lane of taps[k] has
    // kSubPixelFilters[filter_index][filter_id][k], where filter_id depends
    // on x.
    for (int i = 0; i < 4; ++i) {
      taps[i] = VQTbl1U8(positive_filter_taps[i], filter_indices);
    }
    mixed_taps[0] = vmovl_s8(VQTbl1S8(mixed_filter_taps[0], filter_indices));
    mixed_taps[1] = vmovl_s8(VQTbl1S8(mixed_filter_taps[1], filter_indices));

    int y = intermediate_height;
    do {
      // Load a pool of samples to select from using stepped indices.
      const uint8x8x3_t src_vals = LoadSrcVals<grade_x>(src_x);

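      // Taps 0 and 5 carry the signs (each lane is -1, 0 or +1) and are
      // applied with signed multiplies. The four inner taps are positive and
      // accumulate with unsigned vmlal_u8; the reinterpret casts are safe
      // because the true sum fits in int16_t, so any unsigned wraparound is
      // undone when the total is read back as signed.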
      int16x8_t sum_mixed = vmulq_s16(
          mixed_taps[0], ZeroExtend(vtbl3_u8(src_vals, src_lookup[0])));
      sum_mixed = vmlaq_s16(sum_mixed, mixed_taps[1],
                            ZeroExtend(vtbl3_u8(src_vals, src_lookup[5])));
      uint16x8_t sum = vreinterpretq_u16_s16(sum_mixed);
      sum = vmlal_u8(sum, taps[0], vtbl3_u8(src_vals, src_lookup[1]));
      sum = vmlal_u8(sum, taps[1], vtbl3_u8(src_vals, src_lookup[2]));
      sum = vmlal_u8(sum, taps[2], vtbl3_u8(src_vals, src_lookup[3]));
      sum = vmlal_u8(sum, taps[3], vtbl3_u8(src_vals, src_lookup[4]));

      vst1q_s16(intermediate_x, vrshrq_n_s16(vreinterpretq_s16_u16(sum),
                                             kInterRoundBitsHorizontal - 1));
      src_x += src_stride;
      intermediate_x += kIntermediateStride;
    } while (--y != 0);
    x += 8;
    p += step_x8;
  } while (x < width);
}

// Pre-transpose the 8 tap filters in |kAbsHalfSubPixelFilters|[2].
inline uint8x16_t GetSigned8TapFilter(const int tap_index) {
  assert(tap_index < 8);
  alignas(16) static constexpr uint8_t
      kAbsHalfSubPixel8TapSignedFilterColumns[8][16] = {
          {0, 1, 1, 1, 2, 2, 2, 2, 2, 1, 1, 1, 1, 1, 1, 0},
          {0, 1, 3, 4, 5, 5, 5, 5, 6, 5, 4, 4, 3, 3, 2, 1},
          {0, 3, 6, 9, 11, 11, 12, 12, 12, 11, 10, 9, 7, 5, 3, 1},
          {64, 63, 62, 60, 58, 54, 50, 45, 40, 35, 30, 24, 19, 13, 8, 4},
          {0, 4, 8, 13, 19, 24, 30, 35, 40, 45, 50, 54, 58, 60, 62, 63},
          {0, 1, 3, 5, 7, 9, 10, 11, 12, 12, 12, 11, 11, 9, 6, 3},
          {0, 1, 2, 3, 3, 4, 4, 5, 6, 5, 5, 5, 5, 4, 3, 1},
          {0, 0, 1, 1, 1, 1, 1, 1, 2, 2, 2, 2, 2, 1, 1, 1}};

  return vld1q_u8(kAbsHalfSubPixel8TapSignedFilterColumns[tap_index]);
}

// This filter is only possible when width >= 8.
template <int grade_x>
inline void ConvolveKernelHorizontalSigned8Tap(
    const uint8_t* LIBGAV1_RESTRICT const src, const ptrdiff_t src_stride,
    const int width, const int subpixel_x, const int step_x,
    const int intermediate_height,
    int16_t* LIBGAV1_RESTRICT const intermediate) {
  const uint8x8_t one = vdup_n_u8(1);
  const uint8x8_t filter_index_mask = vdup_n_u8(kSubPixelMask);
  const int ref_x = subpixel_x >> kScaleSubPixelBits;
  const int step_x8 = step_x << 3;
  uint8x8_t taps[8];
  uint8x16_t filter_taps[8];
  for (int i = 0; i < 8; ++i) {
    filter_taps[i] = GetSigned8TapFilter(i);
  }
  const uint16x8_t index_steps = vmulq_n_u16(
      vmovl_u8(vcreate_u8(0x0706050403020100)), static_cast<uint16_t>(step_x));

  int16_t* intermediate_x = intermediate;
  int x = 0;
  int p = subpixel_x;
  do {
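    // The 8-tap kernel uses the filter's full support, so unlike the 6-tap
    // kernels above no |kernel_offset| is added when positioning |src_x|.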
    const uint8_t* src_x = &src[(p >> kScaleSubPixelBits) - ref_x];
    // Only add steps to the 10-bit truncated p to avoid overflow.
    const uint16x8_t p_fraction = vdupq_n_u16(p & 1023);
    const uint16x8_t subpel_index_offsets = vaddq_u16(index_steps, p_fraction);
    const uint8x8_t src_indices =
        vmovn_u16(vshrq_n_u16(subpel_index_offsets, kScaleSubPixelBits));
    uint8x8_t src_lookup[8];
    src_lookup[0] = src_indices;
    for (int i = 1; i < 8; ++i) {
      src_lookup[i] = vadd_u8(src_lookup[i - 1], one);
    }

    const uint8x8_t filter_indices =
        vand_u8(vshrn_n_u16(subpel_index_offsets, kFilterIndexShift),
                filter_index_mask);
    // For each x, a lane of taps[k] has
    // kSubPixelFilters[filter_index][filter_id][k], where filter_id depends
    // on x.
    for (int i = 0; i < 8; ++i) {
      taps[i] = VQTbl1U8(filter_taps[i], filter_indices);
    }

    int y = intermediate_height;
    do {
      // Load a pool of samples to select from using stepped indices.
      const uint8x8x3_t src_vals = LoadSrcVals<grade_x>(src_x);

      const uint8x8_t src[8] = {
          vtbl3_u8(src_vals, src_lookup[0]), vtbl3_u8(src_vals, src_lookup[1]),
          vtbl3_u8(src_vals, src_lookup[2]), vtbl3_u8(src_vals, src_lookup[3]),
          vtbl3_u8(src_vals, src_lookup[4]), vtbl3_u8(src_vals, src_lookup[5]),
          vtbl3_u8(src_vals, src_lookup[6]), vtbl3_u8(src_vals, src_lookup[7])};

      vst1q_s16(intermediate_x,
                vrshrq_n_s16(SumOnePassTaps</*filter_index=*/2>(src, taps),
                             kInterRoundBitsHorizontal - 1));
      src_x += src_stride;
      intermediate_x += kIntermediateStride;
    } while (--y != 0);
    x += 8;
    p += step_x8;
  } while (x < width);
}

// This function handles blocks of width 2 or 4.
template <int num_taps, int grade_y, int width, bool is_compound>
void ConvolveVerticalScale4xH(const int16_t* LIBGAV1_RESTRICT const src,
                              const int subpixel_y, const int filter_index,
                              const int step_y, const int height,
                              void* LIBGAV1_RESTRICT const dest,
                              const ptrdiff_t dest_stride) {
  constexpr ptrdiff_t src_stride = kIntermediateStride;
  const int16_t* src_y = src;
  // |dest| is 16-bit in compound mode, Pixel otherwise.
  auto* dest16_y = static_cast<uint16_t*>(dest);
  auto* dest_y = static_cast<uint8_t*>(dest);
  int16x4_t s[num_taps + grade_y];

  int p = subpixel_y & 1023;
  int prev_p = p;
  int y = height;
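  // The loop below is unrolled over two output rows. The second row reuses
  // the source vectors loaded for the first, offset by |p_diff| (at most
  // |grade_y| rows), rather than reloading all |num_taps| of them.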
  do {
    for (int i = 0; i < num_taps; ++i) {
      s[i] = vld1_s16(src_y + i * src_stride);
    }
    int filter_id = (p >> 6) & kSubPixelMask;
    int16x8_t filter =
        vmovl_s8(vld1_s8(kHalfSubPixelFilters[filter_index][filter_id]));
    int16x4_t sums = Sum2DVerticalTaps4<num_taps, is_compound>(s, filter);
    if (is_compound) {
      assert(width != 2);
      const uint16x4_t result = vreinterpret_u16_s16(sums);
      vst1_u16(dest16_y, result);
    } else {
      const uint8x8_t result = vqmovun_s16(vcombine_s16(sums, sums));
      if (width == 2) {
        Store2<0>(dest_y, result);
      } else {
        StoreLo4(dest_y, result);
      }
    }
    p += step_y;
    const int p_diff =
        (p >> kScaleSubPixelBits) - (prev_p >> kScaleSubPixelBits);
    prev_p = p;
    // Here we load extra source in case it is needed. If |p_diff| == 0, these
    // values will be unused, but it's faster to load than to branch.
    s[num_taps] = vld1_s16(src_y + num_taps * src_stride);
    if (grade_y > 1) {
      s[num_taps + 1] = vld1_s16(src_y + (num_taps + 1) * src_stride);
    }
    dest16_y += dest_stride;
    dest_y += dest_stride;

    filter_id = (p >> 6) & kSubPixelMask;
    filter = vmovl_s8(vld1_s8(kHalfSubPixelFilters[filter_index][filter_id]));
    sums = Sum2DVerticalTaps4<num_taps, is_compound>(&s[p_diff], filter);
    if (is_compound) {
      assert(width != 2);
      const uint16x4_t result = vreinterpret_u16_s16(sums);
      vst1_u16(dest16_y, result);
    } else {
      const uint8x8_t result = vqmovun_s16(vcombine_s16(sums, sums));
      if (width == 2) {
        Store2<0>(dest_y, result);
      } else {
        StoreLo4(dest_y, result);
      }
    }
    p += step_y;
    src_y = src + (p >> kScaleSubPixelBits) * src_stride;
    prev_p = p;
    dest16_y += dest_stride;
    dest_y += dest_stride;
    y -= 2;
  } while (y != 0);
}

template <int num_taps, int grade_y, bool is_compound>
inline void ConvolveVerticalScale(const int16_t* LIBGAV1_RESTRICT const source,
                                  const int intermediate_height,
                                  const int width, const int subpixel_y,
                                  const int filter_index, const int step_y,
                                  const int height,
                                  void* LIBGAV1_RESTRICT const dest,
                                  const ptrdiff_t dest_stride) {
  constexpr ptrdiff_t src_stride = kIntermediateStride;
  // A possible improvement is to use arithmetic to decide how many times to
  // apply filters to the same source before checking whether to load new
  // srcs. However, this will only improve performance with very small step
  // sizes.
  int16x8_t s[num_taps + grade_y];
  // |dest| is 16-bit in compound mode, Pixel otherwise.
  uint16_t* dest16_y;
  uint8_t* dest_y;
  const int16_t* src = source;

  int x = 0;
  do {
    const int16_t* src_y = src;
    dest16_y = static_cast<uint16_t*>(dest) + x;
    dest_y = static_cast<uint8_t*>(dest) + x;
    int p = subpixel_y & 1023;
    int prev_p = p;
    int y = height;
    do {
      for (int i = 0; i < num_taps; ++i) {
        s[i] = vld1q_s16(src_y + i * src_stride);
      }
      int filter_id = (p >> 6) & kSubPixelMask;
      int16x8_t filter =
          vmovl_s8(vld1_s8(kHalfSubPixelFilters[filter_index][filter_id]));
      int16x8_t sum = SimpleSum2DVerticalTaps<num_taps, is_compound>(s, filter);
      if (is_compound) {
        vst1q_u16(dest16_y, vreinterpretq_u16_s16(sum));
      } else {
        vst1_u8(dest_y, vqmovun_s16(sum));
      }
      p += step_y;
      const int p_diff =
          (p >> kScaleSubPixelBits) - (prev_p >> kScaleSubPixelBits);
      // |grade_y| > 1 always means p_diff > 0, so load vectors that may be
      // needed. Otherwise, we only need to load one vector because |p_diff|
      // can't exceed 1.
      s[num_taps] = vld1q_s16(src_y + num_taps * src_stride);
      if (grade_y > 1) {
        s[num_taps + 1] = vld1q_s16(src_y + (num_taps + 1) * src_stride);
      }
      dest16_y += dest_stride;
      dest_y += dest_stride;

      filter_id = (p >> 6) & kSubPixelMask;
      filter = vmovl_s8(vld1_s8(kHalfSubPixelFilters[filter_index][filter_id]));
      sum = SimpleSum2DVerticalTaps<num_taps, is_compound>(&s[p_diff], filter);
      if (is_compound) {
        vst1q_u16(dest16_y, vreinterpretq_u16_s16(sum));
      } else {
        vst1_u8(dest_y, vqmovun_s16(sum));
      }
      p += step_y;
      src_y = src + (p >> kScaleSubPixelBits) * src_stride;
      prev_p = p;
      dest16_y += dest_stride;
      dest_y += dest_stride;
      y -= 2;
    } while (y != 0);
    src += kIntermediateStride * intermediate_height;
    x += 8;
  } while (x < width);
}

template <bool is_compound>
void ConvolveScale2D_NEON(const void* LIBGAV1_RESTRICT const reference,
                          const ptrdiff_t reference_stride,
                          const int horizontal_filter_index,
                          const int vertical_filter_index, const int subpixel_x,
                          const int subpixel_y, const int step_x,
                          const int step_y, const int width, const int height,
                          void* LIBGAV1_RESTRICT const prediction,
                          const ptrdiff_t pred_stride) {
  const int horiz_filter_index = GetFilterIndex(horizontal_filter_index, width);
  const int vert_filter_index = GetFilterIndex(vertical_filter_index, height);
  assert(step_x <= 2048);
  assert(step_y <= 2048);
  const int num_vert_taps = GetNumTapsInFilter(vert_filter_index);
  const int intermediate_height =
      (((height - 1) * step_y + (1 << kScaleSubPixelBits) - 1) >>
       kScaleSubPixelBits) +
      num_vert_taps;
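  // That is, enough rows to cover the taps of the last output row:
  // ceil(((height - 1) * step_y) / 1024) + num_vert_taps.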
  // The output of the horizontal filter, i.e. |intermediate_result|, is
  // guaranteed to fit in int16_t.
  int16_t intermediate_result[kIntermediateAllocWidth *
                              (2 * kIntermediateAllocWidth + 8)];
#if LIBGAV1_MSAN
  // Quiet msan warnings. Set with random non-zero value to aid in debugging.
  memset(intermediate_result, 0x44, sizeof(intermediate_result));
#endif
  // Horizontal filter.
  // Filter types used for width <= 4 are different from those for width > 4.
  // When width > 4, the valid filter index range is always [0, 3].
  // When width <= 4, the valid filter index range is always [3, 5].
  // Similarly for height.
  int filter_index = GetFilterIndex(horizontal_filter_index, width);
  int16_t* intermediate = intermediate_result;
  const ptrdiff_t src_stride = reference_stride;
  const auto* src = static_cast<const uint8_t*>(reference);
  const int vert_kernel_offset = (8 - num_vert_taps) / 2;
  src += vert_kernel_offset * src_stride;

  // Derive the maximum value of |step_x| at which all source values fit in
  // one 16-byte load. The final index is src_x + |num_taps| - 1 < 16.
  // |step_x| * 7 is the final base subpel index for the shuffle mask for
  // filter inputs in each iteration on large blocks. When |step_x| is large,
  // we need a larger structure and use a larger table lookup in order to
  // gather all filter inputs.
  // |num_taps| - 1 is the shuffle index of the final filter input.
  const int num_horiz_taps = GetNumTapsInFilter(horiz_filter_index);
  const int kernel_start_ceiling = 16 - num_horiz_taps;
  // This truncated quotient |grade_x_threshold| selects |step_x| such that:
  // (step_x * 7) >> kScaleSubPixelBits < single load limit
  const int grade_x_threshold =
      (kernel_start_ceiling << kScaleSubPixelBits) / 7;
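  // For example, an 8 tap filter gives kernel_start_ceiling = 8, so
  // grade_x_threshold = (8 << kScaleSubPixelBits) / 7 = 1170 (with
  // kScaleSubPixelBits == 10); any larger |step_x| takes the wider
  // grade_x == 2 load path below.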
  switch (filter_index) {
    case 0:
      if (step_x > grade_x_threshold) {
        ConvolveKernelHorizontalSigned6Tap<2>(
            src, src_stride, width, subpixel_x, step_x, intermediate_height,
            intermediate);
      } else {
        ConvolveKernelHorizontalSigned6Tap<1>(
            src, src_stride, width, subpixel_x, step_x, intermediate_height,
            intermediate);
      }
      break;
    case 1:
      if (step_x > grade_x_threshold) {
        ConvolveKernelHorizontalMixed6Tap<2>(src, src_stride, width, subpixel_x,
                                             step_x, intermediate_height,
                                             intermediate);
      } else {
        ConvolveKernelHorizontalMixed6Tap<1>(src, src_stride, width, subpixel_x,
                                             step_x, intermediate_height,
                                             intermediate);
      }
      break;
    case 2:
      if (step_x > grade_x_threshold) {
        ConvolveKernelHorizontalSigned8Tap<2>(
            src, src_stride, width, subpixel_x, step_x, intermediate_height,
            intermediate);
      } else {
        ConvolveKernelHorizontalSigned8Tap<1>(
            src, src_stride, width, subpixel_x, step_x, intermediate_height,
            intermediate);
      }
      break;
    case 3:
      if (step_x > grade_x_threshold) {
        ConvolveKernelHorizontal2Tap<2>(src, src_stride, width, subpixel_x,
                                        step_x, intermediate_height,
                                        intermediate);
      } else {
        ConvolveKernelHorizontal2Tap<1>(src, src_stride, width, subpixel_x,
                                        step_x, intermediate_height,
                                        intermediate);
      }
      break;
    case 4:
      assert(width <= 4);
      ConvolveKernelHorizontalSigned4Tap(src, src_stride, subpixel_x, step_x,
                                         intermediate_height, intermediate);
      break;
    default:
      assert(filter_index == 5);
      ConvolveKernelHorizontalPositive4Tap(src, src_stride, subpixel_x, step_x,
                                           intermediate_height, intermediate);
  }
  // Vertical filter.
  filter_index = GetFilterIndex(vertical_filter_index, height);
  intermediate = intermediate_result;

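  // step_y <= 1024 advances the source by at most one row per output row, so
  // grade_y == 1 suffices; larger steps (up to the asserted 2048) may skip a
  // row and need grade_y == 2.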
  switch (filter_index) {
    case 0:
    case 1:
      if (step_y <= 1024) {
        if (!is_compound && width == 2) {
          ConvolveVerticalScale4xH<6, 1, 2, is_compound>(
              intermediate, subpixel_y, filter_index, step_y, height,
              prediction, pred_stride);
        } else if (width == 4) {
          ConvolveVerticalScale4xH<6, 1, 4, is_compound>(
              intermediate, subpixel_y, filter_index, step_y, height,
              prediction, pred_stride);
        } else {
          ConvolveVerticalScale<6, 1, is_compound>(
              intermediate, intermediate_height, width, subpixel_y,
              filter_index, step_y, height, prediction, pred_stride);
        }
      } else {
        if (!is_compound && width == 2) {
          ConvolveVerticalScale4xH<6, 2, 2, is_compound>(
              intermediate, subpixel_y, filter_index, step_y, height,
              prediction, pred_stride);
        } else if (width == 4) {
          ConvolveVerticalScale4xH<6, 2, 4, is_compound>(
              intermediate, subpixel_y, filter_index, step_y, height,
              prediction, pred_stride);
        } else {
          ConvolveVerticalScale<6, 2, is_compound>(
              intermediate, intermediate_height, width, subpixel_y,
              filter_index, step_y, height, prediction, pred_stride);
        }
      }
      break;
    case 2:
      if (step_y <= 1024) {
        if (!is_compound && width == 2) {
          ConvolveVerticalScale4xH<8, 1, 2, is_compound>(
              intermediate, subpixel_y, filter_index, step_y, height,
              prediction, pred_stride);
        } else if (width == 4) {
          ConvolveVerticalScale4xH<8, 1, 4, is_compound>(
              intermediate, subpixel_y, filter_index, step_y, height,
              prediction, pred_stride);
        } else {
          ConvolveVerticalScale<8, 1, is_compound>(
              intermediate, intermediate_height, width, subpixel_y,
              filter_index, step_y, height, prediction, pred_stride);
        }
      } else {
        if (!is_compound && width == 2) {
          ConvolveVerticalScale4xH<8, 2, 2, is_compound>(
              intermediate, subpixel_y, filter_index, step_y, height,
              prediction, pred_stride);
        } else if (width == 4) {
          ConvolveVerticalScale4xH<8, 2, 4, is_compound>(
              intermediate, subpixel_y, filter_index, step_y, height,
              prediction, pred_stride);
        } else {
          ConvolveVerticalScale<8, 2, is_compound>(
              intermediate, intermediate_height, width, subpixel_y,
              filter_index, step_y, height, prediction, pred_stride);
        }
      }
      break;
    case 3:
      if (step_y <= 1024) {
        if (!is_compound && width == 2) {
          ConvolveVerticalScale4xH<2, 1, 2, is_compound>(
              intermediate, subpixel_y, filter_index, step_y, height,
              prediction, pred_stride);
        } else if (width == 4) {
          ConvolveVerticalScale4xH<2, 1, 4, is_compound>(
              intermediate, subpixel_y, filter_index, step_y, height,
              prediction, pred_stride);
        } else {
          ConvolveVerticalScale<2, 1, is_compound>(
              intermediate, intermediate_height, width, subpixel_y,
              filter_index, step_y, height, prediction, pred_stride);
        }
      } else {
        if (!is_compound && width == 2) {
          ConvolveVerticalScale4xH<2, 2, 2, is_compound>(
              intermediate, subpixel_y, filter_index, step_y, height,
              prediction, pred_stride);
        } else if (width == 4) {
          ConvolveVerticalScale4xH<2, 2, 4, is_compound>(
              intermediate, subpixel_y, filter_index, step_y, height,
              prediction, pred_stride);
        } else {
          ConvolveVerticalScale<2, 2, is_compound>(
              intermediate, intermediate_height, width, subpixel_y,
              filter_index, step_y, height, prediction, pred_stride);
        }
      }
      break;
    case 4:
    default:
      assert(filter_index == 4 || filter_index == 5);
      assert(height <= 4);
      if (step_y <= 1024) {
        if (!is_compound && width == 2) {
          ConvolveVerticalScale4xH<4, 1, 2, is_compound>(
              intermediate, subpixel_y, filter_index, step_y, height,
              prediction, pred_stride);
        } else if (width == 4) {
          ConvolveVerticalScale4xH<4, 1, 4, is_compound>(
              intermediate, subpixel_y, filter_index, step_y, height,
              prediction, pred_stride);
        } else {
          ConvolveVerticalScale<4, 1, is_compound>(
              intermediate, intermediate_height, width, subpixel_y,
              filter_index, step_y, height, prediction, pred_stride);
        }
      } else {
        if (!is_compound && width == 2) {
          ConvolveVerticalScale4xH<4, 2, 2, is_compound>(
              intermediate, subpixel_y, filter_index, step_y, height,
              prediction, pred_stride);
        } else if (width == 4) {
          ConvolveVerticalScale4xH<4, 2, 4, is_compound>(
              intermediate, subpixel_y, filter_index, step_y, height,
              prediction, pred_stride);
        } else {
          ConvolveVerticalScale<4, 2, is_compound>(
              intermediate, intermediate_height, width, subpixel_y,
              filter_index, step_y, height, prediction, pred_stride);
        }
      }
  }
}

void ConvolveHorizontal_NEON(
    const void* LIBGAV1_RESTRICT const reference,
    const ptrdiff_t reference_stride, const int horizontal_filter_index,
    const int /*vertical_filter_index*/, const int horizontal_filter_id,
    const int /*vertical_filter_id*/, const int width, const int height,
    void* LIBGAV1_RESTRICT const prediction, const ptrdiff_t pred_stride) {
  const int filter_index = GetFilterIndex(horizontal_filter_index, width);
  // Set |src| to the outermost tap.
  const auto* const src =
      static_cast<const uint8_t*>(reference) - kHorizontalOffset;
  auto* const dest = static_cast<uint8_t*>(prediction);

  DoHorizontalPass(src, reference_stride, dest, pred_stride, width, height,
                   horizontal_filter_id, filter_index);
}

// The 1D compound shift is always |kInterRoundBitsHorizontal|, even for 1D
// vertical calculations.
uint16x8_t Compound1DShift(const int16x8_t sum) {
  return vreinterpretq_u16_s16(
      vrshrq_n_s16(sum, kInterRoundBitsHorizontal - 1));
}

template <int filter_index, bool is_compound = false,
          bool negative_outside_taps = false>
void FilterVertical(const uint8_t* LIBGAV1_RESTRICT const src,
                    const ptrdiff_t src_stride,
                    void* LIBGAV1_RESTRICT const dst,
                    const ptrdiff_t dst_stride, const int width,
                    const int height, const uint8x8_t* const taps) {
  const int num_taps = GetNumTapsInFilter(filter_index);
  const int next_row = num_taps - 1;
  auto* const dst8 = static_cast<uint8_t*>(dst);
  auto* const dst16 = static_cast<uint16_t*>(dst);
  assert(width >= 8);

  int x = 0;
  do {
    const uint8_t* src_x = src + x;
    uint8x8_t srcs[8];
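    // |srcs| is a sliding window over the source rows: all but the last tap's
    // row are primed here, then each y iteration loads one new row into
    // srcs[next_row] and shifts the window down afterwards.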
    srcs[0] = vld1_u8(src_x);
    src_x += src_stride;
    if (num_taps >= 4) {
      srcs[1] = vld1_u8(src_x);
      src_x += src_stride;
      srcs[2] = vld1_u8(src_x);
      src_x += src_stride;
      if (num_taps >= 6) {
        srcs[3] = vld1_u8(src_x);
        src_x += src_stride;
        srcs[4] = vld1_u8(src_x);
        src_x += src_stride;
        if (num_taps == 8) {
          srcs[5] = vld1_u8(src_x);
          src_x += src_stride;
          srcs[6] = vld1_u8(src_x);
          src_x += src_stride;
        }
      }
    }

    // Decreasing the y loop counter produces worse code with clang.
    // Don't unroll this loop since it generates too much code and the decoder
    // is even slower.
    int y = 0;
    do {
      srcs[next_row] = vld1_u8(src_x);
      src_x += src_stride;

      const int16x8_t sums =
          SumOnePassTaps<filter_index, negative_outside_taps>(srcs, taps);
      if (is_compound) {
        const uint16x8_t results = Compound1DShift(sums);
        vst1q_u16(dst16 + x + y * dst_stride, results);
      } else {
        const uint8x8_t results = vqrshrun_n_s16(sums, kFilterBits - 1);
        vst1_u8(dst8 + x + y * dst_stride, results);
      }

      srcs[0] = srcs[1];
      if (num_taps >= 4) {
        srcs[1] = srcs[2];
        srcs[2] = srcs[3];
        if (num_taps >= 6) {
          srcs[3] = srcs[4];
          srcs[4] = srcs[5];
          if (num_taps == 8) {
            srcs[5] = srcs[6];
            srcs[6] = srcs[7];
          }
        }
      }
    } while (++y < height);
    x += 8;
  } while (x < width);
}

template <int filter_index, bool is_compound = false,
          bool negative_outside_taps = false>
void FilterVertical4xH(const uint8_t* LIBGAV1_RESTRICT src,
                       const ptrdiff_t src_stride,
                       void* LIBGAV1_RESTRICT const dst,
                       const ptrdiff_t dst_stride, const int height,
                       const uint8x8_t* const taps) {
  const int num_taps = GetNumTapsInFilter(filter_index);
  auto* dst8 = static_cast<uint8_t*>(dst);
  auto* dst16 = static_cast<uint16_t*>(dst);

  uint8x8_t srcs[9];
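  // Each uint8x8_t packs two adjacent 4-pixel rows, so one filter pass
  // produces two output rows; vext_u8 stitches consecutive vectors together
  // to form the windows that start one row later.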

  if (num_taps == 2) {
    srcs[2] = vdup_n_u8(0);

    srcs[0] = Load4(src);
    src += src_stride;

    int y = height;
    do {
      srcs[0] = Load4<1>(src, srcs[0]);
      src += src_stride;
      srcs[2] = Load4<0>(src, srcs[2]);
      src += src_stride;
      srcs[1] = vext_u8(srcs[0], srcs[2], 4);

      const int16x8_t sums =
          SumOnePassTaps<filter_index, negative_outside_taps>(srcs, taps);
      if (is_compound) {
        const uint16x8_t results = Compound1DShift(sums);

        vst1q_u16(dst16, results);
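        // Two rows of 4 uint16_t values were just stored; advance 8 entries.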
        dst16 += 4 << 1;
      } else {
        const uint8x8_t results = vqrshrun_n_s16(sums, kFilterBits - 1);

        StoreLo4(dst8, results);
        dst8 += dst_stride;
        StoreHi4(dst8, results);
        dst8 += dst_stride;
      }

      srcs[0] = srcs[2];
      y -= 2;
    } while (y != 0);
  } else if (num_taps == 4) {
    srcs[4] = vdup_n_u8(0);

    srcs[0] = Load4(src);
    src += src_stride;
    srcs[0] = Load4<1>(src, srcs[0]);
    src += src_stride;
    srcs[2] = Load4(src);
    src += src_stride;
    srcs[1] = vext_u8(srcs[0], srcs[2], 4);

    int y = height;
    do {
      srcs[2] = Load4<1>(src, srcs[2]);
      src += src_stride;
      srcs[4] = Load4<0>(src, srcs[4]);
      src += src_stride;
      srcs[3] = vext_u8(srcs[2], srcs[4], 4);

      const int16x8_t sums =
          SumOnePassTaps<filter_index, negative_outside_taps>(srcs, taps);
      if (is_compound) {
        const uint16x8_t results = Compound1DShift(sums);

        vst1q_u16(dst16, results);
        dst16 += 4 << 1;
      } else {
        const uint8x8_t results = vqrshrun_n_s16(sums, kFilterBits - 1);

        StoreLo4(dst8, results);
        dst8 += dst_stride;
        StoreHi4(dst8, results);
        dst8 += dst_stride;
      }

      srcs[0] = srcs[2];
      srcs[1] = srcs[3];
      srcs[2] = srcs[4];
      y -= 2;
    } while (y != 0);
  } else if (num_taps == 6) {
    srcs[6] = vdup_n_u8(0);

    srcs[0] = Load4(src);
    src += src_stride;
    srcs[0] = Load4<1>(src, srcs[0]);
    src += src_stride;
    srcs[2] = Load4(src);
    src += src_stride;
    srcs[1] = vext_u8(srcs[0], srcs[2], 4);
    srcs[2] = Load4<1>(src, srcs[2]);
    src += src_stride;
    srcs[4] = Load4(src);
    src += src_stride;
    srcs[3] = vext_u8(srcs[2], srcs[4], 4);

    int y = height;
    do {
      srcs[4] = Load4<1>(src, srcs[4]);
      src += src_stride;
      srcs[6] = Load4<0>(src, srcs[6]);
      src += src_stride;
      srcs[5] = vext_u8(srcs[4], srcs[6], 4);

      const int16x8_t sums =
          SumOnePassTaps<filter_index, negative_outside_taps>(srcs, taps);
      if (is_compound) {
        const uint16x8_t results = Compound1DShift(sums);

        vst1q_u16(dst16, results);
        dst16 += 4 << 1;
      } else {
        const uint8x8_t results = vqrshrun_n_s16(sums, kFilterBits - 1);

        StoreLo4(dst8, results);
        dst8 += dst_stride;
        StoreHi4(dst8, results);
        dst8 += dst_stride;
      }

      srcs[0] = srcs[2];
      srcs[1] = srcs[3];
      srcs[2] = srcs[4];
      srcs[3] = srcs[5];
      srcs[4] = srcs[6];
      y -= 2;
    } while (y != 0);
  } else if (num_taps == 8) {
    srcs[8] = vdup_n_u8(0);

    srcs[0] = Load4(src);
    src += src_stride;
    srcs[0] = Load4<1>(src, srcs[0]);
    src += src_stride;
    srcs[2] = Load4(src);
    src += src_stride;
    srcs[1] = vext_u8(srcs[0], srcs[2], 4);
    srcs[2] = Load4<1>(src, srcs[2]);
    src += src_stride;
    srcs[4] = Load4(src);
    src += src_stride;
    srcs[3] = vext_u8(srcs[2], srcs[4], 4);
    srcs[4] = Load4<1>(src, srcs[4]);
    src += src_stride;
    srcs[6] = Load4(src);
    src += src_stride;
    srcs[5] = vext_u8(srcs[4], srcs[6], 4);

    int y = height;
    do {
      srcs[6] = Load4<1>(src, srcs[6]);
      src += src_stride;
      srcs[8] = Load4<0>(src, srcs[8]);
      src += src_stride;
      srcs[7] = vext_u8(srcs[6], srcs[8], 4);

      const int16x8_t sums =
          SumOnePassTaps<filter_index, negative_outside_taps>(srcs, taps);
      if (is_compound) {
        const uint16x8_t results = Compound1DShift(sums);

        vst1q_u16(dst16, results);
        dst16 += 4 << 1;
      } else {
        const uint8x8_t results = vqrshrun_n_s16(sums, kFilterBits - 1);

        StoreLo4(dst8, results);
        dst8 += dst_stride;
        StoreHi4(dst8, results);
        dst8 += dst_stride;
      }

      srcs[0] = srcs[2];
      srcs[1] = srcs[3];
      srcs[2] = srcs[4];
      srcs[3] = srcs[5];
      srcs[4] = srcs[6];
      srcs[5] = srcs[7];
      srcs[6] = srcs[8];
      y -= 2;
    } while (y != 0);
  }
}

template <int filter_index, bool negative_outside_taps = false>
void FilterVertical2xH(const uint8_t* LIBGAV1_RESTRICT src,
                       const ptrdiff_t src_stride,
                       void* LIBGAV1_RESTRICT const dst,
                       const ptrdiff_t dst_stride, const int height,
                       const uint8x8_t* const taps) {
  const int num_taps = GetNumTapsInFilter(filter_index);
  auto* dst8 = static_cast<uint8_t*>(dst);

  uint8x8_t srcs[9];
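  // Here each uint8x8_t packs four adjacent 2-pixel rows, so one filter pass
  // produces four output rows (only two are stored when |height| == 2).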

  if (num_taps == 2) {
    srcs[2] = vdup_n_u8(0);

    srcs[0] = Load2(src);
    src += src_stride;

    int y = 0;
    do {
      srcs[0] = Load2<1>(src, srcs[0]);
      src += src_stride;
      srcs[0] = Load2<2>(src, srcs[0]);
      src += src_stride;
      srcs[0] = Load2<3>(src, srcs[0]);
      src += src_stride;
      srcs[2] = Load2<0>(src, srcs[2]);
      src += src_stride;
      srcs[1] = vext_u8(srcs[0], srcs[2], 2);

      // This uses srcs[0]..srcs[1].
      const int16x8_t sums =
          SumOnePassTaps<filter_index, negative_outside_taps>(srcs, taps);
      const uint8x8_t results = vqrshrun_n_s16(sums, kFilterBits - 1);

      Store2<0>(dst8, results);
      dst8 += dst_stride;
      Store2<1>(dst8, results);
      if (height == 2) return;
      dst8 += dst_stride;
      Store2<2>(dst8, results);
      dst8 += dst_stride;
      Store2<3>(dst8, results);
      dst8 += dst_stride;

      srcs[0] = srcs[2];
      y += 4;
    } while (y < height);
  } else if (num_taps == 4) {
    srcs[4] = vdup_n_u8(0);

    srcs[0] = Load2(src);
    src += src_stride;
    srcs[0] = Load2<1>(src, srcs[0]);
    src += src_stride;
    srcs[0] = Load2<2>(src, srcs[0]);
    src += src_stride;

    int y = 0;
    do {
      srcs[0] = Load2<3>(src, srcs[0]);
      src += src_stride;
      srcs[4] = Load2<0>(src, srcs[4]);
      src += src_stride;
      srcs[1] = vext_u8(srcs[0], srcs[4], 2);
      srcs[4] = Load2<1>(src, srcs[4]);
      src += src_stride;
      srcs[2] = vext_u8(srcs[0], srcs[4], 4);
      srcs[4] = Load2<2>(src, srcs[4]);
      src += src_stride;
      srcs[3] = vext_u8(srcs[0], srcs[4], 6);

      // This uses srcs[0]..srcs[3].
      const int16x8_t sums =
          SumOnePassTaps<filter_index, negative_outside_taps>(srcs, taps);
      const uint8x8_t results = vqrshrun_n_s16(sums, kFilterBits - 1);

      Store2<0>(dst8, results);
      dst8 += dst_stride;
      Store2<1>(dst8, results);
      if (height == 2) return;
      dst8 += dst_stride;
      Store2<2>(dst8, results);
      dst8 += dst_stride;
      Store2<3>(dst8, results);
      dst8 += dst_stride;

      srcs[0] = srcs[4];
      y += 4;
    } while (y < height);
  } else if (num_taps == 6) {
    // During the vertical pass the number of taps is restricted when
    // |height| <= 4.
    assert(height > 4);
    srcs[8] = vdup_n_u8(0);

    srcs[0] = Load2(src);
    src += src_stride;
    srcs[0] = Load2<1>(src, srcs[0]);
    src += src_stride;
    srcs[0] = Load2<2>(src, srcs[0]);
    src += src_stride;
    srcs[0] = Load2<3>(src, srcs[0]);
    src += src_stride;
    srcs[4] = Load2(src);
    src += src_stride;
    srcs[1] = vext_u8(srcs[0], srcs[4], 2);

    int y = 0;
    do {
      srcs[4] = Load2<1>(src, srcs[4]);
      src += src_stride;
      srcs[2] = vext_u8(srcs[0], srcs[4], 4);
      srcs[4] = Load2<2>(src, srcs[4]);
      src += src_stride;
      srcs[3] = vext_u8(srcs[0], srcs[4], 6);
      srcs[4] = Load2<3>(src, srcs[4]);
      src += src_stride;
      srcs[8] = Load2<0>(src, srcs[8]);
      src += src_stride;
      srcs[5] = vext_u8(srcs[4], srcs[8], 2);

      // This uses srcs[0]..srcs[5].
      const int16x8_t sums =
          SumOnePassTaps<filter_index, negative_outside_taps>(srcs, taps);
      const uint8x8_t results = vqrshrun_n_s16(sums, kFilterBits - 1);

      Store2<0>(dst8, results);
      dst8 += dst_stride;
      Store2<1>(dst8, results);
      dst8 += dst_stride;
      Store2<2>(dst8, results);
      dst8 += dst_stride;
      Store2<3>(dst8, results);
      dst8 += dst_stride;

      srcs[0] = srcs[4];
      srcs[1] = srcs[5];
      srcs[4] = srcs[8];
      y += 4;
    } while (y < height);
  } else if (num_taps == 8) {
    // During the vertical pass the number of taps is restricted when
    // |height| <= 4.
    assert(height > 4);
    srcs[8] = vdup_n_u8(0);

    srcs[0] = Load2(src);
    src += src_stride;
    srcs[0] = Load2<1>(src, srcs[0]);
    src += src_stride;
    srcs[0] = Load2<2>(src, srcs[0]);
    src += src_stride;
    srcs[0] = Load2<3>(src, srcs[0]);
    src += src_stride;
    srcs[4] = Load2(src);
    src += src_stride;
    srcs[1] = vext_u8(srcs[0], srcs[4], 2);
    srcs[4] = Load2<1>(src, srcs[4]);
    src += src_stride;
    srcs[2] = vext_u8(srcs[0], srcs[4], 4);
    srcs[4] = Load2<2>(src, srcs[4]);
    src += src_stride;
    srcs[3] = vext_u8(srcs[0], srcs[4], 6);

    int y = 0;
    do {
      srcs[4] = Load2<3>(src, srcs[4]);
      src += src_stride;
      srcs[8] = Load2<0>(src, srcs[8]);
      src += src_stride;
      srcs[5] = vext_u8(srcs[4], srcs[8], 2);
      srcs[8] = Load2<1>(src, srcs[8]);
      src += src_stride;
      srcs[6] = vext_u8(srcs[4], srcs[8], 4);
      srcs[8] = Load2<2>(src, srcs[8]);
      src += src_stride;
      srcs[7] = vext_u8(srcs[4], srcs[8], 6);

      // This uses srcs[0]..srcs[7].
      const int16x8_t sums =
          SumOnePassTaps<filter_index, negative_outside_taps>(srcs, taps);
      const uint8x8_t results = vqrshrun_n_s16(sums, kFilterBits - 1);

      Store2<0>(dst8, results);
      dst8 += dst_stride;
      Store2<1>(dst8, results);
      dst8 += dst_stride;
      Store2<2>(dst8, results);
      dst8 += dst_stride;
      Store2<3>(dst8, results);
      dst8 += dst_stride;

      srcs[0] = srcs[4];
      srcs[1] = srcs[5];
      srcs[2] = srcs[6];
      srcs[3] = srcs[7];
      srcs[4] = srcs[8];
      y += 4;
    } while (y < height);
  }
}

// This function is a simplified version of Convolve2D_C.
// It is called in single-prediction mode, where only vertical filtering is
// required.
// The output is the single prediction of the block, clipped to the valid
// pixel range.
void ConvolveVertical_NEON(
    const void* LIBGAV1_RESTRICT const reference,
    const ptrdiff_t reference_stride, const int /*horizontal_filter_index*/,
    const int vertical_filter_index, const int /*horizontal_filter_id*/,
    const int vertical_filter_id, const int width, const int height,
    void* LIBGAV1_RESTRICT const prediction, const ptrdiff_t pred_stride) {
  const int filter_index = GetFilterIndex(vertical_filter_index, height);
  const int vertical_taps = GetNumTapsInFilter(filter_index);
  const ptrdiff_t src_stride = reference_stride;
  const auto* src = static_cast<const uint8_t*>(reference) -
                    (vertical_taps / 2 - 1) * src_stride;
  auto* const dest = static_cast<uint8_t*>(prediction);
  const ptrdiff_t dest_stride = pred_stride;
  assert(vertical_filter_id != 0);

  uint8x8_t taps[8];
  for (int k = 0; k < kSubPixelTaps; ++k) {
    taps[k] =
        vdup_n_u8(kAbsHalfSubPixelFilters[filter_index][vertical_filter_id][k]);
  }
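  // The filters are stored as 8 entry arrays with zero padding. For shorter
  // filters the |taps| pointer passed below is advanced past the leading
  // zeros so taps[0] is the first tap actually applied (e.g. taps + 1 for
  // the 6 tap filters, taps + 3 for the 2 tap filter).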

  if (filter_index == 0) {  // 6 tap.
    if (width == 2) {
      FilterVertical2xH<0>(src, src_stride, dest, dest_stride, height,
                           taps + 1);
    } else if (width == 4) {
      FilterVertical4xH<0>(src, src_stride, dest, dest_stride, height,
                           taps + 1);
    } else {
      FilterVertical<0>(src, src_stride, dest, dest_stride, width, height,
                        taps + 1);
    }
  } else if ((static_cast<int>(filter_index == 1) &
              (static_cast<int>(vertical_filter_id == 1) |
               static_cast<int>(vertical_filter_id == 15))) != 0) {  // 5 tap.
    if (width == 2) {
      FilterVertical2xH<1>(src, src_stride, dest, dest_stride, height,
                           taps + 1);
    } else if (width == 4) {
      FilterVertical4xH<1>(src, src_stride, dest, dest_stride, height,
                           taps + 1);
    } else {
      FilterVertical<1>(src, src_stride, dest, dest_stride, width, height,
                        taps + 1);
    }
  } else if ((static_cast<int>(filter_index == 1) &
              (static_cast<int>(vertical_filter_id == 7) |
               static_cast<int>(vertical_filter_id == 8) |
               static_cast<int>(vertical_filter_id == 9))) !=
             0) {  // 6 tap with weird negative taps.
    if (width == 2) {
      FilterVertical2xH<1,
                        /*negative_outside_taps=*/true>(
          src, src_stride, dest, dest_stride, height, taps + 1);
    } else if (width == 4) {
      FilterVertical4xH<1, /*is_compound=*/false,
                        /*negative_outside_taps=*/true>(
          src, src_stride, dest, dest_stride, height, taps + 1);
    } else {
      FilterVertical<1, /*is_compound=*/false, /*negative_outside_taps=*/true>(
          src, src_stride, dest, dest_stride, width, height, taps + 1);
    }
  } else if (filter_index == 2) {  // 8 tap.
    if (width == 2) {
      FilterVertical2xH<2>(src, src_stride, dest, dest_stride, height, taps);
    } else if (width == 4) {
      FilterVertical4xH<2>(src, src_stride, dest, dest_stride, height, taps);
    } else {
      FilterVertical<2>(src, src_stride, dest, dest_stride, width, height,
                        taps);
    }
  } else if (filter_index == 3) {  // 2 tap.
    if (width == 2) {
      FilterVertical2xH<3>(src, src_stride, dest, dest_stride, height,
                           taps + 3);
    } else if (width == 4) {
      FilterVertical4xH<3>(src, src_stride, dest, dest_stride, height,
                           taps + 3);
    } else {
      FilterVertical<3>(src, src_stride, dest, dest_stride, width, height,
                        taps + 3);
    }
  } else if (filter_index == 4) {  // 4 tap.
    // Outside taps are negative.
    if (width == 2) {
      FilterVertical2xH<4>(src, src_stride, dest, dest_stride, height,
                           taps + 2);
    } else if (width == 4) {
      FilterVertical4xH<4>(src, src_stride, dest, dest_stride, height,
                           taps + 2);
    } else {
      FilterVertical<4>(src, src_stride, dest, dest_stride, width, height,
                        taps + 2);
    }
  } else {
    // 4 tap. When |filter_index| == 1 the |vertical_filter_id| values listed
    // below map to 4 tap filters.
    assert(filter_index == 5 ||
           (filter_index == 1 &&
            (vertical_filter_id == 2 || vertical_filter_id == 3 ||
             vertical_filter_id == 4 || vertical_filter_id == 5 ||
             vertical_filter_id == 6 || vertical_filter_id == 10 ||
             vertical_filter_id == 11 || vertical_filter_id == 12 ||
             vertical_filter_id == 13 || vertical_filter_id == 14)));
    // According to GetNumTapsInFilter() this has 6 taps but here we are
    // treating it as though it has 4.
    if (filter_index == 1) src += src_stride;
    if (width == 2) {
      FilterVertical2xH<5>(src, src_stride, dest, dest_stride, height,
                           taps + 2);
    } else if (width == 4) {
      FilterVertical4xH<5>(src, src_stride, dest, dest_stride, height,
                           taps + 2);
    } else {
      FilterVertical<5>(src, src_stride, dest, dest_stride, width, height,
                        taps + 2);
    }
  }
}

void ConvolveCompoundCopy_NEON(
    const void* LIBGAV1_RESTRICT const reference,
    const ptrdiff_t reference_stride, const int /*horizontal_filter_index*/,
    const int /*vertical_filter_index*/, const int /*horizontal_filter_id*/,
    const int /*vertical_filter_id*/, const int width, const int height,
    void* LIBGAV1_RESTRICT const prediction, const ptrdiff_t /*pred_stride*/) {
  const auto* src = static_cast<const uint8_t*>(reference);
  const ptrdiff_t src_stride = reference_stride;
  auto* dest = static_cast<uint16_t*>(prediction);
  constexpr int final_shift =
      kInterRoundBitsVertical - kInterRoundBitsCompoundVertical;
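  // For this 8-bit path the shift works out to 11 - 7 = 4: the copy scales
  // pixels up to the compound intermediate precision without filtering.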

  if (width >= 16) {
    int y = height;
    do {
      int x = 0;
      do {
        const uint8x16_t v_src = vld1q_u8(&src[x]);
        const uint16x8_t v_dest_lo =
            vshll_n_u8(vget_low_u8(v_src), final_shift);
        const uint16x8_t v_dest_hi =
            vshll_n_u8(vget_high_u8(v_src), final_shift);
        vst1q_u16(&dest[x], v_dest_lo);
        x += 8;
        vst1q_u16(&dest[x], v_dest_hi);
        x += 8;
      } while (x < width);
      src += src_stride;
      dest += width;
    } while (--y != 0);
  } else if (width == 8) {
    int y = height;
    do {
      const uint8x8_t v_src = vld1_u8(&src[0]);
      const uint16x8_t v_dest = vshll_n_u8(v_src, final_shift);
      vst1q_u16(&dest[0], v_dest);
      src += src_stride;
      dest += width;
    } while (--y != 0);
  } else {  // width == 4
    uint8x8_t v_src = vdup_n_u8(0);

    int y = height;
    do {
      v_src = Load4<0>(&src[0], v_src);
      src += src_stride;
      v_src = Load4<1>(&src[0], v_src);
      src += src_stride;
      const uint16x8_t v_dest = vshll_n_u8(v_src, final_shift);
      vst1q_u16(&dest[0], v_dest);
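      // Two rows of 4 uint16_t values were just written; advance 8 entries.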
      dest += 4 << 1;
      y -= 2;
    } while (y != 0);
  }
}

void ConvolveCompoundVertical_NEON(
    const void* LIBGAV1_RESTRICT const reference,
    const ptrdiff_t reference_stride, const int /*horizontal_filter_index*/,
    const int vertical_filter_index, const int /*horizontal_filter_id*/,
    const int vertical_filter_id, const int width, const int height,
    void* LIBGAV1_RESTRICT const prediction, const ptrdiff_t /*pred_stride*/) {
  const int filter_index = GetFilterIndex(vertical_filter_index, height);
  const int vertical_taps = GetNumTapsInFilter(filter_index);
  const ptrdiff_t src_stride = reference_stride;
  const auto* src = static_cast<const uint8_t*>(reference) -
                    (vertical_taps / 2 - 1) * src_stride;
  auto* const dest = static_cast<uint16_t*>(prediction);
  assert(vertical_filter_id != 0);

  uint8x8_t taps[8];
  for (int k = 0; k < kSubPixelTaps; ++k) {
    taps[k] =
        vdup_n_u8(kAbsHalfSubPixelFilters[filter_index][vertical_filter_id][k]);
  }
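  // Compound output is packed with a row stride equal to |width|, hence the
  // repeated width argument in the FilterVertical calls below and the fixed
  // stride of 4 on the 4xH path.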
2435 
  if (filter_index == 0) {  // 6 tap.
    if (width == 4) {
      FilterVertical4xH<0, /*is_compound=*/true>(src, src_stride, dest, 4,
                                                 height, taps + 1);
    } else {
      FilterVertical<0, /*is_compound=*/true>(src, src_stride, dest, width,
                                              width, height, taps + 1);
    }
  } else if ((static_cast<int>(filter_index == 1) &
              (static_cast<int>(vertical_filter_id == 1) |
               static_cast<int>(vertical_filter_id == 15))) != 0) {  // 5 tap.
    if (width == 4) {
      FilterVertical4xH<1, /*is_compound=*/true>(src, src_stride, dest, 4,
                                                 height, taps + 1);
    } else {
      FilterVertical<1, /*is_compound=*/true>(src, src_stride, dest, width,
                                              width, height, taps + 1);
    }
  } else if ((static_cast<int>(filter_index == 1) &
              (static_cast<int>(vertical_filter_id == 7) |
               static_cast<int>(vertical_filter_id == 8) |
               static_cast<int>(vertical_filter_id == 9))) !=
             0) {  // 6 tap with weird negative taps.
    if (width == 4) {
      FilterVertical4xH<1, /*is_compound=*/true,
                        /*negative_outside_taps=*/true>(src, src_stride, dest,
                                                        4, height, taps + 1);
    } else {
      FilterVertical<1, /*is_compound=*/true, /*negative_outside_taps=*/true>(
          src, src_stride, dest, width, width, height, taps + 1);
    }
  } else if (filter_index == 2) {  // 8 tap.
    if (width == 4) {
      FilterVertical4xH<2, /*is_compound=*/true>(src, src_stride, dest, 4,
                                                 height, taps);
    } else {
      FilterVertical<2, /*is_compound=*/true>(src, src_stride, dest, width,
                                              width, height, taps);
    }
  } else if (filter_index == 3) {  // 2 tap.
    if (width == 4) {
      FilterVertical4xH<3, /*is_compound=*/true>(src, src_stride, dest, 4,
                                                 height, taps + 3);
    } else {
      FilterVertical<3, /*is_compound=*/true>(src, src_stride, dest, width,
                                              width, height, taps + 3);
    }
  } else if (filter_index == 4) {  // 4 tap.
    if (width == 4) {
      FilterVertical4xH<4, /*is_compound=*/true>(src, src_stride, dest, 4,
                                                 height, taps + 2);
    } else {
      FilterVertical<4, /*is_compound=*/true>(src, src_stride, dest, width,
                                              width, height, taps + 2);
    }
  } else {
    // 4 tap. When |filter_index| == 1 the |filter_id| values listed below map
    // to 4 tap filters.
    assert(filter_index == 5 ||
           (filter_index == 1 &&
            (vertical_filter_id == 2 || vertical_filter_id == 3 ||
             vertical_filter_id == 4 || vertical_filter_id == 5 ||
             vertical_filter_id == 6 || vertical_filter_id == 10 ||
             vertical_filter_id == 11 || vertical_filter_id == 12 ||
             vertical_filter_id == 13 || vertical_filter_id == 14)));
    // According to GetNumTapsInFilter() this has 6 taps but here we are
    // treating it as though it has 4.
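    // For these |vertical_filter_id| values the two outside taps are zero,
    // so advancing |src| one row and applying the centered 4-tap kernel
    // (filter index 5, taps + 2) produces the same result.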
    if (filter_index == 1) src += src_stride;
    if (width == 4) {
      FilterVertical4xH<5, /*is_compound=*/true>(src, src_stride, dest, 4,
                                                 height, taps + 2);
    } else {
      FilterVertical<5, /*is_compound=*/true>(src, src_stride, dest, width,
                                              width, height, taps + 2);
    }
  }
}

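// Compound horizontal convolution: a single horizontal pass whose 16-bit
// output feeds the compound prediction; all the filtering is delegated to
// DoHorizontalPass.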
void ConvolveCompoundHorizontal_NEON(
    const void* LIBGAV1_RESTRICT const reference,
    const ptrdiff_t reference_stride, const int horizontal_filter_index,
    const int /*vertical_filter_index*/, const int horizontal_filter_id,
    const int /*vertical_filter_id*/, const int width, const int height,
    void* LIBGAV1_RESTRICT const prediction, const ptrdiff_t /*pred_stride*/) {
  const int filter_index = GetFilterIndex(horizontal_filter_index, width);
  const auto* const src =
      static_cast<const uint8_t*>(reference) - kHorizontalOffset;
  auto* const dest = static_cast<uint16_t*>(prediction);

  DoHorizontalPass</*is_2d=*/false, /*is_compound=*/true>(
      src, reference_stride, dest, width, width, height, horizontal_filter_id,
      filter_index);
}

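// Vertical half of the compound 2D path: applies |vertical_taps| filter taps
// to the 16-bit intermediate result produced by the horizontal pass.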
template <int vertical_taps>
void Compound2DVertical(
    const uint16_t* LIBGAV1_RESTRICT const intermediate_result, const int width,
    const int height, const int16x8_t taps,
    void* LIBGAV1_RESTRICT const prediction) {
  auto* const dest = static_cast<uint16_t*>(prediction);
  if (width == 4) {
    Filter2DVerticalWidth4<vertical_taps, /*is_compound=*/true>(
        intermediate_result, dest, width, height, taps);
  } else {
    Filter2DVerticalWidth8AndUp<vertical_taps, /*is_compound=*/true>(
        intermediate_result, dest, width, width, height, taps);
  }
}

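// Compound 2D convolution: a horizontal pass into |intermediate_result|
// followed by a vertical pass that writes the 16-bit compound output.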
void ConvolveCompound2D_NEON(
    const void* LIBGAV1_RESTRICT const reference,
    const ptrdiff_t reference_stride, const int horizontal_filter_index,
    const int vertical_filter_index, const int horizontal_filter_id,
    const int vertical_filter_id, const int width, const int height,
    void* LIBGAV1_RESTRICT const prediction, const ptrdiff_t /*pred_stride*/) {
  // The output of the horizontal filter, i.e. the intermediate_result, is
  // guaranteed to fit in int16_t.
  uint16_t
      intermediate_result[kMaxSuperBlockSizeInPixels *
                          (kMaxSuperBlockSizeInPixels + kSubPixelTaps - 1)];

  // Horizontal filter.
  // Filter types used for width <= 4 are different from those for width > 4.
  // When width > 4, the valid filter index range is always [0, 3].
  // When width <= 4, the valid filter index range is always [4, 5].
  // Similarly for height.
  const int horiz_filter_index = GetFilterIndex(horizontal_filter_index, width);
  const int vert_filter_index = GetFilterIndex(vertical_filter_index, height);
  const int vertical_taps = GetNumTapsInFilter(vert_filter_index);
  const int intermediate_height = height + vertical_taps - 1;
  const ptrdiff_t src_stride = reference_stride;
  const auto* const src = static_cast<const uint8_t*>(reference) -
                          (vertical_taps / 2 - 1) * src_stride -
                          kHorizontalOffset;
  DoHorizontalPass</*is_2d=*/true, /*is_compound=*/true>(
      src, src_stride, intermediate_result, width, width, intermediate_height,
      horizontal_filter_id, horiz_filter_index);

  // Vertical filter.
  assert(vertical_filter_id != 0);
  const int16x8_t taps = vmovl_s8(
      vld1_s8(kHalfSubPixelFilters[vert_filter_index][vertical_filter_id]));
  if (vertical_taps == 8) {
    Compound2DVertical<8>(intermediate_result, width, height, taps, prediction);
  } else if (vertical_taps == 6) {
    Compound2DVertical<6>(intermediate_result, width, height, taps, prediction);
  } else if (vertical_taps == 4) {
    Compound2DVertical<4>(intermediate_result, width, height, taps, prediction);
  } else {  // |vertical_taps| == 2
    Compound2DVertical<2>(intermediate_result, width, height, taps, prediction);
  }
}

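// Averages 16 horizontally adjacent pixel pairs with rounding: vrhaddq_u8
// computes (a + b + 1) >> 1 per lane, e.g. (4 + 7 + 1) >> 1 == 6.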
inline void HalfAddHorizontal(const uint8_t* LIBGAV1_RESTRICT const src,
                              uint8_t* LIBGAV1_RESTRICT const dst) {
  const uint8x16_t left = vld1q_u8(src);
  const uint8x16_t right = vld1q_u8(src + 1);
  vst1q_u8(dst, vrhaddq_u8(left, right));
}

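// Processes a full row in 16-pixel chunks; |src_remainder_stride| subtracts
// the (width - 16) pixels already advanced within the row, so adding it
// moves the pointer to the start of the next row.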
template <int width>
inline void IntraBlockCopyHorizontal(const uint8_t* LIBGAV1_RESTRICT src,
                                     const ptrdiff_t src_stride,
                                     const int height,
                                     uint8_t* LIBGAV1_RESTRICT dst,
                                     const ptrdiff_t dst_stride) {
  const ptrdiff_t src_remainder_stride = src_stride - (width - 16);
  const ptrdiff_t dst_remainder_stride = dst_stride - (width - 16);

  int y = height;
  do {
    HalfAddHorizontal(src, dst);
    if (width >= 32) {
      src += 16;
      dst += 16;
      HalfAddHorizontal(src, dst);
      if (width >= 64) {
        src += 16;
        dst += 16;
        HalfAddHorizontal(src, dst);
        src += 16;
        dst += 16;
        HalfAddHorizontal(src, dst);
        if (width == 128) {
          src += 16;
          dst += 16;
          HalfAddHorizontal(src, dst);
          src += 16;
          dst += 16;
          HalfAddHorizontal(src, dst);
          src += 16;
          dst += 16;
          HalfAddHorizontal(src, dst);
          src += 16;
          dst += 16;
          HalfAddHorizontal(src, dst);
        }
      }
    }
    src += src_remainder_stride;
    dst += dst_remainder_stride;
  } while (--y != 0);
}

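// Intra block copy never applies a true subpixel filter; when a half-pel
// position is reached (presumably on the sub-sampled u/v planes, as with the
// 2D variant below), the prediction is simply the rounded average of the two
// straddling pixels.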
void ConvolveIntraBlockCopyHorizontal_NEON(
    const void* LIBGAV1_RESTRICT const reference,
    const ptrdiff_t reference_stride, const int /*horizontal_filter_index*/,
    const int /*vertical_filter_index*/, const int /*subpixel_x*/,
    const int /*subpixel_y*/, const int width, const int height,
    void* LIBGAV1_RESTRICT const prediction, const ptrdiff_t pred_stride) {
  assert(width >= 4 && width <= kMaxSuperBlockSizeInPixels);
  assert(height >= 4 && height <= kMaxSuperBlockSizeInPixels);
  const auto* src = static_cast<const uint8_t*>(reference);
  auto* dest = static_cast<uint8_t*>(prediction);

  if (width == 128) {
    IntraBlockCopyHorizontal<128>(src, reference_stride, height, dest,
                                  pred_stride);
  } else if (width == 64) {
    IntraBlockCopyHorizontal<64>(src, reference_stride, height, dest,
                                 pred_stride);
  } else if (width == 32) {
    IntraBlockCopyHorizontal<32>(src, reference_stride, height, dest,
                                 pred_stride);
  } else if (width == 16) {
    IntraBlockCopyHorizontal<16>(src, reference_stride, height, dest,
                                 pred_stride);
  } else if (width == 8) {
    int y = height;
    do {
      const uint8x8_t left = vld1_u8(src);
      const uint8x8_t right = vld1_u8(src + 1);
      vst1_u8(dest, vrhadd_u8(left, right));

      src += reference_stride;
      dest += pred_stride;
    } while (--y != 0);
  } else {  // width == 4
    uint8x8_t left = vdup_n_u8(0);
    uint8x8_t right = vdup_n_u8(0);
    int y = height;
    do {
      left = Load4<0>(src, left);
      right = Load4<0>(src + 1, right);
      src += reference_stride;
      left = Load4<1>(src, left);
      right = Load4<1>(src + 1, right);
      src += reference_stride;

      const uint8x8_t result = vrhadd_u8(left, right);

      StoreLo4(dest, result);
      dest += pred_stride;
      StoreHi4(dest, result);
      dest += pred_stride;
      y -= 2;
    } while (y != 0);
  }
}

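// Vertical counterpart of IntraBlockCopyHorizontal. The previous source row
// is kept in |row[]| so each row is loaded from memory only once per pass.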
template <int width>
inline void IntraBlockCopyVertical(const uint8_t* LIBGAV1_RESTRICT src,
                                   const ptrdiff_t src_stride, const int height,
                                   uint8_t* LIBGAV1_RESTRICT dst,
                                   const ptrdiff_t dst_stride) {
  const ptrdiff_t src_remainder_stride = src_stride - (width - 16);
  const ptrdiff_t dst_remainder_stride = dst_stride - (width - 16);
  uint8x16_t row[8], below[8];

  row[0] = vld1q_u8(src);
  if (width >= 32) {
    src += 16;
    row[1] = vld1q_u8(src);
    if (width >= 64) {
      src += 16;
      row[2] = vld1q_u8(src);
      src += 16;
      row[3] = vld1q_u8(src);
      if (width == 128) {
        src += 16;
        row[4] = vld1q_u8(src);
        src += 16;
        row[5] = vld1q_u8(src);
        src += 16;
        row[6] = vld1q_u8(src);
        src += 16;
        row[7] = vld1q_u8(src);
      }
    }
  }
  src += src_remainder_stride;

  int y = height;
  do {
    below[0] = vld1q_u8(src);
    if (width >= 32) {
      src += 16;
      below[1] = vld1q_u8(src);
      if (width >= 64) {
        src += 16;
        below[2] = vld1q_u8(src);
        src += 16;
        below[3] = vld1q_u8(src);
        if (width == 128) {
          src += 16;
          below[4] = vld1q_u8(src);
          src += 16;
          below[5] = vld1q_u8(src);
          src += 16;
          below[6] = vld1q_u8(src);
          src += 16;
          below[7] = vld1q_u8(src);
        }
      }
    }
    src += src_remainder_stride;

    vst1q_u8(dst, vrhaddq_u8(row[0], below[0]));
    row[0] = below[0];
    if (width >= 32) {
      dst += 16;
      vst1q_u8(dst, vrhaddq_u8(row[1], below[1]));
      row[1] = below[1];
      if (width >= 64) {
        dst += 16;
        vst1q_u8(dst, vrhaddq_u8(row[2], below[2]));
        row[2] = below[2];
        dst += 16;
        vst1q_u8(dst, vrhaddq_u8(row[3], below[3]));
        row[3] = below[3];
        if (width >= 128) {
          dst += 16;
          vst1q_u8(dst, vrhaddq_u8(row[4], below[4]));
          row[4] = below[4];
          dst += 16;
          vst1q_u8(dst, vrhaddq_u8(row[5], below[5]));
          row[5] = below[5];
          dst += 16;
          vst1q_u8(dst, vrhaddq_u8(row[6], below[6]));
          row[6] = below[6];
          dst += 16;
          vst1q_u8(dst, vrhaddq_u8(row[7], below[7]));
          row[7] = below[7];
        }
      }
    }
    dst += dst_remainder_stride;
  } while (--y != 0);
}

void ConvolveIntraBlockCopyVertical_NEON(
    const void* LIBGAV1_RESTRICT const reference,
    const ptrdiff_t reference_stride, const int /*horizontal_filter_index*/,
    const int /*vertical_filter_index*/, const int /*horizontal_filter_id*/,
    const int /*vertical_filter_id*/, const int width, const int height,
    void* LIBGAV1_RESTRICT const prediction, const ptrdiff_t pred_stride) {
  assert(width >= 4 && width <= kMaxSuperBlockSizeInPixels);
  assert(height >= 4 && height <= kMaxSuperBlockSizeInPixels);
  const auto* src = static_cast<const uint8_t*>(reference);
  auto* dest = static_cast<uint8_t*>(prediction);

  if (width == 128) {
    IntraBlockCopyVertical<128>(src, reference_stride, height, dest,
                                pred_stride);
  } else if (width == 64) {
    IntraBlockCopyVertical<64>(src, reference_stride, height, dest,
                               pred_stride);
  } else if (width == 32) {
    IntraBlockCopyVertical<32>(src, reference_stride, height, dest,
                               pred_stride);
  } else if (width == 16) {
    IntraBlockCopyVertical<16>(src, reference_stride, height, dest,
                               pred_stride);
  } else if (width == 8) {
    uint8x8_t row, below;
    row = vld1_u8(src);
    src += reference_stride;

    int y = height;
    do {
      below = vld1_u8(src);
      src += reference_stride;

      vst1_u8(dest, vrhadd_u8(row, below));
      dest += pred_stride;

      row = below;
    } while (--y != 0);
  } else {  // width == 4
    uint8x8_t row = Load4(src);
    uint8x8_t below = vdup_n_u8(0);
    src += reference_stride;

    int y = height;
    do {
      below = Load4<0>(src, below);
      src += reference_stride;

      StoreLo4(dest, vrhadd_u8(row, below));
      dest += pred_stride;

      row = below;
    } while (--y != 0);
  }
}

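// 2D intra block copy "filter": a rounded average of the four pixels
// surrounding the half-pel position, (tl + tr + bl + br + 2) >> 2. Each
// |row[i]| holds widened horizontal pair sums, so one vaddq_u16 plus
// vrshrn_n_u16(..., 2) finishes a row, e.g. (1 + 2 + 3 + 4 + 2) >> 2 == 3.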
template <int width>
inline void IntraBlockCopy2D(const uint8_t* LIBGAV1_RESTRICT src,
                             const ptrdiff_t src_stride, const int height,
                             uint8_t* LIBGAV1_RESTRICT dst,
                             const ptrdiff_t dst_stride) {
  const ptrdiff_t src_remainder_stride = src_stride - (width - 8);
  const ptrdiff_t dst_remainder_stride = dst_stride - (width - 8);
  uint16x8_t row[16];
  row[0] = vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
  if (width >= 16) {
    src += 8;
    row[1] = vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
    if (width >= 32) {
      src += 8;
      row[2] = vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
      src += 8;
      row[3] = vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
      if (width >= 64) {
        src += 8;
        row[4] = vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
        src += 8;
        row[5] = vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
        src += 8;
        row[6] = vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
        src += 8;
        row[7] = vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
        if (width == 128) {
          src += 8;
          row[8] = vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
          src += 8;
          row[9] = vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
          src += 8;
          row[10] = vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
          src += 8;
          row[11] = vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
          src += 8;
          row[12] = vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
          src += 8;
          row[13] = vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
          src += 8;
          row[14] = vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
          src += 8;
          row[15] = vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
        }
      }
    }
  }
  src += src_remainder_stride;

  int y = height;
  do {
    const uint16x8_t below_0 = vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
    vst1_u8(dst, vrshrn_n_u16(vaddq_u16(row[0], below_0), 2));
    row[0] = below_0;
    if (width >= 16) {
      src += 8;
      dst += 8;

      const uint16x8_t below_1 = vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
      vst1_u8(dst, vrshrn_n_u16(vaddq_u16(row[1], below_1), 2));
      row[1] = below_1;
      if (width >= 32) {
        src += 8;
        dst += 8;

        const uint16x8_t below_2 = vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
        vst1_u8(dst, vrshrn_n_u16(vaddq_u16(row[2], below_2), 2));
        row[2] = below_2;
        src += 8;
        dst += 8;

        const uint16x8_t below_3 = vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
        vst1_u8(dst, vrshrn_n_u16(vaddq_u16(row[3], below_3), 2));
        row[3] = below_3;
        if (width >= 64) {
          src += 8;
          dst += 8;

          const uint16x8_t below_4 = vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
          vst1_u8(dst, vrshrn_n_u16(vaddq_u16(row[4], below_4), 2));
          row[4] = below_4;
          src += 8;
          dst += 8;

          const uint16x8_t below_5 = vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
          vst1_u8(dst, vrshrn_n_u16(vaddq_u16(row[5], below_5), 2));
          row[5] = below_5;
          src += 8;
          dst += 8;

          const uint16x8_t below_6 = vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
          vst1_u8(dst, vrshrn_n_u16(vaddq_u16(row[6], below_6), 2));
          row[6] = below_6;
          src += 8;
          dst += 8;

          const uint16x8_t below_7 = vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
          vst1_u8(dst, vrshrn_n_u16(vaddq_u16(row[7], below_7), 2));
          row[7] = below_7;
          if (width == 128) {
            src += 8;
            dst += 8;

            const uint16x8_t below_8 = vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
            vst1_u8(dst, vrshrn_n_u16(vaddq_u16(row[8], below_8), 2));
            row[8] = below_8;
            src += 8;
            dst += 8;

            const uint16x8_t below_9 = vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
            vst1_u8(dst, vrshrn_n_u16(vaddq_u16(row[9], below_9), 2));
            row[9] = below_9;
            src += 8;
            dst += 8;

            const uint16x8_t below_10 =
                vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
            vst1_u8(dst, vrshrn_n_u16(vaddq_u16(row[10], below_10), 2));
            row[10] = below_10;
            src += 8;
            dst += 8;

            const uint16x8_t below_11 =
                vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
            vst1_u8(dst, vrshrn_n_u16(vaddq_u16(row[11], below_11), 2));
            row[11] = below_11;
            src += 8;
            dst += 8;

            const uint16x8_t below_12 =
                vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
            vst1_u8(dst, vrshrn_n_u16(vaddq_u16(row[12], below_12), 2));
            row[12] = below_12;
            src += 8;
            dst += 8;

            const uint16x8_t below_13 =
                vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
            vst1_u8(dst, vrshrn_n_u16(vaddq_u16(row[13], below_13), 2));
            row[13] = below_13;
            src += 8;
            dst += 8;

            const uint16x8_t below_14 =
                vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
            vst1_u8(dst, vrshrn_n_u16(vaddq_u16(row[14], below_14), 2));
            row[14] = below_14;
            src += 8;
            dst += 8;

            const uint16x8_t below_15 =
                vaddl_u8(vld1_u8(src), vld1_u8(src + 1));
            vst1_u8(dst, vrshrn_n_u16(vaddq_u16(row[15], below_15), 2));
            row[15] = below_15;
          }
        }
      }
    }
    src += src_remainder_stride;
    dst += dst_remainder_stride;
  } while (--y != 0);
}

void ConvolveIntraBlockCopy2D_NEON(
    const void* LIBGAV1_RESTRICT const reference,
    const ptrdiff_t reference_stride, const int /*horizontal_filter_index*/,
    const int /*vertical_filter_index*/, const int /*horizontal_filter_id*/,
    const int /*vertical_filter_id*/, const int width, const int height,
    void* LIBGAV1_RESTRICT const prediction, const ptrdiff_t pred_stride) {
  assert(width >= 4 && width <= kMaxSuperBlockSizeInPixels);
  assert(height >= 4 && height <= kMaxSuperBlockSizeInPixels);
  const auto* src = static_cast<const uint8_t*>(reference);
  auto* dest = static_cast<uint8_t*>(prediction);
  // Note: allow vertical access to height + 1. Because this function is only
  // for u/v plane of intra block copy, such access is guaranteed to be within
  // the prediction block.

  if (width == 128) {
    IntraBlockCopy2D<128>(src, reference_stride, height, dest, pred_stride);
  } else if (width == 64) {
    IntraBlockCopy2D<64>(src, reference_stride, height, dest, pred_stride);
  } else if (width == 32) {
    IntraBlockCopy2D<32>(src, reference_stride, height, dest, pred_stride);
  } else if (width == 16) {
    IntraBlockCopy2D<16>(src, reference_stride, height, dest, pred_stride);
  } else if (width == 8) {
    IntraBlockCopy2D<8>(src, reference_stride, height, dest, pred_stride);
  } else {  // width == 4
    uint8x8_t left = Load4(src);
    uint8x8_t right = Load4(src + 1);
    src += reference_stride;

    uint16x4_t row = vget_low_u16(vaddl_u8(left, right));

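    // Each loop iteration produces two output rows: the low half of |below|
    // completes the first row's 2x2 average and also serves as the top of
    // the second row's.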
    int y = height;
    do {
      left = Load4<0>(src, left);
      right = Load4<0>(src + 1, right);
      src += reference_stride;
      left = Load4<1>(src, left);
      right = Load4<1>(src + 1, right);
      src += reference_stride;

      const uint16x8_t below = vaddl_u8(left, right);

      const uint8x8_t result = vrshrn_n_u16(
          vaddq_u16(vcombine_u16(row, vget_low_u16(below)), below), 2);
      StoreLo4(dest, result);
      dest += pred_stride;
      StoreHi4(dest, result);
      dest += pred_stride;

      row = vget_high_u16(below);
      y -= 2;
    } while (y != 0);
  }
}

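// Registers the NEON implementations in the 8-bit dsp table. The convolve
// entries appear to be indexed as
// [is_intra_block_copy][is_compound][has_vertical_filter]
// [has_horizontal_filter]; entries not assigned here keep their existing
// (portable) implementations.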
void Init8bpp() {
  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
  assert(dsp != nullptr);
  dsp->convolve[0][0][0][1] = ConvolveHorizontal_NEON;
  dsp->convolve[0][0][1][0] = ConvolveVertical_NEON;
  dsp->convolve[0][0][1][1] = Convolve2D_NEON;

  dsp->convolve[0][1][0][0] = ConvolveCompoundCopy_NEON;
  dsp->convolve[0][1][0][1] = ConvolveCompoundHorizontal_NEON;
  dsp->convolve[0][1][1][0] = ConvolveCompoundVertical_NEON;
  dsp->convolve[0][1][1][1] = ConvolveCompound2D_NEON;

  dsp->convolve[1][0][0][1] = ConvolveIntraBlockCopyHorizontal_NEON;
  dsp->convolve[1][0][1][0] = ConvolveIntraBlockCopyVertical_NEON;
  dsp->convolve[1][0][1][1] = ConvolveIntraBlockCopy2D_NEON;

  dsp->convolve_scale[0] = ConvolveScale2D_NEON<false>;
  dsp->convolve_scale[1] = ConvolveScale2D_NEON<true>;
}

}  // namespace
}  // namespace low_bitdepth

void ConvolveInit_NEON() { low_bitdepth::Init8bpp(); }

}  // namespace dsp
}  // namespace libgav1

#else   // !LIBGAV1_ENABLE_NEON

namespace libgav1 {
namespace dsp {

void ConvolveInit_NEON() {}

}  // namespace dsp
}  // namespace libgav1
#endif  // LIBGAV1_ENABLE_NEON