1 /*
2 * Copyright (c) 2023, Alliance for Open Media. All rights reserved.
3 *
4 * This source code is subject to the terms of the BSD 2 Clause License and
5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6 * was not distributed with this source code in the LICENSE file, you can
7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8 * Media Patent License 1.0 was not distributed with this source code in the
9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10 */
11
12 #include <arm_neon.h>
13
14 #include "config/aom_config.h"
15 #include "config/aom_dsp_rtcd.h"
16
17 #include "aom/aom_integer.h"
18 #include "aom_dsp/arm/mem_neon.h"
19 #include "aom_dsp/arm/sum_neon.h"
20
highbd_obmc_variance_8x1_s16_neon(uint16x8_t pre,const int32_t * wsrc,const int32_t * mask,uint32x4_t * sse,int32x4_t * sum)21 static inline void highbd_obmc_variance_8x1_s16_neon(uint16x8_t pre,
22 const int32_t *wsrc,
23 const int32_t *mask,
24 uint32x4_t *sse,
25 int32x4_t *sum) {
26 int16x8_t pre_s16 = vreinterpretq_s16_u16(pre);
27 int32x4_t wsrc_lo = vld1q_s32(&wsrc[0]);
28 int32x4_t wsrc_hi = vld1q_s32(&wsrc[4]);
29
30 int32x4_t mask_lo = vld1q_s32(&mask[0]);
31 int32x4_t mask_hi = vld1q_s32(&mask[4]);
32
33 int16x8_t mask_s16 = vcombine_s16(vmovn_s32(mask_lo), vmovn_s32(mask_hi));
34
35 int32x4_t diff_lo = vmull_s16(vget_low_s16(pre_s16), vget_low_s16(mask_s16));
36 int32x4_t diff_hi =
37 vmull_s16(vget_high_s16(pre_s16), vget_high_s16(mask_s16));
38
39 diff_lo = vsubq_s32(wsrc_lo, diff_lo);
40 diff_hi = vsubq_s32(wsrc_hi, diff_hi);
41
42 // ROUND_POWER_OF_TWO_SIGNED(value, 12) rounds to nearest with ties away
43 // from zero, however vrshrq_n_s32 rounds to nearest with ties rounded up.
44 // This difference only affects the bit patterns at the rounding breakpoints
45 // exactly, so we can add -1 to all negative numbers to move the breakpoint
46 // one value across and into the correct rounding region.
47 diff_lo = vsraq_n_s32(diff_lo, diff_lo, 31);
48 diff_hi = vsraq_n_s32(diff_hi, diff_hi, 31);
49 int32x4_t round_lo = vrshrq_n_s32(diff_lo, 12);
50 int32x4_t round_hi = vrshrq_n_s32(diff_hi, 12);
51
52 *sum = vaddq_s32(*sum, round_lo);
53 *sum = vaddq_s32(*sum, round_hi);
54 *sse = vmlaq_u32(*sse, vreinterpretq_u32_s32(round_lo),
55 vreinterpretq_u32_s32(round_lo));
56 *sse = vmlaq_u32(*sse, vreinterpretq_u32_s32(round_hi),
57 vreinterpretq_u32_s32(round_hi));
58 }
59
60 // For 12-bit data, we can only accumulate up to 256 elements in the unsigned
61 // 32-bit elements (4095*4095*256 = 4292870400) before we have to accumulate
62 // into 64-bit elements. Therefore blocks of size 32x64, 64x32, 64x64, 64x128,
63 // 128x64, 128x128 are processed in a different helper function.
highbd_obmc_variance_xlarge_neon(const uint8_t * pre,int pre_stride,const int32_t * wsrc,const int32_t * mask,int width,int h,int h_limit,uint64_t * sse,int64_t * sum)64 static inline void highbd_obmc_variance_xlarge_neon(
65 const uint8_t *pre, int pre_stride, const int32_t *wsrc,
66 const int32_t *mask, int width, int h, int h_limit, uint64_t *sse,
67 int64_t *sum) {
68 uint16_t *pre_ptr = CONVERT_TO_SHORTPTR(pre);
69 int32x4_t sum_s32 = vdupq_n_s32(0);
70 uint64x2_t sse_u64 = vdupq_n_u64(0);
71
72 // 'h_limit' is the number of 'w'-width rows we can process before our 32-bit
73 // accumulator overflows. After hitting this limit we accumulate into 64-bit
74 // elements.
75 int h_tmp = h > h_limit ? h_limit : h;
76
77 do {
78 uint32x4_t sse_u32[2] = { vdupq_n_u32(0), vdupq_n_u32(0) };
79 int j = 0;
80
81 do {
82 int i = 0;
83
84 do {
85 uint16x8_t pre0 = vld1q_u16(pre_ptr + i);
86 highbd_obmc_variance_8x1_s16_neon(pre0, wsrc, mask, &sse_u32[0],
87 &sum_s32);
88
89 uint16x8_t pre1 = vld1q_u16(pre_ptr + i + 8);
90 highbd_obmc_variance_8x1_s16_neon(pre1, wsrc + 8, mask + 8, &sse_u32[1],
91 &sum_s32);
92
93 i += 16;
94 wsrc += 16;
95 mask += 16;
96 } while (i < width);
97
98 pre_ptr += pre_stride;
99 j++;
100 } while (j < h_tmp);
101
102 sse_u64 = vpadalq_u32(sse_u64, sse_u32[0]);
103 sse_u64 = vpadalq_u32(sse_u64, sse_u32[1]);
104 h -= h_tmp;
105 } while (h != 0);
106
107 *sse = horizontal_add_u64x2(sse_u64);
108 *sum = horizontal_long_add_s32x4(sum_s32);
109 }
110
highbd_obmc_variance_xlarge_neon_128xh(const uint8_t * pre,int pre_stride,const int32_t * wsrc,const int32_t * mask,int h,uint64_t * sse,int64_t * sum)111 static inline void highbd_obmc_variance_xlarge_neon_128xh(
112 const uint8_t *pre, int pre_stride, const int32_t *wsrc,
113 const int32_t *mask, int h, uint64_t *sse, int64_t *sum) {
114 highbd_obmc_variance_xlarge_neon(pre, pre_stride, wsrc, mask, 128, h, 16, sse,
115 sum);
116 }
117
highbd_obmc_variance_xlarge_neon_64xh(const uint8_t * pre,int pre_stride,const int32_t * wsrc,const int32_t * mask,int h,uint64_t * sse,int64_t * sum)118 static inline void highbd_obmc_variance_xlarge_neon_64xh(
119 const uint8_t *pre, int pre_stride, const int32_t *wsrc,
120 const int32_t *mask, int h, uint64_t *sse, int64_t *sum) {
121 highbd_obmc_variance_xlarge_neon(pre, pre_stride, wsrc, mask, 64, h, 32, sse,
122 sum);
123 }
124
highbd_obmc_variance_xlarge_neon_32xh(const uint8_t * pre,int pre_stride,const int32_t * wsrc,const int32_t * mask,int h,uint64_t * sse,int64_t * sum)125 static inline void highbd_obmc_variance_xlarge_neon_32xh(
126 const uint8_t *pre, int pre_stride, const int32_t *wsrc,
127 const int32_t *mask, int h, uint64_t *sse, int64_t *sum) {
128 highbd_obmc_variance_xlarge_neon(pre, pre_stride, wsrc, mask, 32, h, 64, sse,
129 sum);
130 }
131
highbd_obmc_variance_large_neon(const uint8_t * pre,int pre_stride,const int32_t * wsrc,const int32_t * mask,int width,int h,uint64_t * sse,int64_t * sum)132 static inline void highbd_obmc_variance_large_neon(
133 const uint8_t *pre, int pre_stride, const int32_t *wsrc,
134 const int32_t *mask, int width, int h, uint64_t *sse, int64_t *sum) {
135 uint16_t *pre_ptr = CONVERT_TO_SHORTPTR(pre);
136 uint32x4_t sse_u32 = vdupq_n_u32(0);
137 int32x4_t sum_s32 = vdupq_n_s32(0);
138
139 do {
140 int i = 0;
141 do {
142 uint16x8_t pre0 = vld1q_u16(pre_ptr + i);
143 highbd_obmc_variance_8x1_s16_neon(pre0, wsrc, mask, &sse_u32, &sum_s32);
144
145 uint16x8_t pre1 = vld1q_u16(pre_ptr + i + 8);
146 highbd_obmc_variance_8x1_s16_neon(pre1, wsrc + 8, mask + 8, &sse_u32,
147 &sum_s32);
148
149 i += 16;
150 wsrc += 16;
151 mask += 16;
152 } while (i < width);
153
154 pre_ptr += pre_stride;
155 } while (--h != 0);
156
157 *sse = horizontal_long_add_u32x4(sse_u32);
158 *sum = horizontal_long_add_s32x4(sum_s32);
159 }
160
highbd_obmc_variance_neon_128xh(const uint8_t * pre,int pre_stride,const int32_t * wsrc,const int32_t * mask,int h,uint64_t * sse,int64_t * sum)161 static inline void highbd_obmc_variance_neon_128xh(
162 const uint8_t *pre, int pre_stride, const int32_t *wsrc,
163 const int32_t *mask, int h, uint64_t *sse, int64_t *sum) {
164 highbd_obmc_variance_large_neon(pre, pre_stride, wsrc, mask, 128, h, sse,
165 sum);
166 }
167
highbd_obmc_variance_neon_64xh(const uint8_t * pre,int pre_stride,const int32_t * wsrc,const int32_t * mask,int h,uint64_t * sse,int64_t * sum)168 static inline void highbd_obmc_variance_neon_64xh(const uint8_t *pre,
169 int pre_stride,
170 const int32_t *wsrc,
171 const int32_t *mask, int h,
172 uint64_t *sse, int64_t *sum) {
173 highbd_obmc_variance_large_neon(pre, pre_stride, wsrc, mask, 64, h, sse, sum);
174 }
175
highbd_obmc_variance_neon_32xh(const uint8_t * pre,int pre_stride,const int32_t * wsrc,const int32_t * mask,int h,uint64_t * sse,int64_t * sum)176 static inline void highbd_obmc_variance_neon_32xh(const uint8_t *pre,
177 int pre_stride,
178 const int32_t *wsrc,
179 const int32_t *mask, int h,
180 uint64_t *sse, int64_t *sum) {
181 highbd_obmc_variance_large_neon(pre, pre_stride, wsrc, mask, 32, h, sse, sum);
182 }
183
highbd_obmc_variance_neon_16xh(const uint8_t * pre,int pre_stride,const int32_t * wsrc,const int32_t * mask,int h,uint64_t * sse,int64_t * sum)184 static inline void highbd_obmc_variance_neon_16xh(const uint8_t *pre,
185 int pre_stride,
186 const int32_t *wsrc,
187 const int32_t *mask, int h,
188 uint64_t *sse, int64_t *sum) {
189 highbd_obmc_variance_large_neon(pre, pre_stride, wsrc, mask, 16, h, sse, sum);
190 }
191
highbd_obmc_variance_neon_8xh(const uint8_t * pre8,int pre_stride,const int32_t * wsrc,const int32_t * mask,int h,uint64_t * sse,int64_t * sum)192 static inline void highbd_obmc_variance_neon_8xh(const uint8_t *pre8,
193 int pre_stride,
194 const int32_t *wsrc,
195 const int32_t *mask, int h,
196 uint64_t *sse, int64_t *sum) {
197 uint16_t *pre = CONVERT_TO_SHORTPTR(pre8);
198 uint32x4_t sse_u32 = vdupq_n_u32(0);
199 int32x4_t sum_s32 = vdupq_n_s32(0);
200
201 do {
202 uint16x8_t pre_u16 = vld1q_u16(pre);
203
204 highbd_obmc_variance_8x1_s16_neon(pre_u16, wsrc, mask, &sse_u32, &sum_s32);
205
206 pre += pre_stride;
207 wsrc += 8;
208 mask += 8;
209 } while (--h != 0);
210
211 *sse = horizontal_long_add_u32x4(sse_u32);
212 *sum = horizontal_long_add_s32x4(sum_s32);
213 }
214
highbd_obmc_variance_neon_4xh(const uint8_t * pre8,int pre_stride,const int32_t * wsrc,const int32_t * mask,int h,uint64_t * sse,int64_t * sum)215 static inline void highbd_obmc_variance_neon_4xh(const uint8_t *pre8,
216 int pre_stride,
217 const int32_t *wsrc,
218 const int32_t *mask, int h,
219 uint64_t *sse, int64_t *sum) {
220 assert(h % 2 == 0);
221 uint16_t *pre = CONVERT_TO_SHORTPTR(pre8);
222 uint32x4_t sse_u32 = vdupq_n_u32(0);
223 int32x4_t sum_s32 = vdupq_n_s32(0);
224
225 do {
226 uint16x8_t pre_u16 = load_unaligned_u16_4x2(pre, pre_stride);
227
228 highbd_obmc_variance_8x1_s16_neon(pre_u16, wsrc, mask, &sse_u32, &sum_s32);
229
230 pre += 2 * pre_stride;
231 wsrc += 8;
232 mask += 8;
233 h -= 2;
234 } while (h != 0);
235
236 *sse = horizontal_long_add_u32x4(sse_u32);
237 *sum = horizontal_long_add_s32x4(sum_s32);
238 }
239
highbd_8_obmc_variance_cast(int64_t sum64,uint64_t sse64,int * sum,unsigned int * sse)240 static inline void highbd_8_obmc_variance_cast(int64_t sum64, uint64_t sse64,
241 int *sum, unsigned int *sse) {
242 *sum = (int)sum64;
243 *sse = (unsigned int)sse64;
244 }
245
highbd_10_obmc_variance_cast(int64_t sum64,uint64_t sse64,int * sum,unsigned int * sse)246 static inline void highbd_10_obmc_variance_cast(int64_t sum64, uint64_t sse64,
247 int *sum, unsigned int *sse) {
248 *sum = (int)ROUND_POWER_OF_TWO(sum64, 2);
249 *sse = (unsigned int)ROUND_POWER_OF_TWO(sse64, 4);
250 }
251
highbd_12_obmc_variance_cast(int64_t sum64,uint64_t sse64,int * sum,unsigned int * sse)252 static inline void highbd_12_obmc_variance_cast(int64_t sum64, uint64_t sse64,
253 int *sum, unsigned int *sse) {
254 *sum = (int)ROUND_POWER_OF_TWO(sum64, 4);
255 *sse = (unsigned int)ROUND_POWER_OF_TWO(sse64, 8);
256 }
257
258 #define HIGHBD_OBMC_VARIANCE_WXH_NEON(w, h, bitdepth) \
259 unsigned int aom_highbd_##bitdepth##_obmc_variance##w##x##h##_neon( \
260 const uint8_t *pre, int pre_stride, const int32_t *wsrc, \
261 const int32_t *mask, unsigned int *sse) { \
262 int sum; \
263 int64_t sum64; \
264 uint64_t sse64; \
265 highbd_obmc_variance_neon_##w##xh(pre, pre_stride, wsrc, mask, h, &sse64, \
266 &sum64); \
267 highbd_##bitdepth##_obmc_variance_cast(sum64, sse64, &sum, sse); \
268 return *sse - (unsigned int)(((int64_t)sum * sum) / (w * h)); \
269 }
270
271 #define HIGHBD_OBMC_VARIANCE_WXH_XLARGE_NEON(w, h, bitdepth) \
272 unsigned int aom_highbd_##bitdepth##_obmc_variance##w##x##h##_neon( \
273 const uint8_t *pre, int pre_stride, const int32_t *wsrc, \
274 const int32_t *mask, unsigned int *sse) { \
275 int sum; \
276 int64_t sum64; \
277 uint64_t sse64; \
278 highbd_obmc_variance_xlarge_neon_##w##xh(pre, pre_stride, wsrc, mask, h, \
279 &sse64, &sum64); \
280 highbd_##bitdepth##_obmc_variance_cast(sum64, sse64, &sum, sse); \
281 return *sse - (unsigned int)(((int64_t)sum * sum) / (w * h)); \
282 }
283
284 // 8-bit
285 HIGHBD_OBMC_VARIANCE_WXH_NEON(4, 4, 8)
286 HIGHBD_OBMC_VARIANCE_WXH_NEON(4, 8, 8)
287 HIGHBD_OBMC_VARIANCE_WXH_NEON(4, 16, 8)
288
289 HIGHBD_OBMC_VARIANCE_WXH_NEON(8, 4, 8)
290 HIGHBD_OBMC_VARIANCE_WXH_NEON(8, 8, 8)
291 HIGHBD_OBMC_VARIANCE_WXH_NEON(8, 16, 8)
292 HIGHBD_OBMC_VARIANCE_WXH_NEON(8, 32, 8)
293
294 HIGHBD_OBMC_VARIANCE_WXH_NEON(16, 4, 8)
295 HIGHBD_OBMC_VARIANCE_WXH_NEON(16, 8, 8)
296 HIGHBD_OBMC_VARIANCE_WXH_NEON(16, 16, 8)
297 HIGHBD_OBMC_VARIANCE_WXH_NEON(16, 32, 8)
298 HIGHBD_OBMC_VARIANCE_WXH_NEON(16, 64, 8)
299
300 HIGHBD_OBMC_VARIANCE_WXH_NEON(32, 8, 8)
301 HIGHBD_OBMC_VARIANCE_WXH_NEON(32, 16, 8)
302 HIGHBD_OBMC_VARIANCE_WXH_NEON(32, 32, 8)
303 HIGHBD_OBMC_VARIANCE_WXH_NEON(32, 64, 8)
304
305 HIGHBD_OBMC_VARIANCE_WXH_NEON(64, 16, 8)
306 HIGHBD_OBMC_VARIANCE_WXH_NEON(64, 32, 8)
307 HIGHBD_OBMC_VARIANCE_WXH_NEON(64, 64, 8)
308 HIGHBD_OBMC_VARIANCE_WXH_NEON(64, 128, 8)
309
310 HIGHBD_OBMC_VARIANCE_WXH_NEON(128, 64, 8)
311 HIGHBD_OBMC_VARIANCE_WXH_NEON(128, 128, 8)
312
313 // 10-bit
314 HIGHBD_OBMC_VARIANCE_WXH_NEON(4, 4, 10)
315 HIGHBD_OBMC_VARIANCE_WXH_NEON(4, 8, 10)
316 HIGHBD_OBMC_VARIANCE_WXH_NEON(4, 16, 10)
317
318 HIGHBD_OBMC_VARIANCE_WXH_NEON(8, 4, 10)
319 HIGHBD_OBMC_VARIANCE_WXH_NEON(8, 8, 10)
320 HIGHBD_OBMC_VARIANCE_WXH_NEON(8, 16, 10)
321 HIGHBD_OBMC_VARIANCE_WXH_NEON(8, 32, 10)
322
323 HIGHBD_OBMC_VARIANCE_WXH_NEON(16, 4, 10)
324 HIGHBD_OBMC_VARIANCE_WXH_NEON(16, 8, 10)
325 HIGHBD_OBMC_VARIANCE_WXH_NEON(16, 16, 10)
326 HIGHBD_OBMC_VARIANCE_WXH_NEON(16, 32, 10)
327 HIGHBD_OBMC_VARIANCE_WXH_NEON(16, 64, 10)
328
329 HIGHBD_OBMC_VARIANCE_WXH_NEON(32, 8, 10)
330 HIGHBD_OBMC_VARIANCE_WXH_NEON(32, 16, 10)
331 HIGHBD_OBMC_VARIANCE_WXH_NEON(32, 32, 10)
332 HIGHBD_OBMC_VARIANCE_WXH_NEON(32, 64, 10)
333
334 HIGHBD_OBMC_VARIANCE_WXH_NEON(64, 16, 10)
335 HIGHBD_OBMC_VARIANCE_WXH_NEON(64, 32, 10)
336 HIGHBD_OBMC_VARIANCE_WXH_NEON(64, 64, 10)
337 HIGHBD_OBMC_VARIANCE_WXH_NEON(64, 128, 10)
338
339 HIGHBD_OBMC_VARIANCE_WXH_NEON(128, 64, 10)
340 HIGHBD_OBMC_VARIANCE_WXH_NEON(128, 128, 10)
341
342 // 12-bit
343 HIGHBD_OBMC_VARIANCE_WXH_NEON(4, 4, 12)
344 HIGHBD_OBMC_VARIANCE_WXH_NEON(4, 8, 12)
345 HIGHBD_OBMC_VARIANCE_WXH_NEON(4, 16, 12)
346
347 HIGHBD_OBMC_VARIANCE_WXH_NEON(8, 4, 12)
348 HIGHBD_OBMC_VARIANCE_WXH_NEON(8, 8, 12)
349 HIGHBD_OBMC_VARIANCE_WXH_NEON(8, 16, 12)
350 HIGHBD_OBMC_VARIANCE_WXH_NEON(8, 32, 12)
351
352 HIGHBD_OBMC_VARIANCE_WXH_NEON(16, 4, 12)
353 HIGHBD_OBMC_VARIANCE_WXH_NEON(16, 8, 12)
354 HIGHBD_OBMC_VARIANCE_WXH_NEON(16, 16, 12)
355 HIGHBD_OBMC_VARIANCE_WXH_NEON(16, 32, 12)
356 HIGHBD_OBMC_VARIANCE_WXH_NEON(16, 64, 12)
357
358 HIGHBD_OBMC_VARIANCE_WXH_NEON(32, 8, 12)
359 HIGHBD_OBMC_VARIANCE_WXH_NEON(32, 16, 12)
360 HIGHBD_OBMC_VARIANCE_WXH_NEON(32, 32, 12)
361 HIGHBD_OBMC_VARIANCE_WXH_XLARGE_NEON(32, 64, 12)
362
363 HIGHBD_OBMC_VARIANCE_WXH_NEON(64, 16, 12)
364 HIGHBD_OBMC_VARIANCE_WXH_XLARGE_NEON(64, 32, 12)
365 HIGHBD_OBMC_VARIANCE_WXH_XLARGE_NEON(64, 64, 12)
366 HIGHBD_OBMC_VARIANCE_WXH_XLARGE_NEON(64, 128, 12)
367
368 HIGHBD_OBMC_VARIANCE_WXH_XLARGE_NEON(128, 64, 12)
369 HIGHBD_OBMC_VARIANCE_WXH_XLARGE_NEON(128, 128, 12)
370