/*
 * Copyright (c) 2023, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <arm_neon.h>
#include <assert.h>

#include "aom_dsp/arm/mem_neon.h"
#include "av1/common/arm/compound_convolve_neon.h"
#include "config/aom_config.h"
#include "config/av1_rtcd.h"

DECLARE_ALIGNED(16, static const uint8_t, dot_prod_permute_tbl[48]) = {
  0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6,
  4, 5, 6, 7, 5, 6, 7, 8, 6, 7, 8, 9, 7, 8, 9, 10,
  8, 9, 10, 11, 9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14
};
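// Permute table for the dot-product path: each 16-byte row gathers four
// overlapping 4-sample windows from a single 16-byte load, so one
// vdotq_lane_s32 (SDOT) per row produces four adjacent convolution outputs.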

static inline int16x4_t convolve4_4_2d_h(uint8x16_t samples,
                                         const int8x8_t x_filter,
                                         const int32x4_t correction,
                                         const uint8x16_t range_limit,
                                         const uint8x16_t permute_tbl) {
  // Clamp sample range to [-128, 127] for 8-bit signed dot product.
  int8x16_t clamped_samples =
      vreinterpretq_s8_u8(vsubq_u8(samples, range_limit));

  // Permute samples ready for dot product.
  // { 0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6 }
  int8x16_t permuted_samples = vqtbl1q_s8(clamped_samples, permute_tbl);

  // Accumulate dot product into 'correction' to account for range clamp.
  int32x4_t sum = vdotq_lane_s32(correction, permuted_samples, x_filter, 0);
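  // A sketch of the arithmetic (the callers below construct 'correction' as
  // (128 * sum(filter) + shim) / 2 and pass halved filter values): the dot
  // product adds sum((s - 128) * filter) / 2, so the 128 * sum(filter) / 2
  // term cancels the clamp bias, leaving (sum(s * filter) + shim) / 2 - the
  // true convolution plus the rounding shim, at half scale.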

  // We halved the convolution filter values so -1 from the right shift.
  return vshrn_n_s32(sum, ROUND0_BITS - 1);
}

static inline int16x8_t convolve8_8_2d_h(uint8x16_t samples,
                                         const int8x8_t x_filter,
                                         const int32x4_t correction,
                                         const uint8x16_t range_limit,
                                         const uint8x16x3_t permute_tbl) {
  int8x16_t clamped_samples, permuted_samples[3];
  int32x4_t sum[2];

  // Clamp sample range to [-128, 127] for 8-bit signed dot product.
  clamped_samples = vreinterpretq_s8_u8(vsubq_u8(samples, range_limit));

  // Permute samples ready for dot product.
  // { 0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6 }
  permuted_samples[0] = vqtbl1q_s8(clamped_samples, permute_tbl.val[0]);
  // { 4, 5, 6, 7, 5, 6, 7, 8, 6, 7, 8, 9, 7, 8, 9, 10 }
  permuted_samples[1] = vqtbl1q_s8(clamped_samples, permute_tbl.val[1]);
  // { 8, 9, 10, 11, 9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14 }
  permuted_samples[2] = vqtbl1q_s8(clamped_samples, permute_tbl.val[2]);

  // Accumulate dot product into 'correction' to account for range clamp.
  // First 4 output values.
  sum[0] = vdotq_lane_s32(correction, permuted_samples[0], x_filter, 0);
  sum[0] = vdotq_lane_s32(sum[0], permuted_samples[1], x_filter, 1);
  // Second 4 output values.
  sum[1] = vdotq_lane_s32(correction, permuted_samples[1], x_filter, 0);
  sum[1] = vdotq_lane_s32(sum[1], permuted_samples[2], x_filter, 1);

  // Narrow and re-pack.
  // We halved the convolution filter values so -1 from the right shift.
  return vcombine_s16(vshrn_n_s32(sum[0], ROUND0_BITS - 1),
                      vshrn_n_s32(sum[1], ROUND0_BITS - 1));
}

static inline void dist_wtd_convolve_2d_horiz_neon_dotprod(
    const uint8_t *src, int src_stride, int16_t *im_block, const int im_stride,
    const int16_t *x_filter_ptr, const int im_h, int w) {
  const int bd = 8;
  // Dot product constants and other shims.
  const int16x8_t x_filter_s16 = vld1q_s16(x_filter_ptr);
  // This shim of 1 << (ROUND0_BITS - 1) enables us to use non-rounding shifts
  // - which are generally faster than rounding shifts on modern CPUs.
  const int32_t horiz_const =
      ((1 << (bd + FILTER_BITS - 1)) + (1 << (ROUND0_BITS - 1)));
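  // The 1 << (bd + FILTER_BITS - 1) term is the usual intermediate offset
  // that keeps the horizontal-pass output non-negative for the vertical pass.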
  // Halve the total because we will halve the filter values.
  const int32x4_t correction =
      vdupq_n_s32(((128 << FILTER_BITS) + horiz_const) / 2);
  const uint8x16_t range_limit = vdupq_n_u8(128);

  const uint8_t *src_ptr = src;
  int16_t *dst_ptr = im_block;
  int dst_stride = im_stride;
  int height = im_h;

  if (w == 4) {
    const uint8x16_t permute_tbl = vld1q_u8(dot_prod_permute_tbl);
    // 4-tap filters are used for blocks having width <= 4.
    // Filter values are even, so halve to reduce intermediate precision reqs.
    const int8x8_t x_filter =
        vshrn_n_s16(vcombine_s16(vld1_s16(x_filter_ptr + 2), vdup_n_s16(0)), 1);

    src_ptr += 2;
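    // The 4-tap kernel occupies taps 2..5 of the 8-tap filter array (the
    // outer taps are zero), so only those four values are loaded above and
    // the source pointer is advanced by 2 to keep the filter centred.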

    do {
      uint8x16_t s0, s1, s2, s3;
      load_u8_16x4(src_ptr, src_stride, &s0, &s1, &s2, &s3);

      int16x4_t d0 =
          convolve4_4_2d_h(s0, x_filter, correction, range_limit, permute_tbl);
      int16x4_t d1 =
          convolve4_4_2d_h(s1, x_filter, correction, range_limit, permute_tbl);
      int16x4_t d2 =
          convolve4_4_2d_h(s2, x_filter, correction, range_limit, permute_tbl);
      int16x4_t d3 =
          convolve4_4_2d_h(s3, x_filter, correction, range_limit, permute_tbl);

      store_s16_4x4(dst_ptr, dst_stride, d0, d1, d2, d3);

      src_ptr += 4 * src_stride;
      dst_ptr += 4 * dst_stride;
      height -= 4;
    } while (height > 4);

    do {
      uint8x16_t s0 = vld1q_u8(src_ptr);

      int16x4_t d0 =
          convolve4_4_2d_h(s0, x_filter, correction, range_limit, permute_tbl);

      vst1_s16(dst_ptr, d0);

      src_ptr += src_stride;
      dst_ptr += dst_stride;
    } while (--height != 0);
  } else {
    const uint8x16x3_t permute_tbl = vld1q_u8_x3(dot_prod_permute_tbl);
    // Filter values are even, so halve to reduce intermediate precision reqs.
    const int8x8_t x_filter = vshrn_n_s16(x_filter_s16, 1);

    do {
      const uint8_t *s = src_ptr;
      int16_t *d = dst_ptr;
      int width = w;

      do {
        uint8x16_t s0, s1, s2, s3;
        load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);

        int16x8_t d0 = convolve8_8_2d_h(s0, x_filter, correction, range_limit,
                                        permute_tbl);
        int16x8_t d1 = convolve8_8_2d_h(s1, x_filter, correction, range_limit,
                                        permute_tbl);
        int16x8_t d2 = convolve8_8_2d_h(s2, x_filter, correction, range_limit,
                                        permute_tbl);
        int16x8_t d3 = convolve8_8_2d_h(s3, x_filter, correction, range_limit,
                                        permute_tbl);

        store_s16_8x4(d, dst_stride, d0, d1, d2, d3);

        s += 8;
        d += 8;
        width -= 8;
      } while (width > 0);
      src_ptr += 4 * src_stride;
      dst_ptr += 4 * dst_stride;
      height -= 4;
    } while (height > 4);

    do {
      const uint8_t *s = src_ptr;
      int16_t *d = dst_ptr;
      int width = w;

      do {
        uint8x16_t s0 = vld1q_u8(s);

        int16x8_t d0 = convolve8_8_2d_h(s0, x_filter, correction, range_limit,
                                        permute_tbl);

        vst1q_s16(d, d0);

        s += 8;
        d += 8;
        width -= 8;
      } while (width > 0);
      src_ptr += src_stride;
      dst_ptr += dst_stride;
    } while (--height != 0);
  }
}

void av1_dist_wtd_convolve_2d_neon_dotprod(
    const uint8_t *src, int src_stride, uint8_t *dst8, int dst8_stride, int w,
    int h, const InterpFilterParams *filter_params_x,
    const InterpFilterParams *filter_params_y, const int subpel_x_qn,
    const int subpel_y_qn, ConvolveParams *conv_params) {
  assert(w % 4 == 0);
  assert(h % 4 == 0);

  DECLARE_ALIGNED(16, int16_t,
                  im_block[(MAX_SB_SIZE + SUBPEL_TAPS - 1) * MAX_SB_SIZE]);

  const int y_filter_taps = get_filter_tap(filter_params_y, subpel_y_qn);
  const int clamped_y_taps = y_filter_taps < 6 ? 6 : y_filter_taps;
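  // The vertical pass below only has 6-tap and 8-tap specialisations, so
  // shorter filters are widened to 6 taps; the extra taps are zero, which
  // leaves the result unchanged.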

  const int im_h = h + clamped_y_taps - 1;
  const int im_stride = MAX_SB_SIZE;
  const int vert_offset = clamped_y_taps / 2 - 1;
  const int horiz_offset = filter_params_x->taps / 2 - 1;
  const uint8_t *src_ptr = src - vert_offset * src_stride - horiz_offset;
  const int16_t *x_filter_ptr = av1_get_interp_filter_subpel_kernel(
      filter_params_x, subpel_x_qn & SUBPEL_MASK);
  const int16_t *y_filter_ptr = av1_get_interp_filter_subpel_kernel(
      filter_params_y, subpel_y_qn & SUBPEL_MASK);

  const int16x8_t y_filter = vld1q_s16(y_filter_ptr);

  dist_wtd_convolve_2d_horiz_neon_dotprod(src_ptr, src_stride, im_block,
                                          im_stride, x_filter_ptr, im_h, w);

  if (clamped_y_taps == 6) {
    if (conv_params->do_average) {
      if (UNLIKELY(conv_params->use_dist_wtd_comp_avg)) {
        dist_wtd_convolve_2d_vert_6tap_dist_wtd_avg_neon(
            im_block, im_stride, dst8, dst8_stride, conv_params, y_filter, h,
            w);
      } else {
        dist_wtd_convolve_2d_vert_6tap_avg_neon(im_block, im_stride, dst8,
                                                dst8_stride, conv_params,
                                                y_filter, h, w);
      }
    } else {
      dist_wtd_convolve_2d_vert_6tap_neon(im_block, im_stride, conv_params,
                                          y_filter, h, w);
    }
  } else {
    if (conv_params->do_average) {
      if (UNLIKELY(conv_params->use_dist_wtd_comp_avg)) {
        dist_wtd_convolve_2d_vert_8tap_dist_wtd_avg_neon(
            im_block, im_stride, dst8, dst8_stride, conv_params, y_filter, h,
            w);
      } else {
        dist_wtd_convolve_2d_vert_8tap_avg_neon(im_block, im_stride, dst8,
                                                dst8_stride, conv_params,
                                                y_filter, h, w);
      }
    } else {
      dist_wtd_convolve_2d_vert_8tap_neon(im_block, im_stride, conv_params,
                                          y_filter, h, w);
    }
  }
}

static inline uint16x4_t convolve4_4_x(uint8x16_t samples,
                                       const int8x8_t x_filter,
                                       const int32x4_t correction,
                                       const uint8x16_t range_limit,
                                       const uint8x16_t permute_tbl) {
  // Clamp sample range to [-128, 127] for 8-bit signed dot product.
  int8x16_t clamped_samples =
      vreinterpretq_s8_u8(vsubq_u8(samples, range_limit));

  // Permute samples ready for dot product.
  // { 0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6 }
  int8x16_t permuted_samples = vqtbl1q_s8(clamped_samples, permute_tbl);

  // Accumulate dot product into 'correction' to account for range clamp.
  int32x4_t sum = vdotq_lane_s32(correction, permuted_samples, x_filter, 0);

  // We halved the convolution filter values so -1 from the right shift.
  return vreinterpret_u16_s16(vshrn_n_s32(sum, ROUND0_BITS - 1));
}

static inline uint16x8_t convolve8_8_x(uint8x16_t samples,
                                       const int8x8_t x_filter,
                                       const int32x4_t correction,
                                       const uint8x16_t range_limit,
                                       const uint8x16x3_t permute_tbl) {
  int8x16_t clamped_samples, permuted_samples[3];
  int32x4_t sum[2];

  // Clamp sample range to [-128, 127] for 8-bit signed dot product.
  clamped_samples = vreinterpretq_s8_u8(vsubq_u8(samples, range_limit));

  // Permute samples ready for dot product.
  // { 0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6 }
  permuted_samples[0] = vqtbl1q_s8(clamped_samples, permute_tbl.val[0]);
  // { 4, 5, 6, 7, 5, 6, 7, 8, 6, 7, 8, 9, 7, 8, 9, 10 }
  permuted_samples[1] = vqtbl1q_s8(clamped_samples, permute_tbl.val[1]);
  // { 8, 9, 10, 11, 9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14 }
  permuted_samples[2] = vqtbl1q_s8(clamped_samples, permute_tbl.val[2]);

  // Accumulate dot product into 'correction' to account for range clamp.
  // First 4 output values.
  sum[0] = vdotq_lane_s32(correction, permuted_samples[0], x_filter, 0);
  sum[0] = vdotq_lane_s32(sum[0], permuted_samples[1], x_filter, 1);
  // Second 4 output values.
  sum[1] = vdotq_lane_s32(correction, permuted_samples[1], x_filter, 0);
  sum[1] = vdotq_lane_s32(sum[1], permuted_samples[2], x_filter, 1);

  // Narrow and re-pack.
  // We halved the convolution filter values so -1 from the right shift.
  int16x8_t res = vcombine_s16(vshrn_n_s32(sum[0], ROUND0_BITS - 1),
                               vshrn_n_s32(sum[1], ROUND0_BITS - 1));
  return vreinterpretq_u16_s16(res);
}

static inline void dist_wtd_convolve_x_dist_wtd_avg_neon_dotprod(
    const uint8_t *src, int src_stride, uint8_t *dst8, int dst8_stride, int w,
    int h, const InterpFilterParams *filter_params_x, const int subpel_x_qn,
    ConvolveParams *conv_params) {
  assert(w % 4 == 0);
  assert(h % 4 == 0);

  const int bd = 8;
  const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
  const int16_t round_offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
                               (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));
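  // With the usual macro values (bd = 8, FILTER_BITS = 7, ROUND0_BITS = 3,
  // COMPOUND_ROUND1_BITS = 7) this is (1 << 12) + (1 << 11) = 6144.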
  const int16x8_t round_offset_vec = vdupq_n_s16(round_offset);

  const uint16_t fwd_offset = conv_params->fwd_offset;
  const uint16_t bck_offset = conv_params->bck_offset;

  // Horizontal filter.
  const int16_t *x_filter_ptr = av1_get_interp_filter_subpel_kernel(
      filter_params_x, subpel_x_qn & SUBPEL_MASK);
  const int16x8_t x_filter_s16 = vld1q_s16(x_filter_ptr);

  // Dot-product constants and other shims.
  const uint8x16_t range_limit = vdupq_n_u8(128);
  // Fold round_offset into the dot-product filter correction constant. The
  // additional shim of 1 << (ROUND0_BITS - 1) enables us to use non-rounding
  // shifts - which are generally faster than rounding shifts on modern CPUs.
  // Halve the total because we will halve the filter values.
  int32x4_t correction =
      vdupq_n_s32(((128 << FILTER_BITS) + (round_offset << ROUND0_BITS) +
                   (1 << (ROUND0_BITS - 1))) /
                  2);
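  // Combined with the halved filter values and the ROUND0_BITS - 1 shift in
  // the helpers, this yields (sum(s * filter) + (1 << (ROUND0_BITS - 1))) >>
  // ROUND0_BITS plus round_offset, i.e. the rounded first-stage result offset
  // into the compound range (a sketch of the intent, not a formal proof).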

  const int horiz_offset = filter_params_x->taps / 2 - 1;
  const uint8_t *src_ptr = src - horiz_offset;
  CONV_BUF_TYPE *dst_ptr = conv_params->dst;
  uint8_t *dst8_ptr = dst8;
  int dst_stride = conv_params->dst_stride;
  int height = h;

  if (w == 4) {
    const uint8x16_t permute_tbl = vld1q_u8(dot_prod_permute_tbl);
    // 4-tap filters are used for blocks having width <= 4.
    // Filter values are even, so halve to reduce intermediate precision reqs.
    const int8x8_t x_filter =
        vshrn_n_s16(vcombine_s16(vld1_s16(x_filter_ptr + 2), vdup_n_s16(0)), 1);

    src_ptr += 2;

    do {
      uint8x16_t s0, s1, s2, s3;
      load_u8_16x4(src_ptr, src_stride, &s0, &s1, &s2, &s3);

      uint16x4_t d0 =
          convolve4_4_x(s0, x_filter, correction, range_limit, permute_tbl);
      uint16x4_t d1 =
          convolve4_4_x(s1, x_filter, correction, range_limit, permute_tbl);
      uint16x4_t d2 =
          convolve4_4_x(s2, x_filter, correction, range_limit, permute_tbl);
      uint16x4_t d3 =
          convolve4_4_x(s3, x_filter, correction, range_limit, permute_tbl);

      uint16x4_t dd0, dd1, dd2, dd3;
      load_u16_4x4(dst_ptr, dst_stride, &dd0, &dd1, &dd2, &dd3);

      uint8x8_t d01_u8, d23_u8;
      compute_dist_wtd_avg_4x4(dd0, dd1, dd2, dd3, d0, d1, d2, d3, fwd_offset,
                               bck_offset, round_offset_vec, &d01_u8, &d23_u8);

      store_u8x4_strided_x2(dst8_ptr + 0 * dst8_stride, dst8_stride, d01_u8);
      store_u8x4_strided_x2(dst8_ptr + 2 * dst8_stride, dst8_stride, d23_u8);

      src_ptr += 4 * src_stride;
      dst_ptr += 4 * dst_stride;
      dst8_ptr += 4 * dst8_stride;
      height -= 4;
    } while (height != 0);
  } else {
    const uint8x16x3_t permute_tbl = vld1q_u8_x3(dot_prod_permute_tbl);
    // Filter values are even, so halve to reduce intermediate precision reqs.
    const int8x8_t x_filter = vshrn_n_s16(x_filter_s16, 1);

    do {
      const uint8_t *s = src_ptr;
      CONV_BUF_TYPE *d = dst_ptr;
      uint8_t *d_u8 = dst8_ptr;
      int width = w;

      do {
        uint8x16_t s0, s1, s2, s3;
        load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);

        uint16x8_t d0 =
            convolve8_8_x(s0, x_filter, correction, range_limit, permute_tbl);
        uint16x8_t d1 =
            convolve8_8_x(s1, x_filter, correction, range_limit, permute_tbl);
        uint16x8_t d2 =
            convolve8_8_x(s2, x_filter, correction, range_limit, permute_tbl);
        uint16x8_t d3 =
            convolve8_8_x(s3, x_filter, correction, range_limit, permute_tbl);

        uint16x8_t dd0, dd1, dd2, dd3;
        load_u16_8x4(d, dst_stride, &dd0, &dd1, &dd2, &dd3);

        uint8x8_t d0_u8, d1_u8, d2_u8, d3_u8;
        compute_dist_wtd_avg_8x4(dd0, dd1, dd2, dd3, d0, d1, d2, d3, fwd_offset,
                                 bck_offset, round_offset_vec, &d0_u8, &d1_u8,
                                 &d2_u8, &d3_u8);

        store_u8_8x4(d_u8, dst8_stride, d0_u8, d1_u8, d2_u8, d3_u8);

        s += 8;
        d += 8;
        d_u8 += 8;
        width -= 8;
      } while (width != 0);
      src_ptr += 4 * src_stride;
      dst_ptr += 4 * dst_stride;
      dst8_ptr += 4 * dst8_stride;
      height -= 4;
    } while (height != 0);
  }
}

static inline void dist_wtd_convolve_x_avg_neon_dotprod(
    const uint8_t *src, int src_stride, uint8_t *dst8, int dst8_stride, int w,
    int h, const InterpFilterParams *filter_params_x, const int subpel_x_qn,
    ConvolveParams *conv_params) {
  assert(w % 4 == 0);
  assert(h % 4 == 0);

  const int bd = 8;
  const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
  const int16_t round_offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
                               (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));
  const int16x8_t round_offset_vec = vdupq_n_s16(round_offset);

  // Horizontal filter.
  const int16_t *x_filter_ptr = av1_get_interp_filter_subpel_kernel(
      filter_params_x, subpel_x_qn & SUBPEL_MASK);
  const int16x8_t x_filter_s16 = vld1q_s16(x_filter_ptr);

  // Dot-product constants and other shims.
  const uint8x16_t range_limit = vdupq_n_u8(128);
  // Fold round_offset into the dot-product filter correction constant. The
  // additional shim of 1 << (ROUND0_BITS - 1) enables us to use non-rounding
  // shifts - which are generally faster than rounding shifts on modern CPUs.
  // Halve the total because we will halve the filter values.
  int32x4_t correction =
      vdupq_n_s32(((128 << FILTER_BITS) + (round_offset << ROUND0_BITS) +
                   (1 << (ROUND0_BITS - 1))) /
                  2);

  const int horiz_offset = filter_params_x->taps / 2 - 1;
  const uint8_t *src_ptr = src - horiz_offset;
  CONV_BUF_TYPE *dst_ptr = conv_params->dst;
  uint8_t *dst8_ptr = dst8;
  int dst_stride = conv_params->dst_stride;
  int height = h;

  if (w == 4) {
    const uint8x16_t permute_tbl = vld1q_u8(dot_prod_permute_tbl);
    // 4-tap filters are used for blocks having width <= 4.
    // Filter values are even, so halve to reduce intermediate precision reqs.
    const int8x8_t x_filter =
        vshrn_n_s16(vcombine_s16(vld1_s16(x_filter_ptr + 2), vdup_n_s16(0)), 1);

    src_ptr += 2;

    do {
      uint8x16_t s0, s1, s2, s3;
      load_u8_16x4(src_ptr, src_stride, &s0, &s1, &s2, &s3);

      uint16x4_t d0 =
          convolve4_4_x(s0, x_filter, correction, range_limit, permute_tbl);
      uint16x4_t d1 =
          convolve4_4_x(s1, x_filter, correction, range_limit, permute_tbl);
      uint16x4_t d2 =
          convolve4_4_x(s2, x_filter, correction, range_limit, permute_tbl);
      uint16x4_t d3 =
          convolve4_4_x(s3, x_filter, correction, range_limit, permute_tbl);

      uint16x4_t dd0, dd1, dd2, dd3;
      load_u16_4x4(dst_ptr, dst_stride, &dd0, &dd1, &dd2, &dd3);

      uint8x8_t d01_u8, d23_u8;
      compute_basic_avg_4x4(dd0, dd1, dd2, dd3, d0, d1, d2, d3,
                            round_offset_vec, &d01_u8, &d23_u8);

      store_u8x4_strided_x2(dst8_ptr + 0 * dst8_stride, dst8_stride, d01_u8);
      store_u8x4_strided_x2(dst8_ptr + 2 * dst8_stride, dst8_stride, d23_u8);

      src_ptr += 4 * src_stride;
      dst_ptr += 4 * dst_stride;
      dst8_ptr += 4 * dst8_stride;
      height -= 4;
    } while (height != 0);
  } else {
    const uint8x16x3_t permute_tbl = vld1q_u8_x3(dot_prod_permute_tbl);
    // Filter values are even, so halve to reduce intermediate precision reqs.
    const int8x8_t x_filter = vshrn_n_s16(x_filter_s16, 1);

    do {
      const uint8_t *s = src_ptr;
      CONV_BUF_TYPE *d = dst_ptr;
      uint8_t *d_u8 = dst8_ptr;
      int width = w;

      do {
        uint8x16_t s0, s1, s2, s3;
        load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);

        uint16x8_t d0 =
            convolve8_8_x(s0, x_filter, correction, range_limit, permute_tbl);
        uint16x8_t d1 =
            convolve8_8_x(s1, x_filter, correction, range_limit, permute_tbl);
        uint16x8_t d2 =
            convolve8_8_x(s2, x_filter, correction, range_limit, permute_tbl);
        uint16x8_t d3 =
            convolve8_8_x(s3, x_filter, correction, range_limit, permute_tbl);

        uint16x8_t dd0, dd1, dd2, dd3;
        load_u16_8x4(d, dst_stride, &dd0, &dd1, &dd2, &dd3);

        uint8x8_t d0_u8, d1_u8, d2_u8, d3_u8;
        compute_basic_avg_8x4(dd0, dd1, dd2, dd3, d0, d1, d2, d3,
                              round_offset_vec, &d0_u8, &d1_u8, &d2_u8, &d3_u8);

        store_u8_8x4(d_u8, dst8_stride, d0_u8, d1_u8, d2_u8, d3_u8);

        s += 8;
        d += 8;
        d_u8 += 8;
        width -= 8;
      } while (width != 0);
      src_ptr += 4 * src_stride;
      dst_ptr += 4 * dst_stride;
      dst8_ptr += 4 * dst8_stride;
      height -= 4;
    } while (height != 0);
  }
}

static inline void dist_wtd_convolve_x_neon_dotprod(
    const uint8_t *src, int src_stride, int w, int h,
    const InterpFilterParams *filter_params_x, const int subpel_x_qn,
    ConvolveParams *conv_params) {
  assert(w % 4 == 0);
  assert(h % 4 == 0);

  const int bd = 8;
  const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
  const int16_t round_offset = (1 << (offset_bits - COMPOUND_ROUND1_BITS)) +
                               (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));

  // Horizontal filter.
  const int16_t *x_filter_ptr = av1_get_interp_filter_subpel_kernel(
      filter_params_x, subpel_x_qn & SUBPEL_MASK);
  const int16x8_t x_filter_s16 = vld1q_s16(x_filter_ptr);

  // Dot-product constants and other shims.
  const uint8x16_t range_limit = vdupq_n_u8(128);
  // Fold round_offset into the dot-product filter correction constant. The
  // additional shim of 1 << (ROUND0_BITS - 1) enables us to use non-rounding
  // shifts - which are generally faster than rounding shifts on modern CPUs.
  // Halve the total because we will halve the filter values.
  int32x4_t correction =
      vdupq_n_s32(((128 << FILTER_BITS) + (round_offset << ROUND0_BITS) +
                   (1 << (ROUND0_BITS - 1))) /
                  2);

  const int horiz_offset = filter_params_x->taps / 2 - 1;
  const uint8_t *src_ptr = src - horiz_offset;
  CONV_BUF_TYPE *dst_ptr = conv_params->dst;
  int dst_stride = conv_params->dst_stride;
  int height = h;

  if (w == 4) {
    const uint8x16_t permute_tbl = vld1q_u8(dot_prod_permute_tbl);
    // 4-tap filters are used for blocks having width <= 4.
    // Filter values are even, so halve to reduce intermediate precision reqs.
    const int8x8_t x_filter =
        vshrn_n_s16(vcombine_s16(vld1_s16(x_filter_ptr + 2), vdup_n_s16(0)), 1);

    src_ptr += 2;

    do {
      uint8x16_t s0, s1, s2, s3;
      load_u8_16x4(src_ptr, src_stride, &s0, &s1, &s2, &s3);

      uint16x4_t d0 =
          convolve4_4_x(s0, x_filter, correction, range_limit, permute_tbl);
      uint16x4_t d1 =
          convolve4_4_x(s1, x_filter, correction, range_limit, permute_tbl);
      uint16x4_t d2 =
          convolve4_4_x(s2, x_filter, correction, range_limit, permute_tbl);
      uint16x4_t d3 =
          convolve4_4_x(s3, x_filter, correction, range_limit, permute_tbl);

      store_u16_4x4(dst_ptr, dst_stride, d0, d1, d2, d3);

      src_ptr += 4 * src_stride;
      dst_ptr += 4 * dst_stride;
      height -= 4;
    } while (height != 0);
  } else {
    const uint8x16x3_t permute_tbl = vld1q_u8_x3(dot_prod_permute_tbl);
    // Filter values are even, so halve to reduce intermediate precision reqs.
    const int8x8_t x_filter = vshrn_n_s16(x_filter_s16, 1);

    do {
      const uint8_t *s = src_ptr;
      CONV_BUF_TYPE *d = dst_ptr;
      int width = w;

      do {
        uint8x16_t s0, s1, s2, s3;
        load_u8_16x4(s, src_stride, &s0, &s1, &s2, &s3);

        uint16x8_t d0 =
            convolve8_8_x(s0, x_filter, correction, range_limit, permute_tbl);
        uint16x8_t d1 =
            convolve8_8_x(s1, x_filter, correction, range_limit, permute_tbl);
        uint16x8_t d2 =
            convolve8_8_x(s2, x_filter, correction, range_limit, permute_tbl);
        uint16x8_t d3 =
            convolve8_8_x(s3, x_filter, correction, range_limit, permute_tbl);

        store_u16_8x4(d, dst_stride, d0, d1, d2, d3);

        s += 8;
        d += 8;
        width -= 8;
      } while (width != 0);
      src_ptr += 4 * src_stride;
      dst_ptr += 4 * dst_stride;
      height -= 4;
    } while (height != 0);
  }
}

void av1_dist_wtd_convolve_x_neon_dotprod(
    const uint8_t *src, int src_stride, uint8_t *dst8, int dst8_stride, int w,
    int h, const InterpFilterParams *filter_params_x, const int subpel_x_qn,
    ConvolveParams *conv_params) {
  if (conv_params->do_average) {
    if (UNLIKELY(conv_params->use_dist_wtd_comp_avg)) {
      dist_wtd_convolve_x_dist_wtd_avg_neon_dotprod(
          src, src_stride, dst8, dst8_stride, w, h, filter_params_x,
          subpel_x_qn, conv_params);
    } else {
      dist_wtd_convolve_x_avg_neon_dotprod(src, src_stride, dst8, dst8_stride,
                                           w, h, filter_params_x, subpel_x_qn,
                                           conv_params);
    }
  } else {
    dist_wtd_convolve_x_neon_dotprod(src, src_stride, w, h, filter_params_x,
                                     subpel_x_qn, conv_params);
  }
}