1 /*
2 * Copyright (c) 2023 The WebM project authors. All Rights Reserved.
3 *
4 * Use of this source code is governed by a BSD-style license
5 * that can be found in the LICENSE file in the root of the source
6 * tree. An additional intellectual property rights grant can be found
7 * in the file PATENTS. All contributing project authors may
8 * be found in the AUTHORS file in the root of the source tree.
9 */
10
11 #include <arm_neon.h>
12
13 #include "./vpx_dsp_rtcd.h"
14 #include "vpx_dsp/arm/mem_neon.h"
15 #include "vpx_dsp/arm/sum_neon.h"
16
// Accumulates the squared differences between the 16 bytes at |src| and the
// 16 bytes at |ref| into the four 32-bit lanes of |*sse|.
static INLINE void sse_16x1_neon_dotprod(const uint8_t *src, const uint8_t *ref,
                                         uint32x4_t *sse) {
  // The absolute difference of two bytes still fits in a byte, so taking the
  // unsigned dot product of the difference vector with itself squares each
  // difference and sums groups of four into each 32-bit lane.
  const uint8x16_t diff = vabdq_u8(vld1q_u8(src), vld1q_u8(ref));
  const uint32x4_t running = *sse;

  *sse = vdotq_u32(running, diff, diff);
}
26
// Accumulates the squared differences between the 8 bytes at |src| and the
// 8 bytes at |ref| into the two 32-bit lanes of |*sse|.
static INLINE void sse_8x1_neon_dotprod(const uint8_t *src, const uint8_t *ref,
                                        uint32x2_t *sse) {
  // Dotting the byte-wise absolute difference with itself produces the sum of
  // squared differences, four bytes per 32-bit lane.
  const uint8x8_t diff = vabd_u8(vld1_u8(src), vld1_u8(ref));
  const uint32x2_t running = *sse;

  *sse = vdot_u32(running, diff, diff);
}
36
// Accumulates the squared differences of one 8-byte vector gathered from
// |src| and one gathered from |ref| into the two 32-bit lanes of |*sse|.
// NOTE(review): load_unaligned_u8() is declared outside this file; judging by
// this function's name and its callers it presumably packs 4 bytes from each
// of two consecutive rows into a single vector - confirm in mem_neon.h.
static INLINE void sse_4x2_neon_dotprod(const uint8_t *src, int src_stride,
                                        const uint8_t *ref, int ref_stride,
                                        uint32x2_t *sse) {
  const uint8x8_t diff = vabd_u8(load_unaligned_u8(src, src_stride),
                                 load_unaligned_u8(ref, ref_stride));
  const uint32x2_t running = *sse;

  // Square and sum the differences, four bytes per 32-bit lane.
  *sse = vdot_u32(running, diff, diff);
}
47
// Generic-width kernel: returns the 32-bit sum of squared differences between
// |src| and |ref|, walking the block two rows at a time.
//
// Each pass consumes two rows and the row loop only terminates when its
// counter reaches exactly zero, so |height| must be a positive even number.
// Columns are consumed in 8-byte steps, with one extra 4-byte-wide step per
// row pair when (width % 8) is in 1..4.
static INLINE uint32_t sse_wxh_neon_dotprod(const uint8_t *src, int src_stride,
                                            const uint8_t *ref, int ref_stride,
                                            int width, int height) {
  // sse[0] accumulates the upper row of each pair (plus any 4-wide tail),
  // sse[1] the lower row.
  uint32x2_t sse[2] = { vdup_n_u32(0), vdup_n_u32(0) };
  const int width_mod8 = width & 0x07;
  int rows_left = height;

  if (width_mod8 != 0 && width_mod8 < 5) {
    // 8-byte columns followed by a 4-byte-wide tail covering both rows.
    do {
      const uint8_t *const src_below = src + src_stride;
      const uint8_t *const ref_below = ref + ref_stride;
      int col = 0;

      do {
        sse_8x1_neon_dotprod(src + col, ref + col, &sse[0]);
        sse_8x1_neon_dotprod(src_below + col, ref_below + col, &sse[1]);
        col += 8;
      } while (col + 4 < width);

      sse_4x2_neon_dotprod(src + col, src_stride, ref + col, ref_stride,
                           &sse[0]);

      src += 2 * src_stride;
      ref += 2 * ref_stride;
      rows_left -= 2;
    } while (rows_left != 0);
  } else {
    // 8-byte columns only.
    do {
      const uint8_t *const src_below = src + src_stride;
      const uint8_t *const ref_below = ref + ref_stride;
      int col = 0;

      do {
        sse_8x1_neon_dotprod(src + col, ref + col, &sse[0]);
        sse_8x1_neon_dotprod(src_below + col, ref_below + col, &sse[1]);
        col += 8;
      } while (col < width);

      src += 2 * src_stride;
      ref += 2 * ref_stride;
      rows_left -= 2;
    } while (rows_left != 0);
  }

  // Join both partial accumulators and reduce them to a single value.
  return horizontal_add_uint32x4(vcombine_u32(sse[0], sse[1]));
}
87
// Returns the 32-bit sum of squared differences for a block 64 bytes wide and
// |height| rows tall. The row loop runs at least once and stops only when its
// counter reaches zero, so |height| must be positive.
static INLINE uint32_t sse_64xh_neon_dotprod(const uint8_t *src, int src_stride,
                                             const uint8_t *ref, int ref_stride,
                                             int height) {
  // Two accumulators, fed alternately by consecutive 16-byte segments.
  uint32x4_t acc_a = vdupq_n_u32(0);
  uint32x4_t acc_b = vdupq_n_u32(0);
  int rows_left = height;

  do {
    int col;
    // Segments at offsets 0 and 32 go to acc_a; 16 and 48 go to acc_b.
    for (col = 0; col < 64; col += 32) {
      sse_16x1_neon_dotprod(src + col, ref + col, &acc_a);
      sse_16x1_neon_dotprod(src + col + 16, ref + col + 16, &acc_b);
    }

    src += src_stride;
    ref += ref_stride;
    --rows_left;
  } while (rows_left != 0);

  return horizontal_add_uint32x4(vaddq_u32(acc_a, acc_b));
}
106
// Returns the 32-bit sum of squared differences for a block 32 bytes wide and
// |height| rows tall. The row loop runs at least once and stops only when its
// counter reaches zero, so |height| must be positive.
static INLINE uint32_t sse_32xh_neon_dotprod(const uint8_t *src, int src_stride,
                                             const uint8_t *ref, int ref_stride,
                                             int height) {
  // One accumulator per 16-byte half of the row.
  uint32x4_t acc_left = vdupq_n_u32(0);
  uint32x4_t acc_right = vdupq_n_u32(0);
  int rows_left = height;

  do {
    sse_16x1_neon_dotprod(src, ref, &acc_left);
    sse_16x1_neon_dotprod(src + 16, ref + 16, &acc_right);

    src += src_stride;
    ref += ref_stride;
    --rows_left;
  } while (rows_left != 0);

  return horizontal_add_uint32x4(vaddq_u32(acc_left, acc_right));
}
123
// Returns the 32-bit sum of squared differences for a block 16 bytes wide and
// |height| rows tall. Two rows are consumed per pass and the loop stops only
// when its counter reaches exactly zero, so |height| must be a positive even
// number.
static INLINE uint32_t sse_16xh_neon_dotprod(const uint8_t *src, int src_stride,
                                             const uint8_t *ref, int ref_stride,
                                             int height) {
  // One accumulator for the upper row of each pair, one for the lower row.
  uint32x4_t acc_upper = vdupq_n_u32(0);
  uint32x4_t acc_lower = vdupq_n_u32(0);
  int rows_left = height;

  do {
    const uint8_t *const src_below = src + src_stride;
    const uint8_t *const ref_below = ref + ref_stride;

    sse_16x1_neon_dotprod(src, ref, &acc_upper);
    sse_16x1_neon_dotprod(src_below, ref_below, &acc_lower);

    // Step past the pair of rows just processed.
    src = src_below + src_stride;
    ref = ref_below + ref_stride;
    rows_left -= 2;
  } while (rows_left != 0);

  return horizontal_add_uint32x4(vaddq_u32(acc_upper, acc_lower));
}
142
// Returns the 32-bit sum of squared differences for a block 8 bytes wide and
// |height| rows tall. Two rows are consumed per pass and the loop stops only
// when its counter reaches exactly zero, so |height| must be a positive even
// number.
static INLINE uint32_t sse_8xh_neon_dotprod(const uint8_t *src, int src_stride,
                                            const uint8_t *ref, int ref_stride,
                                            int height) {
  // One 64-bit accumulator for the upper row of each pair, one for the lower.
  uint32x2_t acc_upper = vdup_n_u32(0);
  uint32x2_t acc_lower = vdup_n_u32(0);
  int rows_left = height;

  do {
    const uint8_t *const src_below = src + src_stride;
    const uint8_t *const ref_below = ref + ref_stride;

    sse_8x1_neon_dotprod(src, ref, &acc_upper);
    sse_8x1_neon_dotprod(src_below, ref_below, &acc_lower);

    // Step past the pair of rows just processed.
    src = src_below + src_stride;
    ref = ref_below + ref_stride;
    rows_left -= 2;
  } while (rows_left != 0);

  // Join the two halves into one 128-bit vector before reducing.
  return horizontal_add_uint32x4(vcombine_u32(acc_upper, acc_lower));
}
161
// Returns the 32-bit sum of squared differences for a block 4 bytes wide and
// |height| rows tall. Each pass hands a pair of rows to
// sse_4x2_neon_dotprod() and the loop stops only when its counter reaches
// exactly zero, so |height| must be a positive even number.
static INLINE uint32_t sse_4xh_neon_dotprod(const uint8_t *src, int src_stride,
                                            const uint8_t *ref, int ref_stride,
                                            int height) {
  uint32x2_t acc = vdup_n_u32(0);
  int rows_left = height;

  do {
    sse_4x2_neon_dotprod(src, src_stride, ref, ref_stride, &acc);

    // Advance both pointers by the two rows just consumed.
    src += 2 * src_stride;
    ref += 2 * ref_stride;
    rows_left -= 2;
  } while (rows_left != 0);

  return horizontal_add_uint32x2(acc);
}
178
// Sum of squared differences between the |width| x |height| byte regions at
// |src| and |ref| (row pitches |src_stride| and |ref_stride|).
//
// Widths 4, 8, 16, 32 and 64 are routed to fixed-width kernels; every other
// width goes to the generic kernel. All kernels accumulate in 32 bits, and the
// result is widened to int64_t on return. Every kernel's row loop executes at
// least once, and all but the 32- and 64-wide ones consume two rows per pass,
// so |height| must be positive (and even for those widths).
int64_t vpx_sse_neon_dotprod(const uint8_t *src, int src_stride,
                             const uint8_t *ref, int ref_stride, int width,
                             int height) {
  uint32_t total;

  switch (width) {
    case 4:
      total = sse_4xh_neon_dotprod(src, src_stride, ref, ref_stride, height);
      break;
    case 8:
      total = sse_8xh_neon_dotprod(src, src_stride, ref, ref_stride, height);
      break;
    case 16:
      total = sse_16xh_neon_dotprod(src, src_stride, ref, ref_stride, height);
      break;
    case 32:
      total = sse_32xh_neon_dotprod(src, src_stride, ref, ref_stride, height);
      break;
    case 64:
      total = sse_64xh_neon_dotprod(src, src_stride, ref, ref_stride, height);
      break;
    default:
      total = sse_wxh_neon_dotprod(src, src_stride, ref, ref_stride, width,
                                   height);
      break;
  }

  return total;
}
198