/*
 * Copyright (c) 2023, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <arm_neon.h>
#include <assert.h>

#include "config/aom_config.h"
#include "config/aom_dsp_rtcd.h"

#include "aom/aom_integer.h"
#include "aom_dsp/arm/mem_neon.h"
#include "aom_dsp/arm/sum_neon.h"

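// Process 8 pixels at a time: compute
// diff = ROUND_POWER_OF_TWO_SIGNED(wsrc - pre * mask, 12), then accumulate
// sum += diff and sse += diff * diff into the running vector accumulators.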
static inline void highbd_obmc_variance_8x1_s16_neon(uint16x8_t pre,
                                                     const int32_t *wsrc,
                                                     const int32_t *mask,
                                                     uint32x4_t *sse,
                                                     int32x4_t *sum) {
  int16x8_t pre_s16 = vreinterpretq_s16_u16(pre);
  int32x4_t wsrc_lo = vld1q_s32(&wsrc[0]);
  int32x4_t wsrc_hi = vld1q_s32(&wsrc[4]);

  int32x4_t mask_lo = vld1q_s32(&mask[0]);
  int32x4_t mask_hi = vld1q_s32(&mask[4]);

  int16x8_t mask_s16 = vcombine_s16(vmovn_s32(mask_lo), vmovn_s32(mask_hi));

  int32x4_t diff_lo = vmull_s16(vget_low_s16(pre_s16), vget_low_s16(mask_s16));
  int32x4_t diff_hi =
      vmull_s16(vget_high_s16(pre_s16), vget_high_s16(mask_s16));

  diff_lo = vsubq_s32(wsrc_lo, diff_lo);
  diff_hi = vsubq_s32(wsrc_hi, diff_hi);

  // ROUND_POWER_OF_TWO_SIGNED(value, 12) rounds to nearest with ties away
  // from zero, whereas vrshrq_n_s32 rounds to nearest with ties rounded up.
  // This difference only affects the bit patterns at the rounding breakpoints
  // exactly, so we can add -1 to all negative numbers to move the breakpoint
  // one value across and into the correct rounding region.
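  // For example, with the shift of 12 used here, value = -2048 is a tie:
  // ROUND_POWER_OF_TWO_SIGNED(-2048, 12) = -1, but vrshrq_n_s32 would give
  // (-2048 + 2048) >> 12 = 0. Adding -1 first gives (-2049 + 2048) >> 12 = -1,
  // matching the scalar reference, while non-tie values are unaffected.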
  diff_lo = vsraq_n_s32(diff_lo, diff_lo, 31);
  diff_hi = vsraq_n_s32(diff_hi, diff_hi, 31);
  int32x4_t round_lo = vrshrq_n_s32(diff_lo, 12);
  int32x4_t round_hi = vrshrq_n_s32(diff_hi, 12);

  *sum = vaddq_s32(*sum, round_lo);
  *sum = vaddq_s32(*sum, round_hi);
  *sse = vmlaq_u32(*sse, vreinterpretq_u32_s32(round_lo),
                   vreinterpretq_u32_s32(round_lo));
  *sse = vmlaq_u32(*sse, vreinterpretq_u32_s32(round_hi),
                   vreinterpretq_u32_s32(round_hi));
}

// For 12-bit data we can accumulate at most 256 squared differences into each
// unsigned 32-bit lane (4095 * 4095 * 256 = 4292870400 < 2^32) before we have
// to widen the accumulation to 64-bit elements. Blocks of size 32x64, 64x32,
// 64x64, 64x128, 128x64 and 128x128 are therefore processed in a different
// helper function.
static inline void highbd_obmc_variance_xlarge_neon(
    const uint8_t *pre, int pre_stride, const int32_t *wsrc,
    const int32_t *mask, int width, int h, int h_limit, uint64_t *sse,
    int64_t *sum) {
  uint16_t *pre_ptr = CONVERT_TO_SHORTPTR(pre);
  int32x4_t sum_s32 = vdupq_n_s32(0);
  uint64x2_t sse_u64 = vdupq_n_u64(0);

  // 'h_limit' is the number of 'width'-wide rows we can process before a
  // 32-bit sse accumulator lane overflows. After hitting this limit we
  // accumulate into 64-bit elements.
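  // For example, with width 128 each of the eight 32-bit sse lanes accumulates
  // 128 / 8 = 16 squared differences per row, so h_limit = 256 / 16 = 16;
  // likewise h_limit is 32 for width 64 and 64 for width 32.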
  int h_tmp = h > h_limit ? h_limit : h;

  do {
    uint32x4_t sse_u32[2] = { vdupq_n_u32(0), vdupq_n_u32(0) };
    int j = 0;

    do {
      int i = 0;

      do {
        uint16x8_t pre0 = vld1q_u16(pre_ptr + i);
        highbd_obmc_variance_8x1_s16_neon(pre0, wsrc, mask, &sse_u32[0],
                                          &sum_s32);

        uint16x8_t pre1 = vld1q_u16(pre_ptr + i + 8);
        highbd_obmc_variance_8x1_s16_neon(pre1, wsrc + 8, mask + 8, &sse_u32[1],
                                          &sum_s32);

        i += 16;
        wsrc += 16;
        mask += 16;
      } while (i < width);

      pre_ptr += pre_stride;
      j++;
    } while (j < h_tmp);

    sse_u64 = vpadalq_u32(sse_u64, sse_u32[0]);
    sse_u64 = vpadalq_u32(sse_u64, sse_u32[1]);
    h -= h_tmp;
  } while (h != 0);

  *sse = horizontal_add_u64x2(sse_u64);
  *sum = horizontal_long_add_s32x4(sum_s32);
}

static inline void highbd_obmc_variance_xlarge_neon_128xh(
    const uint8_t *pre, int pre_stride, const int32_t *wsrc,
    const int32_t *mask, int h, uint64_t *sse, int64_t *sum) {
  highbd_obmc_variance_xlarge_neon(pre, pre_stride, wsrc, mask, 128, h, 16, sse,
                                   sum);
}

static inline void highbd_obmc_variance_xlarge_neon_64xh(
    const uint8_t *pre, int pre_stride, const int32_t *wsrc,
    const int32_t *mask, int h, uint64_t *sse, int64_t *sum) {
  highbd_obmc_variance_xlarge_neon(pre, pre_stride, wsrc, mask, 64, h, 32, sse,
                                   sum);
}

static inline void highbd_obmc_variance_xlarge_neon_32xh(
    const uint8_t *pre, int pre_stride, const int32_t *wsrc,
    const int32_t *mask, int h, uint64_t *sse, int64_t *sum) {
  highbd_obmc_variance_xlarge_neon(pre, pre_stride, wsrc, mask, 32, h, 64, sse,
                                   sum);
}

static inline void highbd_obmc_variance_large_neon(
    const uint8_t *pre, int pre_stride, const int32_t *wsrc,
    const int32_t *mask, int width, int h, uint64_t *sse, int64_t *sum) {
  uint16_t *pre_ptr = CONVERT_TO_SHORTPTR(pre);
  uint32x4_t sse_u32 = vdupq_n_u32(0);
  int32x4_t sum_s32 = vdupq_n_s32(0);

  do {
    int i = 0;
    do {
      uint16x8_t pre0 = vld1q_u16(pre_ptr + i);
      highbd_obmc_variance_8x1_s16_neon(pre0, wsrc, mask, &sse_u32, &sum_s32);

      uint16x8_t pre1 = vld1q_u16(pre_ptr + i + 8);
      highbd_obmc_variance_8x1_s16_neon(pre1, wsrc + 8, mask + 8, &sse_u32,
                                        &sum_s32);

      i += 16;
      wsrc += 16;
      mask += 16;
    } while (i < width);

    pre_ptr += pre_stride;
  } while (--h != 0);

  *sse = horizontal_long_add_u32x4(sse_u32);
  *sum = horizontal_long_add_s32x4(sum_s32);
}

static inline void highbd_obmc_variance_neon_128xh(
    const uint8_t *pre, int pre_stride, const int32_t *wsrc,
    const int32_t *mask, int h, uint64_t *sse, int64_t *sum) {
  highbd_obmc_variance_large_neon(pre, pre_stride, wsrc, mask, 128, h, sse,
                                  sum);
}

static inline void highbd_obmc_variance_neon_64xh(const uint8_t *pre,
                                                  int pre_stride,
                                                  const int32_t *wsrc,
                                                  const int32_t *mask, int h,
                                                  uint64_t *sse, int64_t *sum) {
  highbd_obmc_variance_large_neon(pre, pre_stride, wsrc, mask, 64, h, sse, sum);
}

static inline void highbd_obmc_variance_neon_32xh(const uint8_t *pre,
                                                  int pre_stride,
                                                  const int32_t *wsrc,
                                                  const int32_t *mask, int h,
                                                  uint64_t *sse, int64_t *sum) {
  highbd_obmc_variance_large_neon(pre, pre_stride, wsrc, mask, 32, h, sse, sum);
}

static inline void highbd_obmc_variance_neon_16xh(const uint8_t *pre,
                                                  int pre_stride,
                                                  const int32_t *wsrc,
                                                  const int32_t *mask, int h,
                                                  uint64_t *sse, int64_t *sum) {
  highbd_obmc_variance_large_neon(pre, pre_stride, wsrc, mask, 16, h, sse, sum);
}

static inline void highbd_obmc_variance_neon_8xh(const uint8_t *pre8,
                                                 int pre_stride,
                                                 const int32_t *wsrc,
                                                 const int32_t *mask, int h,
                                                 uint64_t *sse, int64_t *sum) {
  uint16_t *pre = CONVERT_TO_SHORTPTR(pre8);
  uint32x4_t sse_u32 = vdupq_n_u32(0);
  int32x4_t sum_s32 = vdupq_n_s32(0);

  do {
    uint16x8_t pre_u16 = vld1q_u16(pre);

    highbd_obmc_variance_8x1_s16_neon(pre_u16, wsrc, mask, &sse_u32, &sum_s32);

    pre += pre_stride;
    wsrc += 8;
    mask += 8;
  } while (--h != 0);

  *sse = horizontal_long_add_u32x4(sse_u32);
  *sum = horizontal_long_add_s32x4(sum_s32);
}

static inline void highbd_obmc_variance_neon_4xh(const uint8_t *pre8,
                                                 int pre_stride,
                                                 const int32_t *wsrc,
                                                 const int32_t *mask, int h,
                                                 uint64_t *sse, int64_t *sum) {
  assert(h % 2 == 0);
  uint16_t *pre = CONVERT_TO_SHORTPTR(pre8);
  uint32x4_t sse_u32 = vdupq_n_u32(0);
  int32x4_t sum_s32 = vdupq_n_s32(0);

  do {
    uint16x8_t pre_u16 = load_unaligned_u16_4x2(pre, pre_stride);

    highbd_obmc_variance_8x1_s16_neon(pre_u16, wsrc, mask, &sse_u32, &sum_s32);

    pre += 2 * pre_stride;
    wsrc += 8;
    mask += 8;
    h -= 2;
  } while (h != 0);

  *sse = horizontal_long_add_u32x4(sse_u32);
  *sum = horizontal_long_add_s32x4(sum_s32);
}

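// The cast helpers below scale the accumulated sum and sse back to an 8-bit
// basis: sum is rounded down by (bitdepth - 8) bits and sse by
// 2 * (bitdepth - 8) bits, matching the scalar highbd variance convention.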
static inline void highbd_8_obmc_variance_cast(int64_t sum64, uint64_t sse64,
                                               int *sum, unsigned int *sse) {
  *sum = (int)sum64;
  *sse = (unsigned int)sse64;
}

static inline void highbd_10_obmc_variance_cast(int64_t sum64, uint64_t sse64,
                                                int *sum, unsigned int *sse) {
  *sum = (int)ROUND_POWER_OF_TWO(sum64, 2);
  *sse = (unsigned int)ROUND_POWER_OF_TWO(sse64, 4);
}

static inline void highbd_12_obmc_variance_cast(int64_t sum64, uint64_t sse64,
                                                int *sum, unsigned int *sse) {
  *sum = (int)ROUND_POWER_OF_TWO(sum64, 4);
  *sse = (unsigned int)ROUND_POWER_OF_TWO(sse64, 8);
}

#define HIGHBD_OBMC_VARIANCE_WXH_NEON(w, h, bitdepth)                         \
  unsigned int aom_highbd_##bitdepth##_obmc_variance##w##x##h##_neon(         \
      const uint8_t *pre, int pre_stride, const int32_t *wsrc,                \
      const int32_t *mask, unsigned int *sse) {                               \
    int sum;                                                                  \
    int64_t sum64;                                                            \
    uint64_t sse64;                                                           \
    highbd_obmc_variance_neon_##w##xh(pre, pre_stride, wsrc, mask, h, &sse64, \
                                      &sum64);                                \
    highbd_##bitdepth##_obmc_variance_cast(sum64, sse64, &sum, sse);          \
    return *sse - (unsigned int)(((int64_t)sum * sum) / (w * h));             \
  }

#define HIGHBD_OBMC_VARIANCE_WXH_XLARGE_NEON(w, h, bitdepth)                 \
  unsigned int aom_highbd_##bitdepth##_obmc_variance##w##x##h##_neon(        \
      const uint8_t *pre, int pre_stride, const int32_t *wsrc,               \
      const int32_t *mask, unsigned int *sse) {                              \
    int sum;                                                                 \
    int64_t sum64;                                                           \
    uint64_t sse64;                                                          \
    highbd_obmc_variance_xlarge_neon_##w##xh(pre, pre_stride, wsrc, mask, h, \
                                             &sse64, &sum64);                \
    highbd_##bitdepth##_obmc_variance_cast(sum64, sse64, &sum, sse);         \
    return *sse - (unsigned int)(((int64_t)sum * sum) / (w * h));            \
  }

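// For example, HIGHBD_OBMC_VARIANCE_WXH_NEON(4, 4, 8) below expands to
// aom_highbd_8_obmc_variance4x4_neon(), which calls
// highbd_obmc_variance_neon_4xh() with h = 4, applies the 8-bit cast and
// returns sse - sum^2 / (4 * 4).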
// 8-bit
HIGHBD_OBMC_VARIANCE_WXH_NEON(4, 4, 8)
HIGHBD_OBMC_VARIANCE_WXH_NEON(4, 8, 8)
HIGHBD_OBMC_VARIANCE_WXH_NEON(4, 16, 8)

HIGHBD_OBMC_VARIANCE_WXH_NEON(8, 4, 8)
HIGHBD_OBMC_VARIANCE_WXH_NEON(8, 8, 8)
HIGHBD_OBMC_VARIANCE_WXH_NEON(8, 16, 8)
HIGHBD_OBMC_VARIANCE_WXH_NEON(8, 32, 8)

HIGHBD_OBMC_VARIANCE_WXH_NEON(16, 4, 8)
HIGHBD_OBMC_VARIANCE_WXH_NEON(16, 8, 8)
HIGHBD_OBMC_VARIANCE_WXH_NEON(16, 16, 8)
HIGHBD_OBMC_VARIANCE_WXH_NEON(16, 32, 8)
HIGHBD_OBMC_VARIANCE_WXH_NEON(16, 64, 8)

HIGHBD_OBMC_VARIANCE_WXH_NEON(32, 8, 8)
HIGHBD_OBMC_VARIANCE_WXH_NEON(32, 16, 8)
HIGHBD_OBMC_VARIANCE_WXH_NEON(32, 32, 8)
HIGHBD_OBMC_VARIANCE_WXH_NEON(32, 64, 8)

HIGHBD_OBMC_VARIANCE_WXH_NEON(64, 16, 8)
HIGHBD_OBMC_VARIANCE_WXH_NEON(64, 32, 8)
HIGHBD_OBMC_VARIANCE_WXH_NEON(64, 64, 8)
HIGHBD_OBMC_VARIANCE_WXH_NEON(64, 128, 8)

HIGHBD_OBMC_VARIANCE_WXH_NEON(128, 64, 8)
HIGHBD_OBMC_VARIANCE_WXH_NEON(128, 128, 8)

// 10-bit
HIGHBD_OBMC_VARIANCE_WXH_NEON(4, 4, 10)
HIGHBD_OBMC_VARIANCE_WXH_NEON(4, 8, 10)
HIGHBD_OBMC_VARIANCE_WXH_NEON(4, 16, 10)

HIGHBD_OBMC_VARIANCE_WXH_NEON(8, 4, 10)
HIGHBD_OBMC_VARIANCE_WXH_NEON(8, 8, 10)
HIGHBD_OBMC_VARIANCE_WXH_NEON(8, 16, 10)
HIGHBD_OBMC_VARIANCE_WXH_NEON(8, 32, 10)

HIGHBD_OBMC_VARIANCE_WXH_NEON(16, 4, 10)
HIGHBD_OBMC_VARIANCE_WXH_NEON(16, 8, 10)
HIGHBD_OBMC_VARIANCE_WXH_NEON(16, 16, 10)
HIGHBD_OBMC_VARIANCE_WXH_NEON(16, 32, 10)
HIGHBD_OBMC_VARIANCE_WXH_NEON(16, 64, 10)

HIGHBD_OBMC_VARIANCE_WXH_NEON(32, 8, 10)
HIGHBD_OBMC_VARIANCE_WXH_NEON(32, 16, 10)
HIGHBD_OBMC_VARIANCE_WXH_NEON(32, 32, 10)
HIGHBD_OBMC_VARIANCE_WXH_NEON(32, 64, 10)

HIGHBD_OBMC_VARIANCE_WXH_NEON(64, 16, 10)
HIGHBD_OBMC_VARIANCE_WXH_NEON(64, 32, 10)
HIGHBD_OBMC_VARIANCE_WXH_NEON(64, 64, 10)
HIGHBD_OBMC_VARIANCE_WXH_NEON(64, 128, 10)

HIGHBD_OBMC_VARIANCE_WXH_NEON(128, 64, 10)
HIGHBD_OBMC_VARIANCE_WXH_NEON(128, 128, 10)

// 12-bit
HIGHBD_OBMC_VARIANCE_WXH_NEON(4, 4, 12)
HIGHBD_OBMC_VARIANCE_WXH_NEON(4, 8, 12)
HIGHBD_OBMC_VARIANCE_WXH_NEON(4, 16, 12)

HIGHBD_OBMC_VARIANCE_WXH_NEON(8, 4, 12)
HIGHBD_OBMC_VARIANCE_WXH_NEON(8, 8, 12)
HIGHBD_OBMC_VARIANCE_WXH_NEON(8, 16, 12)
HIGHBD_OBMC_VARIANCE_WXH_NEON(8, 32, 12)

HIGHBD_OBMC_VARIANCE_WXH_NEON(16, 4, 12)
HIGHBD_OBMC_VARIANCE_WXH_NEON(16, 8, 12)
HIGHBD_OBMC_VARIANCE_WXH_NEON(16, 16, 12)
HIGHBD_OBMC_VARIANCE_WXH_NEON(16, 32, 12)
HIGHBD_OBMC_VARIANCE_WXH_NEON(16, 64, 12)

HIGHBD_OBMC_VARIANCE_WXH_NEON(32, 8, 12)
HIGHBD_OBMC_VARIANCE_WXH_NEON(32, 16, 12)
HIGHBD_OBMC_VARIANCE_WXH_NEON(32, 32, 12)
HIGHBD_OBMC_VARIANCE_WXH_XLARGE_NEON(32, 64, 12)

HIGHBD_OBMC_VARIANCE_WXH_NEON(64, 16, 12)
HIGHBD_OBMC_VARIANCE_WXH_XLARGE_NEON(64, 32, 12)
HIGHBD_OBMC_VARIANCE_WXH_XLARGE_NEON(64, 64, 12)
HIGHBD_OBMC_VARIANCE_WXH_XLARGE_NEON(64, 128, 12)

HIGHBD_OBMC_VARIANCE_WXH_XLARGE_NEON(128, 64, 12)
HIGHBD_OBMC_VARIANCE_WXH_XLARGE_NEON(128, 128, 12)