// Copyright 2021 The libgav1 Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "src/dsp/cdef.h"
#include "src/utils/cpu.h"

#if LIBGAV1_TARGETING_AVX2
#include <immintrin.h>

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdlib>

#include "src/dsp/constants.h"
#include "src/dsp/dsp.h"
#include "src/dsp/x86/common_avx2.h"
#include "src/utils/common.h"
#include "src/utils/constants.h"

namespace libgav1 {
namespace dsp {
namespace low_bitdepth {
namespace {

#include "src/dsp/cdef.inc"

// Used when calculating odd |cost[x]| values.
// Holds elements 1 3 5 7 7 7 7 7
alignas(32) constexpr uint32_t kCdefDivisionTableOddPairsPadded[] = {
    420, 210, 140, 105, 420, 210, 140, 105,
    105, 105, 105, 105, 105, 105, 105, 105};

// ----------------------------------------------------------------------------
// Refer to CdefDirection_C().
//
// int32_t partial[8][15] = {};
// for (int i = 0; i < 8; ++i) {
//   for (int j = 0; j < 8; ++j) {
//     const int x = 1;
//     partial[0][i + j] += x;
//     partial[1][i + j / 2] += x;
//     partial[2][i] += x;
//     partial[3][3 + i - j / 2] += x;
//     partial[4][7 + i - j] += x;
//     partial[5][3 - i / 2 + j] += x;
//     partial[6][j] += x;
//     partial[7][i / 2 + j] += x;
//   }
// }
//
// Using the code above, generate the position count for partial[8][15].
//
// partial[0]: 1 2 3 4 5 6 7 8 7 6 5 4 3 2 1
// partial[1]: 2 4 6 8 8 8 8 8 6 4 2 0 0 0 0
// partial[2]: 8 8 8 8 8 8 8 8 0 0 0 0 0 0 0
// partial[3]: 2 4 6 8 8 8 8 8 6 4 2 0 0 0 0
// partial[4]: 1 2 3 4 5 6 7 8 7 6 5 4 3 2 1
// partial[5]: 2 4 6 8 8 8 8 8 6 4 2 0 0 0 0
// partial[6]: 8 8 8 8 8 8 8 8 0 0 0 0 0 0 0
// partial[7]: 2 4 6 8 8 8 8 8 6 4 2 0 0 0 0
//
// The SIMD code shifts the input horizontally, then adds vertically to get the
// correct partial value for the given position.
// ----------------------------------------------------------------------------

// ----------------------------------------------------------------------------
// partial[0][i + j] += x;
//
// 00 01 02 03 04 05 06 07  00 00 00 00 00 00 00
// 00 10 11 12 13 14 15 16  17 00 00 00 00 00 00
// 00 00 20 21 22 23 24 25  26 27 00 00 00 00 00
// 00 00 00 30 31 32 33 34  35 36 37 00 00 00 00
// 00 00 00 00 40 41 42 43  44 45 46 47 00 00 00
// 00 00 00 00 00 50 51 52  53 54 55 56 57 00 00
// 00 00 00 00 00 00 60 61  62 63 64 65 66 67 00
// 00 00 00 00 00 00 00 70  71 72 73 74 75 76 77
//
// partial[4] is the same except the source is reversed.
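//
// A scalar sketch of the accumulation below (illustrative pseudocode; |src8x8|
// and |d0| are not names used in this file):
//
//   uint16_t d0[15] = {};
//   for (int i = 0; i < 8; ++i) {
//     for (int j = 0; j < 8; ++j) d0[i + j] += src8x8[i][j];
//   }
//
// d0[0..7] ends up in |*partial_lo| and d0[8..14] in |*partial_hi|.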
LIBGAV1_ALWAYS_INLINE void AddPartial_D0_D4(__m256i* v_src_16,
                                            __m256i* partial_lo,
                                            __m256i* partial_hi) {
  // 00 01 02 03 04 05 06 07
  *partial_lo = v_src_16[0];
  // 00 00 00 00 00 00 00 00
  *partial_hi = _mm256_setzero_si256();

  // 00 10 11 12 13 14 15 16
  *partial_lo =
      _mm256_add_epi16(*partial_lo, _mm256_slli_si256(v_src_16[1], 2));
  // 17 00 00 00 00 00 00 00
  *partial_hi =
      _mm256_add_epi16(*partial_hi, _mm256_srli_si256(v_src_16[1], 14));

  // 00 00 20 21 22 23 24 25
  *partial_lo =
      _mm256_add_epi16(*partial_lo, _mm256_slli_si256(v_src_16[2], 4));
  // 26 27 00 00 00 00 00 00
  *partial_hi =
      _mm256_add_epi16(*partial_hi, _mm256_srli_si256(v_src_16[2], 12));

  // 00 00 00 30 31 32 33 34
  *partial_lo =
      _mm256_add_epi16(*partial_lo, _mm256_slli_si256(v_src_16[3], 6));
  // 35 36 37 00 00 00 00 00
  *partial_hi =
      _mm256_add_epi16(*partial_hi, _mm256_srli_si256(v_src_16[3], 10));

  // 00 00 00 00 40 41 42 43
  *partial_lo =
      _mm256_add_epi16(*partial_lo, _mm256_slli_si256(v_src_16[4], 8));
  // 44 45 46 47 00 00 00 00
  *partial_hi =
      _mm256_add_epi16(*partial_hi, _mm256_srli_si256(v_src_16[4], 8));

  // 00 00 00 00 00 50 51 52
  *partial_lo =
      _mm256_add_epi16(*partial_lo, _mm256_slli_si256(v_src_16[5], 10));
  // 53 54 55 56 57 00 00 00
  *partial_hi =
      _mm256_add_epi16(*partial_hi, _mm256_srli_si256(v_src_16[5], 6));

  // 00 00 00 00 00 00 60 61
  *partial_lo =
      _mm256_add_epi16(*partial_lo, _mm256_slli_si256(v_src_16[6], 12));
  // 62 63 64 65 66 67 00 00
  *partial_hi =
      _mm256_add_epi16(*partial_hi, _mm256_srli_si256(v_src_16[6], 4));

  // 00 00 00 00 00 00 00 70
  *partial_lo =
      _mm256_add_epi16(*partial_lo, _mm256_slli_si256(v_src_16[7], 14));
  // 71 72 73 74 75 76 77 00
  *partial_hi =
      _mm256_add_epi16(*partial_hi, _mm256_srli_si256(v_src_16[7], 2));
}

// ----------------------------------------------------------------------------
// partial[1][i + j / 2] += x;
//
// A0 = src[0] + src[1], A1 = src[2] + src[3], ...
//
// A0 A1 A2 A3 00 00 00 00  00 00 00 00 00 00 00
// 00 B0 B1 B2 B3 00 00 00  00 00 00 00 00 00 00
// 00 00 C0 C1 C2 C3 00 00  00 00 00 00 00 00 00
// 00 00 00 D0 D1 D2 D3 00  00 00 00 00 00 00 00
// 00 00 00 00 E0 E1 E2 E3  00 00 00 00 00 00 00
// 00 00 00 00 00 F0 F1 F2  F3 00 00 00 00 00 00
// 00 00 00 00 00 00 G0 G1  G2 G3 00 00 00 00 00
// 00 00 00 00 00 00 00 H0  H1 H2 H3 00 00 00 00
//
// partial[3] is the same except the source is reversed.
LIBGAV1_ALWAYS_INLINE void AddPartial_D1_D3(__m256i* v_src_16,
                                            __m256i* partial_lo,
                                            __m256i* partial_hi) {
  __m256i v_d1_temp[8];
  const __m256i v_zero = _mm256_setzero_si256();

  for (int i = 0; i < 8; ++i) {
    v_d1_temp[i] = _mm256_hadd_epi16(v_src_16[i], v_zero);
  }

  *partial_lo = *partial_hi = v_zero;
  // A0 A1 A2 A3 00 00 00 00
  *partial_lo = _mm256_add_epi16(*partial_lo, v_d1_temp[0]);

  // 00 B0 B1 B2 B3 00 00 00
  *partial_lo =
      _mm256_add_epi16(*partial_lo, _mm256_slli_si256(v_d1_temp[1], 2));

  // 00 00 C0 C1 C2 C3 00 00
  *partial_lo =
      _mm256_add_epi16(*partial_lo, _mm256_slli_si256(v_d1_temp[2], 4));
  // 00 00 00 D0 D1 D2 D3 00
  *partial_lo =
      _mm256_add_epi16(*partial_lo, _mm256_slli_si256(v_d1_temp[3], 6));
  // 00 00 00 00 E0 E1 E2 E3
  *partial_lo =
      _mm256_add_epi16(*partial_lo, _mm256_slli_si256(v_d1_temp[4], 8));

  // 00 00 00 00 00 F0 F1 F2
  *partial_lo =
      _mm256_add_epi16(*partial_lo, _mm256_slli_si256(v_d1_temp[5], 10));
  // F3 00 00 00 00 00 00 00
  *partial_hi =
      _mm256_add_epi16(*partial_hi, _mm256_srli_si256(v_d1_temp[5], 6));

  // 00 00 00 00 00 00 G0 G1
  *partial_lo =
      _mm256_add_epi16(*partial_lo, _mm256_slli_si256(v_d1_temp[6], 12));
  // G2 G3 00 00 00 00 00 00
  *partial_hi =
      _mm256_add_epi16(*partial_hi, _mm256_srli_si256(v_d1_temp[6], 4));

  // 00 00 00 00 00 00 00 H0
  *partial_lo =
      _mm256_add_epi16(*partial_lo, _mm256_slli_si256(v_d1_temp[7], 14));
  // H1 H2 H3 00 00 00 00 00
  *partial_hi =
      _mm256_add_epi16(*partial_hi, _mm256_srli_si256(v_d1_temp[7], 2));
}

// ----------------------------------------------------------------------------
// partial[7][i / 2 + j] += x;
//
// 00 01 02 03 04 05 06 07  00 00 00 00 00 00 00
// 10 11 12 13 14 15 16 17  00 00 00 00 00 00 00
// 00 20 21 22 23 24 25 26  27 00 00 00 00 00 00
// 00 30 31 32 33 34 35 36  37 00 00 00 00 00 00
// 00 00 40 41 42 43 44 45  46 47 00 00 00 00 00
// 00 00 50 51 52 53 54 55  56 57 00 00 00 00 00
// 00 00 00 60 61 62 63 64  65 66 67 00 00 00 00
// 00 00 00 70 71 72 73 74  75 76 77 00 00 00 00
//
// partial[5] is the same except the source is reversed.
LIBGAV1_ALWAYS_INLINE void AddPartial_D7_D5(__m256i* v_src, __m256i* partial_lo,
                                            __m256i* partial_hi) {
  __m256i v_pair_add[4];
  // Add vertical source pairs.
  v_pair_add[0] = _mm256_add_epi16(v_src[0], v_src[1]);
  v_pair_add[1] = _mm256_add_epi16(v_src[2], v_src[3]);
  v_pair_add[2] = _mm256_add_epi16(v_src[4], v_src[5]);
  v_pair_add[3] = _mm256_add_epi16(v_src[6], v_src[7]);

  // 00 01 02 03 04 05 06 07
  // 10 11 12 13 14 15 16 17
  *partial_lo = v_pair_add[0];
  // 00 00 00 00 00 00 00 00
  // 00 00 00 00 00 00 00 00
  *partial_hi = _mm256_setzero_si256();

  // 00 20 21 22 23 24 25 26
  // 00 30 31 32 33 34 35 36
  *partial_lo =
      _mm256_add_epi16(*partial_lo, _mm256_slli_si256(v_pair_add[1], 2));
  // 27 00 00 00 00 00 00 00
  // 37 00 00 00 00 00 00 00
  *partial_hi =
      _mm256_add_epi16(*partial_hi, _mm256_srli_si256(v_pair_add[1], 14));

  // 00 00 40 41 42 43 44 45
  // 00 00 50 51 52 53 54 55
  *partial_lo =
      _mm256_add_epi16(*partial_lo, _mm256_slli_si256(v_pair_add[2], 4));
  // 46 47 00 00 00 00 00 00
  // 56 57 00 00 00 00 00 00
  *partial_hi =
      _mm256_add_epi16(*partial_hi, _mm256_srli_si256(v_pair_add[2], 12));

  // 00 00 00 60 61 62 63 64
  // 00 00 00 70 71 72 73 74
  *partial_lo =
      _mm256_add_epi16(*partial_lo, _mm256_slli_si256(v_pair_add[3], 6));
  // 65 66 67 00 00 00 00 00
  // 75 76 77 00 00 00 00 00
  *partial_hi =
      _mm256_add_epi16(*partial_hi, _mm256_srli_si256(v_pair_add[3], 10));
}

LIBGAV1_ALWAYS_INLINE void AddPartial(const uint8_t* LIBGAV1_RESTRICT src,
                                      ptrdiff_t stride, __m256i* partial) {
  // 8x8 input
  // 00 01 02 03 04 05 06 07
  // 10 11 12 13 14 15 16 17
  // 20 21 22 23 24 25 26 27
  // 30 31 32 33 34 35 36 37
  // 40 41 42 43 44 45 46 47
  // 50 51 52 53 54 55 56 57
  // 60 61 62 63 64 65 66 67
  // 70 71 72 73 74 75 76 77
  __m256i v_src[8];
  for (auto& i : v_src) {
    i = _mm256_castsi128_si256(LoadLo8(src));
    // Dup lower lane.
    i = _mm256_permute2x128_si256(i, i, 0x0);
    src += stride;
  }

  const __m256i v_zero = _mm256_setzero_si256();
  // partial for direction 2
  // --------------------------------------------------------------------------
  // partial[2][i] += x;
  // 00 10 20 30 40 50 60 70  xx xx xx xx xx xx xx xx
  // 01 11 21 31 41 51 61 71  xx xx xx xx xx xx xx xx
  // 02 12 22 32 42 52 62 72  xx xx xx xx xx xx xx xx
  // 03 13 23 33 43 53 63 73  xx xx xx xx xx xx xx xx
  // 04 14 24 34 44 54 64 74  xx xx xx xx xx xx xx xx
  // 05 15 25 35 45 55 65 75  xx xx xx xx xx xx xx xx
  // 06 16 26 36 46 56 66 76  xx xx xx xx xx xx xx xx
  // 07 17 27 37 47 57 67 77  xx xx xx xx xx xx xx xx
  const __m256i v_src_4_0 = _mm256_unpacklo_epi64(v_src[0], v_src[4]);
  const __m256i v_src_5_1 = _mm256_unpacklo_epi64(v_src[1], v_src[5]);
  const __m256i v_src_6_2 = _mm256_unpacklo_epi64(v_src[2], v_src[6]);
  const __m256i v_src_7_3 = _mm256_unpacklo_epi64(v_src[3], v_src[7]);
  const __m256i v_hsum_4_0 = _mm256_sad_epu8(v_src_4_0, v_zero);
  const __m256i v_hsum_5_1 = _mm256_sad_epu8(v_src_5_1, v_zero);
  const __m256i v_hsum_6_2 = _mm256_sad_epu8(v_src_6_2, v_zero);
  const __m256i v_hsum_7_3 = _mm256_sad_epu8(v_src_7_3, v_zero);
  const __m256i v_hsum_1_0 = _mm256_unpacklo_epi16(v_hsum_4_0, v_hsum_5_1);
  const __m256i v_hsum_3_2 = _mm256_unpacklo_epi16(v_hsum_6_2, v_hsum_7_3);
  const __m256i v_hsum_5_4 = _mm256_unpackhi_epi16(v_hsum_4_0, v_hsum_5_1);
  const __m256i v_hsum_7_6 = _mm256_unpackhi_epi16(v_hsum_6_2, v_hsum_7_3);
  partial[2] =
      _mm256_unpacklo_epi64(_mm256_unpacklo_epi32(v_hsum_1_0, v_hsum_3_2),
                            _mm256_unpacklo_epi32(v_hsum_5_4, v_hsum_7_6));

  const __m256i extend_reverse = SetrM128i(
      _mm_set_epi32(static_cast<int>(0x80078006), static_cast<int>(0x80058004),
                    static_cast<int>(0x80038002), static_cast<int>(0x80018000)),
      _mm_set_epi32(static_cast<int>(0x80008001), static_cast<int>(0x80028003),
                    static_cast<int>(0x80048005),
                    static_cast<int>(0x80068007)));

  for (auto& i : v_src) {
    // Zero extend unsigned 8 to 16. The upper lane is reversed.
    i = _mm256_shuffle_epi8(i, extend_reverse);
  }

  // partial for direction 6
  // --------------------------------------------------------------------------
  // partial[6][j] += x;
  // 00 01 02 03 04 05 06 07  xx xx xx xx xx xx xx xx
  // 10 11 12 13 14 15 16 17  xx xx xx xx xx xx xx xx
  // 20 21 22 23 24 25 26 27  xx xx xx xx xx xx xx xx
  // 30 31 32 33 34 35 36 37  xx xx xx xx xx xx xx xx
  // 40 41 42 43 44 45 46 47  xx xx xx xx xx xx xx xx
  // 50 51 52 53 54 55 56 57  xx xx xx xx xx xx xx xx
  // 60 61 62 63 64 65 66 67  xx xx xx xx xx xx xx xx
  // 70 71 72 73 74 75 76 77  xx xx xx xx xx xx xx xx
  partial[6] = v_src[0];
  for (int i = 1; i < 8; ++i) {
    partial[6] = _mm256_add_epi16(partial[6], v_src[i]);
  }

  AddPartial_D0_D4(v_src, &partial[0], &partial[4]);
  AddPartial_D1_D3(v_src, &partial[1], &partial[3]);
  AddPartial_D7_D5(v_src, &partial[7], &partial[5]);
}

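// Sums the four 32-bit elements within each 128-bit lane; the per-lane total
// ends up in the low 32 bits of that lane.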
inline __m256i SumVectorPair_S32(__m256i a) {
  a = _mm256_hadd_epi32(a, a);
  a = _mm256_add_epi32(a, _mm256_srli_si256(a, 4));
  return a;
}

// |cost[0]| and |cost[4]| square the input, sum with the corresponding element
// from the other end of the vector, then multiply by the corresponding
// |kCdefDivisionTable[]| element:
// cost[0] += (Square(partial[0][i]) + Square(partial[0][14 - i])) *
//             kCdefDivisionTable[i + 1];
// cost[0] += Square(partial[0][7]) * kCdefDivisionTable[8];
inline void Cost0Or4_Pair(uint32_t* cost, const __m256i partial_0,
                          const __m256i partial_4,
                          const __m256i division_table) {
  const __m256i division_table_0 =
      _mm256_permute2x128_si256(division_table, division_table, 0x0);
  const __m256i division_table_1 =
      _mm256_permute2x128_si256(division_table, division_table, 0x11);

  // partial_lo
  const __m256i a = partial_0;
  // partial_hi
  const __m256i b = partial_4;

  // Reverse and clear upper 2 bytes.
  const __m256i reverser = _mm256_broadcastsi128_si256(_mm_set_epi32(
      static_cast<int>(0x80800100), 0x03020504, 0x07060908, 0x0b0a0d0c));

  // 14 13 12 11 10 09 08 ZZ
  const __m256i b_reversed = _mm256_shuffle_epi8(b, reverser);
  // 00 14 01 13 02 12 03 11
  const __m256i ab_lo = _mm256_unpacklo_epi16(a, b_reversed);
  // 04 10 05 09 06 08 07 ZZ
  const __m256i ab_hi = _mm256_unpackhi_epi16(a, b_reversed);

  // Square(partial[0][i]) + Square(partial[0][14 - i])
  const __m256i square_lo = _mm256_madd_epi16(ab_lo, ab_lo);
  const __m256i square_hi = _mm256_madd_epi16(ab_hi, ab_hi);

  const __m256i c = _mm256_mullo_epi32(square_lo, division_table_0);
  const __m256i d = _mm256_mullo_epi32(square_hi, division_table_1);
  const __m256i e = SumVectorPair_S32(_mm256_add_epi32(c, d));
  // Copy upper 32bit sum to lower lane.
  const __m128i sums =
      _mm256_castsi256_si128(_mm256_permute4x64_epi64(e, 0x08));
  cost[0] = _mm_cvtsi128_si32(sums);
  cost[4] = _mm_cvtsi128_si32(_mm_srli_si128(sums, 8));
}

template <int index_a, int index_b>
inline void CostOdd_Pair(uint32_t* cost, const __m256i partial_a,
                         const __m256i partial_b,
                         const __m256i division_table[2]) {
  // partial_lo
  const __m256i a = partial_a;
  // partial_hi
  const __m256i b = partial_b;

  // Reverse and clear upper 10 bytes.
  const __m256i reverser = _mm256_broadcastsi128_si256(
      _mm_set_epi32(static_cast<int>(0x80808080), static_cast<int>(0x80808080),
                    static_cast<int>(0x80800100), 0x03020504));

  // 10 09 08 ZZ ZZ ZZ ZZ ZZ
  const __m256i b_reversed = _mm256_shuffle_epi8(b, reverser);
  // 00 10 01 09 02 08 03 ZZ
  const __m256i ab_lo = _mm256_unpacklo_epi16(a, b_reversed);
  // 04 ZZ 05 ZZ 06 ZZ 07 ZZ
  const __m256i ab_hi = _mm256_unpackhi_epi16(a, b_reversed);

  // Square(partial[0][i]) + Square(partial[0][14 - i])
  const __m256i square_lo = _mm256_madd_epi16(ab_lo, ab_lo);
  const __m256i square_hi = _mm256_madd_epi16(ab_hi, ab_hi);

  const __m256i c = _mm256_mullo_epi32(square_lo, division_table[0]);
  const __m256i d = _mm256_mullo_epi32(square_hi, division_table[1]);
  const __m256i e = SumVectorPair_S32(_mm256_add_epi32(c, d));
  // Copy upper 32bit sum to lower lane.
  const __m128i sums =
      _mm256_castsi256_si128(_mm256_permute4x64_epi64(e, 0x08));
  cost[index_a] = _mm_cvtsi128_si32(sums);
  cost[index_b] = _mm_cvtsi128_si32(_mm_srli_si128(sums, 8));
}

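// |cost[2]| and |cost[6]| sum the squares of the eight partial values and
// scale by |kCdefDivisionTable[7]|:
// cost[2] += Square(partial[2][i]) * kCdefDivisionTable[7];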
inline void Cost2And6_Pair(uint32_t* cost, const __m256i partial_a,
                           const __m256i partial_b,
                           const __m256i division_table) {
  // The upper lanes are a "don't care", so pack the lower lanes of |partial_a|
  // and |partial_b| into one vector for calculating cost.
  const __m256i a = _mm256_permute2x128_si256(partial_a, partial_b, 0x20);

  const __m256i square_a = _mm256_madd_epi16(a, a);
  const __m256i b = _mm256_mullo_epi32(square_a, division_table);
  const __m256i c = SumVectorPair_S32(b);
  // Copy upper 32bit sum to lower lane.
  const __m128i sums =
      _mm256_castsi256_si128(_mm256_permute4x64_epi64(c, 0x08));
  cost[2] = _mm_cvtsi128_si32(sums);
  cost[6] = _mm_cvtsi128_si32(_mm_srli_si128(sums, 8));
}

void CdefDirection_AVX2(const void* LIBGAV1_RESTRICT const source,
                        ptrdiff_t stride,
                        uint8_t* LIBGAV1_RESTRICT const direction,
                        int* LIBGAV1_RESTRICT const variance) {
  assert(direction != nullptr);
  assert(variance != nullptr);
  const auto* src = static_cast<const uint8_t*>(source);
  uint32_t cost[8];

  // partial[0] = add partial 0,4 low
  // partial[1] = add partial 1,3 low
  // partial[2] = add partial 2 low
  // partial[3] = add partial 1,3 high
  // partial[4] = add partial 0,4 high
  // partial[5] = add partial 7,5 high
  // partial[6] = add partial 6 low
  // partial[7] = add partial 7,5 low
  __m256i partial[8];

  AddPartial(src, stride, partial);

  const __m256i division_table = LoadUnaligned32(kCdefDivisionTable);
  const __m256i division_table_7 =
      _mm256_broadcastd_epi32(_mm_cvtsi32_si128(kCdefDivisionTable[7]));

  Cost2And6_Pair(cost, partial[2], partial[6], division_table_7);

  Cost0Or4_Pair(cost, partial[0], partial[4], division_table);

  const __m256i division_table_odd[2] = {
      LoadUnaligned32(kCdefDivisionTableOddPairsPadded),
      LoadUnaligned32(kCdefDivisionTableOddPairsPadded + 8)};

  CostOdd_Pair<1, 3>(cost, partial[1], partial[3], division_table_odd);
  CostOdd_Pair<7, 5>(cost, partial[7], partial[5], division_table_odd);

  uint32_t best_cost = 0;
  *direction = 0;
  for (int i = 0; i < 8; ++i) {
    if (cost[i] > best_cost) {
      best_cost = cost[i];
      *direction = i;
    }
  }
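  // The variance is the cost gap between the best direction and the direction
  // 90 degrees away ((*direction + 4) & 7), scaled down by 10 bits.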
  *variance = (best_cost - cost[(*direction + 4) & 7]) >> 10;
}

// -------------------------------------------------------------------------
// CdefFilter

// Load 4 vectors based on the given |direction|.
inline void LoadDirection(const uint16_t* LIBGAV1_RESTRICT const src,
                          const ptrdiff_t stride, __m128i* output,
                          const int direction) {
  // Each |direction| describes a different set of source values. Expand this
  // set by negating each set. For |direction| == 0 this gives a diagonal line
  // from top right to bottom left. The first value is y, the second x. Negative
  // y values move up.
  //    a       b         c       d
  // {-1, 1}, {1, -1}, {-2, 2}, {2, -2}
  //         c
  //       a
  //     0
  //   b
  // d
  const int y_0 = kCdefDirections[direction][0][0];
  const int x_0 = kCdefDirections[direction][0][1];
  const int y_1 = kCdefDirections[direction][1][0];
  const int x_1 = kCdefDirections[direction][1][1];
  output[0] = LoadUnaligned16(src - y_0 * stride - x_0);
  output[1] = LoadUnaligned16(src + y_0 * stride + x_0);
  output[2] = LoadUnaligned16(src - y_1 * stride - x_1);
  output[3] = LoadUnaligned16(src + y_1 * stride + x_1);
}

// Load 4 vectors based on the given |direction|. Use when |block_width| == 4 to
// do 2 rows at a time.
void LoadDirection4(const uint16_t* LIBGAV1_RESTRICT const src,
                    const ptrdiff_t stride, __m128i* output,
                    const int direction) {
  const int y_0 = kCdefDirections[direction][0][0];
  const int x_0 = kCdefDirections[direction][0][1];
  const int y_1 = kCdefDirections[direction][1][0];
  const int x_1 = kCdefDirections[direction][1][1];
  output[0] = LoadHi8(LoadLo8(src - y_0 * stride - x_0),
                      src - y_0 * stride + stride - x_0);
  output[1] = LoadHi8(LoadLo8(src + y_0 * stride + x_0),
                      src + y_0 * stride + stride + x_0);
  output[2] = LoadHi8(LoadLo8(src - y_1 * stride - x_1),
                      src - y_1 * stride + stride - x_1);
  output[3] = LoadHi8(LoadLo8(src + y_1 * stride + x_1),
                      src + y_1 * stride + stride + x_1);
}

inline __m256i Constrain(const __m256i& pixel, const __m256i& reference,
                         const __m128i& damping, const __m256i& threshold) {
  const __m256i diff = _mm256_sub_epi16(pixel, reference);
  const __m256i abs_diff = _mm256_abs_epi16(diff);
  // sign(diff) * Clip3(threshold - (std::abs(diff) >> damping),
  //                    0, std::abs(diff))
  const __m256i shifted_diff = _mm256_srl_epi16(abs_diff, damping);
  // For bitdepth == 8, the threshold range is [0, 15] and the damping range is
  // [3, 6]. If pixel == kCdefLargeValue (0x4000), shifted_diff will always be
  // larger than threshold, so subtracting with saturation returns 0 when pixel
  // == kCdefLargeValue.
  static_assert(kCdefLargeValue == 0x4000, "Invalid kCdefLargeValue");
  const __m256i thresh_minus_shifted_diff =
      _mm256_subs_epu16(threshold, shifted_diff);
  const __m256i clamp_abs_diff =
      _mm256_min_epi16(thresh_minus_shifted_diff, abs_diff);
  // Restore the sign.
  return _mm256_sign_epi16(clamp_abs_diff, diff);
}
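
// A scalar sketch of the constraint above for a single value (illustrative
// only; |ConstrainScalar| is not a name used in this file, and |damping| here
// is already reduced by FloorLog2(strength) as computed in CdefFilter_AVX2()):
//
//   int ConstrainScalar(int pixel, int reference, int damping, int threshold) {
//     const int diff = pixel - reference;
//     const int abs_diff = std::abs(diff);
//     const int clamped =
//         Clip3(threshold - (abs_diff >> damping), 0, abs_diff);
//     return (diff < 0) ? -clamped : clamped;
//   }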

inline __m256i ApplyConstrainAndTap(const __m256i& pixel, const __m256i& val,
                                    const __m256i& tap, const __m128i& damping,
                                    const __m256i& threshold) {
  const __m256i constrained = Constrain(val, pixel, damping, threshold);
  return _mm256_mullo_epi16(constrained, tap);
}

template <int width, bool enable_primary = true, bool enable_secondary = true>
void CdefFilter_AVX2(const uint16_t* LIBGAV1_RESTRICT src,
                     const ptrdiff_t src_stride, const int height,
                     const int primary_strength, const int secondary_strength,
                     const int damping, const int direction,
                     void* LIBGAV1_RESTRICT dest, const ptrdiff_t dst_stride) {
  static_assert(width == 8 || width == 4, "Invalid CDEF width.");
  static_assert(enable_primary || enable_secondary, "");
  constexpr bool clipping_required = enable_primary && enable_secondary;
  auto* dst = static_cast<uint8_t*>(dest);
  __m128i primary_damping_shift, secondary_damping_shift;

  // FloorLog2() requires input to be > 0.
  // 8-bit damping range: Y: [3, 6], UV: [2, 5].
  if (enable_primary) {
    // primary_strength: [0, 15] -> FloorLog2: [0, 3] so a clamp is necessary
    // for UV filtering.
    primary_damping_shift =
        _mm_cvtsi32_si128(std::max(0, damping - FloorLog2(primary_strength)));
  }
  if (enable_secondary) {
    // secondary_strength: [0, 4] -> FloorLog2: [0, 2] so no clamp to 0 is
    // necessary.
    assert(damping - FloorLog2(secondary_strength) >= 0);
    secondary_damping_shift =
        _mm_cvtsi32_si128(damping - FloorLog2(secondary_strength));
  }
  const __m256i primary_tap_0 = _mm256_broadcastw_epi16(
      _mm_cvtsi32_si128(kCdefPrimaryTaps[primary_strength & 1][0]));
  const __m256i primary_tap_1 = _mm256_broadcastw_epi16(
      _mm_cvtsi32_si128(kCdefPrimaryTaps[primary_strength & 1][1]));
  const __m256i secondary_tap_0 =
      _mm256_broadcastw_epi16(_mm_cvtsi32_si128(kCdefSecondaryTap0));
  const __m256i secondary_tap_1 =
      _mm256_broadcastw_epi16(_mm_cvtsi32_si128(kCdefSecondaryTap1));
  const __m256i cdef_large_value_mask = _mm256_broadcastw_epi16(
      _mm_cvtsi32_si128(static_cast<int16_t>(~kCdefLargeValue)));
  const __m256i primary_threshold =
      _mm256_broadcastw_epi16(_mm_cvtsi32_si128(primary_strength));
  const __m256i secondary_threshold =
      _mm256_broadcastw_epi16(_mm_cvtsi32_si128(secondary_strength));

  int y = height;
  do {
    __m128i pixel_128;
    if (width == 8) {
      pixel_128 = LoadUnaligned16(src);
    } else {
      pixel_128 = LoadHi8(LoadLo8(src), src + src_stride);
    }

    __m256i pixel = SetrM128i(pixel_128, pixel_128);

    __m256i min = pixel;
    __m256i max = pixel;
    __m256i sum_pair;

    if (enable_primary) {
      // Primary |direction|.
      __m128i primary_val_128[4];
      if (width == 8) {
        LoadDirection(src, src_stride, primary_val_128, direction);
      } else {
        LoadDirection4(src, src_stride, primary_val_128, direction);
      }

      __m256i primary_val[2];
      primary_val[0] = SetrM128i(primary_val_128[0], primary_val_128[1]);
      primary_val[1] = SetrM128i(primary_val_128[2], primary_val_128[3]);

      if (clipping_required) {
        min = _mm256_min_epu16(min, primary_val[0]);
        min = _mm256_min_epu16(min, primary_val[1]);

        // The source is 16 bits; however, we only care about the lower 8 bits.
        // The upper 8 bits contain the "large" flag. After the final primary
        // max has been calculated, zero out the upper 8 bits. Use this to find
        // the "16 bit" max.
        const __m256i max_p01 = _mm256_max_epu8(primary_val[0], primary_val[1]);
        max = _mm256_max_epu16(
            max, _mm256_and_si256(max_p01, cdef_large_value_mask));
      }

      sum_pair = ApplyConstrainAndTap(pixel, primary_val[0], primary_tap_0,
                                      primary_damping_shift, primary_threshold);
      sum_pair = _mm256_add_epi16(
          sum_pair,
          ApplyConstrainAndTap(pixel, primary_val[1], primary_tap_1,
                               primary_damping_shift, primary_threshold));
    } else {
      sum_pair = _mm256_setzero_si256();
    }

    if (enable_secondary) {
      // Secondary |direction| values (+/- 2). Clamp |direction|.
      __m128i secondary_val_128[8];
      if (width == 8) {
        LoadDirection(src, src_stride, secondary_val_128, direction + 2);
        LoadDirection(src, src_stride, secondary_val_128 + 4, direction - 2);
      } else {
        LoadDirection4(src, src_stride, secondary_val_128, direction + 2);
        LoadDirection4(src, src_stride, secondary_val_128 + 4, direction - 2);
      }

      __m256i secondary_val[4];
      secondary_val[0] = SetrM128i(secondary_val_128[0], secondary_val_128[1]);
      secondary_val[1] = SetrM128i(secondary_val_128[2], secondary_val_128[3]);
      secondary_val[2] = SetrM128i(secondary_val_128[4], secondary_val_128[5]);
      secondary_val[3] = SetrM128i(secondary_val_128[6], secondary_val_128[7]);

      if (clipping_required) {
        min = _mm256_min_epu16(min, secondary_val[0]);
        min = _mm256_min_epu16(min, secondary_val[1]);
        min = _mm256_min_epu16(min, secondary_val[2]);
        min = _mm256_min_epu16(min, secondary_val[3]);

        const __m256i max_s01 =
            _mm256_max_epu8(secondary_val[0], secondary_val[1]);
        const __m256i max_s23 =
            _mm256_max_epu8(secondary_val[2], secondary_val[3]);
        const __m256i max_s = _mm256_max_epu8(max_s01, max_s23);
        max = _mm256_max_epu8(max,
                              _mm256_and_si256(max_s, cdef_large_value_mask));
      }

      sum_pair = _mm256_add_epi16(
          sum_pair,
          ApplyConstrainAndTap(pixel, secondary_val[0], secondary_tap_0,
                               secondary_damping_shift, secondary_threshold));
      sum_pair = _mm256_add_epi16(
          sum_pair,
          ApplyConstrainAndTap(pixel, secondary_val[1], secondary_tap_1,
                               secondary_damping_shift, secondary_threshold));
      sum_pair = _mm256_add_epi16(
          sum_pair,
          ApplyConstrainAndTap(pixel, secondary_val[2], secondary_tap_0,
                               secondary_damping_shift, secondary_threshold));
      sum_pair = _mm256_add_epi16(
          sum_pair,
          ApplyConstrainAndTap(pixel, secondary_val[3], secondary_tap_1,
                               secondary_damping_shift, secondary_threshold));
    }

    __m128i sum = _mm_add_epi16(_mm256_castsi256_si128(sum_pair),
                                _mm256_extracti128_si256(sum_pair, 1));
    // Clip3(pixel + ((8 + sum - (sum < 0)) >> 4), min, max)
    const __m128i sum_lt_0 = _mm_srai_epi16(sum, 15);
    // 8 + sum
    sum = _mm_add_epi16(sum, _mm_set1_epi16(8));
    // (... - (sum < 0)) >> 4
    sum = _mm_add_epi16(sum, sum_lt_0);
    sum = _mm_srai_epi16(sum, 4);
    // pixel + ...
    sum = _mm_add_epi16(sum, _mm256_castsi256_si128(pixel));
    if (clipping_required) {
      const __m128i min_128 = _mm_min_epu16(_mm256_castsi256_si128(min),
                                            _mm256_extracti128_si256(min, 1));

      const __m128i max_128 = _mm_max_epu16(_mm256_castsi256_si128(max),
                                            _mm256_extracti128_si256(max, 1));
      // Clip3
      sum = _mm_min_epi16(sum, max_128);
      sum = _mm_max_epi16(sum, min_128);
    }

    const __m128i result = _mm_packus_epi16(sum, sum);
    if (width == 8) {
      src += src_stride;
      StoreLo8(dst, result);
      dst += dst_stride;
      --y;
    } else {
      src += src_stride << 1;
      Store4(dst, result);
      dst += dst_stride;
      Store4(dst, _mm_srli_si128(result, 4));
      dst += dst_stride;
      y -= 2;
    }
  } while (y != 0);
}

void Init8bpp() {
  Dsp* const dsp = dsp_internal::GetWritableDspTable(8);
  assert(dsp != nullptr);
  dsp->cdef_direction = CdefDirection_AVX2;

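  // cdef_filters[0][] handles 4-wide blocks and cdef_filters[1][] handles
  // 8-wide blocks; index 0 applies both primary and secondary filtering,
  // index 1 primary only, and index 2 secondary only.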
  dsp->cdef_filters[0][0] = CdefFilter_AVX2<4>;
  dsp->cdef_filters[0][1] =
      CdefFilter_AVX2<4, /*enable_primary=*/true, /*enable_secondary=*/false>;
  dsp->cdef_filters[0][2] = CdefFilter_AVX2<4, /*enable_primary=*/false>;
  dsp->cdef_filters[1][0] = CdefFilter_AVX2<8>;
  dsp->cdef_filters[1][1] =
      CdefFilter_AVX2<8, /*enable_primary=*/true, /*enable_secondary=*/false>;
  dsp->cdef_filters[1][2] = CdefFilter_AVX2<8, /*enable_primary=*/false>;
}

}  // namespace
}  // namespace low_bitdepth

void CdefInit_AVX2() { low_bitdepth::Init8bpp(); }

}  // namespace dsp
}  // namespace libgav1
#else   // !LIBGAV1_TARGETING_AVX2
namespace libgav1 {
namespace dsp {

void CdefInit_AVX2() {}

}  // namespace dsp
}  // namespace libgav1
#endif  // LIBGAV1_TARGETING_AVX2