xref: /aosp_15_r20/external/libvpx/vpx_dsp/arm/variance_neon_dotprod.c (revision fb1b10ab9aebc7c7068eedab379b749d7e3900be)
1 /*
2  *  Copyright (c) 2021 The WebM project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #include <arm_neon.h>
12 #include <assert.h>
13 
14 #include "./vpx_dsp_rtcd.h"
15 #include "./vpx_config.h"
16 
17 #include "vpx/vpx_integer.h"
18 #include "vpx_dsp/arm/mem_neon.h"
19 #include "vpx_dsp/arm/sum_neon.h"
20 #include "vpx_ports/mem.h"
21 
// Process a block of width 4 four rows at a time.
static INLINE void variance_4xh_neon_dotprod(const uint8_t *src_ptr,
                                             int src_stride,
                                             const uint8_t *ref_ptr,
                                             int ref_stride, int h,
                                             uint32_t *sse, int *sum) {
  // Vector accumulators; reduced to scalars after the whole block is summed.
  uint32x4_t src_acc = vdupq_n_u32(0);
  uint32x4_t ref_acc = vdupq_n_u32(0);
  uint32x4_t sse_acc = vdupq_n_u32(0);
  // Dotting against all-ones sums each group of 4 bytes into a 32-bit lane.
  const uint8x16_t ones = vdupq_n_u8(1);
  int i;

  for (i = 0; i < h; i += 4) {
    // Gather four 4-pixel rows into one 16-byte vector.
    const uint8x16_t s = load_unaligned_u8q(src_ptr, src_stride);
    const uint8x16_t r = load_unaligned_u8q(ref_ptr, ref_stride);
    const uint8x16_t abs_diff = vabdq_u8(s, r);

    sse_acc = vdotq_u32(sse_acc, abs_diff, abs_diff);
    src_acc = vdotq_u32(src_acc, s, ones);
    ref_acc = vdotq_u32(ref_acc, r, ones);

    src_ptr += 4 * src_stride;
    ref_ptr += 4 * ref_stride;
  }

  // sum(src - ref) == sum(src) - sum(ref); the subtraction wraps correctly
  // in unsigned arithmetic before the signed reinterpretation.
  *sum = horizontal_add_int32x4(
      vreinterpretq_s32_u32(vsubq_u32(src_acc, ref_acc)));
  *sse = horizontal_add_uint32x4(sse_acc);
}
52 
// Process a block of width 8 two rows at a time.
static INLINE void variance_8xh_neon_dotprod(const uint8_t *src_ptr,
                                             int src_stride,
                                             const uint8_t *ref_ptr,
                                             int ref_stride, int h,
                                             uint32_t *sse, int *sum) {
  // Vector accumulators; reduced to scalars after the whole block is summed.
  uint32x4_t src_acc = vdupq_n_u32(0);
  uint32x4_t ref_acc = vdupq_n_u32(0);
  uint32x4_t sse_acc = vdupq_n_u32(0);
  // Dotting against all-ones sums each group of 4 bytes into a 32-bit lane.
  const uint8x16_t ones = vdupq_n_u8(1);
  int i;

  for (i = 0; i < h; i += 2) {
    // Pack two 8-pixel rows into one 16-byte vector.
    const uint8x16_t s =
        vcombine_u8(vld1_u8(src_ptr), vld1_u8(src_ptr + src_stride));
    const uint8x16_t r =
        vcombine_u8(vld1_u8(ref_ptr), vld1_u8(ref_ptr + ref_stride));
    const uint8x16_t abs_diff = vabdq_u8(s, r);

    sse_acc = vdotq_u32(sse_acc, abs_diff, abs_diff);
    src_acc = vdotq_u32(src_acc, s, ones);
    ref_acc = vdotq_u32(ref_acc, r, ones);

    src_ptr += 2 * src_stride;
    ref_ptr += 2 * ref_stride;
  }

  // sum(src - ref) == sum(src) - sum(ref), computed lane-wise then reduced.
  *sum = horizontal_add_int32x4(
      vreinterpretq_s32_u32(vsubq_u32(src_acc, ref_acc)));
  *sse = horizontal_add_uint32x4(sse_acc);
}
85 
// Process a block of width 16 one row at a time.
static INLINE void variance_16xh_neon_dotprod(const uint8_t *src_ptr,
                                              int src_stride,
                                              const uint8_t *ref_ptr,
                                              int ref_stride, int h,
                                              uint32_t *sse, int *sum) {
  // Vector accumulators; reduced to scalars after the whole block is summed.
  uint32x4_t src_acc = vdupq_n_u32(0);
  uint32x4_t ref_acc = vdupq_n_u32(0);
  uint32x4_t sse_acc = vdupq_n_u32(0);
  // Dotting against all-ones sums each group of 4 bytes into a 32-bit lane.
  const uint8x16_t ones = vdupq_n_u8(1);
  int i;

  for (i = 0; i < h; ++i) {
    const uint8x16_t s = vld1q_u8(src_ptr);
    const uint8x16_t r = vld1q_u8(ref_ptr);
    const uint8x16_t abs_diff = vabdq_u8(s, r);

    sse_acc = vdotq_u32(sse_acc, abs_diff, abs_diff);
    src_acc = vdotq_u32(src_acc, s, ones);
    ref_acc = vdotq_u32(ref_acc, r, ones);

    src_ptr += src_stride;
    ref_ptr += ref_stride;
  }

  // sum(src - ref) == sum(src) - sum(ref), computed lane-wise then reduced.
  *sum = horizontal_add_int32x4(
      vreinterpretq_s32_u32(vsubq_u32(src_acc, ref_acc)));
  *sse = horizontal_add_uint32x4(sse_acc);
}
115 
// Process a block of any size where the width is divisible by 16.
static INLINE void variance_large_neon_dotprod(const uint8_t *src_ptr,
                                               int src_stride,
                                               const uint8_t *ref_ptr,
                                               int ref_stride, int w, int h,
                                               uint32_t *sse, int *sum) {
  // Vector accumulators; reduced to scalars after the whole block is summed.
  uint32x4_t src_acc = vdupq_n_u32(0);
  uint32x4_t ref_acc = vdupq_n_u32(0);
  uint32x4_t sse_acc = vdupq_n_u32(0);
  // Dotting against all-ones sums each group of 4 bytes into a 32-bit lane.
  const uint8x16_t ones = vdupq_n_u8(1);
  int i, j;

  for (i = 0; i < h; ++i) {
    // Walk each row 16 pixels at a time.
    for (j = 0; j < w; j += 16) {
      const uint8x16_t s = vld1q_u8(src_ptr + j);
      const uint8x16_t r = vld1q_u8(ref_ptr + j);
      const uint8x16_t abs_diff = vabdq_u8(s, r);

      sse_acc = vdotq_u32(sse_acc, abs_diff, abs_diff);
      src_acc = vdotq_u32(src_acc, s, ones);
      ref_acc = vdotq_u32(ref_acc, r, ones);
    }

    src_ptr += src_stride;
    ref_ptr += ref_stride;
  }

  // sum(src - ref) == sum(src) - sum(ref), computed lane-wise then reduced.
  *sum = horizontal_add_int32x4(
      vreinterpretq_s32_u32(vsubq_u32(src_acc, ref_acc)));
  *sse = horizontal_add_uint32x4(sse_acc);
}
150 
// Sum and SSE for a 32-wide block: defers to the generic implementation
// that handles any width divisible by 16.
static INLINE void variance_32xh_neon_dotprod(const uint8_t *src,
                                              int src_stride,
                                              const uint8_t *ref,
                                              int ref_stride, int h,
                                              uint32_t *sse, int *sum) {
  variance_large_neon_dotprod(src, src_stride, ref, ref_stride, /*w=*/32, h,
                              sse, sum);
}
159 
// Sum and SSE for a 64-wide block: defers to the generic implementation
// that handles any width divisible by 16.
static INLINE void variance_64xh_neon_dotprod(const uint8_t *src,
                                              int src_stride,
                                              const uint8_t *ref,
                                              int ref_stride, int h,
                                              uint32_t *sse, int *sum) {
  variance_large_neon_dotprod(src, src_stride, ref, ref_stride, /*w=*/64, h,
                              sse, sum);
}
168 
// Public entry point: writes the SSE and the sum of differences for an
// 8x8 block into *sse and *sum.
void vpx_get8x8var_neon_dotprod(const uint8_t *src_ptr, int src_stride,
                                const uint8_t *ref_ptr, int ref_stride,
                                unsigned int *sse, int *sum) {
  variance_8xh_neon_dotprod(src_ptr, src_stride, ref_ptr, ref_stride, /*h=*/8,
                            sse, sum);
}
175 
// Public entry point: writes the SSE and the sum of differences for a
// 16x16 block into *sse and *sum.
void vpx_get16x16var_neon_dotprod(const uint8_t *src_ptr, int src_stride,
                                  const uint8_t *ref_ptr, int ref_stride,
                                  unsigned int *sse, int *sum) {
  variance_16xh_neon_dotprod(src_ptr, src_stride, ref_ptr, ref_stride,
                             /*h=*/16, sse, sum);
}
182 
// Generates vpx_variance<w>x<h>_neon_dotprod(). `shift` must equal
// log2(w * h): variance = SSE - (sum^2 / (w * h)), and the division is
// replaced by `(int64_t)sum * sum >> shift`. The product is widened to
// 64 bits before shifting to avoid signed overflow for the larger blocks.
#define VARIANCE_WXH_NEON_DOTPROD(w, h, shift)                                \
  unsigned int vpx_variance##w##x##h##_neon_dotprod(                          \
      const uint8_t *src, int src_stride, const uint8_t *ref, int ref_stride, \
      unsigned int *sse) {                                                    \
    int sum;                                                                  \
    variance_##w##xh_neon_dotprod(src, src_stride, ref, ref_stride, h, sse,   \
                                  &sum);                                      \
    return *sse - (uint32_t)(((int64_t)sum * sum) >> shift);                  \
  }

VARIANCE_WXH_NEON_DOTPROD(4, 4, 4)
VARIANCE_WXH_NEON_DOTPROD(4, 8, 5)

VARIANCE_WXH_NEON_DOTPROD(8, 4, 5)
VARIANCE_WXH_NEON_DOTPROD(8, 8, 6)
VARIANCE_WXH_NEON_DOTPROD(8, 16, 7)

VARIANCE_WXH_NEON_DOTPROD(16, 8, 7)
VARIANCE_WXH_NEON_DOTPROD(16, 16, 8)
VARIANCE_WXH_NEON_DOTPROD(16, 32, 9)

VARIANCE_WXH_NEON_DOTPROD(32, 16, 9)
VARIANCE_WXH_NEON_DOTPROD(32, 32, 10)
VARIANCE_WXH_NEON_DOTPROD(32, 64, 11)

VARIANCE_WXH_NEON_DOTPROD(64, 32, 11)
VARIANCE_WXH_NEON_DOTPROD(64, 64, 12)

#undef VARIANCE_WXH_NEON_DOTPROD
212 
// Sum of squared differences over an 8-wide block, two rows per iteration.
// (The value is not divided by the block size; callers treat it as raw SSE.)
static INLINE unsigned int vpx_mse8xh_neon_dotprod(const unsigned char *src_ptr,
                                                   int src_stride,
                                                   const unsigned char *ref_ptr,
                                                   int ref_stride, int h) {
  // Two independent accumulators shorten the dot-product dependency chain.
  uint32x2_t acc0 = vdup_n_u32(0);
  uint32x2_t acc1 = vdup_n_u32(0);
  int i;

  for (i = 0; i < h; i += 2) {
    const uint8x8_t s0 = vld1_u8(src_ptr);
    const uint8x8_t s1 = vld1_u8(src_ptr + src_stride);
    const uint8x8_t r0 = vld1_u8(ref_ptr);
    const uint8x8_t r1 = vld1_u8(ref_ptr + ref_stride);

    const uint8x8_t d0 = vabd_u8(s0, r0);
    const uint8x8_t d1 = vabd_u8(s1, r1);

    acc0 = vdot_u32(acc0, d0, d0);
    acc1 = vdot_u32(acc1, d1, d1);

    src_ptr += 2 * src_stride;
    ref_ptr += 2 * ref_stride;
  }

  return horizontal_add_uint32x2(vadd_u32(acc0, acc1));
}
241 
// Sum of squared differences over a 16-wide block, two rows per iteration.
// (The value is not divided by the block size; callers treat it as raw SSE.)
static INLINE unsigned int vpx_mse16xh_neon_dotprod(
    const unsigned char *src_ptr, int src_stride, const unsigned char *ref_ptr,
    int ref_stride, int h) {
  // Two independent accumulators shorten the dot-product dependency chain.
  uint32x4_t acc0 = vdupq_n_u32(0);
  uint32x4_t acc1 = vdupq_n_u32(0);
  int i;

  for (i = 0; i < h; i += 2) {
    const uint8x16_t s0 = vld1q_u8(src_ptr);
    const uint8x16_t s1 = vld1q_u8(src_ptr + src_stride);
    const uint8x16_t r0 = vld1q_u8(ref_ptr);
    const uint8x16_t r1 = vld1q_u8(ref_ptr + ref_stride);

    const uint8x16_t d0 = vabdq_u8(s0, r0);
    const uint8x16_t d1 = vabdq_u8(s1, r1);

    acc0 = vdotq_u32(acc0, d0, d0);
    acc1 = vdotq_u32(acc1, d1, d1);

    src_ptr += 2 * src_stride;
    ref_ptr += 2 * ref_stride;
  }

  return horizontal_add_uint32x4(vaddq_u32(acc0, acc1));
}
269 
// Returns the sum of squared differences for a single 4x4 block.
unsigned int vpx_get4x4sse_cs_neon_dotprod(const unsigned char *src_ptr,
                                           int src_stride,
                                           const unsigned char *ref_ptr,
                                           int ref_stride) {
  // Gather each 4x4 block into one 16-byte vector, then square-accumulate
  // the absolute differences with a single dot product.
  const uint8x16_t s = load_unaligned_u8q(src_ptr, src_stride);
  const uint8x16_t r = load_unaligned_u8q(ref_ptr, ref_stride);
  const uint8x16_t abs_diff = vabdq_u8(s, r);
  const uint32x4_t sse_u32 = vdotq_u32(vdupq_n_u32(0), abs_diff, abs_diff);

  return horizontal_add_uint32x4(sse_u32);
}
283 
// Generates vpx_mse<w>x<h>_neon_dotprod(): stores the raw sum of squared
// differences to *sse and also returns it. (Despite the "mse" name, the
// value is not divided by the number of pixels.)
#define VPX_MSE_WXH_NEON_DOTPROD(w, h)                                   \
  unsigned int vpx_mse##w##x##h##_neon_dotprod(                          \
      const unsigned char *src_ptr, int src_stride,                      \
      const unsigned char *ref_ptr, int ref_stride, unsigned int *sse) { \
    *sse = vpx_mse##w##xh_neon_dotprod(src_ptr, src_stride, ref_ptr,     \
                                       ref_stride, h);                   \
    return *sse;                                                         \
  }

VPX_MSE_WXH_NEON_DOTPROD(8, 8)
VPX_MSE_WXH_NEON_DOTPROD(8, 16)
VPX_MSE_WXH_NEON_DOTPROD(16, 8)
VPX_MSE_WXH_NEON_DOTPROD(16, 16)

#undef VPX_MSE_WXH_NEON_DOTPROD
299