/*
 * Copyright (c) 2023, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <arm_neon.h>
#include <assert.h>
#include <stdbool.h>

#include "aom_dsp/aom_dsp_common.h"
#include "aom_dsp/arm/mem_neon.h"
#include "aom_dsp/arm/sum_neon.h"
#include "aom_dsp/arm/transpose_neon.h"
#include "aom_ports/mem.h"
#include "av1/common/scale.h"
#include "av1/common/warped_motion.h"
#include "config/av1_rtcd.h"
#include "highbd_warp_plane_neon.h"

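// Horizontal filter for four output pixels of one row when the filter varies
// per pixel (alpha != 0). rv0..rv3 are the 8-tap input windows for each
// output pixel, and sx/alpha index the warp filter table. The four sums are
// offset by 1 << offset_bits_horiz, rounded by round0 (two extra bits at
// 12-bit depth) and returned in the low half of the vector.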
static AOM_FORCE_INLINE int16x8_t
highbd_horizontal_filter_4x1_f4(int16x8_t rv0, int16x8_t rv1, int16x8_t rv2,
                                int16x8_t rv3, int bd, int sx, int alpha) {
  int16x8_t f[4];
  load_filters_4(f, sx, alpha);

  int32x4_t m0 = vmull_s16(vget_low_s16(f[0]), vget_low_s16(rv0));
  m0 = vmlal_s16(m0, vget_high_s16(f[0]), vget_high_s16(rv0));
  int32x4_t m1 = vmull_s16(vget_low_s16(f[1]), vget_low_s16(rv1));
  m1 = vmlal_s16(m1, vget_high_s16(f[1]), vget_high_s16(rv1));
  int32x4_t m2 = vmull_s16(vget_low_s16(f[2]), vget_low_s16(rv2));
  m2 = vmlal_s16(m2, vget_high_s16(f[2]), vget_high_s16(rv2));
  int32x4_t m3 = vmull_s16(vget_low_s16(f[3]), vget_low_s16(rv3));
  m3 = vmlal_s16(m3, vget_high_s16(f[3]), vget_high_s16(rv3));

  int32x4_t m0123[] = { m0, m1, m2, m3 };

  const int round0 = (bd == 12) ? ROUND0_BITS + 2 : ROUND0_BITS;
  const int offset_bits_horiz = bd + FILTER_BITS - 1;

  int32x4_t res = horizontal_add_4d_s32x4(m0123);
  res = vaddq_s32(res, vdupq_n_s32(1 << offset_bits_horiz));
  res = vrshlq_s32(res, vdupq_n_s32(-round0));
  return vcombine_s16(vmovn_s32(res), vdup_n_s16(0));
}

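// As above, but for eight output pixels with eight per-pixel filters; both
// four-lane sums are narrowed and packed into one int16x8_t.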
static AOM_FORCE_INLINE int16x8_t highbd_horizontal_filter_8x1_f8(
    int16x8_t rv0, int16x8_t rv1, int16x8_t rv2, int16x8_t rv3, int16x8_t rv4,
    int16x8_t rv5, int16x8_t rv6, int16x8_t rv7, int bd, int sx, int alpha) {
  int16x8_t f[8];
  load_filters_8(f, sx, alpha);

  int32x4_t m0 = vmull_s16(vget_low_s16(f[0]), vget_low_s16(rv0));
  m0 = vmlal_s16(m0, vget_high_s16(f[0]), vget_high_s16(rv0));
  int32x4_t m1 = vmull_s16(vget_low_s16(f[1]), vget_low_s16(rv1));
  m1 = vmlal_s16(m1, vget_high_s16(f[1]), vget_high_s16(rv1));
  int32x4_t m2 = vmull_s16(vget_low_s16(f[2]), vget_low_s16(rv2));
  m2 = vmlal_s16(m2, vget_high_s16(f[2]), vget_high_s16(rv2));
  int32x4_t m3 = vmull_s16(vget_low_s16(f[3]), vget_low_s16(rv3));
  m3 = vmlal_s16(m3, vget_high_s16(f[3]), vget_high_s16(rv3));
  int32x4_t m4 = vmull_s16(vget_low_s16(f[4]), vget_low_s16(rv4));
  m4 = vmlal_s16(m4, vget_high_s16(f[4]), vget_high_s16(rv4));
  int32x4_t m5 = vmull_s16(vget_low_s16(f[5]), vget_low_s16(rv5));
  m5 = vmlal_s16(m5, vget_high_s16(f[5]), vget_high_s16(rv5));
  int32x4_t m6 = vmull_s16(vget_low_s16(f[6]), vget_low_s16(rv6));
  m6 = vmlal_s16(m6, vget_high_s16(f[6]), vget_high_s16(rv6));
  int32x4_t m7 = vmull_s16(vget_low_s16(f[7]), vget_low_s16(rv7));
  m7 = vmlal_s16(m7, vget_high_s16(f[7]), vget_high_s16(rv7));

  int32x4_t m0123[] = { m0, m1, m2, m3 };
  int32x4_t m4567[] = { m4, m5, m6, m7 };

  const int round0 = (bd == 12) ? ROUND0_BITS + 2 : ROUND0_BITS;
  const int offset_bits_horiz = bd + FILTER_BITS - 1;

  int32x4_t res0 = horizontal_add_4d_s32x4(m0123);
  int32x4_t res1 = horizontal_add_4d_s32x4(m4567);
  res0 = vaddq_s32(res0, vdupq_n_s32(1 << offset_bits_horiz));
  res1 = vaddq_s32(res1, vdupq_n_s32(1 << offset_bits_horiz));
  res0 = vrshlq_s32(res0, vdupq_n_s32(-round0));
  res1 = vrshlq_s32(res1, vdupq_n_s32(-round0));
  return vcombine_s16(vmovn_s32(res0), vmovn_s32(res1));
}

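// Horizontal filter for four output pixels when a single filter applies to
// the whole row (alpha == 0): one 8-tap filter loaded at offset sx is used
// for every input window.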
static AOM_FORCE_INLINE int16x8_t
highbd_horizontal_filter_4x1_f1(int16x8_t rv0, int16x8_t rv1, int16x8_t rv2,
                                int16x8_t rv3, int bd, int sx) {
  int16x8_t f = load_filters_1(sx);

  int32x4_t m0 = vmull_s16(vget_low_s16(f), vget_low_s16(rv0));
  m0 = vmlal_s16(m0, vget_high_s16(f), vget_high_s16(rv0));
  int32x4_t m1 = vmull_s16(vget_low_s16(f), vget_low_s16(rv1));
  m1 = vmlal_s16(m1, vget_high_s16(f), vget_high_s16(rv1));
  int32x4_t m2 = vmull_s16(vget_low_s16(f), vget_low_s16(rv2));
  m2 = vmlal_s16(m2, vget_high_s16(f), vget_high_s16(rv2));
  int32x4_t m3 = vmull_s16(vget_low_s16(f), vget_low_s16(rv3));
  m3 = vmlal_s16(m3, vget_high_s16(f), vget_high_s16(rv3));

  int32x4_t m0123[] = { m0, m1, m2, m3 };

  const int round0 = (bd == 12) ? ROUND0_BITS + 2 : ROUND0_BITS;
  const int offset_bits_horiz = bd + FILTER_BITS - 1;

  int32x4_t res = horizontal_add_4d_s32x4(m0123);
  res = vaddq_s32(res, vdupq_n_s32(1 << offset_bits_horiz));
  res = vrshlq_s32(res, vdupq_n_s32(-round0));
  return vcombine_s16(vmovn_s32(res), vdup_n_s16(0));
}

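// Single-filter (alpha == 0) variant of the 8-pixel horizontal filter.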
static AOM_FORCE_INLINE int16x8_t highbd_horizontal_filter_8x1_f1(
    int16x8_t rv0, int16x8_t rv1, int16x8_t rv2, int16x8_t rv3, int16x8_t rv4,
    int16x8_t rv5, int16x8_t rv6, int16x8_t rv7, int bd, int sx) {
  int16x8_t f = load_filters_1(sx);

  int32x4_t m0 = vmull_s16(vget_low_s16(f), vget_low_s16(rv0));
  m0 = vmlal_s16(m0, vget_high_s16(f), vget_high_s16(rv0));
  int32x4_t m1 = vmull_s16(vget_low_s16(f), vget_low_s16(rv1));
  m1 = vmlal_s16(m1, vget_high_s16(f), vget_high_s16(rv1));
  int32x4_t m2 = vmull_s16(vget_low_s16(f), vget_low_s16(rv2));
  m2 = vmlal_s16(m2, vget_high_s16(f), vget_high_s16(rv2));
  int32x4_t m3 = vmull_s16(vget_low_s16(f), vget_low_s16(rv3));
  m3 = vmlal_s16(m3, vget_high_s16(f), vget_high_s16(rv3));
  int32x4_t m4 = vmull_s16(vget_low_s16(f), vget_low_s16(rv4));
  m4 = vmlal_s16(m4, vget_high_s16(f), vget_high_s16(rv4));
  int32x4_t m5 = vmull_s16(vget_low_s16(f), vget_low_s16(rv5));
  m5 = vmlal_s16(m5, vget_high_s16(f), vget_high_s16(rv5));
  int32x4_t m6 = vmull_s16(vget_low_s16(f), vget_low_s16(rv6));
  m6 = vmlal_s16(m6, vget_high_s16(f), vget_high_s16(rv6));
  int32x4_t m7 = vmull_s16(vget_low_s16(f), vget_low_s16(rv7));
  m7 = vmlal_s16(m7, vget_high_s16(f), vget_high_s16(rv7));

  int32x4_t m0123[] = { m0, m1, m2, m3 };
  int32x4_t m4567[] = { m4, m5, m6, m7 };

  const int round0 = (bd == 12) ? ROUND0_BITS + 2 : ROUND0_BITS;
  const int offset_bits_horiz = bd + FILTER_BITS - 1;

  int32x4_t res0 = horizontal_add_4d_s32x4(m0123);
  int32x4_t res1 = horizontal_add_4d_s32x4(m4567);
  res0 = vaddq_s32(res0, vdupq_n_s32(1 << offset_bits_horiz));
  res1 = vaddq_s32(res1, vdupq_n_s32(1 << offset_bits_horiz));
  res0 = vrshlq_s32(res0, vdupq_n_s32(-round0));
  res1 = vrshlq_s32(res1, vdupq_n_s32(-round0));
  return vcombine_s16(vmovn_s32(res0), vmovn_s32(res1));
}

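// Vertical filter over the intermediate rows tmp[0..7] for four columns,
// using one shared 8-tap filter (the gamma == 0 case). Each tap is broadcast
// with vmlal_lane_s16 so one multiply-accumulate consumes one whole row.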
static AOM_FORCE_INLINE int32x4_t vertical_filter_4x1_f1(const int16x8_t *tmp,
                                                         int sy) {
  const int16x8_t f = load_filters_1(sy);
  const int16x4_t f0123 = vget_low_s16(f);
  const int16x4_t f4567 = vget_high_s16(f);

  int32x4_t m0123 = vmull_lane_s16(vget_low_s16(tmp[0]), f0123, 0);
  m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[1]), f0123, 1);
  m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[2]), f0123, 2);
  m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[3]), f0123, 3);
  m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[4]), f4567, 0);
  m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[5]), f4567, 1);
  m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[6]), f4567, 2);
  m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[7]), f4567, 3);
  return m0123;
}

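// Shared-filter vertical pass widened to eight columns; the low and high
// halves of each intermediate row are accumulated into separate int32x4
// sums and returned as a pair.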
static AOM_FORCE_INLINE int32x4x2_t vertical_filter_8x1_f1(const int16x8_t *tmp,
                                                           int sy) {
  const int16x8_t f = load_filters_1(sy);
  const int16x4_t f0123 = vget_low_s16(f);
  const int16x4_t f4567 = vget_high_s16(f);

  int32x4_t m0123 = vmull_lane_s16(vget_low_s16(tmp[0]), f0123, 0);
  m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[1]), f0123, 1);
  m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[2]), f0123, 2);
  m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[3]), f0123, 3);
  m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[4]), f4567, 0);
  m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[5]), f4567, 1);
  m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[6]), f4567, 2);
  m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[7]), f4567, 3);

  int32x4_t m4567 = vmull_lane_s16(vget_high_s16(tmp[0]), f0123, 0);
  m4567 = vmlal_lane_s16(m4567, vget_high_s16(tmp[1]), f0123, 1);
  m4567 = vmlal_lane_s16(m4567, vget_high_s16(tmp[2]), f0123, 2);
  m4567 = vmlal_lane_s16(m4567, vget_high_s16(tmp[3]), f0123, 3);
  m4567 = vmlal_lane_s16(m4567, vget_high_s16(tmp[4]), f4567, 0);
  m4567 = vmlal_lane_s16(m4567, vget_high_s16(tmp[5]), f4567, 1);
  m4567 = vmlal_lane_s16(m4567, vget_high_s16(tmp[6]), f4567, 2);
  m4567 = vmlal_lane_s16(m4567, vget_high_s16(tmp[7]), f4567, 3);
  return (int32x4x2_t){ { m0123, m4567 } };
}

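// Vertical filter for four columns with per-column filters (gamma != 0).
// The low halves of the eight intermediate rows are transposed so each of
// s0..s3 holds one column's eight samples, which are then multiplied by
// four filters loaded at offsets sy, sy + gamma, and so on.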
static AOM_FORCE_INLINE int32x4_t vertical_filter_4x1_f4(const int16x8_t *tmp,
                                                         int sy, int gamma) {
  int16x8_t s0, s1, s2, s3;
  transpose_elems_s16_4x8(
      vget_low_s16(tmp[0]), vget_low_s16(tmp[1]), vget_low_s16(tmp[2]),
      vget_low_s16(tmp[3]), vget_low_s16(tmp[4]), vget_low_s16(tmp[5]),
      vget_low_s16(tmp[6]), vget_low_s16(tmp[7]), &s0, &s1, &s2, &s3);

  int16x8_t f[4];
  load_filters_4(f, sy, gamma);

  int32x4_t m0 = vmull_s16(vget_low_s16(s0), vget_low_s16(f[0]));
  m0 = vmlal_s16(m0, vget_high_s16(s0), vget_high_s16(f[0]));
  int32x4_t m1 = vmull_s16(vget_low_s16(s1), vget_low_s16(f[1]));
  m1 = vmlal_s16(m1, vget_high_s16(s1), vget_high_s16(f[1]));
  int32x4_t m2 = vmull_s16(vget_low_s16(s2), vget_low_s16(f[2]));
  m2 = vmlal_s16(m2, vget_high_s16(s2), vget_high_s16(f[2]));
  int32x4_t m3 = vmull_s16(vget_low_s16(s3), vget_low_s16(f[3]));
  m3 = vmlal_s16(m3, vget_high_s16(s3), vget_high_s16(f[3]));

  int32x4_t m0123[] = { m0, m1, m2, m3 };
  return horizontal_add_4d_s32x4(m0123);
}

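// Per-column (gamma != 0) vertical filter for eight columns, using an 8x8
// in-place transpose and eight distinct filters.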
static AOM_FORCE_INLINE int32x4x2_t vertical_filter_8x1_f8(const int16x8_t *tmp,
                                                           int sy, int gamma) {
  int16x8_t s0 = tmp[0];
  int16x8_t s1 = tmp[1];
  int16x8_t s2 = tmp[2];
  int16x8_t s3 = tmp[3];
  int16x8_t s4 = tmp[4];
  int16x8_t s5 = tmp[5];
  int16x8_t s6 = tmp[6];
  int16x8_t s7 = tmp[7];
  transpose_elems_inplace_s16_8x8(&s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7);

  int16x8_t f[8];
  load_filters_8(f, sy, gamma);

  int32x4_t m0 = vmull_s16(vget_low_s16(s0), vget_low_s16(f[0]));
  m0 = vmlal_s16(m0, vget_high_s16(s0), vget_high_s16(f[0]));
  int32x4_t m1 = vmull_s16(vget_low_s16(s1), vget_low_s16(f[1]));
  m1 = vmlal_s16(m1, vget_high_s16(s1), vget_high_s16(f[1]));
  int32x4_t m2 = vmull_s16(vget_low_s16(s2), vget_low_s16(f[2]));
  m2 = vmlal_s16(m2, vget_high_s16(s2), vget_high_s16(f[2]));
  int32x4_t m3 = vmull_s16(vget_low_s16(s3), vget_low_s16(f[3]));
  m3 = vmlal_s16(m3, vget_high_s16(s3), vget_high_s16(f[3]));
  int32x4_t m4 = vmull_s16(vget_low_s16(s4), vget_low_s16(f[4]));
  m4 = vmlal_s16(m4, vget_high_s16(s4), vget_high_s16(f[4]));
  int32x4_t m5 = vmull_s16(vget_low_s16(s5), vget_low_s16(f[5]));
  m5 = vmlal_s16(m5, vget_high_s16(s5), vget_high_s16(f[5]));
  int32x4_t m6 = vmull_s16(vget_low_s16(s6), vget_low_s16(f[6]));
  m6 = vmlal_s16(m6, vget_high_s16(s6), vget_high_s16(f[6]));
  int32x4_t m7 = vmull_s16(vget_low_s16(s7), vget_low_s16(f[7]));
  m7 = vmlal_s16(m7, vget_high_s16(s7), vget_high_s16(f[7]));

  int32x4_t m0123[] = { m0, m1, m2, m3 };
  int32x4_t m4567[] = { m4, m5, m6, m7 };

  int32x4x2_t ret;
  ret.val[0] = horizontal_add_4d_s32x4(m0123);
  ret.val[1] = horizontal_add_4d_s32x4(m4567);
  return ret;
}

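// Entry point registered through av1_rtcd. The block traversal and the
// dispatch between the f1/f4/f8 helpers above live in
// highbd_warp_affine_common(), shared via the highbd_warp_plane_neon.h
// header included at the top of this file.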
void av1_highbd_warp_affine_neon(const int32_t *mat, const uint16_t *ref,
                                 int width, int height, int stride,
                                 uint16_t *pred, int p_col, int p_row,
                                 int p_width, int p_height, int p_stride,
                                 int subsampling_x, int subsampling_y, int bd,
                                 ConvolveParams *conv_params, int16_t alpha,
                                 int16_t beta, int16_t gamma, int16_t delta) {
  highbd_warp_affine_common(mat, ref, width, height, stride, pred, p_col, p_row,
                            p_width, p_height, p_stride, subsampling_x,
                            subsampling_y, bd, conv_params, alpha, beta, gamma,
                            delta);
}