xref: /aosp_15_r20/external/libaom/av1/common/arm/highbd_reconintra_neon.c (revision 77c1e3ccc04c968bd2bc212e87364f250e820521)
1 /*
2  * Copyright (c) 2023, Alliance for Open Media. All rights reserved.
3  *
4  * This source code is subject to the terms of the BSD 2 Clause License and
5  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6  * was not distributed with this source code in the LICENSE file, you can
7  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8  * Media Patent License 1.0 was not distributed with this source code in the
9  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10  */
11 
12 #include <arm_neon.h>
13 #include <assert.h>
14 
15 #include "aom_dsp/arm/sum_neon.h"
16 #include "config/av1_rtcd.h"
17 
18 #define MAX_UPSAMPLE_SZ 16
19 
// Computes eight outputs of the 3-tap {4, 8, 4} / 16 smoothing kernel.
// Output lane i is built from src[i], src[i + 1] and src[i + 2].
static inline uint16x8_t highbd_edge_filter_4_8_4(const uint16_t *src) {
  const uint16x8_t left = vld1q_u16(src);
  const uint16x8_t centre = vld1q_u16(src + 1);
  const uint16x8_t right = vld1q_u16(src + 2);

  // With rounding, (4*a + 8*b + 4*c) >> 4 == (a + (b << 1) + c) >> 2.
  const uint16x8_t outer = vaddq_u16(left, right);
  const uint16x8_t inner = vaddq_u16(centre, centre);
  return vrshrq_n_u16(vaddq_u16(outer, inner), 2);
}

// Computes eight outputs of the 3-tap {5, 6, 5} / 16 smoothing kernel.
// Output lane i is built from src[i], src[i + 1] and src[i + 2].
static inline uint16x8_t highbd_edge_filter_5_6_5(const uint16_t *src) {
  const uint16x8_t left = vld1q_u16(src);
  const uint16x8_t centre = vld1q_u16(src + 1);
  const uint16x8_t right = vld1q_u16(src + 2);

  // The taps sum to 16, so 12-bit input cannot overflow the 16-bit lanes.
  uint16x8_t acc = vmulq_n_u16(left, 5);
  acc = vmlaq_n_u16(acc, centre, 6);
  acc = vmlaq_n_u16(acc, right, 5);
  return vrshrq_n_u16(acc, 4);
}

// Computes eight outputs of the 5-tap {2, 4, 4, 4, 2} / 16 smoothing kernel.
// Output lane i is built from src[i] .. src[i + 4].
static inline uint16x8_t highbd_edge_filter_2_4_4_4_2(const uint16_t *src) {
  const uint16x8_t s0 = vld1q_u16(src);
  const uint16x8_t s1 = vld1q_u16(src + 1);
  const uint16x8_t s2 = vld1q_u16(src + 2);
  const uint16x8_t s3 = vld1q_u16(src + 3);
  const uint16x8_t s4 = vld1q_u16(src + 4);

  // With rounding,
  // (2*a + 4*b + 4*c + 4*d + 2*e) >> 4 == (a + ((b + c + d) << 1) + e) >> 3.
  const uint16x8_t outer = vaddq_u16(s0, s4);
  uint16x8_t inner = vaddq_u16(vaddq_u16(s1, s2), s3);
  inner = vaddq_u16(inner, inner);
  return vrshrq_n_u16(vaddq_u16(outer, inner), 3);
}

// Writes only the first n (0 < n < 8) lanes of v to dst. The vector is
// spilled to a stack scratch first so that no memory beyond dst[n - 1] is
// read or written.
static inline void highbd_edge_store_first_n(uint16_t *dst, uint16x8_t v,
                                             int n) {
  uint16_t lanes[8];
  vst1q_u16(lanes, v);
  memcpy(dst, lanes, (size_t)n * sizeof(*dst));
}

// Smooths the intra prediction edge p[0..sz-1] in place.
//
//   strength == 0     : no-op.
//   strength == 1     : 3-tap {4, 8, 4} / 16.
//   strength == 2     : 3-tap {5, 6, 5} / 16.
//   any other value   : 5-tap {2, 4, 4, 4, 2} / 16.
//
// p[0] is never modified; p[1..sz-1] are replaced by the rounded filter
// output, with taps that fall outside the edge clamped to the first/last
// pixel. Only p[0..sz-1] is accessed.
void av1_highbd_filter_intra_edge_neon(uint16_t *p, int sz, int strength) {
  if (!strength) return;
  assert(sz >= 0 && sz <= 129);

  // The first pixel is always preserved, so fewer than two pixels means there
  // is nothing to filter (and nothing valid to pad the scratch edge with).
  if (sz < 2) return;

  // Scratch copy of the edge: one replicated pixel in front, two behind, plus
  // slack so that whole-vector loads near the end stay inside the array. Lanes
  // computed from the uninitialised slack are never stored.
  uint16_t edge[160];
  memcpy(edge + 1, p, sz * sizeof(*p));
  edge[0] = edge[1];
  edge[sz + 1] = edge[sz];
  edge[sz + 2] = edge[sz];

  uint16_t *dst = p + 1;    // Skip the first pixel.
  int remaining = sz - 1;   // Number of pixels still to be produced.

  if (strength == 1) {
    // 3-tap window for p[i] starts at p[i - 1], i.e. edge[i].
    const uint16_t *src = edge + 1;
    for (; remaining >= 8; remaining -= 8, src += 8, dst += 8) {
      vst1q_u16(dst, highbd_edge_filter_4_8_4(src));
    }
    if (remaining > 0) {
      highbd_edge_store_first_n(dst, highbd_edge_filter_4_8_4(src), remaining);
    }
  } else if (strength == 2) {
    const uint16_t *src = edge + 1;
    for (; remaining >= 8; remaining -= 8, src += 8, dst += 8) {
      vst1q_u16(dst, highbd_edge_filter_5_6_5(src));
    }
    if (remaining > 0) {
      highbd_edge_store_first_n(dst, highbd_edge_filter_5_6_5(src), remaining);
    }
  } else {
    // 5-tap window reaches one pixel further back, hence no +1 offset.
    const uint16_t *src = edge;
    for (; remaining >= 8; remaining -= 8, src += 8, dst += 8) {
      vst1q_u16(dst, highbd_edge_filter_2_4_4_4_2(src));
    }
    if (remaining > 0) {
      highbd_edge_store_first_n(dst, highbd_edge_filter_2_4_4_4_2(src),
                                remaining);
    }
  }
}
171 
// Interpolates eight half-sample positions with the {-1, 9, 9, -1} / 16
// kernel using 32-bit intermediates. Needed for 12-bit input, where
// 9 * (b + c) no longer fits in 16 bits. Negative sums saturate to zero.
static inline uint16x8_t highbd_upsample_interp_u32(const uint16_t *src) {
  const uint16x8_t s0 = vld1q_u16(src);
  const uint16x8_t s1 = vld1q_u16(src + 1);
  const uint16x8_t s2 = vld1q_u16(src + 2);
  const uint16x8_t s3 = vld1q_u16(src + 3);

  const uint16x8_t inner = vaddq_u16(s1, s2);
  const uint16x8_t outer = vaddq_u16(s0, s3);

  uint32x4_t lo = vmull_n_u16(vget_low_u16(inner), 9);
  lo = vqsubq_u32(lo, vmovl_u16(vget_low_u16(outer)));
  uint32x4_t hi = vmull_n_u16(vget_high_u16(inner), 9);
  hi = vqsubq_u32(hi, vmovl_u16(vget_high_u16(outer)));

  return vcombine_u16(vrshrn_n_u32(lo, 4), vrshrn_n_u32(hi, 4));
}

// Interpolates eight half-sample positions with the {-1, 9, 9, -1} / 16
// kernel entirely in 16-bit lanes (sufficient for 8- and 10-bit input).
// Negative sums saturate to zero.
static inline uint16x8_t highbd_upsample_interp_u16(const uint16_t *src) {
  const uint16x8_t s0 = vld1q_u16(src);
  const uint16x8_t s1 = vld1q_u16(src + 1);
  const uint16x8_t s2 = vld1q_u16(src + 2);
  const uint16x8_t s3 = vld1q_u16(src + 3);

  const uint16x8_t outer = vaddq_u16(s0, s3);
  uint16x8_t inner = vaddq_u16(s1, s2);
  inner = vmulq_n_u16(inner, 9);
  inner = vqsubq_u16(inner, outer);

  return vrshrq_n_u16(inner, 4);
}

// Stores the first n (n > 0) interleaved {interpolated, original} pairs of
// px to dst. A full vector store is used when all eight pairs are valid;
// otherwise the pairs are spilled to a stack scratch so that nothing beyond
// dst[2 * n - 1] is written.
static inline void highbd_upsample_store_pairs(uint16_t *dst, uint16x8x2_t px,
                                               int n) {
  if (n >= 8) {
    vst2q_u16(dst, px);
    return;
  }
  uint16_t pairs[16];
  vst2q_u16(pairs, px);
  memcpy(dst, pairs, 2 * (size_t)n * sizeof(*dst));
}

// Upsamples the intra prediction edge by a factor of two, in place.
//
// On entry p[-1..sz-1] holds the edge. On return p[-2] holds the old p[-1],
// and for each i in [0, sz) p[2*i - 1] holds the value interpolated between
// the old p[i - 1] and p[i] (clamped to [0, (1 << bd) - 1]) while p[2*i]
// holds the old p[i]. Only p[-2..2*sz-2] is written.
//
// sz must not exceed MAX_UPSAMPLE_SZ; a non-positive sz is a no-op.
void av1_highbd_upsample_intra_edge_neon(uint16_t *p, int sz, int bd) {
  if (sz <= 0) return;

  assert(sz <= MAX_UPSAMPLE_SZ);

  // Copy p[-1..(sz-1)] and replicate the first and last samples so that every
  // interpolation window is fully populated. Lanes computed from the
  // uninitialised remainder of the array are never stored.
  uint16_t edge[MAX_UPSAMPLE_SZ + 3];
  edge[0] = p[-1];
  edge[1] = p[-1];
  memcpy(edge + 2, p, sz * sizeof(*p));
  edge[sz + 2] = p[sz - 1];

  // Must happen before p[-1] is overwritten by the first interpolated value.
  p[-2] = p[-1];

  const uint16x8_t pixel_val_max = vdupq_n_u16((1 << bd) - 1);
  const uint16_t *src = edge;
  uint16_t *dst = p - 1;

  if (bd == 12) {
    do {
      uint16x8x2_t out;
      // Clamp pixel values at bitdepth maximum.
      out.val[0] = vminq_u16(highbd_upsample_interp_u32(src), pixel_val_max);
      // src + 2 lines up with the original sample that follows each
      // interpolated one.
      out.val[1] = vld1q_u16(src + 2);

      highbd_upsample_store_pairs(dst, out, sz);

      src += 8;
      dst += 16;
      sz -= 8;
    } while (sz > 0);
  } else {  // Bit depth is 8 or 10.
    do {
      uint16x8x2_t out;
      // Clamp pixel values at bitdepth maximum.
      out.val[0] = vminq_u16(highbd_upsample_interp_u16(src), pixel_val_max);
      out.val[1] = vld1q_u16(src + 2);

      highbd_upsample_store_pairs(dst, out, sz);

      src += 8;
      dst += 16;
      sz -= 8;
    } while (sz > 0);
  }
}
243