xref: /aosp_15_r20/external/libaom/aom_dsp/arm/obmc_sad_neon.c (revision 77c1e3ccc04c968bd2bc212e87364f250e820521)
1 /*
2  * Copyright (c) 2023, Alliance for Open Media. All rights reserved.
3  *
4  * This source code is subject to the terms of the BSD 2 Clause License and
5  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6  * was not distributed with this source code in the LICENSE file, you can
7  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8  * Media Patent License 1.0 was not distributed with this source code in the
9  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10  */
11 
12 #include <arm_neon.h>
13 #include "config/aom_config.h"
14 #include "config/aom_dsp_rtcd.h"
15 #include "mem_neon.h"
16 #include "sum_neon.h"
17 
// Accumulate one run of 8 OBMC SAD terms into *sum:
//   sum += round_shift(|wsrc - ref * mask|, 12) per lane.
static inline void obmc_sad_8x1_s16_neon(int16x8_t ref_s16, const int32_t *mask,
                                         const int32_t *wsrc, uint32x4_t *sum) {
  const int32x4_t m_lo = vld1q_s32(mask);
  const int32x4_t m_hi = vld1q_s32(mask + 4);

  // Keep only the even 16-bit halves of each 32-bit mask lane; consistent
  // with the >>12 below, mask values are assumed to fit in 16 bits.
  const int16x8_t m_s16 =
      vuzpq_s16(vreinterpretq_s16_s32(m_lo), vreinterpretq_s16_s32(m_hi))
          .val[0];

  // Widening multiply of the (zero-extended) reference by the mask.
  const int32x4_t weighted_lo =
      vmull_s16(vget_low_s16(ref_s16), vget_low_s16(m_s16));
  const int32x4_t weighted_hi =
      vmull_s16(vget_high_s16(ref_s16), vget_high_s16(m_s16));

  const int32x4_t w_lo = vld1q_s32(wsrc);
  const int32x4_t w_hi = vld1q_s32(wsrc + 4);

  // vabd results are non-negative, so reinterpreting s32 as u32 is safe.
  const uint32x4_t diff_lo =
      vreinterpretq_u32_s32(vabdq_s32(w_lo, weighted_lo));
  const uint32x4_t diff_hi =
      vreinterpretq_u32_s32(vabdq_s32(w_hi, weighted_hi));

  // Rounding shift right by 12, then accumulate.
  *sum = vrsraq_n_u32(*sum, diff_lo, 12);
  *sum = vrsraq_n_u32(*sum, diff_hi, 12);
}
39 
40 #if AOM_ARCH_AARCH64
41 
// Permute table for tbl-based zero extension of bytes straight to 32 bits in
// one instruction rather than two widening moves. Each group of four entries
// selects one source byte followed by three out-of-range indices (255), which
// tbl turns into zero bytes — producing the zero-extended 32-bit lane.
DECLARE_ALIGNED(16, static const uint8_t, obmc_variance_permute_idx[]) = {
  0,  255, 255, 255, 1,  255, 255, 255, 2,  255, 255, 255, 3,  255, 255, 255,
  4,  255, 255, 255, 5,  255, 255, 255, 6,  255, 255, 255, 7,  255, 255, 255,
  8,  255, 255, 255, 9,  255, 255, 255, 10, 255, 255, 255, 11, 255, 255, 255,
  12, 255, 255, 255, 13, 255, 255, 255, 14, 255, 255, 255, 15, 255, 255, 255
};
51 
// Accumulate 8 OBMC SAD terms from already zero-extended 32-bit reference
// lanes: sum[i] += round_shift(|wsrc - ref * mask|, 12) per lane.
static inline void obmc_sad_8x1_s32_neon(uint32x4_t ref_u32_lo,
                                         uint32x4_t ref_u32_hi,
                                         const int32_t *mask,
                                         const int32_t *wsrc,
                                         uint32x4_t sum[2]) {
  // Reference lanes were zero-extended from 8 bits, so they are small
  // non-negative values and a plain 32-bit multiply cannot overflow here.
  const int32x4_t weighted_lo =
      vmulq_s32(vreinterpretq_s32_u32(ref_u32_lo), vld1q_s32(mask));
  const int32x4_t weighted_hi =
      vmulq_s32(vreinterpretq_s32_u32(ref_u32_hi), vld1q_s32(mask + 4));

  // |wsrc - weighted| is non-negative; the u32 reinterpret is therefore safe.
  const uint32x4_t diff_lo =
      vreinterpretq_u32_s32(vabdq_s32(vld1q_s32(wsrc), weighted_lo));
  const uint32x4_t diff_hi =
      vreinterpretq_u32_s32(vabdq_s32(vld1q_s32(wsrc + 4), weighted_hi));

  // Rounding shift right by 12, then accumulate into the two partial sums.
  sum[0] = vrsraq_n_u32(sum[0], diff_lo, 12);
  sum[1] = vrsraq_n_u32(sum[1], diff_hi, 12);
}
71 
// OBMC SAD for blocks whose width is a multiple of 16 (AArch64 path).
// Zero-extends reference pixels to 32 bits with a single tbl per 4 pixels.
static inline unsigned int obmc_sad_large_neon(const uint8_t *ref,
                                               int ref_stride,
                                               const int32_t *wsrc,
                                               const int32_t *mask, int width,
                                               int height) {
  uint32x4_t acc[2] = { vdupq_n_u32(0), vdupq_n_u32(0) };

  // One tbl performs the double-width 8->32 bit zero extension; out-of-range
  // indices (255) in the permute table produce zero bytes.
  const uint8x16_t idx0 = vld1q_u8(&obmc_variance_permute_idx[0]);
  const uint8x16_t idx1 = vld1q_u8(&obmc_variance_permute_idx[16]);
  const uint8x16_t idx2 = vld1q_u8(&obmc_variance_permute_idx[32]);
  const uint8x16_t idx3 = vld1q_u8(&obmc_variance_permute_idx[48]);

  int rows_left = height;
  do {
    const uint8_t *row = ref;
    int cols_left = width;
    do {
      // 16 reference pixels per inner iteration; width must be a multiple
      // of 16 for the loop to terminate correctly.
      const uint8x16_t r = vld1q_u8(row);

      obmc_sad_8x1_s32_neon(vreinterpretq_u32_u8(vqtbl1q_u8(r, idx0)),
                            vreinterpretq_u32_u8(vqtbl1q_u8(r, idx1)), mask,
                            wsrc, acc);
      obmc_sad_8x1_s32_neon(vreinterpretq_u32_u8(vqtbl1q_u8(r, idx2)),
                            vreinterpretq_u32_u8(vqtbl1q_u8(r, idx3)), mask + 8,
                            wsrc + 8, acc);

      row += 16;
      wsrc += 16;
      mask += 16;
      cols_left -= 16;
    } while (cols_left != 0);

    ref += ref_stride;
  } while (--rows_left != 0);

  return horizontal_add_u32x4(vaddq_u32(acc[0], acc[1]));
}
112 
113 #else  // !AOM_ARCH_AARCH64
114 
// OBMC SAD for blocks whose width is a multiple of 16 (non-AArch64 path).
// Widens reference pixels to 16 bits and uses the widening-multiply helper.
static inline unsigned int obmc_sad_large_neon(const uint8_t *ref,
                                               int ref_stride,
                                               const int32_t *wsrc,
                                               const int32_t *mask, int width,
                                               int height) {
  uint32x4_t acc = vdupq_n_u32(0);

  int rows_left = height;
  do {
    const uint8_t *row = ref;
    int cols_left = width;
    do {
      // 16 reference pixels per inner iteration; width must be a multiple
      // of 16 for the loop to terminate correctly.
      const uint8x16_t r = vld1q_u8(row);

      obmc_sad_8x1_s16_neon(vreinterpretq_s16_u16(vmovl_u8(vget_low_u8(r))),
                            mask, wsrc, &acc);
      obmc_sad_8x1_s16_neon(vreinterpretq_s16_u16(vmovl_u8(vget_high_u8(r))),
                            mask + 8, wsrc + 8, &acc);

      row += 16;
      wsrc += 16;
      mask += 16;
      cols_left -= 16;
    } while (cols_left != 0);

    ref += ref_stride;
  } while (--rows_left != 0);

  return horizontal_add_u32x4(acc);
}
146 
147 #endif  // AOM_ARCH_AARCH64
148 
// OBMC SAD for 128-wide blocks: thin wrapper over the generic wide kernel.
static inline unsigned int obmc_sad_128xh_neon(const uint8_t *ref,
                                               int ref_stride,
                                               const int32_t *wsrc,
                                               const int32_t *mask, int h) {
  return obmc_sad_large_neon(ref, ref_stride, wsrc, mask, 128, h);
}
155 
// OBMC SAD for 64-wide blocks: thin wrapper over the generic wide kernel.
static inline unsigned int obmc_sad_64xh_neon(const uint8_t *ref,
                                              int ref_stride,
                                              const int32_t *wsrc,
                                              const int32_t *mask, int h) {
  return obmc_sad_large_neon(ref, ref_stride, wsrc, mask, 64, h);
}
162 
// OBMC SAD for 32-wide blocks: thin wrapper over the generic wide kernel.
static inline unsigned int obmc_sad_32xh_neon(const uint8_t *ref,
                                              int ref_stride,
                                              const int32_t *wsrc,
                                              const int32_t *mask, int h) {
  return obmc_sad_large_neon(ref, ref_stride, wsrc, mask, 32, h);
}
169 
// OBMC SAD for 16-wide blocks: thin wrapper over the generic wide kernel.
static inline unsigned int obmc_sad_16xh_neon(const uint8_t *ref,
                                              int ref_stride,
                                              const int32_t *wsrc,
                                              const int32_t *mask, int h) {
  return obmc_sad_large_neon(ref, ref_stride, wsrc, mask, 16, h);
}
176 
// OBMC SAD for 8-wide blocks: one 8-pixel row per iteration.
static inline unsigned int obmc_sad_8xh_neon(const uint8_t *ref, int ref_stride,
                                             const int32_t *wsrc,
                                             const int32_t *mask, int height) {
  uint32x4_t acc = vdupq_n_u32(0);

  int rows_left = height;
  do {
    // Widen the 8 reference pixels to 16 bits for the widening multiply.
    const int16x8_t ref_s16 =
        vreinterpretq_s16_u16(vmovl_u8(vld1_u8(ref)));
    obmc_sad_8x1_s16_neon(ref_s16, mask, wsrc, &acc);

    mask += 8;
    wsrc += 8;
    ref += ref_stride;
  } while (--rows_left != 0);

  return horizontal_add_u32x4(acc);
}
196 
// OBMC SAD for 4-wide blocks. Two 4-pixel rows are packed into one 8-lane
// vector per iteration, so height must be even.
static inline unsigned int obmc_sad_4xh_neon(const uint8_t *ref, int ref_stride,
                                             const int32_t *wsrc,
                                             const int32_t *mask, int height) {
  uint32x4_t acc = vdupq_n_u32(0);

  int row_pairs = height / 2;
  do {
    // Gather two stride-separated 4-byte rows into a single 8-byte vector.
    const uint8x8_t r = load_unaligned_u8(ref, ref_stride);
    const int16x8_t ref_s16 = vreinterpretq_s16_u16(vmovl_u8(r));
    obmc_sad_8x1_s16_neon(ref_s16, mask, wsrc, &acc);

    mask += 8;
    wsrc += 8;
    ref += 2 * ref_stride;
  } while (--row_pairs != 0);

  return horizontal_add_u32x4(acc);
}
216 
// Define the public aom_obmc_sad<w>x<h>_neon entry point (presumably declared
// via aom_dsp_rtcd.h — confirm) by delegating to the width-specialized helper.
#define OBMC_SAD_WXH_NEON(w, h)                                   \
  unsigned int aom_obmc_sad##w##x##h##_neon(                      \
      const uint8_t *ref, int ref_stride, const int32_t *wsrc,    \
      const int32_t *mask) {                                      \
    return obmc_sad_##w##xh_neon(ref, ref_stride, wsrc, mask, h); \
  }
223 
// Instantiate the aom_obmc_sad<w>x<h>_neon entry points for every supported
// block size, grouped by width.
OBMC_SAD_WXH_NEON(4, 4)
OBMC_SAD_WXH_NEON(4, 8)
OBMC_SAD_WXH_NEON(4, 16)

OBMC_SAD_WXH_NEON(8, 4)
OBMC_SAD_WXH_NEON(8, 8)
OBMC_SAD_WXH_NEON(8, 16)
OBMC_SAD_WXH_NEON(8, 32)

OBMC_SAD_WXH_NEON(16, 4)
OBMC_SAD_WXH_NEON(16, 8)
OBMC_SAD_WXH_NEON(16, 16)
OBMC_SAD_WXH_NEON(16, 32)
OBMC_SAD_WXH_NEON(16, 64)

OBMC_SAD_WXH_NEON(32, 8)
OBMC_SAD_WXH_NEON(32, 16)
OBMC_SAD_WXH_NEON(32, 32)
OBMC_SAD_WXH_NEON(32, 64)

OBMC_SAD_WXH_NEON(64, 16)
OBMC_SAD_WXH_NEON(64, 32)
OBMC_SAD_WXH_NEON(64, 64)
OBMC_SAD_WXH_NEON(64, 128)

OBMC_SAD_WXH_NEON(128, 64)
OBMC_SAD_WXH_NEON(128, 128)
251