// Copyright 2019 The libgav1 Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "src/dsp/cdef.h"
#include "src/utils/cpu.h"

#if LIBGAV1_ENABLE_NEON

#include <arm_neon.h>

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdlib>

#include "src/dsp/arm/common_neon.h"
#include "src/dsp/constants.h"
#include "src/dsp/dsp.h"
#include "src/utils/common.h"
#include "src/utils/constants.h"

namespace libgav1 {
namespace dsp {
namespace {

#include "src/dsp/cdef.inc"

// ----------------------------------------------------------------------------
// Refer to CdefDirection_C().
//
// int32_t partial[8][15] = {};
// for (int i = 0; i < 8; ++i) {
//   for (int j = 0; j < 8; ++j) {
//     const int x = 1;
//     partial[0][i + j] += x;
//     partial[1][i + j / 2] += x;
//     partial[2][i] += x;
//     partial[3][3 + i - j / 2] += x;
//     partial[4][7 + i - j] += x;
//     partial[5][3 - i / 2 + j] += x;
//     partial[6][j] += x;
//     partial[7][i / 2 + j] += x;
//   }
// }
//
// Using the code above, generate the position count for partial[8][15].
//
// partial[0]: 1 2 3 4 5 6 7 8 7 6 5 4 3 2 1
// partial[1]: 2 4 6 8 8 8 8 8 6 4 2 0 0 0 0
// partial[2]: 8 8 8 8 8 8 8 8 0 0 0 0 0 0 0
// partial[3]: 2 4 6 8 8 8 8 8 6 4 2 0 0 0 0
// partial[4]: 1 2 3 4 5 6 7 8 7 6 5 4 3 2 1
// partial[5]: 2 4 6 8 8 8 8 8 6 4 2 0 0 0 0
// partial[6]: 8 8 8 8 8 8 8 8 0 0 0 0 0 0 0
// partial[7]: 2 4 6 8 8 8 8 8 6 4 2 0 0 0 0
//
// The SIMD code shifts the input horizontally, then adds vertically to get the
// correct partial value for the given position.
// ----------------------------------------------------------------------------

// ----------------------------------------------------------------------------
// partial[0][i + j] += x;
//
// 00 01 02 03 04 05 06 07  00 00 00 00 00 00 00
// 00 10 11 12 13 14 15 16  17 00 00 00 00 00 00
// 00 00 20 21 22 23 24 25  26 27 00 00 00 00 00
// 00 00 00 30 31 32 33 34  35 36 37 00 00 00 00
// 00 00 00 00 40 41 42 43  44 45 46 47 00 00 00
// 00 00 00 00 00 50 51 52  53 54 55 56 57 00 00
// 00 00 00 00 00 00 60 61  62 63 64 65 66 67 00
// 00 00 00 00 00 00 00 70  71 72 73 74 75 76 77
//
// partial[4] is the same except the source is reversed.
LIBGAV1_ALWAYS_INLINE void AddPartial_D0_D4(uint8x8_t* v_src,
                                            uint16x8_t* partial_lo,
                                            uint16x8_t* partial_hi) {
  const uint8x8_t v_zero = vdup_n_u8(0);
  // 00 01 02 03 04 05 06 07
  // 00 10 11 12 13 14 15 16
  *partial_lo = vaddl_u8(v_src[0], vext_u8(v_zero, v_src[1], 7));

  // 00 00 20 21 22 23 24 25
  *partial_lo = vaddw_u8(*partial_lo, vext_u8(v_zero, v_src[2], 6));
  // 17 00 00 00 00 00 00 00
  // 26 27 00 00 00 00 00 00
  *partial_hi =
      vaddl_u8(vext_u8(v_src[1], v_zero, 7), vext_u8(v_src[2], v_zero, 6));

  // 00 00 00 30 31 32 33 34
  *partial_lo = vaddw_u8(*partial_lo, vext_u8(v_zero, v_src[3], 5));
  // 35 36 37 00 00 00 00 00
  *partial_hi = vaddw_u8(*partial_hi, vext_u8(v_src[3], v_zero, 5));

  // 00 00 00 00 40 41 42 43
  *partial_lo = vaddw_u8(*partial_lo, vext_u8(v_zero, v_src[4], 4));
  // 44 45 46 47 00 00 00 00
  *partial_hi = vaddw_u8(*partial_hi, vext_u8(v_src[4], v_zero, 4));

  // 00 00 00 00 00 50 51 52
  *partial_lo = vaddw_u8(*partial_lo, vext_u8(v_zero, v_src[5], 3));
  // 53 54 55 56 57 00 00 00
  *partial_hi = vaddw_u8(*partial_hi, vext_u8(v_src[5], v_zero, 3));

  // 00 00 00 00 00 00 60 61
  *partial_lo = vaddw_u8(*partial_lo, vext_u8(v_zero, v_src[6], 2));
  // 62 63 64 65 66 67 00 00
  *partial_hi = vaddw_u8(*partial_hi, vext_u8(v_src[6], v_zero, 2));

  // 00 00 00 00 00 00 00 70
  *partial_lo = vaddw_u8(*partial_lo, vext_u8(v_zero, v_src[7], 1));
  // 71 72 73 74 75 76 77 00
  *partial_hi = vaddw_u8(*partial_hi, vext_u8(v_src[7], v_zero, 1));
}

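// A scalar sketch of the accumulation above (reference only, not built): row
// i is shifted right by i lanes before the widening vertical adds, so lane k
// of the {*partial_lo, *partial_hi} pair receives every src[i][j] with
// i + j == k.
//
// uint16_t partial[15] = {};
// for (int i = 0; i < 8; ++i) {
//   for (int j = 0; j < 8; ++j) partial[i + j] += src[i][j];
// }
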
// ----------------------------------------------------------------------------
// partial[1][i + j / 2] += x;
//
// A0 = src[0] + src[1], A1 = src[2] + src[3], ...
//
// A0 A1 A2 A3 00 00 00 00  00 00 00 00 00 00 00
// 00 B0 B1 B2 B3 00 00 00  00 00 00 00 00 00 00
// 00 00 C0 C1 C2 C3 00 00  00 00 00 00 00 00 00
// 00 00 00 D0 D1 D2 D3 00  00 00 00 00 00 00 00
// 00 00 00 00 E0 E1 E2 E3  00 00 00 00 00 00 00
// 00 00 00 00 00 F0 F1 F2  F3 00 00 00 00 00 00
// 00 00 00 00 00 00 G0 G1  G2 G3 00 00 00 00 00
// 00 00 00 00 00 00 00 H0  H1 H2 H3 00 00 00 00
//
// partial[3] is the same except the source is reversed.
LIBGAV1_ALWAYS_INLINE void AddPartial_D1_D3(uint8x8_t* v_src,
                                            uint16x8_t* partial_lo,
                                            uint16x8_t* partial_hi) {
  uint8x16_t v_d1_temp[8];
  const uint8x8_t v_zero = vdup_n_u8(0);
  const uint8x16_t v_zero_16 = vdupq_n_u8(0);

  for (int i = 0; i < 8; ++i) {
    v_d1_temp[i] = vcombine_u8(v_src[i], v_zero);
  }

  *partial_lo = *partial_hi = vdupq_n_u16(0);
  // A0 A1 A2 A3 00 00 00 00
  *partial_lo = vpadalq_u8(*partial_lo, v_d1_temp[0]);

  // 00 B0 B1 B2 B3 00 00 00
  *partial_lo = vpadalq_u8(*partial_lo, vextq_u8(v_zero_16, v_d1_temp[1], 14));

  // 00 00 C0 C1 C2 C3 00 00
  *partial_lo = vpadalq_u8(*partial_lo, vextq_u8(v_zero_16, v_d1_temp[2], 12));
  // 00 00 00 D0 D1 D2 D3 00
  *partial_lo = vpadalq_u8(*partial_lo, vextq_u8(v_zero_16, v_d1_temp[3], 10));
  // 00 00 00 00 E0 E1 E2 E3
  *partial_lo = vpadalq_u8(*partial_lo, vextq_u8(v_zero_16, v_d1_temp[4], 8));

  // 00 00 00 00 00 F0 F1 F2
  *partial_lo = vpadalq_u8(*partial_lo, vextq_u8(v_zero_16, v_d1_temp[5], 6));
  // F3 00 00 00 00 00 00 00
  *partial_hi = vpadalq_u8(*partial_hi, vextq_u8(v_d1_temp[5], v_zero_16, 6));

  // 00 00 00 00 00 00 G0 G1
  *partial_lo = vpadalq_u8(*partial_lo, vextq_u8(v_zero_16, v_d1_temp[6], 4));
  // G2 G3 00 00 00 00 00 00
  *partial_hi = vpadalq_u8(*partial_hi, vextq_u8(v_d1_temp[6], v_zero_16, 4));

  // 00 00 00 00 00 00 00 H0
  *partial_lo = vpadalq_u8(*partial_lo, vextq_u8(v_zero_16, v_d1_temp[7], 2));
  // H1 H2 H3 00 00 00 00 00
  *partial_hi = vpadalq_u8(*partial_hi, vextq_u8(v_d1_temp[7], v_zero_16, 2));
}

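// Scalar sketch (reference only): row i is byte-shifted right by 2 * i before
// the pairwise accumulate, so src[i][j] lands in u16 lane
// (2 * i + j) / 2 == i + j / 2.
//
// uint16_t partial[15] = {};
// for (int i = 0; i < 8; ++i) {
//   for (int j = 0; j < 8; ++j) partial[i + j / 2] += src[i][j];
// }
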
// ----------------------------------------------------------------------------
// partial[7][i / 2 + j] += x;
//
// 00 01 02 03 04 05 06 07  00 00 00 00 00 00 00
// 10 11 12 13 14 15 16 17  00 00 00 00 00 00 00
// 00 20 21 22 23 24 25 26  27 00 00 00 00 00 00
// 00 30 31 32 33 34 35 36  37 00 00 00 00 00 00
// 00 00 40 41 42 43 44 45  46 47 00 00 00 00 00
// 00 00 50 51 52 53 54 55  56 57 00 00 00 00 00
// 00 00 00 60 61 62 63 64  65 66 67 00 00 00 00
// 00 00 00 70 71 72 73 74  75 76 77 00 00 00 00
//
// partial[5] is the same except the source is reversed.
LIBGAV1_ALWAYS_INLINE void AddPartial_D5_D7(uint8x8_t* v_src,
                                            uint16x8_t* partial_lo,
                                            uint16x8_t* partial_hi) {
  const uint16x8_t v_zero = vdupq_n_u16(0);
  uint16x8_t v_pair_add[4];
  // Add vertical source pairs.
  v_pair_add[0] = vaddl_u8(v_src[0], v_src[1]);
  v_pair_add[1] = vaddl_u8(v_src[2], v_src[3]);
  v_pair_add[2] = vaddl_u8(v_src[4], v_src[5]);
  v_pair_add[3] = vaddl_u8(v_src[6], v_src[7]);

  // 00 01 02 03 04 05 06 07
  // 10 11 12 13 14 15 16 17
  *partial_lo = v_pair_add[0];
  // 00 00 00 00 00 00 00 00
  // 00 00 00 00 00 00 00 00
  *partial_hi = vdupq_n_u16(0);

  // 00 20 21 22 23 24 25 26
  // 00 30 31 32 33 34 35 36
  *partial_lo = vaddq_u16(*partial_lo, vextq_u16(v_zero, v_pair_add[1], 7));
  // 27 00 00 00 00 00 00 00
  // 37 00 00 00 00 00 00 00
  *partial_hi = vaddq_u16(*partial_hi, vextq_u16(v_pair_add[1], v_zero, 7));

  // 00 00 40 41 42 43 44 45
  // 00 00 50 51 52 53 54 55
  *partial_lo = vaddq_u16(*partial_lo, vextq_u16(v_zero, v_pair_add[2], 6));
  // 46 47 00 00 00 00 00 00
  // 56 57 00 00 00 00 00 00
  *partial_hi = vaddq_u16(*partial_hi, vextq_u16(v_pair_add[2], v_zero, 6));

  // 00 00 00 60 61 62 63 64
  // 00 00 00 70 71 72 73 74
  *partial_lo = vaddq_u16(*partial_lo, vextq_u16(v_zero, v_pair_add[3], 5));
  // 65 66 67 00 00 00 00 00
  // 75 76 77 00 00 00 00 00
  *partial_hi = vaddq_u16(*partial_hi, vextq_u16(v_pair_add[3], v_zero, 5));
}

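// Scalar sketch (reference only): rows 2 * k and 2 * k + 1 share i / 2 == k,
// so they are pair-added first and the combined row is shifted right by k
// lanes.
//
// uint16_t partial[15] = {};
// for (int i = 0; i < 8; ++i) {
//   for (int j = 0; j < 8; ++j) partial[i / 2 + j] += src[i][j];
// }
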
template <int bitdepth>
LIBGAV1_ALWAYS_INLINE void AddPartial(const void* LIBGAV1_RESTRICT const source,
                                      ptrdiff_t stride, uint16x8_t* partial_lo,
                                      uint16x8_t* partial_hi) {
  const auto* src = static_cast<const uint8_t*>(source);

  // 8x8 input
  // 00 01 02 03 04 05 06 07
  // 10 11 12 13 14 15 16 17
  // 20 21 22 23 24 25 26 27
  // 30 31 32 33 34 35 36 37
  // 40 41 42 43 44 45 46 47
  // 50 51 52 53 54 55 56 57
  // 60 61 62 63 64 65 66 67
  // 70 71 72 73 74 75 76 77
  uint8x8_t v_src[8];
  if (bitdepth == kBitdepth8) {
    for (auto& v : v_src) {
      v = vld1_u8(src);
      src += stride;
    }
  } else {
    // Normalize to 8 bits by shifting off bitdepth - 8: 2 for 10bpp, 4 for
    // 12bpp.
    constexpr int src_shift = (bitdepth == kBitdepth10) ? 2 : 4;
    for (auto& v : v_src) {
      v = vshrn_n_u16(vld1q_u16(reinterpret_cast<const uint16_t*>(src)),
                      src_shift);
      src += stride;
    }
  }
  // partial for direction 2
  // --------------------------------------------------------------------------
  // partial[2][i] += x;
  // 00 10 20 30 40 50 60 70  00 00 00 00 00 00 00 00
  // 01 11 21 31 41 51 61 71  00 00 00 00 00 00 00 00
  // 02 12 22 32 42 52 62 72  00 00 00 00 00 00 00 00
  // 03 13 23 33 43 53 63 73  00 00 00 00 00 00 00 00
  // 04 14 24 34 44 54 64 74  00 00 00 00 00 00 00 00
  // 05 15 25 35 45 55 65 75  00 00 00 00 00 00 00 00
  // 06 16 26 36 46 56 66 76  00 00 00 00 00 00 00 00
  // 07 17 27 37 47 57 67 77  00 00 00 00 00 00 00 00
  partial_lo[2] = vsetq_lane_u16(SumVector(v_src[0]), vdupq_n_u16(0), 0);
  partial_lo[2] = vsetq_lane_u16(SumVector(v_src[1]), partial_lo[2], 1);
  partial_lo[2] = vsetq_lane_u16(SumVector(v_src[2]), partial_lo[2], 2);
  partial_lo[2] = vsetq_lane_u16(SumVector(v_src[3]), partial_lo[2], 3);
  partial_lo[2] = vsetq_lane_u16(SumVector(v_src[4]), partial_lo[2], 4);
  partial_lo[2] = vsetq_lane_u16(SumVector(v_src[5]), partial_lo[2], 5);
  partial_lo[2] = vsetq_lane_u16(SumVector(v_src[6]), partial_lo[2], 6);
  partial_lo[2] = vsetq_lane_u16(SumVector(v_src[7]), partial_lo[2], 7);

  // partial for direction 6
  // --------------------------------------------------------------------------
  // partial[6][j] += x;
  // 00 01 02 03 04 05 06 07  00 00 00 00 00 00 00 00
  // 10 11 12 13 14 15 16 17  00 00 00 00 00 00 00 00
  // 20 21 22 23 24 25 26 27  00 00 00 00 00 00 00 00
  // 30 31 32 33 34 35 36 37  00 00 00 00 00 00 00 00
  // 40 41 42 43 44 45 46 47  00 00 00 00 00 00 00 00
  // 50 51 52 53 54 55 56 57  00 00 00 00 00 00 00 00
  // 60 61 62 63 64 65 66 67  00 00 00 00 00 00 00 00
  // 70 71 72 73 74 75 76 77  00 00 00 00 00 00 00 00
  partial_lo[6] = vaddl_u8(v_src[0], v_src[1]);
  for (int i = 2; i < 8; ++i) {
    partial_lo[6] = vaddw_u8(partial_lo[6], v_src[i]);
  }

  // partial for direction 0
  AddPartial_D0_D4(v_src, &partial_lo[0], &partial_hi[0]);

  // partial for direction 1
  AddPartial_D1_D3(v_src, &partial_lo[1], &partial_hi[1]);

  // partial for direction 7
  AddPartial_D5_D7(v_src, &partial_lo[7], &partial_hi[7]);

  uint8x8_t v_src_reverse[8];
  for (int i = 0; i < 8; ++i) {
    v_src_reverse[i] = vrev64_u8(v_src[i]);
  }

  // partial for direction 4
  AddPartial_D0_D4(v_src_reverse, &partial_lo[4], &partial_hi[4]);

  // partial for direction 3
  AddPartial_D1_D3(v_src_reverse, &partial_lo[3], &partial_hi[3]);

  // partial for direction 5
  AddPartial_D5_D7(v_src_reverse, &partial_lo[5], &partial_hi[5]);
}

uint32x4_t Square(uint16x4_t a) { return vmull_u16(a, a); }

uint32x4_t SquareAccumulate(uint32x4_t a, uint16x4_t b) {
  return vmlal_u16(a, b, b);
}

// |cost[0]| and |cost[4]| square each input element, sum it with the
// corresponding element from the other end of the vector, and scale by the
// matching |kCdefDivisionTable[]| element:
// cost[0] += (Square(partial[0][i]) + Square(partial[0][14 - i])) *
//             kCdefDivisionTable[i + 1];
// cost[0] += Square(partial[0][7]) * kCdefDivisionTable[8];
// Because everything is being summed into a single value, the distributive
// property allows us to mirror the division table and accumulate once.
uint32_t Cost0Or4(const uint16x8_t a, const uint16x8_t b,
                  const uint32x4_t division_table[4]) {
  uint32x4_t c = vmulq_u32(Square(vget_low_u16(a)), division_table[0]);
  c = vmlaq_u32(c, Square(vget_high_u16(a)), division_table[1]);
  c = vmlaq_u32(c, Square(vget_low_u16(b)), division_table[2]);
  c = vmlaq_u32(c, Square(vget_high_u16(b)), division_table[3]);
  return SumVector(c);
}

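// Scalar sketch of the mirrored accumulation (reference only), assuming the
// 16 values loaded from kCdefDivisionTable form a table d[] that mirrors the
// C divisors around the center and ends in 0, so the unused 16th lane of the
// partials drops out:
//
// uint32_t cost = 0;
// for (int i = 0; i < 16; ++i) cost += Square(p[i]) * d[i];
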
// |cost[2]| and |cost[6]| square the input and accumulate:
// cost[2] += Square(partial[2][i])
uint32_t SquareAccumulate(const uint16x8_t a) {
  uint32x4_t c = Square(vget_low_u16(a));
  c = SquareAccumulate(c, vget_high_u16(a));
  c = vmulq_n_u32(c, kCdefDivisionTable[7]);
  return SumVector(c);
}

uint32_t CostOdd(const uint16x8_t a, const uint16x8_t b, const uint32x4_t mask,
                 const uint32x4_t division_table[2]) {
  // Remove elements 0-2.
  uint32x4_t c = vandq_u32(mask, Square(vget_low_u16(a)));
  c = vaddq_u32(c, Square(vget_high_u16(a)));
  c = vmulq_n_u32(c, kCdefDivisionTable[7]);

  c = vmlaq_u32(c, Square(vget_low_u16(a)), division_table[0]);
  c = vmlaq_u32(c, Square(vget_low_u16(b)), division_table[1]);
  return SumVector(c);
}

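// Scalar sketch of CostOdd() (reference only): the odd-direction partials
// have 11 nonzero entries with position counts 2 4 6 8 8 8 8 8 6 4 2, so
// entries 3..7 share the single divisor kCdefDivisionTable[7] while the
// outer entries are scaled by the mirrored kCdefDivisionTableOdd[], assumed
// here to be 0 in the lanes the center term already covers:
//
// uint32_t cost = 0;
// for (int i = 3; i < 8; ++i) cost += Square(p[i]) * kCdefDivisionTable[7];
// for (int i = 0; i < 4; ++i) {
//   cost += Square(p[i]) * kCdefDivisionTableOdd[i];
//   cost += Square(p[8 + i]) * kCdefDivisionTableOdd[4 + i];
// }
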
template <int bitdepth>
void CdefDirection_NEON(const void* LIBGAV1_RESTRICT const source,
                        ptrdiff_t stride,
                        uint8_t* LIBGAV1_RESTRICT const direction,
                        int* LIBGAV1_RESTRICT const variance) {
  assert(direction != nullptr);
  assert(variance != nullptr);
  const auto* src = static_cast<const uint8_t*>(source);

  uint32_t cost[8];
  uint16x8_t partial_lo[8], partial_hi[8];

  AddPartial<bitdepth>(src, stride, partial_lo, partial_hi);

  cost[2] = SquareAccumulate(partial_lo[2]);
  cost[6] = SquareAccumulate(partial_lo[6]);

  const uint32x4_t division_table[4] = {
      vld1q_u32(kCdefDivisionTable), vld1q_u32(kCdefDivisionTable + 4),
      vld1q_u32(kCdefDivisionTable + 8), vld1q_u32(kCdefDivisionTable + 12)};

  cost[0] = Cost0Or4(partial_lo[0], partial_hi[0], division_table);
  cost[4] = Cost0Or4(partial_lo[4], partial_hi[4], division_table);

  const uint32x4_t division_table_odd[2] = {
      vld1q_u32(kCdefDivisionTableOdd), vld1q_u32(kCdefDivisionTableOdd + 4)};

  const uint32x4_t element_3_mask = {0, 0, 0, static_cast<uint32_t>(-1)};

  cost[1] =
      CostOdd(partial_lo[1], partial_hi[1], element_3_mask, division_table_odd);
  cost[3] =
      CostOdd(partial_lo[3], partial_hi[3], element_3_mask, division_table_odd);
  cost[5] =
      CostOdd(partial_lo[5], partial_hi[5], element_3_mask, division_table_odd);
  cost[7] =
      CostOdd(partial_lo[7], partial_hi[7], element_3_mask, division_table_odd);

  uint32_t best_cost = 0;
  *direction = 0;
  for (int i = 0; i < 8; ++i) {
    if (cost[i] > best_cost) {
      best_cost = cost[i];
      *direction = i;
    }
  }
  *variance = (best_cost - cost[(*direction + 4) & 7]) >> 10;
}

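// Usage sketch (reference only, hypothetical |src| and |stride|):
//
// uint8_t direction;
// int variance;
// CdefDirection_NEON<kBitdepth8>(src, stride, &direction, &variance);
//
// |direction| is the d in [0, 7] maximizing cost[d]; |variance| measures how
// strongly it dominates the orthogonal direction (d + 4) & 7.
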
// -------------------------------------------------------------------------
// CdefFilter

// Load 4 vectors based on the given |direction|.
void LoadDirection(const uint16_t* LIBGAV1_RESTRICT const src,
                   const ptrdiff_t stride, uint16x8_t* output,
                   const int direction) {
  // Each |direction| describes a different set of source values. Expand this
  // set by negating each offset pair. For |direction| == 0 this gives a
  // diagonal line from top right to bottom left. The first value is y, the
  // second x. Negative y values move up.
  //    a       b         c       d
  // {-1, 1}, {1, -1}, {-2, 2}, {2, -2}
  //         c
  //       a
  //     0
  //   b
  // d
  const int y_0 = kCdefDirections[direction][0][0];
  const int x_0 = kCdefDirections[direction][0][1];
  const int y_1 = kCdefDirections[direction][1][0];
  const int x_1 = kCdefDirections[direction][1][1];
  output[0] = vld1q_u16(src + y_0 * stride + x_0);
  output[1] = vld1q_u16(src - y_0 * stride - x_0);
  output[2] = vld1q_u16(src + y_1 * stride + x_1);
  output[3] = vld1q_u16(src - y_1 * stride - x_1);
}

// Load 4 vectors based on the given |direction|. Use when |block_width| == 4
// to do 2 rows at a time.
void LoadDirection4(const uint16_t* LIBGAV1_RESTRICT const src,
                    const ptrdiff_t stride, uint16x8_t* output,
                    const int direction) {
  const int y_0 = kCdefDirections[direction][0][0];
  const int x_0 = kCdefDirections[direction][0][1];
  const int y_1 = kCdefDirections[direction][1][0];
  const int x_1 = kCdefDirections[direction][1][1];
  output[0] = vcombine_u16(vld1_u16(src + y_0 * stride + x_0),
                           vld1_u16(src + y_0 * stride + stride + x_0));
  output[1] = vcombine_u16(vld1_u16(src - y_0 * stride - x_0),
                           vld1_u16(src - y_0 * stride + stride - x_0));
  output[2] = vcombine_u16(vld1_u16(src + y_1 * stride + x_1),
                           vld1_u16(src + y_1 * stride + stride + x_1));
  output[3] = vcombine_u16(vld1_u16(src - y_1 * stride - x_1),
                           vld1_u16(src - y_1 * stride + stride - x_1));
}

int16x8_t Constrain(const uint16x8_t pixel, const uint16x8_t reference,
                    const uint16x8_t threshold, const int16x8_t damping) {
  // If reference > pixel, the difference will be negative, so convert to 0 or
  // -1.
  const uint16x8_t sign = vcgtq_u16(reference, pixel);
  const uint16x8_t abs_diff = vabdq_u16(pixel, reference);
  const uint16x8_t shifted_diff = vshlq_u16(abs_diff, damping);
  // For bitdepth == 8, the threshold range is [0, 15] and the damping range is
  // [3, 6]. If pixel == kCdefLargeValue (0x4000), shifted_diff will always be
  // larger than threshold, so subtracting with saturation returns 0 when pixel
  // == kCdefLargeValue.
  static_assert(kCdefLargeValue == 0x4000, "Invalid kCdefLargeValue");
  const uint16x8_t thresh_minus_shifted_diff =
      vqsubq_u16(threshold, shifted_diff);
  const uint16x8_t clamp_abs_diff =
      vminq_u16(thresh_minus_shifted_diff, abs_diff);
  // Restore the sign.
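  // With |sign| equal to 0 or 0xffff, (x ^ sign) - sign is x or -x in two's
  // complement, i.e. a conditional negate when |reference| > |pixel|.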
  return vreinterpretq_s16_u16(
      vsubq_u16(veorq_u16(clamp_abs_diff, sign), sign));
}

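// Scalar equivalent of one Constrain() lane (reference only); |shift| stands
// for the non-negative right shift that the caller encodes as the negative
// |damping| vector:
//
// int Constrain1(int pixel, int reference, int threshold, int shift) {
//   const int diff = pixel - reference;
//   const int abs_diff = std::abs(diff);
//   const int constrained =
//       std::min(abs_diff, std::max(0, threshold - (abs_diff >> shift)));
//   return (diff < 0) ? -constrained : constrained;
// }
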
template <typename Pixel>
uint16x8_t GetMaxPrimary(uint16x8_t* primary_val, uint16x8_t max,
                         uint16x8_t cdef_large_value_mask) {
  if (sizeof(Pixel) == 1) {
    // The source is 16 bits; however, we only really care about the lower
    // 8 bits.  The upper 8 bits contain the "large" flag.  After the final
    // primary max has been calculated, zero out the upper 8 bits.  Use this
    // to find the "16 bit" max.
    const uint8x16_t max_p01 = vmaxq_u8(vreinterpretq_u8_u16(primary_val[0]),
                                        vreinterpretq_u8_u16(primary_val[1]));
    const uint8x16_t max_p23 = vmaxq_u8(vreinterpretq_u8_u16(primary_val[2]),
                                        vreinterpretq_u8_u16(primary_val[3]));
    const uint16x8_t max_p = vreinterpretq_u16_u8(vmaxq_u8(max_p01, max_p23));
    max = vmaxq_u16(max, vandq_u16(max_p, cdef_large_value_mask));
  } else {
    // Convert kCdefLargeValue to 0 before calculating max.
    max = vmaxq_u16(max, vandq_u16(primary_val[0], cdef_large_value_mask));
    max = vmaxq_u16(max, vandq_u16(primary_val[1], cdef_large_value_mask));
    max = vmaxq_u16(max, vandq_u16(primary_val[2], cdef_large_value_mask));
    max = vmaxq_u16(max, vandq_u16(primary_val[3], cdef_large_value_mask));
  }
  return max;
}

template <typename Pixel>
uint16x8_t GetMaxSecondary(uint16x8_t* secondary_val, uint16x8_t max,
                           uint16x8_t cdef_large_value_mask) {
  if (sizeof(Pixel) == 1) {
    const uint8x16_t max_s01 = vmaxq_u8(vreinterpretq_u8_u16(secondary_val[0]),
                                        vreinterpretq_u8_u16(secondary_val[1]));
    const uint8x16_t max_s23 = vmaxq_u8(vreinterpretq_u8_u16(secondary_val[2]),
                                        vreinterpretq_u8_u16(secondary_val[3]));
    const uint8x16_t max_s45 = vmaxq_u8(vreinterpretq_u8_u16(secondary_val[4]),
                                        vreinterpretq_u8_u16(secondary_val[5]));
    const uint8x16_t max_s67 = vmaxq_u8(vreinterpretq_u8_u16(secondary_val[6]),
                                        vreinterpretq_u8_u16(secondary_val[7]));
    const uint16x8_t max_s = vreinterpretq_u16_u8(
        vmaxq_u8(vmaxq_u8(max_s01, max_s23), vmaxq_u8(max_s45, max_s67)));
    max = vmaxq_u16(max, vandq_u16(max_s, cdef_large_value_mask));
  } else {
    max = vmaxq_u16(max, vandq_u16(secondary_val[0], cdef_large_value_mask));
    max = vmaxq_u16(max, vandq_u16(secondary_val[1], cdef_large_value_mask));
    max = vmaxq_u16(max, vandq_u16(secondary_val[2], cdef_large_value_mask));
    max = vmaxq_u16(max, vandq_u16(secondary_val[3], cdef_large_value_mask));
    max = vmaxq_u16(max, vandq_u16(secondary_val[4], cdef_large_value_mask));
    max = vmaxq_u16(max, vandq_u16(secondary_val[5], cdef_large_value_mask));
    max = vmaxq_u16(max, vandq_u16(secondary_val[6], cdef_large_value_mask));
    max = vmaxq_u16(max, vandq_u16(secondary_val[7], cdef_large_value_mask));
  }
  return max;
}

template <typename Pixel, int width>
void StorePixels(void* dest, ptrdiff_t dst_stride, int16x8_t result) {
  auto* const dst8 = static_cast<uint8_t*>(dest);
  if (sizeof(Pixel) == 1) {
    const uint8x8_t dst_pixel = vqmovun_s16(result);
    if (width == 8) {
      vst1_u8(dst8, dst_pixel);
    } else {
      StoreLo4(dst8, dst_pixel);
      StoreHi4(dst8 + dst_stride, dst_pixel);
    }
  } else {
    const uint16x8_t dst_pixel = vreinterpretq_u16_s16(result);
    auto* const dst16 = reinterpret_cast<uint16_t*>(dst8);
    if (width == 8) {
      vst1q_u16(dst16, dst_pixel);
    } else {
      auto* const dst16_next_row =
          reinterpret_cast<uint16_t*>(dst8 + dst_stride);
      vst1_u16(dst16, vget_low_u16(dst_pixel));
      vst1_u16(dst16_next_row, vget_high_u16(dst_pixel));
    }
  }
}

template <int width, typename Pixel, bool enable_primary = true,
          bool enable_secondary = true>
void CdefFilter_NEON(const uint16_t* LIBGAV1_RESTRICT src,
                     const ptrdiff_t src_stride, const int height,
                     const int primary_strength, const int secondary_strength,
                     const int damping, const int direction,
                     void* LIBGAV1_RESTRICT dest, const ptrdiff_t dst_stride) {
  static_assert(width == 8 || width == 4, "");
  static_assert(enable_primary || enable_secondary, "");
  constexpr bool clipping_required = enable_primary && enable_secondary;
  auto* dst = static_cast<uint8_t*>(dest);
  const uint16x8_t cdef_large_value_mask =
      vdupq_n_u16(static_cast<uint16_t>(~kCdefLargeValue));
  const uint16x8_t primary_threshold = vdupq_n_u16(primary_strength);
  const uint16x8_t secondary_threshold = vdupq_n_u16(secondary_strength);

  int16x8_t primary_damping_shift, secondary_damping_shift;

  // FloorLog2() requires input to be > 0.
  // 8-bit damping range: Y: [3, 6], UV: [2, 5].
  // 10-bit damping range: Y: [3, 6 + 2], UV: [2, 5 + 2].
  if (enable_primary) {
    // 8-bit primary_strength: [0, 15] -> FloorLog2: [0, 3] so a clamp is
    // necessary for UV filtering.
    // 10-bit primary_strength: [0, 15 << 2].
    primary_damping_shift =
        vdupq_n_s16(-std::max(0, damping - FloorLog2(primary_strength)));
  }

  if (enable_secondary) {
    if (sizeof(Pixel) == 1) {
      // secondary_strength: [0, 4] -> FloorLog2: [0, 2] so no clamp to 0 is
      // necessary.
      assert(damping - FloorLog2(secondary_strength) >= 0);
      secondary_damping_shift =
          vdupq_n_s16(-(damping - FloorLog2(secondary_strength)));
    } else {
      // secondary_strength: [0, 4 << 2]
      secondary_damping_shift =
          vdupq_n_s16(-std::max(0, damping - FloorLog2(secondary_strength)));
    }
  }

  constexpr int coeff_shift = (sizeof(Pixel) == 1) ? 0 : kBitdepth10 - 8;
  const int primary_tap_0 =
      kCdefPrimaryTaps[(primary_strength >> coeff_shift) & 1][0];
  const int primary_tap_1 =
      kCdefPrimaryTaps[(primary_strength >> coeff_shift) & 1][1];

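  // Per the AV1 spec, the primary taps selected above are {4, 2} for even
  // strengths and {3, 3} for odd strengths, keyed by the low bit of the
  // bitdepth-normalized primary strength.
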
  int y = height;
  do {
    uint16x8_t pixel;
    if (width == 8) {
      pixel = vld1q_u16(src);
    } else {
      pixel = vcombine_u16(vld1_u16(src), vld1_u16(src + src_stride));
    }

    uint16x8_t min = pixel;
    uint16x8_t max = pixel;
    int16x8_t sum;

    if (enable_primary) {
      // Primary |direction|.
      uint16x8_t primary_val[4];
      if (width == 8) {
        LoadDirection(src, src_stride, primary_val, direction);
      } else {
        LoadDirection4(src, src_stride, primary_val, direction);
      }

      if (clipping_required) {
        min = vminq_u16(min, primary_val[0]);
        min = vminq_u16(min, primary_val[1]);
        min = vminq_u16(min, primary_val[2]);
        min = vminq_u16(min, primary_val[3]);

        max = GetMaxPrimary<Pixel>(primary_val, max, cdef_large_value_mask);
      }

      sum = Constrain(primary_val[0], pixel, primary_threshold,
                      primary_damping_shift);
      sum = vmulq_n_s16(sum, primary_tap_0);
      sum = vmlaq_n_s16(sum,
                        Constrain(primary_val[1], pixel, primary_threshold,
                                  primary_damping_shift),
                        primary_tap_0);
      sum = vmlaq_n_s16(sum,
                        Constrain(primary_val[2], pixel, primary_threshold,
                                  primary_damping_shift),
                        primary_tap_1);
      sum = vmlaq_n_s16(sum,
                        Constrain(primary_val[3], pixel, primary_threshold,
                                  primary_damping_shift),
                        primary_tap_1);
    } else {
      sum = vdupq_n_s16(0);
    }

    if (enable_secondary) {
      // Secondary |direction| values (+/- 2). Clamp |direction|.
      uint16x8_t secondary_val[8];
      if (width == 8) {
        LoadDirection(src, src_stride, secondary_val, direction + 2);
        LoadDirection(src, src_stride, secondary_val + 4, direction - 2);
      } else {
        LoadDirection4(src, src_stride, secondary_val, direction + 2);
        LoadDirection4(src, src_stride, secondary_val + 4, direction - 2);
      }

      if (clipping_required) {
        min = vminq_u16(min, secondary_val[0]);
        min = vminq_u16(min, secondary_val[1]);
        min = vminq_u16(min, secondary_val[2]);
        min = vminq_u16(min, secondary_val[3]);
        min = vminq_u16(min, secondary_val[4]);
        min = vminq_u16(min, secondary_val[5]);
        min = vminq_u16(min, secondary_val[6]);
        min = vminq_u16(min, secondary_val[7]);

        max = GetMaxSecondary<Pixel>(secondary_val, max, cdef_large_value_mask);
      }

      sum = vmlaq_n_s16(sum,
                        Constrain(secondary_val[0], pixel, secondary_threshold,
                                  secondary_damping_shift),
                        kCdefSecondaryTap0);
      sum = vmlaq_n_s16(sum,
                        Constrain(secondary_val[1], pixel, secondary_threshold,
                                  secondary_damping_shift),
                        kCdefSecondaryTap0);
      sum = vmlaq_n_s16(sum,
                        Constrain(secondary_val[2], pixel, secondary_threshold,
                                  secondary_damping_shift),
                        kCdefSecondaryTap1);
      sum = vmlaq_n_s16(sum,
                        Constrain(secondary_val[3], pixel, secondary_threshold,
                                  secondary_damping_shift),
                        kCdefSecondaryTap1);
      sum = vmlaq_n_s16(sum,
                        Constrain(secondary_val[4], pixel, secondary_threshold,
                                  secondary_damping_shift),
                        kCdefSecondaryTap0);
      sum = vmlaq_n_s16(sum,
                        Constrain(secondary_val[5], pixel, secondary_threshold,
                                  secondary_damping_shift),
                        kCdefSecondaryTap0);
      sum = vmlaq_n_s16(sum,
                        Constrain(secondary_val[6], pixel, secondary_threshold,
                                  secondary_damping_shift),
                        kCdefSecondaryTap1);
      sum = vmlaq_n_s16(sum,
                        Constrain(secondary_val[7], pixel, secondary_threshold,
                                  secondary_damping_shift),
                        kCdefSecondaryTap1);
    }
    // Clip3(pixel + ((8 + sum - (sum < 0)) >> 4), min, max)
    const int16x8_t sum_lt_0 = vshrq_n_s16(sum, 15);
    sum = vaddq_s16(sum, sum_lt_0);
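    // vshrq_n_s16(sum, 15) is 0 when sum >= 0 and -1 when sum < 0, supplying
    // the "- (sum < 0)" term; vrsraq_n_s16 below adds the rounding bias of 8
    // and shifts right by 4 while accumulating into |pixel|.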
    int16x8_t result = vrsraq_n_s16(vreinterpretq_s16_u16(pixel), sum, 4);
    if (clipping_required) {
      result = vminq_s16(result, vreinterpretq_s16_u16(max));
      result = vmaxq_s16(result, vreinterpretq_s16_u16(min));
    }

    StorePixels<Pixel, width>(dst, dst_stride, result);

    src += (width == 8) ? src_stride : src_stride << 1;
    dst += (width == 8) ? dst_stride : dst_stride << 1;
    y -= (width == 8) ? 1 : 2;
  } while (y != 0);
}

}  // namespace

namespace low_bitdepth {
namespace {

void Init8bpp() {
  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
  assert(dsp != nullptr);
  dsp->cdef_direction = CdefDirection_NEON<kBitdepth8>;
  dsp->cdef_filters[0][0] = CdefFilter_NEON<4, uint8_t>;
  dsp->cdef_filters[0][1] = CdefFilter_NEON<4, uint8_t, /*enable_primary=*/true,
                                            /*enable_secondary=*/false>;
  dsp->cdef_filters[0][2] =
      CdefFilter_NEON<4, uint8_t, /*enable_primary=*/false>;
  dsp->cdef_filters[1][0] = CdefFilter_NEON<8, uint8_t>;
  dsp->cdef_filters[1][1] = CdefFilter_NEON<8, uint8_t, /*enable_primary=*/true,
                                            /*enable_secondary=*/false>;
  dsp->cdef_filters[1][2] =
      CdefFilter_NEON<8, uint8_t, /*enable_primary=*/false>;
}

}  // namespace
}  // namespace low_bitdepth

#if LIBGAV1_MAX_BITDEPTH >= 10
namespace high_bitdepth {
namespace {

void Init10bpp() {
  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10);
  assert(dsp != nullptr);
  dsp->cdef_direction = CdefDirection_NEON<kBitdepth10>;
  dsp->cdef_filters[0][0] = CdefFilter_NEON<4, uint16_t>;
  dsp->cdef_filters[0][1] =
      CdefFilter_NEON<4, uint16_t, /*enable_primary=*/true,
                      /*enable_secondary=*/false>;
  dsp->cdef_filters[0][2] =
      CdefFilter_NEON<4, uint16_t, /*enable_primary=*/false>;
  dsp->cdef_filters[1][0] = CdefFilter_NEON<8, uint16_t>;
  dsp->cdef_filters[1][1] =
      CdefFilter_NEON<8, uint16_t, /*enable_primary=*/true,
                      /*enable_secondary=*/false>;
  dsp->cdef_filters[1][2] =
      CdefFilter_NEON<8, uint16_t, /*enable_primary=*/false>;
}

}  // namespace
}  // namespace high_bitdepth
#endif  // LIBGAV1_MAX_BITDEPTH >= 10

void CdefInit_NEON() {
  low_bitdepth::Init8bpp();
#if LIBGAV1_MAX_BITDEPTH >= 10
  high_bitdepth::Init10bpp();
#endif
}

}  // namespace dsp
}  // namespace libgav1
#else   // !LIBGAV1_ENABLE_NEON
namespace libgav1 {
namespace dsp {

void CdefInit_NEON() {}

}  // namespace dsp
}  // namespace libgav1
#endif  // LIBGAV1_ENABLE_NEON
805