xref: /aosp_15_r20/external/libaom/av1/common/arm/wiener_convolve_neon.c (revision 77c1e3ccc04c968bd2bc212e87364f250e820521)
1 /*
2  * Copyright (c) 2018, Alliance for Open Media. All rights reserved.
3  *
4  * This source code is subject to the terms of the BSD 2 Clause License and
5  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6  * was not distributed with this source code in the LICENSE file, you can
7  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8  * Media Patent License 1.0 was not distributed with this source code in the
9  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10  */
11 
12 #include <arm_neon.h>
13 #include <assert.h>
14 
15 #include "config/aom_config.h"
16 #include "config/av1_rtcd.h"
17 
18 #include "aom_dsp/arm/mem_neon.h"
19 #include "aom_dsp/arm/transpose_neon.h"
20 #include "aom_dsp/txfm_common.h"
21 #include "aom_ports/mem.h"
22 #include "av1/common/common.h"
23 #include "av1/common/restoration.h"
24 
// Compute 8 outputs of the horizontal pass of a 5-tap Wiener filter.
// t0..t4 hold the source pixels for filter taps 0..4 (each vector covers 8
// adjacent output positions). The accumulated sum is rounded, shifted down by
// WIENER_ROUND0_BITS and clamped to im_max_val before being returned as the
// 16-bit intermediate value.
static inline uint16x8_t wiener_convolve5_8_2d_h(
    const uint8x8_t t0, const uint8x8_t t1, const uint8x8_t t2,
    const uint8x8_t t3, const uint8x8_t t4, const int16x4_t x_filter,
    const int32x4_t round_vec, const uint16x8_t im_max_val) {
  // Since the Wiener filter is symmetric about the middle tap (tap 2) add
  // mirrored source elements before multiplying filter coefficients.
  int16x8_t s04 = vreinterpretq_s16_u16(vaddl_u8(t0, t4));
  int16x8_t s13 = vreinterpretq_s16_u16(vaddl_u8(t1, t3));
  int16x8_t s2 = vreinterpretq_s16_u16(vmovl_u8(t2));

  // x_filter[0] = 0. (5-tap filters are 0-padded to 7 taps.)
  // Lanes 1-3 of x_filter hold the three distinct coefficients of the
  // symmetric filter; round_vec carries the bias added before the shift.
  int32x4_t sum_lo = vmlal_lane_s16(round_vec, vget_low_s16(s04), x_filter, 1);
  sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(s13), x_filter, 2);
  sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(s2), x_filter, 3);

  int32x4_t sum_hi = vmlal_lane_s16(round_vec, vget_high_s16(s04), x_filter, 1);
  sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(s13), x_filter, 2);
  sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(s2), x_filter, 3);

  // Saturating rounding narrow to unsigned 16-bit, then clamp to the
  // intermediate pixel range.
  uint16x8_t res = vcombine_u16(vqrshrun_n_s32(sum_lo, WIENER_ROUND0_BITS),
                                vqrshrun_n_s32(sum_hi, WIENER_ROUND0_BITS));

  return vminq_u16(res, im_max_val);
}
49 
// Horizontal 5-tap Wiener pass over the whole tile: filters src_ptr row by
// row into the 16-bit intermediate buffer dst_ptr, producing 8 output pixels
// per inner iteration. Caller guarantees w % 8 == 0 and w, h > 0.
static inline void convolve_add_src_horiz_5tap_neon(
    const uint8_t *src_ptr, ptrdiff_t src_stride, uint16_t *dst_ptr,
    ptrdiff_t dst_stride, int w, int h, const int16x4_t x_filter,
    const int32x4_t round_vec, const uint16x8_t im_max_val) {
  for (int row = h; row > 0; row--) {
    const uint8_t *src_col = src_ptr;
    uint16_t *dst_col = dst_ptr;

    for (int col = w; col > 0; col -= 8) {
      uint8x8_t t0, t1, t2, t3, t4;
      load_u8_8x5(src_col, 1, &t0, &t1, &t2, &t3, &t4);

      const uint16x8_t res = wiener_convolve5_8_2d_h(
          t0, t1, t2, t3, t4, x_filter, round_vec, im_max_val);

      vst1q_u16(dst_col, res);

      src_col += 8;
      dst_col += 8;
    }

    src_ptr += src_stride;
    dst_ptr += dst_stride;
  }
}
76 
// Compute 8 outputs of the horizontal pass of a 7-tap Wiener filter.
// t0..t6 hold the source pixels for filter taps 0..6 (each vector covers 8
// adjacent output positions). The accumulated sum is rounded, shifted down by
// WIENER_ROUND0_BITS and clamped to im_max_val before being returned as the
// 16-bit intermediate value.
static inline uint16x8_t wiener_convolve7_8_2d_h(
    const uint8x8_t t0, const uint8x8_t t1, const uint8x8_t t2,
    const uint8x8_t t3, const uint8x8_t t4, const uint8x8_t t5,
    const uint8x8_t t6, const int16x4_t x_filter, const int32x4_t round_vec,
    const uint16x8_t im_max_val) {
  // Since the Wiener filter is symmetric about the middle tap (tap 3) add
  // mirrored source elements before multiplying by filter coefficients.
  int16x8_t s06 = vreinterpretq_s16_u16(vaddl_u8(t0, t6));
  int16x8_t s15 = vreinterpretq_s16_u16(vaddl_u8(t1, t5));
  int16x8_t s24 = vreinterpretq_s16_u16(vaddl_u8(t2, t4));
  int16x8_t s3 = vreinterpretq_s16_u16(vmovl_u8(t3));

  // Lanes 0-3 of x_filter hold the four distinct coefficients of the
  // symmetric filter; round_vec carries the bias added before the shift.
  int32x4_t sum_lo = vmlal_lane_s16(round_vec, vget_low_s16(s06), x_filter, 0);
  sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(s15), x_filter, 1);
  sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(s24), x_filter, 2);
  sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(s3), x_filter, 3);

  int32x4_t sum_hi = vmlal_lane_s16(round_vec, vget_high_s16(s06), x_filter, 0);
  sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(s15), x_filter, 1);
  sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(s24), x_filter, 2);
  sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(s3), x_filter, 3);

  // Saturating rounding narrow to unsigned 16-bit, then clamp to the
  // intermediate pixel range.
  uint16x8_t res = vcombine_u16(vqrshrun_n_s32(sum_lo, WIENER_ROUND0_BITS),
                                vqrshrun_n_s32(sum_hi, WIENER_ROUND0_BITS));

  return vminq_u16(res, im_max_val);
}
104 
// Horizontal 7-tap Wiener pass over the whole tile: filters src_ptr row by
// row into the 16-bit intermediate buffer dst_ptr, producing 8 output pixels
// per inner iteration. Caller guarantees w % 8 == 0 and w, h > 0.
static inline void convolve_add_src_horiz_7tap_neon(
    const uint8_t *src_ptr, ptrdiff_t src_stride, uint16_t *dst_ptr,
    ptrdiff_t dst_stride, int w, int h, const int16x4_t x_filter,
    const int32x4_t round_vec, const uint16x8_t im_max_val) {
  for (int row = h; row > 0; row--) {
    const uint8_t *src_col = src_ptr;
    uint16_t *dst_col = dst_ptr;

    for (int col = w; col > 0; col -= 8) {
      uint8x8_t t0, t1, t2, t3, t4, t5, t6;
      load_u8_8x7(src_col, 1, &t0, &t1, &t2, &t3, &t4, &t5, &t6);

      const uint16x8_t res = wiener_convolve7_8_2d_h(
          t0, t1, t2, t3, t4, t5, t6, x_filter, round_vec, im_max_val);

      vst1q_u16(dst_col, res);

      src_col += 8;
      dst_col += 8;
    }

    src_ptr += src_stride;
    dst_ptr += dst_stride;
  }
}
131 
// Compute 8 outputs of the vertical pass of a 5-tap Wiener filter.
// s0..s4 are the 16-bit intermediate rows (horizontal-pass output) for
// filter taps 0..4. The sum is shifted down by
// (2 * FILTER_BITS - WIENER_ROUND0_BITS) and saturated to the 8-bit output
// range.
static inline uint8x8_t wiener_convolve5_8_2d_v(
    const int16x8_t s0, const int16x8_t s1, const int16x8_t s2,
    const int16x8_t s3, const int16x8_t s4, const int16x4_t y_filter,
    const int32x4_t round_vec) {
  // Since the Wiener filter is symmetric about the middle tap (tap 2) add
  // mirrored source elements before multiplying by filter coefficients.
  int16x8_t s04 = vaddq_s16(s0, s4);
  int16x8_t s13 = vaddq_s16(s1, s3);

  // y_filter[0] = 0 for the 5-tap case, so lane 0 is skipped.
  // round_vec carries the rounding/offset bias added up-front.
  int32x4_t sum_lo = vmlal_lane_s16(round_vec, vget_low_s16(s04), y_filter, 1);
  sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(s13), y_filter, 2);
  sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(s2), y_filter, 3);

  int32x4_t sum_hi = vmlal_lane_s16(round_vec, vget_high_s16(s04), y_filter, 1);
  sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(s13), y_filter, 2);
  sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(s2), y_filter, 3);

  // Narrow to 16 bits (the rounding bias is already in round_vec), then
  // saturate to unsigned 8-bit.
  int16x4_t res_lo = vshrn_n_s32(sum_lo, 2 * FILTER_BITS - WIENER_ROUND0_BITS);
  int16x4_t res_hi = vshrn_n_s32(sum_hi, 2 * FILTER_BITS - WIENER_ROUND0_BITS);

  return vqmovun_s16(vcombine_s16(res_lo, res_hi));
}
154 
// Vertical 5-tap Wiener pass: filters the 16-bit intermediate buffer down to
// 8-bit output pixels, 8 columns at a time. Rows are processed four at a time
// with a single-row tail loop. Caller guarantees w % 8 == 0 and w, h > 0.
static inline void convolve_add_src_vert_5tap_neon(
    const uint16_t *src, ptrdiff_t src_stride, uint8_t *dst,
    ptrdiff_t dst_stride, int w, int h, const int16x4_t y_filter,
    const int32x4_t round_vec) {
  do {
    // Reinterpret the intermediate pixels as signed 16-bit. Keep the const
    // qualifier in the cast: the buffer is only read here.
    const int16_t *s = (const int16_t *)src;
    uint8_t *d = dst;
    int height = h;

    // Main loop: 4 output rows per iteration, reusing overlapping input rows.
    while (height > 3) {
      int16x8_t s0, s1, s2, s3, s4, s5, s6, s7;
      load_s16_8x8(s, src_stride, &s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7);

      uint8x8_t d0 =
          wiener_convolve5_8_2d_v(s0, s1, s2, s3, s4, y_filter, round_vec);
      uint8x8_t d1 =
          wiener_convolve5_8_2d_v(s1, s2, s3, s4, s5, y_filter, round_vec);
      uint8x8_t d2 =
          wiener_convolve5_8_2d_v(s2, s3, s4, s5, s6, y_filter, round_vec);
      uint8x8_t d3 =
          wiener_convolve5_8_2d_v(s3, s4, s5, s6, s7, y_filter, round_vec);

      store_u8_8x4(d, dst_stride, d0, d1, d2, d3);

      s += 4 * src_stride;
      d += 4 * dst_stride;
      height -= 4;
    }

    // Tail: process the remaining 0-3 rows one at a time.
    while (height-- != 0) {
      int16x8_t s0, s1, s2, s3, s4;
      load_s16_8x5(s, src_stride, &s0, &s1, &s2, &s3, &s4);

      uint8x8_t d0 =
          wiener_convolve5_8_2d_v(s0, s1, s2, s3, s4, y_filter, round_vec);

      vst1_u8(d, d0);

      d += dst_stride;
      s += src_stride;
    }

    src += 8;
    dst += 8;
    w -= 8;
  } while (w != 0);
}
202 
// Compute 8 outputs of the vertical pass of a 7-tap Wiener filter.
// s0..s6 are the 16-bit intermediate rows (horizontal-pass output) for
// filter taps 0..6. The sum is shifted down by
// (2 * FILTER_BITS - WIENER_ROUND0_BITS) and saturated to the 8-bit output
// range.
static inline uint8x8_t wiener_convolve7_8_2d_v(
    const int16x8_t s0, const int16x8_t s1, const int16x8_t s2,
    const int16x8_t s3, const int16x8_t s4, const int16x8_t s5,
    const int16x8_t s6, const int16x4_t y_filter, const int32x4_t round_vec) {
  // Since the Wiener filter is symmetric about the middle tap (tap 3) add
  // mirrored source elements before multiplying by filter coefficients.
  int16x8_t s06 = vaddq_s16(s0, s6);
  int16x8_t s15 = vaddq_s16(s1, s5);
  int16x8_t s24 = vaddq_s16(s2, s4);

  // Lanes 0-3 of y_filter hold the four distinct coefficients of the
  // symmetric filter; round_vec carries the rounding/offset bias.
  int32x4_t sum_lo = vmlal_lane_s16(round_vec, vget_low_s16(s06), y_filter, 0);
  sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(s15), y_filter, 1);
  sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(s24), y_filter, 2);
  sum_lo = vmlal_lane_s16(sum_lo, vget_low_s16(s3), y_filter, 3);

  int32x4_t sum_hi = vmlal_lane_s16(round_vec, vget_high_s16(s06), y_filter, 0);
  sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(s15), y_filter, 1);
  sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(s24), y_filter, 2);
  sum_hi = vmlal_lane_s16(sum_hi, vget_high_s16(s3), y_filter, 3);

  // Narrow to 16 bits (the rounding bias is already in round_vec), then
  // saturate to unsigned 8-bit.
  int16x4_t res_lo = vshrn_n_s32(sum_lo, 2 * FILTER_BITS - WIENER_ROUND0_BITS);
  int16x4_t res_hi = vshrn_n_s32(sum_hi, 2 * FILTER_BITS - WIENER_ROUND0_BITS);

  return vqmovun_s16(vcombine_s16(res_lo, res_hi));
}
228 
// Vertical 7-tap Wiener pass: filters the 16-bit intermediate buffer down to
// 8-bit output pixels, 8 columns at a time. Rows are processed four at a time
// with a single-row tail loop. Caller guarantees w % 8 == 0 and w, h > 0.
static inline void convolve_add_src_vert_7tap_neon(
    const uint16_t *src, ptrdiff_t src_stride, uint8_t *dst,
    ptrdiff_t dst_stride, int w, int h, const int16x4_t y_filter,
    const int32x4_t round_vec) {
  do {
    // Reinterpret the intermediate pixels as signed 16-bit. Keep the const
    // qualifier in the cast: the buffer is only read here.
    const int16_t *s = (const int16_t *)src;
    uint8_t *d = dst;
    int height = h;

    // Main loop: 4 output rows per iteration, reusing overlapping input rows.
    while (height > 3) {
      int16x8_t s0, s1, s2, s3, s4, s5, s6, s7, s8, s9;
      load_s16_8x10(s, src_stride, &s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7, &s8,
                    &s9);

      uint8x8_t d0 = wiener_convolve7_8_2d_v(s0, s1, s2, s3, s4, s5, s6,
                                             y_filter, round_vec);
      uint8x8_t d1 = wiener_convolve7_8_2d_v(s1, s2, s3, s4, s5, s6, s7,
                                             y_filter, round_vec);
      uint8x8_t d2 = wiener_convolve7_8_2d_v(s2, s3, s4, s5, s6, s7, s8,
                                             y_filter, round_vec);
      uint8x8_t d3 = wiener_convolve7_8_2d_v(s3, s4, s5, s6, s7, s8, s9,
                                             y_filter, round_vec);

      store_u8_8x4(d, dst_stride, d0, d1, d2, d3);

      s += 4 * src_stride;
      d += 4 * dst_stride;
      height -= 4;
    }

    // Tail: process the remaining 0-3 rows one at a time.
    while (height-- != 0) {
      int16x8_t s0, s1, s2, s3, s4, s5, s6;
      load_s16_8x7(s, src_stride, &s0, &s1, &s2, &s3, &s4, &s5, &s6);

      uint8x8_t d0 = wiener_convolve7_8_2d_v(s0, s1, s2, s3, s4, s5, s6,
                                             y_filter, round_vec);

      vst1_u8(d, d0);

      d += dst_stride;
      s += src_stride;
    }

    src += 8;
    dst += 8;
    w -= 8;
  } while (w != 0);
}
277 
get_wiener_filter_taps(const int16_t * filter)278 static inline int get_wiener_filter_taps(const int16_t *filter) {
279   assert(filter[7] == 0);
280   if (filter[0] == 0 && filter[6] == 0) {
281     return WIENER_WIN_REDUCED;
282   }
283   return WIENER_WIN;
284 }
285 
286 // Wiener filter 2D
287 // Apply horizontal filter and store in a temporary buffer. When applying
288 // vertical filter, overwrite the original pixel values.
void av1_wiener_convolve_add_src_neon(const uint8_t *src, ptrdiff_t src_stride,
                                      uint8_t *dst, ptrdiff_t dst_stride,
                                      const int16_t *x_filter, int x_step_q4,
                                      const int16_t *y_filter, int y_step_q4,
                                      int w, int h,
                                      const WienerConvolveParams *conv_params) {
  // Only unit-step filtering is supported: the step parameters and
  // conv_params are validated by the asserts below and otherwise unused.
  (void)x_step_q4;
  (void)y_step_q4;
  (void)conv_params;

  assert(w % 8 == 0);
  assert(w <= MAX_SB_SIZE && h <= MAX_SB_SIZE);
  assert(x_step_q4 == 16 && y_step_q4 == 16);
  assert(x_filter[7] == 0 && y_filter[7] == 0);
  // For bd == 8, assert horizontal filtering output will not exceed 15-bit:
  assert(8 + 1 + FILTER_BITS - conv_params->round_0 <= 15);

  // Intermediate buffer for the horizontal pass: up to
  // (h + filter taps - 1) rows of 16-bit pixels, MAX_SB_SIZE wide.
  DECLARE_ALIGNED(16, uint16_t,
                  im_block[(MAX_SB_SIZE + WIENER_WIN - 1) * MAX_SB_SIZE]);

  const int x_filter_taps = get_wiener_filter_taps(x_filter);
  const int y_filter_taps = get_wiener_filter_taps(y_filter);
  // Only taps 0-3 are loaded: the filter is symmetric, so taps 4-6 mirror
  // taps 2-0 and are folded in by the convolve helpers.
  int16x4_t x_filter_s16 = vld1_s16(x_filter);
  int16x4_t y_filter_s16 = vld1_s16(y_filter);
  // Add 128 to tap 3. (Needed for rounding.)
  // 128ULL << 48 places the constant in lane 3 (the middle tap).
  x_filter_s16 = vadd_s16(x_filter_s16, vcreate_s16(128ULL << 48));
  y_filter_s16 = vadd_s16(y_filter_s16, vcreate_s16(128ULL << 48));

  const int im_stride = MAX_SB_SIZE;
  // Horizontal pass must produce enough extra rows for the vertical taps.
  const int im_h = h + y_filter_taps - 1;
  // Back up src so the taps are centered on each output pixel.
  const int horiz_offset = x_filter_taps / 2;
  const int vert_offset = (y_filter_taps / 2) * (int)src_stride;

  const int bd = 8;
  // Clamp for the 16-bit intermediate values produced by the first pass.
  const uint16x8_t im_max_val =
      vdupq_n_u16((1 << (bd + 1 + FILTER_BITS - WIENER_ROUND0_BITS)) - 1);
  const int32x4_t horiz_round_vec = vdupq_n_s32(1 << (bd + FILTER_BITS - 1));

  const int32x4_t vert_round_vec =
      vdupq_n_s32((1 << (2 * FILTER_BITS - WIENER_ROUND0_BITS - 1)) -
                  (1 << (bd + (2 * FILTER_BITS - WIENER_ROUND0_BITS) - 1)));

  // Horizontal pass into im_block, selecting the 5- or 7-tap kernel based on
  // the zero-padding of the filter.
  if (x_filter_taps == WIENER_WIN_REDUCED) {
    convolve_add_src_horiz_5tap_neon(src - horiz_offset - vert_offset,
                                     src_stride, im_block, im_stride, w, im_h,
                                     x_filter_s16, horiz_round_vec, im_max_val);
  } else {
    convolve_add_src_horiz_7tap_neon(src - horiz_offset - vert_offset,
                                     src_stride, im_block, im_stride, w, im_h,
                                     x_filter_s16, horiz_round_vec, im_max_val);
  }

  // Vertical pass from im_block into the final 8-bit destination.
  if (y_filter_taps == WIENER_WIN_REDUCED) {
    convolve_add_src_vert_5tap_neon(im_block, im_stride, dst, dst_stride, w, h,
                                    y_filter_s16, vert_round_vec);
  } else {
    convolve_add_src_vert_7tap_neon(im_block, im_stride, dst, dst_stride, w, h,
                                    y_filter_s16, vert_round_vec);
  }
}
349