// xref: /aosp_15_r20/external/libaom/av1/common/arm/warp_plane_neon.h (revision 77c1e3ccc04c968bd2bc212e87364f250e820521)
/*
 * Copyright (c) 2023, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */
11 #ifndef AOM_AV1_COMMON_ARM_WARP_PLANE_NEON_H_
12 #define AOM_AV1_COMMON_ARM_WARP_PLANE_NEON_H_
13 
#include <assert.h>
#include <arm_neon.h>
#include <memory.h>
#include <math.h>
#include <string.h>

#include "aom_dsp/aom_dsp_common.h"
#include "aom_dsp/arm/sum_neon.h"
#include "aom_dsp/arm/transpose_neon.h"
#include "aom_ports/mem.h"
#include "config/av1_rtcd.h"
#include "av1/common/warped_motion.h"
#include "av1/common/scale.h"
25 #include "av1/common/scale.h"
26 
27 static AOM_FORCE_INLINE int16x8_t horizontal_filter_4x1_f4(const uint8x16_t in,
28                                                            int sx, int alpha);
29 
30 static AOM_FORCE_INLINE int16x8_t horizontal_filter_8x1_f8(const uint8x16_t in,
31                                                            int sx, int alpha);
32 
33 static AOM_FORCE_INLINE int16x8_t horizontal_filter_4x1_f1(const uint8x16_t in,
34                                                            int sx);
35 
36 static AOM_FORCE_INLINE int16x8_t horizontal_filter_8x1_f1(const uint8x16_t in,
37                                                            int sx);
38 
39 static AOM_FORCE_INLINE int16x8_t
40 horizontal_filter_4x1_f1_beta0(const uint8x16_t in, int16x8_t f_s16);
41 
42 static AOM_FORCE_INLINE int16x8_t
43 horizontal_filter_8x1_f1_beta0(const uint8x16_t in, int16x8_t f_s16);
44 
45 static AOM_FORCE_INLINE void vertical_filter_4x1_f1(const int16x8_t *src,
46                                                     int32x4_t *res, int sy);
47 
48 static AOM_FORCE_INLINE void vertical_filter_4x1_f4(const int16x8_t *src,
49                                                     int32x4_t *res, int sy,
50                                                     int gamma);
51 
52 static AOM_FORCE_INLINE void vertical_filter_8x1_f1(const int16x8_t *src,
53                                                     int32x4_t *res_low,
54                                                     int32x4_t *res_high,
55                                                     int sy);
56 
57 static AOM_FORCE_INLINE void vertical_filter_8x1_f8(const int16x8_t *src,
58                                                     int32x4_t *res_low,
59                                                     int32x4_t *res_high, int sy,
60                                                     int gamma);
61 
load_filters_4(int16x8_t out[],int offset,int stride)62 static AOM_FORCE_INLINE void load_filters_4(int16x8_t out[], int offset,
63                                             int stride) {
64   out[0] = vld1q_s16((int16_t *)(av1_warped_filter + ((offset + 0 * stride) >>
65                                                       WARPEDDIFF_PREC_BITS)));
66   out[1] = vld1q_s16((int16_t *)(av1_warped_filter + ((offset + 1 * stride) >>
67                                                       WARPEDDIFF_PREC_BITS)));
68   out[2] = vld1q_s16((int16_t *)(av1_warped_filter + ((offset + 2 * stride) >>
69                                                       WARPEDDIFF_PREC_BITS)));
70   out[3] = vld1q_s16((int16_t *)(av1_warped_filter + ((offset + 3 * stride) >>
71                                                       WARPEDDIFF_PREC_BITS)));
72 }
73 
load_filters_8(int16x8_t out[],int offset,int stride)74 static AOM_FORCE_INLINE void load_filters_8(int16x8_t out[], int offset,
75                                             int stride) {
76   out[0] = vld1q_s16((int16_t *)(av1_warped_filter + ((offset + 0 * stride) >>
77                                                       WARPEDDIFF_PREC_BITS)));
78   out[1] = vld1q_s16((int16_t *)(av1_warped_filter + ((offset + 1 * stride) >>
79                                                       WARPEDDIFF_PREC_BITS)));
80   out[2] = vld1q_s16((int16_t *)(av1_warped_filter + ((offset + 2 * stride) >>
81                                                       WARPEDDIFF_PREC_BITS)));
82   out[3] = vld1q_s16((int16_t *)(av1_warped_filter + ((offset + 3 * stride) >>
83                                                       WARPEDDIFF_PREC_BITS)));
84   out[4] = vld1q_s16((int16_t *)(av1_warped_filter + ((offset + 4 * stride) >>
85                                                       WARPEDDIFF_PREC_BITS)));
86   out[5] = vld1q_s16((int16_t *)(av1_warped_filter + ((offset + 5 * stride) >>
87                                                       WARPEDDIFF_PREC_BITS)));
88   out[6] = vld1q_s16((int16_t *)(av1_warped_filter + ((offset + 6 * stride) >>
89                                                       WARPEDDIFF_PREC_BITS)));
90   out[7] = vld1q_s16((int16_t *)(av1_warped_filter + ((offset + 7 * stride) >>
91                                                       WARPEDDIFF_PREC_BITS)));
92 }
93 
clamp_iy(int iy,int height)94 static AOM_FORCE_INLINE int clamp_iy(int iy, int height) {
95   return clamp(iy, 0, height - 1);
96 }
97 
warp_affine_horizontal(const uint8_t * ref,int width,int height,int stride,int p_width,int p_height,int16_t alpha,int16_t beta,const int64_t x4,const int64_t y4,const int i,int16x8_t tmp[])98 static AOM_FORCE_INLINE void warp_affine_horizontal(
99     const uint8_t *ref, int width, int height, int stride, int p_width,
100     int p_height, int16_t alpha, int16_t beta, const int64_t x4,
101     const int64_t y4, const int i, int16x8_t tmp[]) {
102   const int bd = 8;
103   const int reduce_bits_horiz = ROUND0_BITS;
104   const int height_limit = AOMMIN(8, p_height - i) + 7;
105 
106   int32_t ix4 = (int32_t)(x4 >> WARPEDMODEL_PREC_BITS);
107   int32_t iy4 = (int32_t)(y4 >> WARPEDMODEL_PREC_BITS);
108 
109   int32_t sx4 = x4 & ((1 << WARPEDMODEL_PREC_BITS) - 1);
110   sx4 += alpha * (-4) + beta * (-4) + (1 << (WARPEDDIFF_PREC_BITS - 1)) +
111          (WARPEDPIXEL_PREC_SHIFTS << WARPEDDIFF_PREC_BITS);
112   sx4 &= ~((1 << WARP_PARAM_REDUCE_BITS) - 1);
113 
114   if (ix4 <= -7) {
115     for (int k = 0; k < height_limit; ++k) {
116       int iy = clamp_iy(iy4 + k - 7, height);
117       int16_t dup_val =
118           (1 << (bd + FILTER_BITS - reduce_bits_horiz - 1)) +
119           ref[iy * stride] * (1 << (FILTER_BITS - reduce_bits_horiz));
120       tmp[k] = vdupq_n_s16(dup_val);
121     }
122     return;
123   } else if (ix4 >= width + 6) {
124     for (int k = 0; k < height_limit; ++k) {
125       int iy = clamp_iy(iy4 + k - 7, height);
126       int16_t dup_val = (1 << (bd + FILTER_BITS - reduce_bits_horiz - 1)) +
127                         ref[iy * stride + (width - 1)] *
128                             (1 << (FILTER_BITS - reduce_bits_horiz));
129       tmp[k] = vdupq_n_s16(dup_val);
130     }
131     return;
132   }
133 
134   static const uint8_t kIotaArr[] = { 0, 1, 2,  3,  4,  5,  6,  7,
135                                       8, 9, 10, 11, 12, 13, 14, 15 };
136   const uint8x16_t indx = vld1q_u8(kIotaArr);
137 
138   const int out_of_boundary_left = -(ix4 - 6);
139   const int out_of_boundary_right = (ix4 + 8) - width;
140 
141 #define APPLY_HORIZONTAL_SHIFT(fn, ...)                                \
142   do {                                                                 \
143     if (out_of_boundary_left >= 0 || out_of_boundary_right >= 0) {     \
144       for (int k = 0; k < height_limit; ++k) {                         \
145         const int iy = clamp_iy(iy4 + k - 7, height);                  \
146         const uint8_t *src = ref + iy * stride + ix4 - 7;              \
147         uint8x16_t src_1 = vld1q_u8(src);                              \
148                                                                        \
149         if (out_of_boundary_left >= 0) {                               \
150           int limit = out_of_boundary_left + 1;                        \
151           uint8x16_t cmp_vec = vdupq_n_u8(out_of_boundary_left);       \
152           uint8x16_t vec_dup = vdupq_n_u8(*(src + limit));             \
153           uint8x16_t mask_val = vcleq_u8(indx, cmp_vec);               \
154           src_1 = vbslq_u8(mask_val, vec_dup, src_1);                  \
155         }                                                              \
156         if (out_of_boundary_right >= 0) {                              \
157           int limit = 15 - (out_of_boundary_right + 1);                \
158           uint8x16_t cmp_vec = vdupq_n_u8(15 - out_of_boundary_right); \
159           uint8x16_t vec_dup = vdupq_n_u8(*(src + limit));             \
160           uint8x16_t mask_val = vcgeq_u8(indx, cmp_vec);               \
161           src_1 = vbslq_u8(mask_val, vec_dup, src_1);                  \
162         }                                                              \
163         tmp[k] = (fn)(src_1, __VA_ARGS__);                             \
164       }                                                                \
165     } else {                                                           \
166       for (int k = 0; k < height_limit; ++k) {                         \
167         const int iy = clamp_iy(iy4 + k - 7, height);                  \
168         const uint8_t *src = ref + iy * stride + ix4 - 7;              \
169         uint8x16_t src_1 = vld1q_u8(src);                              \
170         tmp[k] = (fn)(src_1, __VA_ARGS__);                             \
171       }                                                                \
172     }                                                                  \
173   } while (0)
174 
175   if (p_width == 4) {
176     if (beta == 0) {
177       if (alpha == 0) {
178         int16x8_t f_s16 = vld1q_s16(
179             (int16_t *)(av1_warped_filter + (sx4 >> WARPEDDIFF_PREC_BITS)));
180         APPLY_HORIZONTAL_SHIFT(horizontal_filter_4x1_f1_beta0, f_s16);
181       } else {
182         APPLY_HORIZONTAL_SHIFT(horizontal_filter_4x1_f4, sx4, alpha);
183       }
184     } else {
185       if (alpha == 0) {
186         APPLY_HORIZONTAL_SHIFT(horizontal_filter_4x1_f1,
187                                (sx4 + beta * (k - 3)));
188       } else {
189         APPLY_HORIZONTAL_SHIFT(horizontal_filter_4x1_f4, (sx4 + beta * (k - 3)),
190                                alpha);
191       }
192     }
193   } else {
194     if (beta == 0) {
195       if (alpha == 0) {
196         int16x8_t f_s16 = vld1q_s16(
197             (int16_t *)(av1_warped_filter + (sx4 >> WARPEDDIFF_PREC_BITS)));
198         APPLY_HORIZONTAL_SHIFT(horizontal_filter_8x1_f1_beta0, f_s16);
199       } else {
200         APPLY_HORIZONTAL_SHIFT(horizontal_filter_8x1_f8, sx4, alpha);
201       }
202     } else {
203       if (alpha == 0) {
204         APPLY_HORIZONTAL_SHIFT(horizontal_filter_8x1_f1,
205                                (sx4 + beta * (k - 3)));
206       } else {
207         APPLY_HORIZONTAL_SHIFT(horizontal_filter_8x1_f8, (sx4 + beta * (k - 3)),
208                                alpha);
209       }
210     }
211   }
212 }
213 
warp_affine_vertical(uint8_t * pred,int p_width,int p_height,int p_stride,int is_compound,uint16_t * dst,int dst_stride,int do_average,int use_dist_wtd_comp_avg,int16_t gamma,int16_t delta,const int64_t y4,const int i,const int j,int16x8_t tmp[],const int fwd,const int bwd)214 static AOM_FORCE_INLINE void warp_affine_vertical(
215     uint8_t *pred, int p_width, int p_height, int p_stride, int is_compound,
216     uint16_t *dst, int dst_stride, int do_average, int use_dist_wtd_comp_avg,
217     int16_t gamma, int16_t delta, const int64_t y4, const int i, const int j,
218     int16x8_t tmp[], const int fwd, const int bwd) {
219   const int bd = 8;
220   const int reduce_bits_horiz = ROUND0_BITS;
221   const int offset_bits_vert = bd + 2 * FILTER_BITS - reduce_bits_horiz;
222   int add_const_vert;
223   if (is_compound) {
224     add_const_vert =
225         (1 << offset_bits_vert) + (1 << (COMPOUND_ROUND1_BITS - 1));
226   } else {
227     add_const_vert =
228         (1 << offset_bits_vert) + (1 << (2 * FILTER_BITS - ROUND0_BITS - 1));
229   }
230   const int sub_constant = (1 << (bd - 1)) + (1 << bd);
231 
232   const int offset_bits = bd + 2 * FILTER_BITS - ROUND0_BITS;
233   const int res_sub_const =
234       (1 << (2 * FILTER_BITS - ROUND0_BITS - COMPOUND_ROUND1_BITS - 1)) -
235       (1 << (offset_bits - COMPOUND_ROUND1_BITS)) -
236       (1 << (offset_bits - COMPOUND_ROUND1_BITS - 1));
237 
238   int32_t sy4 = y4 & ((1 << WARPEDMODEL_PREC_BITS) - 1);
239   sy4 += gamma * (-4) + delta * (-4) + (1 << (WARPEDDIFF_PREC_BITS - 1)) +
240          (WARPEDPIXEL_PREC_SHIFTS << WARPEDDIFF_PREC_BITS);
241   sy4 &= ~((1 << WARP_PARAM_REDUCE_BITS) - 1);
242 
243   if (p_width > 4) {
244     for (int k = -4; k < AOMMIN(4, p_height - i - 4); ++k) {
245       int sy = sy4 + delta * (k + 4);
246       const int16x8_t *v_src = tmp + (k + 4);
247 
248       int32x4_t res_lo, res_hi;
249       if (gamma == 0) {
250         vertical_filter_8x1_f1(v_src, &res_lo, &res_hi, sy);
251       } else {
252         vertical_filter_8x1_f8(v_src, &res_lo, &res_hi, sy, gamma);
253       }
254 
255       res_lo = vaddq_s32(res_lo, vdupq_n_s32(add_const_vert));
256       res_hi = vaddq_s32(res_hi, vdupq_n_s32(add_const_vert));
257 
258       if (is_compound) {
259         uint16_t *const p = (uint16_t *)&dst[(i + k + 4) * dst_stride + j];
260         int16x8_t res_s16 =
261             vcombine_s16(vshrn_n_s32(res_lo, COMPOUND_ROUND1_BITS),
262                          vshrn_n_s32(res_hi, COMPOUND_ROUND1_BITS));
263         if (do_average) {
264           int16x8_t tmp16 = vreinterpretq_s16_u16(vld1q_u16(p));
265           if (use_dist_wtd_comp_avg) {
266             int32x4_t tmp32_lo = vmull_n_s16(vget_low_s16(tmp16), fwd);
267             int32x4_t tmp32_hi = vmull_n_s16(vget_high_s16(tmp16), fwd);
268             tmp32_lo = vmlal_n_s16(tmp32_lo, vget_low_s16(res_s16), bwd);
269             tmp32_hi = vmlal_n_s16(tmp32_hi, vget_high_s16(res_s16), bwd);
270             tmp16 = vcombine_s16(vshrn_n_s32(tmp32_lo, DIST_PRECISION_BITS),
271                                  vshrn_n_s32(tmp32_hi, DIST_PRECISION_BITS));
272           } else {
273             tmp16 = vhaddq_s16(tmp16, res_s16);
274           }
275           int16x8_t res = vaddq_s16(tmp16, vdupq_n_s16(res_sub_const));
276           uint8x8_t res8 = vqshrun_n_s16(
277               res, 2 * FILTER_BITS - ROUND0_BITS - COMPOUND_ROUND1_BITS);
278           vst1_u8(&pred[(i + k + 4) * p_stride + j], res8);
279         } else {
280           vst1q_u16(p, vreinterpretq_u16_s16(res_s16));
281         }
282       } else {
283         int16x8_t res16 =
284             vcombine_s16(vshrn_n_s32(res_lo, 2 * FILTER_BITS - ROUND0_BITS),
285                          vshrn_n_s32(res_hi, 2 * FILTER_BITS - ROUND0_BITS));
286         res16 = vsubq_s16(res16, vdupq_n_s16(sub_constant));
287 
288         uint8_t *const p = (uint8_t *)&pred[(i + k + 4) * p_stride + j];
289         vst1_u8(p, vqmovun_s16(res16));
290       }
291     }
292   } else {
293     // p_width == 4
294     for (int k = -4; k < AOMMIN(4, p_height - i - 4); ++k) {
295       int sy = sy4 + delta * (k + 4);
296       const int16x8_t *v_src = tmp + (k + 4);
297 
298       int32x4_t res_lo;
299       if (gamma == 0) {
300         vertical_filter_4x1_f1(v_src, &res_lo, sy);
301       } else {
302         vertical_filter_4x1_f4(v_src, &res_lo, sy, gamma);
303       }
304 
305       res_lo = vaddq_s32(res_lo, vdupq_n_s32(add_const_vert));
306 
307       if (is_compound) {
308         uint16_t *const p = (uint16_t *)&dst[(i + k + 4) * dst_stride + j];
309 
310         int16x4_t res_lo_s16 = vshrn_n_s32(res_lo, COMPOUND_ROUND1_BITS);
311         if (do_average) {
312           uint8_t *const dst8 = &pred[(i + k + 4) * p_stride + j];
313           int16x4_t tmp16_lo = vreinterpret_s16_u16(vld1_u16(p));
314           if (use_dist_wtd_comp_avg) {
315             int32x4_t tmp32_lo = vmull_n_s16(tmp16_lo, fwd);
316             tmp32_lo = vmlal_n_s16(tmp32_lo, res_lo_s16, bwd);
317             tmp16_lo = vshrn_n_s32(tmp32_lo, DIST_PRECISION_BITS);
318           } else {
319             tmp16_lo = vhadd_s16(tmp16_lo, res_lo_s16);
320           }
321           int16x4_t res = vadd_s16(tmp16_lo, vdup_n_s16(res_sub_const));
322           uint8x8_t res8 = vqshrun_n_s16(
323               vcombine_s16(res, vdup_n_s16(0)),
324               2 * FILTER_BITS - ROUND0_BITS - COMPOUND_ROUND1_BITS);
325           vst1_lane_u32((uint32_t *)dst8, vreinterpret_u32_u8(res8), 0);
326         } else {
327           uint16x4_t res_u16_low = vreinterpret_u16_s16(res_lo_s16);
328           vst1_u16(p, res_u16_low);
329         }
330       } else {
331         int16x4_t res16 = vshrn_n_s32(res_lo, 2 * FILTER_BITS - ROUND0_BITS);
332         res16 = vsub_s16(res16, vdup_n_s16(sub_constant));
333 
334         uint8_t *const p = (uint8_t *)&pred[(i + k + 4) * p_stride + j];
335         uint8x8_t val = vqmovun_s16(vcombine_s16(res16, vdup_n_s16(0)));
336         vst1_lane_u32((uint32_t *)p, vreinterpret_u32_u8(val), 0);
337       }
338     }
339   }
340 }
341 
av1_warp_affine_common(const int32_t * mat,const uint8_t * ref,int width,int height,int stride,uint8_t * pred,int p_col,int p_row,int p_width,int p_height,int p_stride,int subsampling_x,int subsampling_y,ConvolveParams * conv_params,int16_t alpha,int16_t beta,int16_t gamma,int16_t delta)342 static AOM_FORCE_INLINE void av1_warp_affine_common(
343     const int32_t *mat, const uint8_t *ref, int width, int height, int stride,
344     uint8_t *pred, int p_col, int p_row, int p_width, int p_height,
345     int p_stride, int subsampling_x, int subsampling_y,
346     ConvolveParams *conv_params, int16_t alpha, int16_t beta, int16_t gamma,
347     int16_t delta) {
348   const int w0 = conv_params->fwd_offset;
349   const int w1 = conv_params->bck_offset;
350   const int is_compound = conv_params->is_compound;
351   uint16_t *const dst = conv_params->dst;
352   const int dst_stride = conv_params->dst_stride;
353   const int do_average = conv_params->do_average;
354   const int use_dist_wtd_comp_avg = conv_params->use_dist_wtd_comp_avg;
355 
356   assert(IMPLIES(is_compound, dst != NULL));
357   assert(IMPLIES(do_average, is_compound));
358 
359   for (int i = 0; i < p_height; i += 8) {
360     for (int j = 0; j < p_width; j += 8) {
361       const int32_t src_x = (p_col + j + 4) << subsampling_x;
362       const int32_t src_y = (p_row + i + 4) << subsampling_y;
363       const int64_t dst_x =
364           (int64_t)mat[2] * src_x + (int64_t)mat[3] * src_y + (int64_t)mat[0];
365       const int64_t dst_y =
366           (int64_t)mat[4] * src_x + (int64_t)mat[5] * src_y + (int64_t)mat[1];
367 
368       const int64_t x4 = dst_x >> subsampling_x;
369       const int64_t y4 = dst_y >> subsampling_y;
370 
371       int16x8_t tmp[15];
372       warp_affine_horizontal(ref, width, height, stride, p_width, p_height,
373                              alpha, beta, x4, y4, i, tmp);
374       warp_affine_vertical(pred, p_width, p_height, p_stride, is_compound, dst,
375                            dst_stride, do_average, use_dist_wtd_comp_avg, gamma,
376                            delta, y4, i, j, tmp, w0, w1);
377     }
378   }
379 }
380 
381 #endif  // AOM_AV1_COMMON_ARM_WARP_PLANE_NEON_H_
382