/*
 * Copyright (c) 2023, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <arm_neon.h>
#include <assert.h>

#include "config/aom_config.h"
#include "config/aom_dsp_rtcd.h"
#include "mem_neon.h"
#include "sum_neon.h"

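// What every kernel below computes, written as a rough scalar sketch (the
// loop structure and names here are illustrative, not the exact C reference
// implementation):
//
//   for each pixel j in the block:
//     int diff = ROUND_POWER_OF_TWO_SIGNED(wsrc[j] - pre[j] * mask[j], 12);
//     *sum += diff;
//     *sse += diff * diff;
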
static inline void obmc_variance_8x1_s16_neon(int16x8_t pre_s16,
                                              const int32_t *wsrc,
                                              const int32_t *mask,
                                              int32x4_t *ssev,
                                              int32x4_t *sumv) {
  // For 4xh and 8xh we observe it is faster to avoid the double-widening of
  // pre. Instead we do a single widening step and narrow the mask to 16 bits
  // to allow us to perform a widening multiply. Widening multiply
  // instructions have better throughput on some micro-architectures but for
  // the larger block sizes this benefit is outweighed by the additional
  // instruction needed to first narrow the mask vectors.
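  //
  // Note: the mask values are fixed-point weights scaled by 1 << 12 (hence
  // the shift by 12 below), so they are assumed to fit comfortably in 16 bits
  // and the narrowing step loses no information.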

  int32x4_t wsrc_s32_lo = vld1q_s32(&wsrc[0]);
  int32x4_t wsrc_s32_hi = vld1q_s32(&wsrc[4]);
  int16x8_t mask_s16 = vuzpq_s16(vreinterpretq_s16_s32(vld1q_s32(&mask[0])),
                                 vreinterpretq_s16_s32(vld1q_s32(&mask[4])))
                           .val[0];

  int32x4_t diff_s32_lo =
      vmlsl_s16(wsrc_s32_lo, vget_low_s16(pre_s16), vget_low_s16(mask_s16));
  int32x4_t diff_s32_hi =
      vmlsl_s16(wsrc_s32_hi, vget_high_s16(pre_s16), vget_high_s16(mask_s16));

  // ROUND_POWER_OF_TWO_SIGNED(value, 12) rounds to nearest with ties away
  // from zero, however vrshrq_n_s32 rounds to nearest with ties rounded up.
  // This difference only affects the bit patterns at the rounding breakpoints
  // exactly, so we can add -1 to all negative numbers to move the breakpoint
  // one value across and into the correct rounding region.
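  //
  // For example, at the breakpoint value -2048:
  //   ROUND_POWER_OF_TWO_SIGNED(-2048, 12) = -((2048 + 2048) >> 12) = -1
  //   vrshrq_n_s32(-2048, 12)              = (-2048 + 2048) >> 12   =  0
  // whereas after the -1 adjustment applied by vsraq_n_s32 below:
  //   vrshrq_n_s32(-2049, 12)              = (-2049 + 2048) >> 12   = -1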
  diff_s32_lo = vsraq_n_s32(diff_s32_lo, diff_s32_lo, 31);
  diff_s32_hi = vsraq_n_s32(diff_s32_hi, diff_s32_hi, 31);
  int32x4_t round_s32_lo = vrshrq_n_s32(diff_s32_lo, 12);
  int32x4_t round_s32_hi = vrshrq_n_s32(diff_s32_hi, 12);

  *sumv = vrsraq_n_s32(*sumv, diff_s32_lo, 12);
  *sumv = vrsraq_n_s32(*sumv, diff_s32_hi, 12);
  *ssev = vmlaq_s32(*ssev, round_s32_lo, round_s32_lo);
  *ssev = vmlaq_s32(*ssev, round_s32_hi, round_s32_hi);
}

#if AOM_ARCH_AARCH64

// Use tbl for doing a double-width zero extension from 8->32 bits since we can
// do this in one instruction rather than two (indices out of range (255 here)
// are set to zero by tbl).
DECLARE_ALIGNED(16, static const uint8_t, obmc_variance_permute_idx[]) = {
  0,  255, 255, 255, 1,  255, 255, 255, 2,  255, 255, 255, 3,  255, 255, 255,
  4,  255, 255, 255, 5,  255, 255, 255, 6,  255, 255, 255, 7,  255, 255, 255,
  8,  255, 255, 255, 9,  255, 255, 255, 10, 255, 255, 255, 11, 255, 255, 255,
  12, 255, 255, 255, 13, 255, 255, 255, 14, 255, 255, 255, 15, 255, 255, 255
};
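
// For example, the first 16 indices above expand the input bytes
// { b0, b1, ..., b15 } into four little-endian 32-bit lanes { b0, b1, b2, b3 }:
// each in-range index selects the value byte and the three out-of-range (255)
// indices supply the zero bytes.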

static inline void obmc_variance_8x1_s32_neon(
    int32x4_t pre_lo, int32x4_t pre_hi, const int32_t *wsrc,
    const int32_t *mask, int32x4_t *ssev, int32x4_t *sumv) {
  int32x4_t wsrc_lo = vld1q_s32(&wsrc[0]);
  int32x4_t wsrc_hi = vld1q_s32(&wsrc[4]);
  int32x4_t mask_lo = vld1q_s32(&mask[0]);
  int32x4_t mask_hi = vld1q_s32(&mask[4]);

  int32x4_t diff_lo = vmlsq_s32(wsrc_lo, pre_lo, mask_lo);
  int32x4_t diff_hi = vmlsq_s32(wsrc_hi, pre_hi, mask_hi);

  // ROUND_POWER_OF_TWO_SIGNED(value, 12) rounds to nearest with ties away from
  // zero, however vrshrq_n_s32 rounds to nearest with ties rounded up. This
  // difference only affects the bit patterns at the rounding breakpoints
  // exactly, so we can add -1 to all negative numbers to move the breakpoint
  // one value across and into the correct rounding region.
  diff_lo = vsraq_n_s32(diff_lo, diff_lo, 31);
  diff_hi = vsraq_n_s32(diff_hi, diff_hi, 31);
  int32x4_t round_lo = vrshrq_n_s32(diff_lo, 12);
  int32x4_t round_hi = vrshrq_n_s32(diff_hi, 12);

  *sumv = vrsraq_n_s32(*sumv, diff_lo, 12);
  *sumv = vrsraq_n_s32(*sumv, diff_hi, 12);
  *ssev = vmlaq_s32(*ssev, round_lo, round_lo);
  *ssev = vmlaq_s32(*ssev, round_hi, round_hi);
}

static inline void obmc_variance_large_neon(const uint8_t *pre, int pre_stride,
                                            const int32_t *wsrc,
                                            const int32_t *mask, int width,
                                            int height, unsigned *sse,
                                            int *sum) {
  assert(width % 16 == 0);

  // Use tbl for doing a double-width zero extension from 8->32 bits since we
  // can do this in one instruction rather than two.
  uint8x16_t pre_idx0 = vld1q_u8(&obmc_variance_permute_idx[0]);
  uint8x16_t pre_idx1 = vld1q_u8(&obmc_variance_permute_idx[16]);
  uint8x16_t pre_idx2 = vld1q_u8(&obmc_variance_permute_idx[32]);
  uint8x16_t pre_idx3 = vld1q_u8(&obmc_variance_permute_idx[48]);

  int32x4_t ssev = vdupq_n_s32(0);
  int32x4_t sumv = vdupq_n_s32(0);

  int h = height;
  do {
    int w = width;
    do {
      uint8x16_t pre_u8 = vld1q_u8(pre);

      int32x4_t pre_s32_lo = vreinterpretq_s32_u8(vqtbl1q_u8(pre_u8, pre_idx0));
      int32x4_t pre_s32_hi = vreinterpretq_s32_u8(vqtbl1q_u8(pre_u8, pre_idx1));
      obmc_variance_8x1_s32_neon(pre_s32_lo, pre_s32_hi, &wsrc[0], &mask[0],
                                 &ssev, &sumv);

      pre_s32_lo = vreinterpretq_s32_u8(vqtbl1q_u8(pre_u8, pre_idx2));
      pre_s32_hi = vreinterpretq_s32_u8(vqtbl1q_u8(pre_u8, pre_idx3));
      obmc_variance_8x1_s32_neon(pre_s32_lo, pre_s32_hi, &wsrc[8], &mask[8],
                                 &ssev, &sumv);

      wsrc += 16;
      mask += 16;
      pre += 16;
      w -= 16;
    } while (w != 0);

    pre += pre_stride - width;
  } while (--h != 0);

  *sse = horizontal_add_s32x4(ssev);
  *sum = horizontal_add_s32x4(sumv);
}

#else  // !AOM_ARCH_AARCH64

static inline void obmc_variance_large_neon(const uint8_t *pre, int pre_stride,
                                            const int32_t *wsrc,
                                            const int32_t *mask, int width,
                                            int height, unsigned *sse,
                                            int *sum) {
  // Non-aarch64 targets do not have a 128-bit tbl instruction, so use the
  // widening version of the core kernel instead.

  assert(width % 16 == 0);

  int32x4_t ssev = vdupq_n_s32(0);
  int32x4_t sumv = vdupq_n_s32(0);

  int h = height;
  do {
    int w = width;
    do {
      uint8x16_t pre_u8 = vld1q_u8(pre);

      int16x8_t pre_s16 = vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(pre_u8)));
      obmc_variance_8x1_s16_neon(pre_s16, &wsrc[0], &mask[0], &ssev, &sumv);

      pre_s16 = vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(pre_u8)));
      obmc_variance_8x1_s16_neon(pre_s16, &wsrc[8], &mask[8], &ssev, &sumv);

      wsrc += 16;
      mask += 16;
      pre += 16;
      w -= 16;
    } while (w != 0);

    pre += pre_stride - width;
  } while (--h != 0);

  *sse = horizontal_add_s32x4(ssev);
  *sum = horizontal_add_s32x4(sumv);
}

#endif  // AOM_ARCH_AARCH64

static inline void obmc_variance_neon_128xh(const uint8_t *pre, int pre_stride,
                                            const int32_t *wsrc,
                                            const int32_t *mask, int h,
                                            unsigned *sse, int *sum) {
  obmc_variance_large_neon(pre, pre_stride, wsrc, mask, 128, h, sse, sum);
}

static inline void obmc_variance_neon_64xh(const uint8_t *pre, int pre_stride,
                                           const int32_t *wsrc,
                                           const int32_t *mask, int h,
                                           unsigned *sse, int *sum) {
  obmc_variance_large_neon(pre, pre_stride, wsrc, mask, 64, h, sse, sum);
}

static inline void obmc_variance_neon_32xh(const uint8_t *pre, int pre_stride,
                                           const int32_t *wsrc,
                                           const int32_t *mask, int h,
                                           unsigned *sse, int *sum) {
  obmc_variance_large_neon(pre, pre_stride, wsrc, mask, 32, h, sse, sum);
}

static inline void obmc_variance_neon_16xh(const uint8_t *pre, int pre_stride,
                                           const int32_t *wsrc,
                                           const int32_t *mask, int h,
                                           unsigned *sse, int *sum) {
  obmc_variance_large_neon(pre, pre_stride, wsrc, mask, 16, h, sse, sum);
}

static inline void obmc_variance_neon_8xh(const uint8_t *pre, int pre_stride,
                                          const int32_t *wsrc,
                                          const int32_t *mask, int h,
                                          unsigned *sse, int *sum) {
  int32x4_t ssev = vdupq_n_s32(0);
  int32x4_t sumv = vdupq_n_s32(0);

  do {
    uint8x8_t pre_u8 = vld1_u8(pre);
    int16x8_t pre_s16 = vreinterpretq_s16_u16(vmovl_u8(pre_u8));

    obmc_variance_8x1_s16_neon(pre_s16, wsrc, mask, &ssev, &sumv);

    pre += pre_stride;
    wsrc += 8;
    mask += 8;
  } while (--h != 0);

  *sse = horizontal_add_s32x4(ssev);
  *sum = horizontal_add_s32x4(sumv);
}

static inline void obmc_variance_neon_4xh(const uint8_t *pre, int pre_stride,
                                          const int32_t *wsrc,
                                          const int32_t *mask, int h,
                                          unsigned *sse, int *sum) {
  assert(h % 2 == 0);

  int32x4_t ssev = vdupq_n_s32(0);
  int32x4_t sumv = vdupq_n_s32(0);

  do {
    uint8x8_t pre_u8 = load_unaligned_u8(pre, pre_stride);
    int16x8_t pre_s16 = vreinterpretq_s16_u16(vmovl_u8(pre_u8));

    obmc_variance_8x1_s16_neon(pre_s16, wsrc, mask, &ssev, &sumv);

    pre += 2 * pre_stride;
    wsrc += 8;
    mask += 8;
    h -= 2;
  } while (h != 0);

  *sse = horizontal_add_s32x4(ssev);
  *sum = horizontal_add_s32x4(sumv);
}

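// Each instantiation below expands to a function of the form (e.g. for 8x8)
//   unsigned aom_obmc_variance8x8_neon(const uint8_t *pre, int pre_stride,
//                                      const int32_t *wsrc,
//                                      const int32_t *mask, unsigned *sse);
// returning sse - sum^2 / (W * H), i.e. (up to integer truncation) the sum of
// squared deviations of the rounded differences from their mean over the
// W * H block.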
#define OBMC_VARIANCE_WXH_NEON(W, H)                                       \
  unsigned aom_obmc_variance##W##x##H##_neon(                              \
      const uint8_t *pre, int pre_stride, const int32_t *wsrc,             \
      const int32_t *mask, unsigned *sse) {                                \
    int sum;                                                               \
    obmc_variance_neon_##W##xh(pre, pre_stride, wsrc, mask, H, sse, &sum); \
    return *sse - (unsigned)(((int64_t)sum * sum) / (W * H));              \
  }

OBMC_VARIANCE_WXH_NEON(4, 4)
OBMC_VARIANCE_WXH_NEON(4, 8)
OBMC_VARIANCE_WXH_NEON(8, 4)
OBMC_VARIANCE_WXH_NEON(8, 8)
OBMC_VARIANCE_WXH_NEON(8, 16)
OBMC_VARIANCE_WXH_NEON(16, 8)
OBMC_VARIANCE_WXH_NEON(16, 16)
OBMC_VARIANCE_WXH_NEON(16, 32)
OBMC_VARIANCE_WXH_NEON(32, 16)
OBMC_VARIANCE_WXH_NEON(32, 32)
OBMC_VARIANCE_WXH_NEON(32, 64)
OBMC_VARIANCE_WXH_NEON(64, 32)
OBMC_VARIANCE_WXH_NEON(64, 64)
OBMC_VARIANCE_WXH_NEON(64, 128)
OBMC_VARIANCE_WXH_NEON(128, 64)
OBMC_VARIANCE_WXH_NEON(128, 128)
OBMC_VARIANCE_WXH_NEON(4, 16)
OBMC_VARIANCE_WXH_NEON(16, 4)
OBMC_VARIANCE_WXH_NEON(8, 32)
OBMC_VARIANCE_WXH_NEON(32, 8)
OBMC_VARIANCE_WXH_NEON(16, 64)
OBMC_VARIANCE_WXH_NEON(64, 16)