/*
 * Copyright (c) 2023, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <arm_neon.h>
#include <assert.h>
#include <stdbool.h>

#include "aom_dsp/aom_dsp_common.h"
#include "aom_dsp/arm/mem_neon.h"
#include "aom_dsp/arm/sum_neon.h"
#include "aom_dsp/arm/transpose_neon.h"
#include "aom_ports/mem.h"
#include "av1/common/scale.h"
#include "av1/common/warped_motion.h"
#include "config/av1_rtcd.h"
#include "highbd_warp_plane_neon.h"

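// Horizontally filter a single row of 4 output pixels. Four 8-tap filters are
// loaded (starting at offset sx and stepped by alpha), one per output pixel,
// and each is applied to its vector of 8 source samples rv0..rv3. Each sum of
// products is offset by 1 << (bd + FILTER_BITS - 1) and rounding-shifted
// right by ROUND0_BITS (+2 for 12-bit input); the 4 results are returned in
// the low half of the vector.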
static AOM_FORCE_INLINE int16x8_t
highbd_horizontal_filter_4x1_f4(int16x8_t rv0, int16x8_t rv1, int16x8_t rv2,
                                int16x8_t rv3, int bd, int sx, int alpha) {
  int16x8_t f[4];
  load_filters_4(f, sx, alpha);

  int32x4_t m0 = vmull_s16(vget_low_s16(f[0]), vget_low_s16(rv0));
  m0 = vmlal_s16(m0, vget_high_s16(f[0]), vget_high_s16(rv0));
  int32x4_t m1 = vmull_s16(vget_low_s16(f[1]), vget_low_s16(rv1));
  m1 = vmlal_s16(m1, vget_high_s16(f[1]), vget_high_s16(rv1));
  int32x4_t m2 = vmull_s16(vget_low_s16(f[2]), vget_low_s16(rv2));
  m2 = vmlal_s16(m2, vget_high_s16(f[2]), vget_high_s16(rv2));
  int32x4_t m3 = vmull_s16(vget_low_s16(f[3]), vget_low_s16(rv3));
  m3 = vmlal_s16(m3, vget_high_s16(f[3]), vget_high_s16(rv3));

  int32x4_t m0123[] = { m0, m1, m2, m3 };

  const int round0 = (bd == 12) ? ROUND0_BITS + 2 : ROUND0_BITS;
  const int offset_bits_horiz = bd + FILTER_BITS - 1;

  int32x4_t res = horizontal_add_4d_s32x4(m0123);
  res = vaddq_s32(res, vdupq_n_s32(1 << offset_bits_horiz));
  res = vrshlq_s32(res, vdupq_n_s32(-round0));
  return vcombine_s16(vmovn_s32(res), vdup_n_s16(0));
}

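// As above, but filtering a row of 8 output pixels: 8 filters (stepped by
// alpha from sx) are applied to the source vectors rv0..rv7, and all 8
// offset-and-rounded results are returned.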
static AOM_FORCE_INLINE int16x8_t highbd_horizontal_filter_8x1_f8(
    int16x8_t rv0, int16x8_t rv1, int16x8_t rv2, int16x8_t rv3, int16x8_t rv4,
    int16x8_t rv5, int16x8_t rv6, int16x8_t rv7, int bd, int sx, int alpha) {
  int16x8_t f[8];
  load_filters_8(f, sx, alpha);

  int32x4_t m0 = vmull_s16(vget_low_s16(f[0]), vget_low_s16(rv0));
  m0 = vmlal_s16(m0, vget_high_s16(f[0]), vget_high_s16(rv0));
  int32x4_t m1 = vmull_s16(vget_low_s16(f[1]), vget_low_s16(rv1));
  m1 = vmlal_s16(m1, vget_high_s16(f[1]), vget_high_s16(rv1));
  int32x4_t m2 = vmull_s16(vget_low_s16(f[2]), vget_low_s16(rv2));
  m2 = vmlal_s16(m2, vget_high_s16(f[2]), vget_high_s16(rv2));
  int32x4_t m3 = vmull_s16(vget_low_s16(f[3]), vget_low_s16(rv3));
  m3 = vmlal_s16(m3, vget_high_s16(f[3]), vget_high_s16(rv3));
  int32x4_t m4 = vmull_s16(vget_low_s16(f[4]), vget_low_s16(rv4));
  m4 = vmlal_s16(m4, vget_high_s16(f[4]), vget_high_s16(rv4));
  int32x4_t m5 = vmull_s16(vget_low_s16(f[5]), vget_low_s16(rv5));
  m5 = vmlal_s16(m5, vget_high_s16(f[5]), vget_high_s16(rv5));
  int32x4_t m6 = vmull_s16(vget_low_s16(f[6]), vget_low_s16(rv6));
  m6 = vmlal_s16(m6, vget_high_s16(f[6]), vget_high_s16(rv6));
  int32x4_t m7 = vmull_s16(vget_low_s16(f[7]), vget_low_s16(rv7));
  m7 = vmlal_s16(m7, vget_high_s16(f[7]), vget_high_s16(rv7));

  int32x4_t m0123[] = { m0, m1, m2, m3 };
  int32x4_t m4567[] = { m4, m5, m6, m7 };

  const int round0 = (bd == 12) ? ROUND0_BITS + 2 : ROUND0_BITS;
  const int offset_bits_horiz = bd + FILTER_BITS - 1;

  int32x4_t res0 = horizontal_add_4d_s32x4(m0123);
  int32x4_t res1 = horizontal_add_4d_s32x4(m4567);
  res0 = vaddq_s32(res0, vdupq_n_s32(1 << offset_bits_horiz));
  res1 = vaddq_s32(res1, vdupq_n_s32(1 << offset_bits_horiz));
  res0 = vrshlq_s32(res0, vdupq_n_s32(-round0));
  res1 = vrshlq_s32(res1, vdupq_n_s32(-round0));
  return vcombine_s16(vmovn_s32(res0), vmovn_s32(res1));
}

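// Single-filter variant of the 4-pixel horizontal filter, used when every
// output pixel shares the same filter (alpha == 0): the filter at offset sx
// is loaded once and reused for rv0..rv3.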
static AOM_FORCE_INLINE int16x8_t
highbd_horizontal_filter_4x1_f1(int16x8_t rv0, int16x8_t rv1, int16x8_t rv2,
                                int16x8_t rv3, int bd, int sx) {
  int16x8_t f = load_filters_1(sx);

  int32x4_t m0 = vmull_s16(vget_low_s16(f), vget_low_s16(rv0));
  m0 = vmlal_s16(m0, vget_high_s16(f), vget_high_s16(rv0));
  int32x4_t m1 = vmull_s16(vget_low_s16(f), vget_low_s16(rv1));
  m1 = vmlal_s16(m1, vget_high_s16(f), vget_high_s16(rv1));
  int32x4_t m2 = vmull_s16(vget_low_s16(f), vget_low_s16(rv2));
  m2 = vmlal_s16(m2, vget_high_s16(f), vget_high_s16(rv2));
  int32x4_t m3 = vmull_s16(vget_low_s16(f), vget_low_s16(rv3));
  m3 = vmlal_s16(m3, vget_high_s16(f), vget_high_s16(rv3));

  int32x4_t m0123[] = { m0, m1, m2, m3 };

  const int round0 = (bd == 12) ? ROUND0_BITS + 2 : ROUND0_BITS;
  const int offset_bits_horiz = bd + FILTER_BITS - 1;

  int32x4_t res = horizontal_add_4d_s32x4(m0123);
  res = vaddq_s32(res, vdupq_n_s32(1 << offset_bits_horiz));
  res = vrshlq_s32(res, vdupq_n_s32(-round0));
  return vcombine_s16(vmovn_s32(res), vdup_n_s16(0));
}

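// Single-filter variant of the 8-pixel horizontal filter (alpha == 0): one
// filter at offset sx is applied to all of rv0..rv7.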
static AOM_FORCE_INLINE int16x8_t highbd_horizontal_filter_8x1_f1(
    int16x8_t rv0, int16x8_t rv1, int16x8_t rv2, int16x8_t rv3, int16x8_t rv4,
    int16x8_t rv5, int16x8_t rv6, int16x8_t rv7, int bd, int sx) {
  int16x8_t f = load_filters_1(sx);

  int32x4_t m0 = vmull_s16(vget_low_s16(f), vget_low_s16(rv0));
  m0 = vmlal_s16(m0, vget_high_s16(f), vget_high_s16(rv0));
  int32x4_t m1 = vmull_s16(vget_low_s16(f), vget_low_s16(rv1));
  m1 = vmlal_s16(m1, vget_high_s16(f), vget_high_s16(rv1));
  int32x4_t m2 = vmull_s16(vget_low_s16(f), vget_low_s16(rv2));
  m2 = vmlal_s16(m2, vget_high_s16(f), vget_high_s16(rv2));
  int32x4_t m3 = vmull_s16(vget_low_s16(f), vget_low_s16(rv3));
  m3 = vmlal_s16(m3, vget_high_s16(f), vget_high_s16(rv3));
  int32x4_t m4 = vmull_s16(vget_low_s16(f), vget_low_s16(rv4));
  m4 = vmlal_s16(m4, vget_high_s16(f), vget_high_s16(rv4));
  int32x4_t m5 = vmull_s16(vget_low_s16(f), vget_low_s16(rv5));
  m5 = vmlal_s16(m5, vget_high_s16(f), vget_high_s16(rv5));
  int32x4_t m6 = vmull_s16(vget_low_s16(f), vget_low_s16(rv6));
  m6 = vmlal_s16(m6, vget_high_s16(f), vget_high_s16(rv6));
  int32x4_t m7 = vmull_s16(vget_low_s16(f), vget_low_s16(rv7));
  m7 = vmlal_s16(m7, vget_high_s16(f), vget_high_s16(rv7));

  int32x4_t m0123[] = { m0, m1, m2, m3 };
  int32x4_t m4567[] = { m4, m5, m6, m7 };

  const int round0 = (bd == 12) ? ROUND0_BITS + 2 : ROUND0_BITS;
  const int offset_bits_horiz = bd + FILTER_BITS - 1;

  int32x4_t res0 = horizontal_add_4d_s32x4(m0123);
  int32x4_t res1 = horizontal_add_4d_s32x4(m4567);
  res0 = vaddq_s32(res0, vdupq_n_s32(1 << offset_bits_horiz));
  res1 = vaddq_s32(res1, vdupq_n_s32(1 << offset_bits_horiz));
  res0 = vrshlq_s32(res0, vdupq_n_s32(-round0));
  res1 = vrshlq_s32(res1, vdupq_n_s32(-round0));
  return vcombine_s16(vmovn_s32(res0), vmovn_s32(res1));
}

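// Vertically filter 4 output pixels with a single 8-tap filter at offset sy.
// Each of the 8 intermediate rows in tmp contributes one tap, accumulated per
// column with multiply-by-lane into a 32-bit result.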
static AOM_FORCE_INLINE int32x4_t vertical_filter_4x1_f1(const int16x8_t *tmp,
                                                         int sy) {
  const int16x8_t f = load_filters_1(sy);
  const int16x4_t f0123 = vget_low_s16(f);
  const int16x4_t f4567 = vget_high_s16(f);

  int32x4_t m0123 = vmull_lane_s16(vget_low_s16(tmp[0]), f0123, 0);
  m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[1]), f0123, 1);
  m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[2]), f0123, 2);
  m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[3]), f0123, 3);
  m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[4]), f4567, 0);
  m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[5]), f4567, 1);
  m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[6]), f4567, 2);
  m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[7]), f4567, 3);
  return m0123;
}

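// As above, but for 8 output pixels: the low and high halves of the
// intermediate rows are accumulated separately, returning two vectors of
// 32-bit sums.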
static AOM_FORCE_INLINE int32x4x2_t vertical_filter_8x1_f1(const int16x8_t *tmp,
                                                           int sy) {
  const int16x8_t f = load_filters_1(sy);
  const int16x4_t f0123 = vget_low_s16(f);
  const int16x4_t f4567 = vget_high_s16(f);

  int32x4_t m0123 = vmull_lane_s16(vget_low_s16(tmp[0]), f0123, 0);
  m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[1]), f0123, 1);
  m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[2]), f0123, 2);
  m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[3]), f0123, 3);
  m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[4]), f4567, 0);
  m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[5]), f4567, 1);
  m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[6]), f4567, 2);
  m0123 = vmlal_lane_s16(m0123, vget_low_s16(tmp[7]), f4567, 3);

  int32x4_t m4567 = vmull_lane_s16(vget_high_s16(tmp[0]), f0123, 0);
  m4567 = vmlal_lane_s16(m4567, vget_high_s16(tmp[1]), f0123, 1);
  m4567 = vmlal_lane_s16(m4567, vget_high_s16(tmp[2]), f0123, 2);
  m4567 = vmlal_lane_s16(m4567, vget_high_s16(tmp[3]), f0123, 3);
  m4567 = vmlal_lane_s16(m4567, vget_high_s16(tmp[4]), f4567, 0);
  m4567 = vmlal_lane_s16(m4567, vget_high_s16(tmp[5]), f4567, 1);
  m4567 = vmlal_lane_s16(m4567, vget_high_s16(tmp[6]), f4567, 2);
  m4567 = vmlal_lane_s16(m4567, vget_high_s16(tmp[7]), f4567, 3);
  return (int32x4x2_t){ { m0123, m4567 } };
}

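// Vertically filter 4 output pixels with 4 different filters (stepped by
// gamma from sy). The 8 rows of 4 intermediate values are transposed so the 8
// taps of each output column lie in one vector, then each column is
// multiplied by its own filter and reduced with a horizontal add.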
static AOM_FORCE_INLINE int32x4_t vertical_filter_4x1_f4(const int16x8_t *tmp,
                                                         int sy, int gamma) {
  int16x8_t s0, s1, s2, s3;
  transpose_elems_s16_4x8(
      vget_low_s16(tmp[0]), vget_low_s16(tmp[1]), vget_low_s16(tmp[2]),
      vget_low_s16(tmp[3]), vget_low_s16(tmp[4]), vget_low_s16(tmp[5]),
      vget_low_s16(tmp[6]), vget_low_s16(tmp[7]), &s0, &s1, &s2, &s3);

  int16x8_t f[4];
  load_filters_4(f, sy, gamma);

  int32x4_t m0 = vmull_s16(vget_low_s16(s0), vget_low_s16(f[0]));
  m0 = vmlal_s16(m0, vget_high_s16(s0), vget_high_s16(f[0]));
  int32x4_t m1 = vmull_s16(vget_low_s16(s1), vget_low_s16(f[1]));
  m1 = vmlal_s16(m1, vget_high_s16(s1), vget_high_s16(f[1]));
  int32x4_t m2 = vmull_s16(vget_low_s16(s2), vget_low_s16(f[2]));
  m2 = vmlal_s16(m2, vget_high_s16(s2), vget_high_s16(f[2]));
  int32x4_t m3 = vmull_s16(vget_low_s16(s3), vget_low_s16(f[3]));
  m3 = vmlal_s16(m3, vget_high_s16(s3), vget_high_s16(f[3]));

  int32x4_t m0123[] = { m0, m1, m2, m3 };
  return horizontal_add_4d_s32x4(m0123);
}

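// As above, but for 8 output pixels: the 8x8 block of intermediate values is
// transposed in place, and each of the 8 columns is filtered with its own
// filter (stepped by gamma from sy).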
static AOM_FORCE_INLINE int32x4x2_t vertical_filter_8x1_f8(const int16x8_t *tmp,
                                                           int sy, int gamma) {
  int16x8_t s0 = tmp[0];
  int16x8_t s1 = tmp[1];
  int16x8_t s2 = tmp[2];
  int16x8_t s3 = tmp[3];
  int16x8_t s4 = tmp[4];
  int16x8_t s5 = tmp[5];
  int16x8_t s6 = tmp[6];
  int16x8_t s7 = tmp[7];
  transpose_elems_inplace_s16_8x8(&s0, &s1, &s2, &s3, &s4, &s5, &s6, &s7);

  int16x8_t f[8];
  load_filters_8(f, sy, gamma);

  int32x4_t m0 = vmull_s16(vget_low_s16(s0), vget_low_s16(f[0]));
  m0 = vmlal_s16(m0, vget_high_s16(s0), vget_high_s16(f[0]));
  int32x4_t m1 = vmull_s16(vget_low_s16(s1), vget_low_s16(f[1]));
  m1 = vmlal_s16(m1, vget_high_s16(s1), vget_high_s16(f[1]));
  int32x4_t m2 = vmull_s16(vget_low_s16(s2), vget_low_s16(f[2]));
  m2 = vmlal_s16(m2, vget_high_s16(s2), vget_high_s16(f[2]));
  int32x4_t m3 = vmull_s16(vget_low_s16(s3), vget_low_s16(f[3]));
  m3 = vmlal_s16(m3, vget_high_s16(s3), vget_high_s16(f[3]));
  int32x4_t m4 = vmull_s16(vget_low_s16(s4), vget_low_s16(f[4]));
  m4 = vmlal_s16(m4, vget_high_s16(s4), vget_high_s16(f[4]));
  int32x4_t m5 = vmull_s16(vget_low_s16(s5), vget_low_s16(f[5]));
  m5 = vmlal_s16(m5, vget_high_s16(s5), vget_high_s16(f[5]));
  int32x4_t m6 = vmull_s16(vget_low_s16(s6), vget_low_s16(f[6]));
  m6 = vmlal_s16(m6, vget_high_s16(s6), vget_high_s16(f[6]));
  int32x4_t m7 = vmull_s16(vget_low_s16(s7), vget_low_s16(f[7]));
  m7 = vmlal_s16(m7, vget_high_s16(s7), vget_high_s16(f[7]));

  int32x4_t m0123[] = { m0, m1, m2, m3 };
  int32x4_t m4567[] = { m4, m5, m6, m7 };

  int32x4x2_t ret;
  ret.val[0] = horizontal_add_4d_s32x4(m0123);
  ret.val[1] = horizontal_add_4d_s32x4(m4567);
  return ret;
}

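// High-bitdepth affine warp entry point. The shared block loop lives in
// highbd_warp_affine_common() (see highbd_warp_plane_neon.h), which
// dispatches to the horizontal and vertical filter helpers defined in this
// file.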
void av1_highbd_warp_affine_neon(const int32_t *mat, const uint16_t *ref,
                                 int width, int height, int stride,
                                 uint16_t *pred, int p_col, int p_row,
                                 int p_width, int p_height, int p_stride,
                                 int subsampling_x, int subsampling_y, int bd,
                                 ConvolveParams *conv_params, int16_t alpha,
                                 int16_t beta, int16_t gamma, int16_t delta) {
  highbd_warp_affine_common(mat, ref, width, height, stride, pred, p_col, p_row,
                            p_width, p_height, p_stride, subsampling_x,
                            subsampling_y, bd, conv_params, alpha, beta, gamma,
                            delta);
}