1 /*
2 * Copyright (c) 2023, Alliance for Open Media. All rights reserved.
3 *
4 * This source code is subject to the terms of the BSD 2 Clause License and
5 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6 * was not distributed with this source code in the LICENSE file, you can
7 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8 * Media Patent License 1.0 was not distributed with this source code in the
9 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10 */
11
12 #include <arm_neon.h>
13
14 #include "config/aom_config.h"
15 #include "config/aom_dsp_rtcd.h"
16
17 #include "aom/aom_integer.h"
18 #include "aom_dsp/arm/mem_neon.h"
19 #include "aom_dsp/arm/sum_neon.h"
20
// Accumulates the rounded OBMC SAD terms of eight adjacent high-bitdepth
// pixels into the four 32-bit lanes of *sum.
//
// For each element k in [0, 8) this adds
//   (|wsrc[k] - ref[k] * mask[k]| + (1 << 11)) >> 12
// to lane (k % 4) of *sum; vrsraq_n_u32 performs the rounding shift and the
// accumulate in a single step.
//
// The 32-bit mask values are narrowed with vmovn_s32 (which keeps only the low
// 16 bits) and the pixels are reinterpreted as signed 16-bit, so both are
// assumed to fit in int16_t.
// NOTE(review): presumably guaranteed by callers (mask <= 1 << 12 and
// <= 12-bit pixels) — confirm against the encoder's OBMC mask generation.
static inline void highbd_obmc_sad_8x1_s16_neon(uint16x8_t ref,
                                                const int32_t *mask,
                                                const int32_t *wsrc,
                                                uint32x4_t *sum) {
  // Narrow each half of the mask so a widening 16x16->32 multiply can be used.
  const int16x4_t mask_first = vmovn_s32(vld1q_s32(mask + 0));
  const int16x4_t mask_second = vmovn_s32(vld1q_s32(mask + 4));
  const int16x8_t pixels = vreinterpretq_s16_u16(ref);

  // Weighted prediction ref * mask, widened to 32 bits.
  const int32x4_t pred_first = vmull_s16(vget_low_s16(pixels), mask_first);
  const int32x4_t pred_second = vmull_s16(vget_high_s16(pixels), mask_second);

  // Per-element absolute difference |wsrc - pred|.
  const uint32x4_t diff_first =
      vreinterpretq_u32_s32(vabdq_s32(vld1q_s32(wsrc + 0), pred_first));
  const uint32x4_t diff_second =
      vreinterpretq_u32_s32(vabdq_s32(vld1q_s32(wsrc + 4), pred_second));

  // acc += round(diff >> 12), first half then second half.
  uint32x4_t acc = vrsraq_n_u32(*sum, diff_first, 12);
  acc = vrsraq_n_u32(acc, diff_second, 12);
  *sum = acc;
}
44
// Returns the OBMC SAD of a 4-pixel-wide high-bitdepth block.
//
// Each iteration handles two rows (eight pixels) with one call to the 8-lane
// kernel, so height must be a non-zero multiple of 2: with height < 2 the
// pair counter starts at 0 and the do/while decrement never reaches 0 again.
// wsrc and mask are consumed contiguously, 8 values per row pair.
static inline unsigned int highbd_obmc_sad_4xh_neon(const uint8_t *ref,
                                                    int ref_stride,
                                                    const int32_t *wsrc,
                                                    const int32_t *mask,
                                                    int height) {
  const uint16_t *src16 = CONVERT_TO_SHORTPTR(ref);
  uint32x4_t acc = vdupq_n_u32(0);
  int pairs_left = height / 2;

  do {
    // Two 4-pixel rows gathered into one vector (layout as defined by
    // load_unaligned_u16_4x2).
    const uint16x8_t two_rows = load_unaligned_u16_4x2(src16, ref_stride);
    highbd_obmc_sad_8x1_s16_neon(two_rows, mask, wsrc, &acc);

    src16 += 2 * ref_stride;
    mask += 8;
    wsrc += 8;
  } while (--pairs_left != 0);

  return horizontal_add_u32x4(acc);
}
66
// Returns the OBMC SAD of an 8-pixel-wide high-bitdepth block.
//
// Processes one row per iteration; wsrc and mask advance by 8 values per row.
// height must be non-zero (the loop is a do/while on a decrementing counter).
static inline unsigned int highbd_obmc_sad_8xh_neon(const uint8_t *ref,
                                                    int ref_stride,
                                                    const int32_t *wsrc,
                                                    const int32_t *mask,
                                                    int height) {
  const uint16_t *row = CONVERT_TO_SHORTPTR(ref);
  uint32x4_t acc = vdupq_n_u32(0);
  int rows_left = height;

  do {
    highbd_obmc_sad_8x1_s16_neon(vld1q_u16(row), mask, wsrc, &acc);

    row += ref_stride;
    wsrc += 8;
    mask += 8;
  } while (--rows_left != 0);

  return horizontal_add_u32x4(acc);
}
87
// Returns the OBMC SAD of a high-bitdepth block whose width is a positive
// multiple of 16 (the inner do/while always consumes 16 pixels, so any other
// width over-reads).
//
// Each row is walked 16 pixels at a time. The first and second 8-pixel halves
// feed two independent accumulators, which are only combined at the end.
// wsrc and mask are read contiguously, width values per row. height must be
// non-zero.
static inline unsigned int highbd_obmc_sad_large_neon(const uint8_t *ref,
                                                      int ref_stride,
                                                      const int32_t *wsrc,
                                                      const int32_t *mask,
                                                      int width, int height) {
  const uint16_t *row = CONVERT_TO_SHORTPTR(ref);
  uint32x4_t acc_a = vdupq_n_u32(0);
  uint32x4_t acc_b = vdupq_n_u32(0);

  do {
    int col = 0;
    do {
      const uint16x8_t first8 = vld1q_u16(row + col);
      const uint16x8_t next8 = vld1q_u16(row + col + 8);

      highbd_obmc_sad_8x1_s16_neon(first8, mask, wsrc, &acc_a);
      highbd_obmc_sad_8x1_s16_neon(next8, mask + 8, wsrc + 8, &acc_b);

      col += 16;
      mask += 16;
      wsrc += 16;
    } while (col < width);

    row += ref_stride;
  } while (--height != 0);

  return horizontal_add_u32x4(vaddq_u32(acc_a, acc_b));
}
115
// Returns the OBMC SAD of a 16-pixel-wide high-bitdepth block of height h,
// using the generic multiple-of-16 kernel.
static inline unsigned int highbd_obmc_sad_16xh_neon(const uint8_t *ref,
                                                     int ref_stride,
                                                     const int32_t *wsrc,
                                                     const int32_t *mask,
                                                     int h) {
  const int block_width = 16;
  return highbd_obmc_sad_large_neon(ref, ref_stride, wsrc, mask, block_width,
                                    h);
}
123
// Returns the OBMC SAD of a 32-pixel-wide high-bitdepth block.
//
// Each row is split into four 8-pixel segments, each with its own accumulator.
// wsrc and mask advance by 32 values per row. height must be non-zero.
static inline unsigned int highbd_obmc_sad_32xh_neon(const uint8_t *ref,
                                                     int ref_stride,
                                                     const int32_t *wsrc,
                                                     const int32_t *mask,
                                                     int height) {
  const uint16_t *row = CONVERT_TO_SHORTPTR(ref);
  uint32x4_t acc0 = vdupq_n_u32(0);
  uint32x4_t acc1 = vdupq_n_u32(0);
  uint32x4_t acc2 = vdupq_n_u32(0);
  uint32x4_t acc3 = vdupq_n_u32(0);
  int rows_left = height;

  do {
    const uint16x8_t seg0 = vld1q_u16(row + 0);
    const uint16x8_t seg1 = vld1q_u16(row + 8);
    const uint16x8_t seg2 = vld1q_u16(row + 16);
    const uint16x8_t seg3 = vld1q_u16(row + 24);

    highbd_obmc_sad_8x1_s16_neon(seg0, mask + 0, wsrc + 0, &acc0);
    highbd_obmc_sad_8x1_s16_neon(seg1, mask + 8, wsrc + 8, &acc1);
    highbd_obmc_sad_8x1_s16_neon(seg2, mask + 16, wsrc + 16, &acc2);
    highbd_obmc_sad_8x1_s16_neon(seg3, mask + 24, wsrc + 24, &acc3);

    row += ref_stride;
    mask += 32;
    wsrc += 32;
  } while (--rows_left != 0);

  // Pairwise reduction of the four accumulators before the final horizontal
  // add.
  const uint32x4_t low_pair = vaddq_u32(acc0, acc1);
  const uint32x4_t high_pair = vaddq_u32(acc2, acc3);
  return horizontal_add_u32x4(vaddq_u32(low_pair, high_pair));
}
154
// Returns the OBMC SAD of a 64-pixel-wide high-bitdepth block of height h,
// using the generic multiple-of-16 kernel.
static inline unsigned int highbd_obmc_sad_64xh_neon(const uint8_t *ref,
                                                     int ref_stride,
                                                     const int32_t *wsrc,
                                                     const int32_t *mask,
                                                     int h) {
  const int block_width = 64;
  return highbd_obmc_sad_large_neon(ref, ref_stride, wsrc, mask, block_width,
                                    h);
}
162
// Returns the OBMC SAD of a 128-pixel-wide high-bitdepth block of height h,
// using the generic multiple-of-16 kernel.
static inline unsigned int highbd_obmc_sad_128xh_neon(const uint8_t *ref,
                                                      int ref_stride,
                                                      const int32_t *wsrc,
                                                      const int32_t *mask,
                                                      int h) {
  const int block_width = 128;
  return highbd_obmc_sad_large_neon(ref, ref_stride, wsrc, mask, block_width,
                                    h);
}
170
// Defines the externally visible aom_highbd_obmc_sad<bw>x<bh>_neon entry
// point. It forwards to the width-specialised kernel
// highbd_obmc_sad_<bw>xh_neon, passing the block height bh as a run-time
// argument.
#define HIGHBD_OBMC_SAD_WXH_NEON(bw, bh)                                   \
  unsigned int aom_highbd_obmc_sad##bw##x##bh##_neon(                      \
      const uint8_t *ref, int ref_stride, const int32_t *wsrc,             \
      const int32_t *mask) {                                               \
    return highbd_obmc_sad_##bw##xh_neon(ref, ref_stride, wsrc, mask, bh); \
  }
177
178 HIGHBD_OBMC_SAD_WXH_NEON(4, 4)
179 HIGHBD_OBMC_SAD_WXH_NEON(4, 8)
180
181 HIGHBD_OBMC_SAD_WXH_NEON(8, 4)
182 HIGHBD_OBMC_SAD_WXH_NEON(8, 8)
183 HIGHBD_OBMC_SAD_WXH_NEON(8, 16)
184
185 HIGHBD_OBMC_SAD_WXH_NEON(16, 8)
186 HIGHBD_OBMC_SAD_WXH_NEON(16, 16)
187 HIGHBD_OBMC_SAD_WXH_NEON(16, 32)
188
189 HIGHBD_OBMC_SAD_WXH_NEON(32, 16)
190 HIGHBD_OBMC_SAD_WXH_NEON(32, 32)
191 HIGHBD_OBMC_SAD_WXH_NEON(32, 64)
192
193 HIGHBD_OBMC_SAD_WXH_NEON(64, 32)
194 HIGHBD_OBMC_SAD_WXH_NEON(64, 64)
195 HIGHBD_OBMC_SAD_WXH_NEON(64, 128)
196
197 HIGHBD_OBMC_SAD_WXH_NEON(128, 64)
198 HIGHBD_OBMC_SAD_WXH_NEON(128, 128)
199
200 #if !CONFIG_REALTIME_ONLY
201 HIGHBD_OBMC_SAD_WXH_NEON(4, 16)
202
203 HIGHBD_OBMC_SAD_WXH_NEON(8, 32)
204
205 HIGHBD_OBMC_SAD_WXH_NEON(16, 4)
206 HIGHBD_OBMC_SAD_WXH_NEON(16, 64)
207
208 HIGHBD_OBMC_SAD_WXH_NEON(32, 8)
209
210 HIGHBD_OBMC_SAD_WXH_NEON(64, 16)
211 #endif // !CONFIG_REALTIME_ONLY
212