/*
 *  Copyright (c) 2012 The WebM project authors. All Rights Reserved.
 *
 *  Use of this source code is governed by a BSD-style license
 *  that can be found in the LICENSE file in the root of the source
 *  tree. An additional intellectual property rights grant can be found
 *  in the file PATENTS.  All contributing project authors may
 *  be found in the AUTHORS file in the root of the source tree.
 */

#include <immintrin.h>  // AVX2
#include "./vpx_config.h"
#include "./vpx_dsp_rtcd.h"

#include "vpx_dsp/txfm_common.h"
#define ADD256_EPI16 _mm256_add_epi16
#define SUB256_EPI16 _mm256_sub_epi16
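// ADD256_EPI16/SUB256_EPI16 are thin aliases over the AVX2 add/sub
// intrinsics so the butterfly code below reads closer to the scalar
// reference implementation.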

static INLINE void load_buffer_16bit_to_16bit_avx2(const int16_t *in,
                                                   int stride, __m256i *out,
                                                   int out_size, int pass) {
  int i;
  const __m256i kOne = _mm256_set1_epi16(1);
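  // Pass 0 reads rows of the caller's input (pitch = stride) and
  // pre-scales them by 4 (<< 2) to preserve precision through the first
  // 1-D transform; pass 1 reads the 16-wide intermediate buffer and
  // rounds that scaling back out with (x + 1) >> 2.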
  if (pass == 0) {
    for (i = 0; i < out_size; i++) {
      out[i] = _mm256_loadu_si256((const __m256i *)(in + i * stride));
      // x = x << 2
      out[i] = _mm256_slli_epi16(out[i], 2);
    }
  } else {
    for (i = 0; i < out_size; i++) {
      out[i] = _mm256_loadu_si256((const __m256i *)(in + i * 16));
      // x = (x + 1) >> 2
      out[i] = _mm256_add_epi16(out[i], kOne);
      out[i] = _mm256_srai_epi16(out[i], 2);
    }
  }
}

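// Transpose two 8x8 16-bit tiles at once: each __m256i holds one row of
// each tile (one tile per 128-bit lane), and every unpack step below
// operates on the two lanes independently.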
static INLINE void transpose2_8x8_avx2(const __m256i *const in,
                                       __m256i *const out) {
  int i;
  __m256i t[16], u[16];
  // (1st, 2nd) ==> (lo, hi)
  //   (0, 1)   ==>  (0, 1)
  //   (2, 3)   ==>  (2, 3)
  //   (4, 5)   ==>  (4, 5)
  //   (6, 7)   ==>  (6, 7)
  for (i = 0; i < 4; i++) {
    t[2 * i] = _mm256_unpacklo_epi16(in[2 * i], in[2 * i + 1]);
    t[2 * i + 1] = _mm256_unpackhi_epi16(in[2 * i], in[2 * i + 1]);
  }

  // (1st, 2nd) ==> (lo, hi)
  //   (0, 2)   ==>  (0, 2)
  //   (1, 3)   ==>  (1, 3)
  //   (4, 6)   ==>  (4, 6)
  //   (5, 7)   ==>  (5, 7)
  for (i = 0; i < 2; i++) {
    u[i] = _mm256_unpacklo_epi32(t[i], t[i + 2]);
    u[i + 2] = _mm256_unpackhi_epi32(t[i], t[i + 2]);

    u[i + 4] = _mm256_unpacklo_epi32(t[i + 4], t[i + 6]);
    u[i + 6] = _mm256_unpackhi_epi32(t[i + 4], t[i + 6]);
  }

  // (1st, 2nd) ==> (lo, hi)
  //   (0, 4)   ==>  (0, 1)
  //   (1, 5)   ==>  (4, 5)
  //   (2, 6)   ==>  (2, 3)
  //   (3, 7)   ==>  (6, 7)
  for (i = 0; i < 2; i++) {
    out[2 * i] = _mm256_unpacklo_epi64(u[2 * i], u[2 * i + 4]);
    out[2 * i + 1] = _mm256_unpackhi_epi64(u[2 * i], u[2 * i + 4]);

    out[2 * i + 4] = _mm256_unpacklo_epi64(u[2 * i + 1], u[2 * i + 5]);
    out[2 * i + 5] = _mm256_unpackhi_epi64(u[2 * i + 1], u[2 * i + 5]);
  }
}

static INLINE void transpose_16bit_16x16_avx2(const __m256i *const in,
                                              __m256i *const out) {
  __m256i t[16];

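// LOADL(idx) gathers the left 128-bit halves of rows idx and idx + 8 into
// the two lanes of t[idx]; LOADR(idx) does the same for the right halves
// into t[8 + idx]. Pairing rows idx and idx + 8 this way lets
// transpose2_8x8_avx2 emit rows of the full 16x16 transpose directly.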
#define LOADL(idx)                                                            \
  t[idx] = _mm256_castsi128_si256(_mm_load_si128((__m128i const *)&in[idx])); \
  t[idx] = _mm256_inserti128_si256(                                           \
      t[idx], _mm_load_si128((__m128i const *)&in[idx + 8]), 1);

#define LOADR(idx)                                                           \
  t[8 + idx] =                                                               \
      _mm256_castsi128_si256(_mm_load_si128((__m128i const *)&in[idx] + 1)); \
  t[8 + idx] = _mm256_inserti128_si256(                                      \
      t[8 + idx], _mm_load_si128((__m128i const *)&in[idx + 8] + 1), 1);

  // load left 8x16
  LOADL(0)
  LOADL(1)
  LOADL(2)
  LOADL(3)
  LOADL(4)
  LOADL(5)
  LOADL(6)
  LOADL(7)

  // load right 8x16
  LOADR(0)
  LOADR(1)
  LOADR(2)
  LOADR(3)
  LOADR(4)
  LOADR(5)
  LOADR(6)
  LOADR(7)

  // get the top 16x8 result
  transpose2_8x8_avx2(t, out);
  // get the bottom 16x8 result
  transpose2_8x8_avx2(&t[8], &out[8]);
}

// Store rows of 16 16-bit values. The raw 256-bit store assumes that
// tran_low_t is 16 bits wide, as in the non-high-bitdepth build this
// kernel is used for.
static INLINE void store_buffer_16bit_to_32bit_w16_avx2(const __m256i *const in,
                                                        tran_low_t *out,
                                                        const int stride,
                                                        const int out_size) {
  int i;
  for (i = 0; i < out_size; ++i) {
    _mm256_storeu_si256((__m256i *)(out), in[i]);
    out += stride;
  }
}

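// PAIR256_SET_EPI16(a, b) builds a vector of interleaved (a, b) 16-bit
// pairs; combined with _mm256_madd_epi16, each 32-bit lane then computes
// a * x + b * y for interleaved inputs (x, y).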
#define PAIR256_SET_EPI16(a, b)                                            \
  _mm256_set_epi16((int16_t)(b), (int16_t)(a), (int16_t)(b), (int16_t)(a), \
                   (int16_t)(b), (int16_t)(a), (int16_t)(b), (int16_t)(a), \
                   (int16_t)(b), (int16_t)(a), (int16_t)(b), (int16_t)(a), \
                   (int16_t)(b), (int16_t)(a), (int16_t)(b), (int16_t)(a))

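// Butterfly helper: pin0/pin1 hold interleaved (x, y) pairs and
// *pmultiplier holds (c0, c1) pairs, so each 32-bit lane computes
// (x * c0 + y * c1 + rounding) >> shift; the two halves are then packed
// back to 16 bits with signed saturation.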
static INLINE __m256i mult256_round_shift(const __m256i *pin0,
                                          const __m256i *pin1,
                                          const __m256i *pmultiplier,
                                          const __m256i *prounding,
                                          const int shift) {
  const __m256i u0 = _mm256_madd_epi16(*pin0, *pmultiplier);
  const __m256i u1 = _mm256_madd_epi16(*pin1, *pmultiplier);
  const __m256i v0 = _mm256_add_epi32(u0, *prounding);
  const __m256i v1 = _mm256_add_epi32(u1, *prounding);
  const __m256i w0 = _mm256_srai_epi32(v0, shift);
  const __m256i w1 = _mm256_srai_epi32(v1, shift);
  return _mm256_packs_epi32(w0, w1);
}

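// One in-register 16-point forward DCT over 16 independent columns: each
// of the 16 input registers holds one row of the block, so lane j of
// every register carries column j.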
static INLINE void fdct16x16_1D_avx2(__m256i *input, __m256i *output) {
  int i;
  __m256i step2[4];
  __m256i in[8];
  __m256i step1[8];
  __m256i step3[8];

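  // Constant naming: k__cospi_pA_mB packs (cospi_A_64, -cospi_B_64)
  // pairs for the multiply-add butterflies ('p' = positive, 'm' = minus).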
  const __m256i k__cospi_p16_p16 = _mm256_set1_epi16(cospi_16_64);
  const __m256i k__cospi_p16_m16 = PAIR256_SET_EPI16(cospi_16_64, -cospi_16_64);
  const __m256i k__cospi_p24_p08 = PAIR256_SET_EPI16(cospi_24_64, cospi_8_64);
  const __m256i k__cospi_p08_m24 = PAIR256_SET_EPI16(cospi_8_64, -cospi_24_64);
  const __m256i k__cospi_m08_p24 = PAIR256_SET_EPI16(-cospi_8_64, cospi_24_64);
  const __m256i k__cospi_p28_p04 = PAIR256_SET_EPI16(cospi_28_64, cospi_4_64);
  const __m256i k__cospi_m04_p28 = PAIR256_SET_EPI16(-cospi_4_64, cospi_28_64);
  const __m256i k__cospi_p12_p20 = PAIR256_SET_EPI16(cospi_12_64, cospi_20_64);
  const __m256i k__cospi_m20_p12 = PAIR256_SET_EPI16(-cospi_20_64, cospi_12_64);
  const __m256i k__cospi_p30_p02 = PAIR256_SET_EPI16(cospi_30_64, cospi_2_64);
  const __m256i k__cospi_p14_p18 = PAIR256_SET_EPI16(cospi_14_64, cospi_18_64);
  const __m256i k__cospi_m02_p30 = PAIR256_SET_EPI16(-cospi_2_64, cospi_30_64);
  const __m256i k__cospi_m18_p14 = PAIR256_SET_EPI16(-cospi_18_64, cospi_14_64);
  const __m256i k__cospi_p22_p10 = PAIR256_SET_EPI16(cospi_22_64, cospi_10_64);
  const __m256i k__cospi_p06_p26 = PAIR256_SET_EPI16(cospi_6_64, cospi_26_64);
  const __m256i k__cospi_m10_p22 = PAIR256_SET_EPI16(-cospi_10_64, cospi_22_64);
  const __m256i k__cospi_m26_p06 = PAIR256_SET_EPI16(-cospi_26_64, cospi_6_64);
  const __m256i k__DCT_CONST_ROUNDING = _mm256_set1_epi32(DCT_CONST_ROUNDING);

  // Calculate input for the first 8 results.
  for (i = 0; i < 8; i++) {
    in[i] = ADD256_EPI16(input[i], input[15 - i]);
  }

  // Calculate input for the next 8 results.
  for (i = 0; i < 8; i++) {
    step1[i] = SUB256_EPI16(input[7 - i], input[8 + i]);
  }

  // Work on the first eight values; fdct8(input, even_results);
  {
    // Add/subtract
    const __m256i q0 = ADD256_EPI16(in[0], in[7]);
    const __m256i q1 = ADD256_EPI16(in[1], in[6]);
    const __m256i q2 = ADD256_EPI16(in[2], in[5]);
    const __m256i q3 = ADD256_EPI16(in[3], in[4]);
    const __m256i q4 = SUB256_EPI16(in[3], in[4]);
    const __m256i q5 = SUB256_EPI16(in[2], in[5]);
    const __m256i q6 = SUB256_EPI16(in[1], in[6]);
    const __m256i q7 = SUB256_EPI16(in[0], in[7]);

    // Work on first four results
    {
      // Add/subtract
      const __m256i r0 = ADD256_EPI16(q0, q3);
      const __m256i r1 = ADD256_EPI16(q1, q2);
      const __m256i r2 = SUB256_EPI16(q1, q2);
      const __m256i r3 = SUB256_EPI16(q0, q3);

      // Interleave to do the multiply by constants which gets us
      // into 32 bits.
      {
        const __m256i t0 = _mm256_unpacklo_epi16(r0, r1);
        const __m256i t1 = _mm256_unpackhi_epi16(r0, r1);
        const __m256i t2 = _mm256_unpacklo_epi16(r2, r3);
        const __m256i t3 = _mm256_unpackhi_epi16(r2, r3);

        output[0] = mult256_round_shift(&t0, &t1, &k__cospi_p16_p16,
                                        &k__DCT_CONST_ROUNDING, DCT_CONST_BITS);
        output[8] = mult256_round_shift(&t0, &t1, &k__cospi_p16_m16,
                                        &k__DCT_CONST_ROUNDING, DCT_CONST_BITS);
        output[4] = mult256_round_shift(&t2, &t3, &k__cospi_p24_p08,
                                        &k__DCT_CONST_ROUNDING, DCT_CONST_BITS);
        output[12] =
            mult256_round_shift(&t2, &t3, &k__cospi_m08_p24,
                                &k__DCT_CONST_ROUNDING, DCT_CONST_BITS);
      }
    }

    // Work on next four results
    {
      // Interleave to do the multiply by constants which gets us
      // into 32 bits.
      const __m256i d0 = _mm256_unpacklo_epi16(q6, q5);
      const __m256i d1 = _mm256_unpackhi_epi16(q6, q5);
      const __m256i r0 = mult256_round_shift(
          &d0, &d1, &k__cospi_p16_m16, &k__DCT_CONST_ROUNDING, DCT_CONST_BITS);
      const __m256i r1 = mult256_round_shift(
          &d0, &d1, &k__cospi_p16_p16, &k__DCT_CONST_ROUNDING, DCT_CONST_BITS);

      {
        // Add/subtract
        const __m256i x0 = ADD256_EPI16(q4, r0);
        const __m256i x1 = SUB256_EPI16(q4, r0);
        const __m256i x2 = SUB256_EPI16(q7, r1);
        const __m256i x3 = ADD256_EPI16(q7, r1);

        // Interleave to do the multiply by constants which gets us
        // into 32 bits.
        {
          const __m256i t0 = _mm256_unpacklo_epi16(x0, x3);
          const __m256i t1 = _mm256_unpackhi_epi16(x0, x3);
          const __m256i t2 = _mm256_unpacklo_epi16(x1, x2);
          const __m256i t3 = _mm256_unpackhi_epi16(x1, x2);
          output[2] =
              mult256_round_shift(&t0, &t1, &k__cospi_p28_p04,
                                  &k__DCT_CONST_ROUNDING, DCT_CONST_BITS);
          output[14] =
              mult256_round_shift(&t0, &t1, &k__cospi_m04_p28,
                                  &k__DCT_CONST_ROUNDING, DCT_CONST_BITS);
          output[10] =
              mult256_round_shift(&t2, &t3, &k__cospi_p12_p20,
                                  &k__DCT_CONST_ROUNDING, DCT_CONST_BITS);
          output[6] =
              mult256_round_shift(&t2, &t3, &k__cospi_m20_p12,
                                  &k__DCT_CONST_ROUNDING, DCT_CONST_BITS);
        }
      }
    }
  }
  // Work on the next eight values; step1 -> odd_results
  {  // step 2
    {
      const __m256i t0 = _mm256_unpacklo_epi16(step1[5], step1[2]);
      const __m256i t1 = _mm256_unpackhi_epi16(step1[5], step1[2]);
      const __m256i t2 = _mm256_unpacklo_epi16(step1[4], step1[3]);
      const __m256i t3 = _mm256_unpackhi_epi16(step1[4], step1[3]);
      step2[0] = mult256_round_shift(&t0, &t1, &k__cospi_p16_m16,
                                     &k__DCT_CONST_ROUNDING, DCT_CONST_BITS);
      step2[1] = mult256_round_shift(&t2, &t3, &k__cospi_p16_m16,
                                     &k__DCT_CONST_ROUNDING, DCT_CONST_BITS);
      step2[2] = mult256_round_shift(&t0, &t1, &k__cospi_p16_p16,
                                     &k__DCT_CONST_ROUNDING, DCT_CONST_BITS);
      step2[3] = mult256_round_shift(&t2, &t3, &k__cospi_p16_p16,
                                     &k__DCT_CONST_ROUNDING, DCT_CONST_BITS);
    }
    // step 3
    {
      step3[0] = ADD256_EPI16(step1[0], step2[1]);
      step3[1] = ADD256_EPI16(step1[1], step2[0]);
      step3[2] = SUB256_EPI16(step1[1], step2[0]);
      step3[3] = SUB256_EPI16(step1[0], step2[1]);
      step3[4] = SUB256_EPI16(step1[7], step2[3]);
      step3[5] = SUB256_EPI16(step1[6], step2[2]);
      step3[6] = ADD256_EPI16(step1[6], step2[2]);
      step3[7] = ADD256_EPI16(step1[7], step2[3]);
    }
    // step 4
    {
      const __m256i t0 = _mm256_unpacklo_epi16(step3[1], step3[6]);
      const __m256i t1 = _mm256_unpackhi_epi16(step3[1], step3[6]);
      const __m256i t2 = _mm256_unpacklo_epi16(step3[2], step3[5]);
      const __m256i t3 = _mm256_unpackhi_epi16(step3[2], step3[5]);
      step2[0] = mult256_round_shift(&t0, &t1, &k__cospi_m08_p24,
                                     &k__DCT_CONST_ROUNDING, DCT_CONST_BITS);
      step2[1] = mult256_round_shift(&t2, &t3, &k__cospi_p24_p08,
                                     &k__DCT_CONST_ROUNDING, DCT_CONST_BITS);
      step2[2] = mult256_round_shift(&t0, &t1, &k__cospi_p24_p08,
                                     &k__DCT_CONST_ROUNDING, DCT_CONST_BITS);
      step2[3] = mult256_round_shift(&t2, &t3, &k__cospi_p08_m24,
                                     &k__DCT_CONST_ROUNDING, DCT_CONST_BITS);
    }
    // step 5
    {
      step1[0] = ADD256_EPI16(step3[0], step2[0]);
      step1[1] = SUB256_EPI16(step3[0], step2[0]);
      step1[2] = ADD256_EPI16(step3[3], step2[1]);
      step1[3] = SUB256_EPI16(step3[3], step2[1]);
      step1[4] = SUB256_EPI16(step3[4], step2[3]);
      step1[5] = ADD256_EPI16(step3[4], step2[3]);
      step1[6] = SUB256_EPI16(step3[7], step2[2]);
      step1[7] = ADD256_EPI16(step3[7], step2[2]);
    }
    // step 6
    {
      const __m256i t0 = _mm256_unpacklo_epi16(step1[0], step1[7]);
      const __m256i t1 = _mm256_unpackhi_epi16(step1[0], step1[7]);
      const __m256i t2 = _mm256_unpacklo_epi16(step1[1], step1[6]);
      const __m256i t3 = _mm256_unpackhi_epi16(step1[1], step1[6]);
      output[1] = mult256_round_shift(&t0, &t1, &k__cospi_p30_p02,
                                      &k__DCT_CONST_ROUNDING, DCT_CONST_BITS);
      output[9] = mult256_round_shift(&t2, &t3, &k__cospi_p14_p18,
                                      &k__DCT_CONST_ROUNDING, DCT_CONST_BITS);
      output[15] = mult256_round_shift(&t0, &t1, &k__cospi_m02_p30,
                                       &k__DCT_CONST_ROUNDING, DCT_CONST_BITS);
      output[7] = mult256_round_shift(&t2, &t3, &k__cospi_m18_p14,
                                      &k__DCT_CONST_ROUNDING, DCT_CONST_BITS);
    }
    {
      const __m256i t0 = _mm256_unpacklo_epi16(step1[2], step1[5]);
      const __m256i t1 = _mm256_unpackhi_epi16(step1[2], step1[5]);
      const __m256i t2 = _mm256_unpacklo_epi16(step1[3], step1[4]);
      const __m256i t3 = _mm256_unpackhi_epi16(step1[3], step1[4]);
      output[5] = mult256_round_shift(&t0, &t1, &k__cospi_p22_p10,
                                      &k__DCT_CONST_ROUNDING, DCT_CONST_BITS);
      output[13] = mult256_round_shift(&t2, &t3, &k__cospi_p06_p26,
                                       &k__DCT_CONST_ROUNDING, DCT_CONST_BITS);
      output[11] = mult256_round_shift(&t0, &t1, &k__cospi_m10_p22,
                                       &k__DCT_CONST_ROUNDING, DCT_CONST_BITS);
      output[3] = mult256_round_shift(&t2, &t3, &k__cospi_m26_p06,
                                      &k__DCT_CONST_ROUNDING, DCT_CONST_BITS);
    }
  }
}

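// Full 16x16 forward DCT: each pass applies the 1-D transform down the
// columns of its input and then transposes, so the second pass transforms
// what were originally the rows. The pass-0 store casts the int16_t
// intermediate buffer to tran_low_t, relying on tran_low_t being 16 bits
// in this build.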
void vpx_fdct16x16_avx2(const int16_t *input, tran_low_t *output, int stride) {
  int pass;
  DECLARE_ALIGNED(32, int16_t, intermediate[256]);
  int16_t *out0 = intermediate;
  tran_low_t *out1 = output;
  const int width = 16;
  const int height = 16;
  __m256i buf0[16], buf1[16];

  // Two transform and transpose passes
  // Process 16 columns (transposed rows in second pass) at a time.
  for (pass = 0; pass < 2; ++pass) {
    // Load and pre-condition input.
    load_buffer_16bit_to_16bit_avx2(input, stride, buf1, height, pass);

    // Calculate dct for 16x16 values
    fdct16x16_1D_avx2(buf1, buf0);

    // Transpose the results.
    transpose_16bit_16x16_avx2(buf0, buf1);

    if (pass == 0) {
      store_buffer_16bit_to_32bit_w16_avx2(buf1, (tran_low_t *)out0, width,
                                           height);
    } else {
      store_buffer_16bit_to_32bit_w16_avx2(buf1, out1, width, height);
    }
    // Setup in/out for next pass.
    input = intermediate;
  }
}

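// Instantiate the shared 32x32 forward DCT template twice: once with
// FDCT32x32_HIGH_PRECISION 0 to produce vpx_fdct32x32_rd_avx2, and once
// with it set to 1 to produce vpx_fdct32x32_avx2.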
#if !CONFIG_VP9_HIGHBITDEPTH
#define FDCT32x32_2D_AVX2 vpx_fdct32x32_rd_avx2
#define FDCT32x32_HIGH_PRECISION 0
#include "vpx_dsp/x86/fwd_dct32x32_impl_avx2.h"
#undef FDCT32x32_2D_AVX2
#undef FDCT32x32_HIGH_PRECISION

#define FDCT32x32_2D_AVX2 vpx_fdct32x32_avx2
#define FDCT32x32_HIGH_PRECISION 1
#include "vpx_dsp/x86/fwd_dct32x32_impl_avx2.h"  // NOLINT
#undef FDCT32x32_2D_AVX2
#undef FDCT32x32_HIGH_PRECISION
#endif  // !CONFIG_VP9_HIGHBITDEPTH