/*
 * Copyright (c) 2024, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#ifndef AOM_AOM_DSP_ARM_HIGHBD_CONVOLVE8_NEON_H_
#define AOM_AOM_DSP_ARM_HIGHBD_CONVOLVE8_NEON_H_

#include <arm_neon.h>
#include <assert.h>

#include "config/aom_config.h"

#include "aom_dsp/aom_filter.h"
#include "aom_dsp/arm/mem_neon.h"

static inline void highbd_convolve8_horiz_2tap_neon(
    const uint16_t *src_ptr, ptrdiff_t src_stride, uint16_t *dst_ptr,
    ptrdiff_t dst_stride, const int16_t *x_filter_ptr, int w, int h, int bd) {
  // Bilinear filter values are all positive and multiples of 8. Divide by 8
  // to reduce intermediate precision requirements and allow the use of a
  // non-widening multiply.
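  // E.g. with bd == 12 the largest sample is 4095, and after the division the
  // two taps sum to at most 128 / 8 = 16, so the largest accumulator value is
  // 4095 * 16 = 65520, which still fits in an unsigned 16-bit lane.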
  const uint16x8_t f0 = vdupq_n_u16((uint16_t)x_filter_ptr[3] / 8);
  const uint16x8_t f1 = vdupq_n_u16((uint16_t)x_filter_ptr[4] / 8);

  const uint16x8_t max = vdupq_n_u16((1 << bd) - 1);

  if (w == 4) {
    do {
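      // Each 4x2 load packs two rows into one vector. s1 and s3 repeat s0
      // and s2 shifted right by one column, so the two taps combine
      // horizontally adjacent samples.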
      uint16x8_t s0 =
          load_unaligned_u16_4x2(src_ptr + 0 * src_stride + 0, (int)src_stride);
      uint16x8_t s1 =
          load_unaligned_u16_4x2(src_ptr + 0 * src_stride + 1, (int)src_stride);
      uint16x8_t s2 =
          load_unaligned_u16_4x2(src_ptr + 2 * src_stride + 0, (int)src_stride);
      uint16x8_t s3 =
          load_unaligned_u16_4x2(src_ptr + 2 * src_stride + 1, (int)src_stride);

      uint16x8_t sum01 = vmulq_u16(s0, f0);
      sum01 = vmlaq_u16(sum01, s1, f1);
      uint16x8_t sum23 = vmulq_u16(s2, f0);
      sum23 = vmlaq_u16(sum23, s3, f1);

      // We divided the filter taps by 8, so subtract 3 from the right shift.
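      // (Dividing the taps by 8 and shifting by FILTER_BITS - 3 gives the
      // same overall 2^-FILTER_BITS scaling as the full-precision filter.)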
      sum01 = vrshrq_n_u16(sum01, FILTER_BITS - 3);
      sum23 = vrshrq_n_u16(sum23, FILTER_BITS - 3);

      sum01 = vminq_u16(sum01, max);
      sum23 = vminq_u16(sum23, max);

      store_u16x4_strided_x2(dst_ptr + 0 * dst_stride, (int)dst_stride, sum01);
      store_u16x4_strided_x2(dst_ptr + 2 * dst_stride, (int)dst_stride, sum23);

      src_ptr += 4 * src_stride;
      dst_ptr += 4 * dst_stride;
      h -= 4;
    } while (h > 0);
  } else {
    do {
      int width = w;
      const uint16_t *s = src_ptr;
      uint16_t *d = dst_ptr;

      do {
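        // s1 and s3 are s0 and s2 shifted right by one column: each output is
        // f0 * src[x] + f1 * src[x + 1].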
        uint16x8_t s0 = vld1q_u16(s + 0 * src_stride + 0);
        uint16x8_t s1 = vld1q_u16(s + 0 * src_stride + 1);
        uint16x8_t s2 = vld1q_u16(s + 1 * src_stride + 0);
        uint16x8_t s3 = vld1q_u16(s + 1 * src_stride + 1);

        uint16x8_t sum01 = vmulq_u16(s0, f0);
        sum01 = vmlaq_u16(sum01, s1, f1);
        uint16x8_t sum23 = vmulq_u16(s2, f0);
        sum23 = vmlaq_u16(sum23, s3, f1);

        // We divided the filter taps by 8, so subtract 3 from the right shift.
        sum01 = vrshrq_n_u16(sum01, FILTER_BITS - 3);
        sum23 = vrshrq_n_u16(sum23, FILTER_BITS - 3);

        sum01 = vminq_u16(sum01, max);
        sum23 = vminq_u16(sum23, max);

        vst1q_u16(d + 0 * dst_stride, sum01);
        vst1q_u16(d + 1 * dst_stride, sum23);

        s += 8;
        d += 8;
        width -= 8;
      } while (width != 0);
      src_ptr += 2 * src_stride;
      dst_ptr += 2 * dst_stride;
      h -= 2;
    } while (h > 0);
  }
}

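// Convolve four 16-bit samples with a 4-tap filter: widen to 32 bits for the
// accumulation, then narrow back with rounding and unsigned saturation, and
// clamp to the maximum pixel value for the bit depth.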
static inline uint16x4_t highbd_convolve4_4(
    const int16x4_t s0, const int16x4_t s1, const int16x4_t s2,
    const int16x4_t s3, const int16x4_t filter, const uint16x4_t max) {
  int32x4_t sum = vmull_lane_s16(s0, filter, 0);
  sum = vmlal_lane_s16(sum, s1, filter, 1);
  sum = vmlal_lane_s16(sum, s2, filter, 2);
  sum = vmlal_lane_s16(sum, s3, filter, 3);

  uint16x4_t res = vqrshrun_n_s32(sum, FILTER_BITS);

  return vmin_u16(res, max);
}

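// As highbd_convolve4_4, but for eight samples at a time: the inputs are
// split into low and high halves so the widening multiplies stay in 32-bit
// lanes.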
static inline uint16x8_t highbd_convolve4_8(
    const int16x8_t s0, const int16x8_t s1, const int16x8_t s2,
    const int16x8_t s3, const int16x4_t filter, const uint16x8_t max) {
  int32x4_t sum0 = vmull_lane_s16(vget_low_s16(s0), filter, 0);
  sum0 = vmlal_lane_s16(sum0, vget_low_s16(s1), filter, 1);
  sum0 = vmlal_lane_s16(sum0, vget_low_s16(s2), filter, 2);
  sum0 = vmlal_lane_s16(sum0, vget_low_s16(s3), filter, 3);

  int32x4_t sum1 = vmull_lane_s16(vget_high_s16(s0), filter, 0);
  sum1 = vmlal_lane_s16(sum1, vget_high_s16(s1), filter, 1);
  sum1 = vmlal_lane_s16(sum1, vget_high_s16(s2), filter, 2);
  sum1 = vmlal_lane_s16(sum1, vget_high_s16(s3), filter, 3);

  uint16x8_t res = vcombine_u16(vqrshrun_n_s32(sum0, FILTER_BITS),
                                vqrshrun_n_s32(sum1, FILTER_BITS));

  return vminq_u16(res, max);
}

static inline void highbd_convolve8_vert_4tap_neon(
    const uint16_t *src_ptr, ptrdiff_t src_stride, uint16_t *dst_ptr,
    ptrdiff_t dst_stride, const int16_t *y_filter_ptr, int w, int h, int bd) {
  assert(w >= 4 && h >= 4);
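  // A 4-tap filter has its nonzero taps stored in the middle of the 8-tap
  // filter array, hence the +2 offset.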
  const int16x4_t y_filter = vld1_s16(y_filter_ptr + 2);

  if (w == 4) {
    const uint16x4_t max = vdup_n_u16((1 << bd) - 1);
    const int16_t *s = (const int16_t *)src_ptr;
    uint16_t *d = dst_ptr;

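    // Load the first three rows up front; each loop iteration below loads
    // four more rows and computes four output rows.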
    int16x4_t s0, s1, s2;
    load_s16_4x3(s, src_stride, &s0, &s1, &s2);
    s += 3 * src_stride;

    do {
      int16x4_t s3, s4, s5, s6;
      load_s16_4x4(s, src_stride, &s3, &s4, &s5, &s6);

      uint16x4_t d0 = highbd_convolve4_4(s0, s1, s2, s3, y_filter, max);
      uint16x4_t d1 = highbd_convolve4_4(s1, s2, s3, s4, y_filter, max);
      uint16x4_t d2 = highbd_convolve4_4(s2, s3, s4, s5, y_filter, max);
      uint16x4_t d3 = highbd_convolve4_4(s3, s4, s5, s6, y_filter, max);

      store_u16_4x4(d, dst_stride, d0, d1, d2, d3);

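      // Slide the filter window down: the last three loaded rows become the
      // first three of the next iteration.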
      s0 = s4;
      s1 = s5;
      s2 = s6;

      s += 4 * src_stride;
      d += 4 * dst_stride;
      h -= 4;
    } while (h > 0);
  } else {
    const uint16x8_t max = vdupq_n_u16((1 << bd) - 1);

    do {
      int height = h;
      const int16_t *s = (const int16_t *)src_ptr;
      uint16_t *d = dst_ptr;

      int16x8_t s0, s1, s2;
      load_s16_8x3(s, src_stride, &s0, &s1, &s2);
      s += 3 * src_stride;

      do {
        int16x8_t s3, s4, s5, s6;
        load_s16_8x4(s, src_stride, &s3, &s4, &s5, &s6);

        uint16x8_t d0 = highbd_convolve4_8(s0, s1, s2, s3, y_filter, max);
        uint16x8_t d1 = highbd_convolve4_8(s1, s2, s3, s4, y_filter, max);
        uint16x8_t d2 = highbd_convolve4_8(s2, s3, s4, s5, y_filter, max);
        uint16x8_t d3 = highbd_convolve4_8(s3, s4, s5, s6, y_filter, max);

        store_u16_8x4(d, dst_stride, d0, d1, d2, d3);

        s0 = s4;
        s1 = s5;
        s2 = s6;

        s += 4 * src_stride;
        d += 4 * dst_stride;
        height -= 4;
      } while (height > 0);
      src_ptr += 8;
      dst_ptr += 8;
      w -= 8;
    } while (w > 0);
  }
}

static inline void highbd_convolve8_vert_2tap_neon(
    const uint16_t *src_ptr, ptrdiff_t src_stride, uint16_t *dst_ptr,
    ptrdiff_t dst_stride, const int16_t *x_filter_ptr, int w, int h, int bd) {
  // Bilinear filter values are all positive and multiples of 8. Divide by 8
  // to reduce intermediate precision requirements and allow the use of a
  // non-widening multiply.
  const uint16x8_t f0 = vdupq_n_u16((uint16_t)x_filter_ptr[3] / 8);
  const uint16x8_t f1 = vdupq_n_u16((uint16_t)x_filter_ptr[4] / 8);

  const uint16x8_t max = vdupq_n_u16((1 << bd) - 1);

  if (w == 4) {
    do {
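      // Each 4x2 load packs two consecutive rows into one vector. s1 and s3
      // repeat s0 and s2 shifted down one row, so each bilinear pair combines
      // vertically adjacent rows.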
      uint16x8_t s0 =
          load_unaligned_u16_4x2(src_ptr + 0 * src_stride, (int)src_stride);
      uint16x8_t s1 =
          load_unaligned_u16_4x2(src_ptr + 1 * src_stride, (int)src_stride);
      uint16x8_t s2 =
          load_unaligned_u16_4x2(src_ptr + 2 * src_stride, (int)src_stride);
      uint16x8_t s3 =
          load_unaligned_u16_4x2(src_ptr + 3 * src_stride, (int)src_stride);

      uint16x8_t sum01 = vmulq_u16(s0, f0);
      sum01 = vmlaq_u16(sum01, s1, f1);
      uint16x8_t sum23 = vmulq_u16(s2, f0);
      sum23 = vmlaq_u16(sum23, s3, f1);

      // We divided the filter taps by 8, so subtract 3 from the right shift.
      sum01 = vrshrq_n_u16(sum01, FILTER_BITS - 3);
      sum23 = vrshrq_n_u16(sum23, FILTER_BITS - 3);

      sum01 = vminq_u16(sum01, max);
      sum23 = vminq_u16(sum23, max);

      store_u16x4_strided_x2(dst_ptr + 0 * dst_stride, (int)dst_stride, sum01);
      store_u16x4_strided_x2(dst_ptr + 2 * dst_stride, (int)dst_stride, sum23);

      src_ptr += 4 * src_stride;
      dst_ptr += 4 * dst_stride;
      h -= 4;
    } while (h > 0);
  } else {
    do {
      int width = w;
      const uint16_t *s = src_ptr;
      uint16_t *d = dst_ptr;

      do {
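        // Three consecutive rows give two output rows; the middle row, s1, is
        // shared between the two bilinear pairs.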
        uint16x8_t s0, s1, s2;
        load_u16_8x3(s, src_stride, &s0, &s1, &s2);

        uint16x8_t sum01 = vmulq_u16(s0, f0);
        sum01 = vmlaq_u16(sum01, s1, f1);
        uint16x8_t sum23 = vmulq_u16(s1, f0);
        sum23 = vmlaq_u16(sum23, s2, f1);

        // We divided the filter taps by 8, so subtract 3 from the right shift.
        sum01 = vrshrq_n_u16(sum01, FILTER_BITS - 3);
        sum23 = vrshrq_n_u16(sum23, FILTER_BITS - 3);

        sum01 = vminq_u16(sum01, max);
        sum23 = vminq_u16(sum23, max);

        vst1q_u16(d + 0 * dst_stride, sum01);
        vst1q_u16(d + 1 * dst_stride, sum23);

        s += 8;
        d += 8;
        width -= 8;
      } while (width != 0);
      src_ptr += 2 * src_stride;
      dst_ptr += 2 * dst_stride;
      h -= 2;
    } while (h > 0);
  }
}

#endif  // AOM_AOM_DSP_ARM_HIGHBD_CONVOLVE8_NEON_H_