/*
 * Copyright (c) 2018, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#ifndef AOM_AOM_DSP_X86_TXFM_COMMON_AVX2_H_
#define AOM_AOM_DSP_X86_TXFM_COMMON_AVX2_H_

#include <emmintrin.h>
#include "aom/aom_integer.h"
#include "aom_dsp/x86/synonyms.h"

#ifdef __cplusplus
extern "C" {
#endif

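// Broadcasts the 16-bit pair {a, b} into every 32-bit lane of the result
// (a in the low half of each lane, b in the high half).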
static inline __m256i pair_set_w16_epi16(int16_t a, int16_t b) {
  return _mm256_set1_epi32(
      (int32_t)(((uint16_t)(a)) | (((uint32_t)(uint16_t)(b)) << 16)));
}

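// Butterfly on 16-bit inputs: interleaves *in0/*in1, multiply-adds the pairs
// against the coefficient vectors w0 and w1, rounds with _r, shifts right by
// cos_bit, and packs the 32-bit results back to 16 bits with saturation.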
static inline void btf_16_w16_avx2(const __m256i w0, const __m256i w1,
                                   __m256i *in0, __m256i *in1, const __m256i _r,
                                   const int32_t cos_bit) {
  __m256i t0 = _mm256_unpacklo_epi16(*in0, *in1);
  __m256i t1 = _mm256_unpackhi_epi16(*in0, *in1);
  __m256i u0 = _mm256_madd_epi16(t0, w0);
  __m256i u1 = _mm256_madd_epi16(t1, w0);
  __m256i v0 = _mm256_madd_epi16(t0, w1);
  __m256i v1 = _mm256_madd_epi16(t1, w1);

  __m256i a0 = _mm256_add_epi32(u0, _r);
  __m256i a1 = _mm256_add_epi32(u1, _r);
  __m256i b0 = _mm256_add_epi32(v0, _r);
  __m256i b1 = _mm256_add_epi32(v1, _r);

  __m256i c0 = _mm256_srai_epi32(a0, cos_bit);
  __m256i c1 = _mm256_srai_epi32(a1, cos_bit);
  __m256i d0 = _mm256_srai_epi32(b0, cos_bit);
  __m256i d1 = _mm256_srai_epi32(b1, cos_bit);

  *in0 = _mm256_packs_epi32(c0, c1);
  *in1 = _mm256_packs_epi32(d0, d1);
}

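// In-place add/sub butterfly on 16-bit lanes with saturation:
// *in0 = in0 + in1, *in1 = in0 - in1. The variants below operate on 32-bit
// lanes and/or write to separate outputs.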
static inline void btf_16_adds_subs_avx2(__m256i *in0, __m256i *in1) {
  const __m256i _in0 = *in0;
  const __m256i _in1 = *in1;
  *in0 = _mm256_adds_epi16(_in0, _in1);
  *in1 = _mm256_subs_epi16(_in0, _in1);
}

static inline void btf_32_add_sub_avx2(__m256i *in0, __m256i *in1) {
  const __m256i _in0 = *in0;
  const __m256i _in1 = *in1;
  *in0 = _mm256_add_epi32(_in0, _in1);
  *in1 = _mm256_sub_epi32(_in0, _in1);
}

static inline void btf_16_adds_subs_out_avx2(__m256i *out0, __m256i *out1,
                                             __m256i in0, __m256i in1) {
  const __m256i _in0 = in0;
  const __m256i _in1 = in1;
  *out0 = _mm256_adds_epi16(_in0, _in1);
  *out1 = _mm256_subs_epi16(_in0, _in1);
}

static inline void btf_32_add_sub_out_avx2(__m256i *out0, __m256i *out1,
                                           __m256i in0, __m256i in1) {
  const __m256i _in0 = in0;
  const __m256i _in1 = in1;
  *out0 = _mm256_add_epi32(_in0, _in1);
  *out1 = _mm256_sub_epi32(_in0, _in1);
}

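// Loads 16 int16_t values with an aligned 256-bit load; `a` must be 32-byte
// aligned.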
static inline __m256i load_16bit_to_16bit_avx2(const int16_t *a) {
  return _mm256_load_si256((const __m256i *)a);
}

static inline void load_buffer_16bit_to_16bit_avx2(const int16_t *in,
                                                   int stride, __m256i *out,
                                                   int out_size) {
  for (int i = 0; i < out_size; ++i) {
    out[i] = load_16bit_to_16bit_avx2(in + i * stride);
  }
}

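// Same as above, but stores the rows into out[] in reverse order.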
static inline void load_buffer_16bit_to_16bit_flip_avx2(const int16_t *in,
                                                        int stride,
                                                        __m256i *out,
                                                        int out_size) {
  for (int i = 0; i < out_size; ++i) {
    out[out_size - i - 1] = load_16bit_to_16bit_avx2(in + i * stride);
  }
}

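// Loads 16 int32_t values, packs them to 16 bits with signed saturation, and
// permutes the 64-bit lanes (0xD8) so the packed values stay in source order.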
static inline __m256i load_32bit_to_16bit_w16_avx2(const int32_t *a) {
  const __m256i a_low = _mm256_lddqu_si256((const __m256i *)a);
  const __m256i b = _mm256_packs_epi32(a_low, *(const __m256i *)(a + 8));
  return _mm256_permute4x64_epi64(b, 0xD8);
}

static inline void load_buffer_32bit_to_16bit_w16_avx2(const int32_t *in,
                                                       int stride, __m256i *out,
                                                       int out_size) {
  for (int i = 0; i < out_size; ++i) {
    out[i] = load_32bit_to_16bit_w16_avx2(in + i * stride);
  }
}

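// Transposes two 8x8 blocks of 16-bit values in parallel, one block per
// 128-bit lane of in[0..7].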
static inline void transpose2_8x8_avx2(const __m256i *const in,
                                       __m256i *const out) {
  __m256i t[16], u[16];
  // (1st, 2nd) ==> (lo, hi)
  //   (0, 1)   ==>  (0, 1)
  //   (2, 3)   ==>  (2, 3)
  //   (4, 5)   ==>  (4, 5)
  //   (6, 7)   ==>  (6, 7)
  for (int i = 0; i < 4; i++) {
    t[2 * i] = _mm256_unpacklo_epi16(in[2 * i], in[2 * i + 1]);
    t[2 * i + 1] = _mm256_unpackhi_epi16(in[2 * i], in[2 * i + 1]);
  }

  // (1st, 2nd) ==> (lo, hi)
  //   (0, 2)   ==>  (0, 2)
  //   (1, 3)   ==>  (1, 3)
  //   (4, 6)   ==>  (4, 6)
  //   (5, 7)   ==>  (5, 7)
  for (int i = 0; i < 2; i++) {
    u[i] = _mm256_unpacklo_epi32(t[i], t[i + 2]);
    u[i + 2] = _mm256_unpackhi_epi32(t[i], t[i + 2]);

    u[i + 4] = _mm256_unpacklo_epi32(t[i + 4], t[i + 6]);
    u[i + 6] = _mm256_unpackhi_epi32(t[i + 4], t[i + 6]);
  }

  // (1st, 2nd) ==> (lo, hi)
  //   (0, 4)   ==>  (0, 1)
  //   (1, 5)   ==>  (4, 5)
  //   (2, 6)   ==>  (2, 3)
  //   (3, 7)   ==>  (6, 7)
  for (int i = 0; i < 2; i++) {
    out[2 * i] = _mm256_unpacklo_epi64(u[2 * i], u[2 * i + 4]);
    out[2 * i + 1] = _mm256_unpackhi_epi64(u[2 * i], u[2 * i + 4]);

    out[2 * i + 4] = _mm256_unpacklo_epi64(u[2 * i + 1], u[2 * i + 5]);
    out[2 * i + 5] = _mm256_unpackhi_epi64(u[2 * i + 1], u[2 * i + 5]);
  }
}

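// Transposes a 16x16 block of 16-bit values held in in[0..15] (one row per
// register) into out[0..15].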
static inline void transpose_16bit_16x16_avx2(const __m256i *const in,
                                              __m256i *const out) {
  __m256i t[16];

#define LOADL(idx)                                                            \
  t[idx] = _mm256_castsi128_si256(_mm_load_si128((__m128i const *)&in[idx])); \
  t[idx] = _mm256_inserti128_si256(                                           \
      t[idx], _mm_load_si128((__m128i const *)&in[idx + 8]), 1);

#define LOADR(idx)                                                           \
  t[8 + idx] =                                                               \
      _mm256_castsi128_si256(_mm_load_si128((__m128i const *)&in[idx] + 1)); \
  t[8 + idx] = _mm256_inserti128_si256(                                      \
      t[8 + idx], _mm_load_si128((__m128i const *)&in[idx + 8] + 1), 1);

  // load left 8x16
  LOADL(0)
  LOADL(1)
  LOADL(2)
  LOADL(3)
  LOADL(4)
  LOADL(5)
  LOADL(6)
  LOADL(7)

  // load right 8x16
  LOADR(0)
  LOADR(1)
  LOADR(2)
  LOADR(3)
  LOADR(4)
  LOADR(5)
  LOADR(6)
  LOADR(7)

  // get the top 16x8 result
  transpose2_8x8_avx2(t, out);
  // get the bottom 16x8 result
  transpose2_8x8_avx2(&t[8], &out[8]);
}

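// Transposes an 8-row block of 16-bit values as two 8x8 sub-blocks, one per
// 128-bit lane of in[0..7].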
static inline void transpose_16bit_16x8_avx2(const __m256i *const in,
                                             __m256i *const out) {
  const __m256i a0 = _mm256_unpacklo_epi16(in[0], in[1]);
  const __m256i a1 = _mm256_unpacklo_epi16(in[2], in[3]);
  const __m256i a2 = _mm256_unpacklo_epi16(in[4], in[5]);
  const __m256i a3 = _mm256_unpacklo_epi16(in[6], in[7]);
  const __m256i a4 = _mm256_unpackhi_epi16(in[0], in[1]);
  const __m256i a5 = _mm256_unpackhi_epi16(in[2], in[3]);
  const __m256i a6 = _mm256_unpackhi_epi16(in[4], in[5]);
  const __m256i a7 = _mm256_unpackhi_epi16(in[6], in[7]);

  const __m256i b0 = _mm256_unpacklo_epi32(a0, a1);
  const __m256i b1 = _mm256_unpacklo_epi32(a2, a3);
  const __m256i b2 = _mm256_unpacklo_epi32(a4, a5);
  const __m256i b3 = _mm256_unpacklo_epi32(a6, a7);
  const __m256i b4 = _mm256_unpackhi_epi32(a0, a1);
  const __m256i b5 = _mm256_unpackhi_epi32(a2, a3);
  const __m256i b6 = _mm256_unpackhi_epi32(a4, a5);
  const __m256i b7 = _mm256_unpackhi_epi32(a6, a7);

  out[0] = _mm256_unpacklo_epi64(b0, b1);
  out[1] = _mm256_unpackhi_epi64(b0, b1);
  out[2] = _mm256_unpacklo_epi64(b4, b5);
  out[3] = _mm256_unpackhi_epi64(b4, b5);
  out[4] = _mm256_unpacklo_epi64(b2, b3);
  out[5] = _mm256_unpackhi_epi64(b2, b3);
  out[6] = _mm256_unpacklo_epi64(b6, b7);
  out[7] = _mm256_unpackhi_epi64(b6, b7);
}

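// Reverses the order of `size` registers from in[] into out[].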
static inline void flip_buf_avx2(__m256i *in, __m256i *out, int size) {
  for (int i = 0; i < size; ++i) {
    out[size - i - 1] = in[i];
  }
}

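// Rounds and right-shifts each 16-bit lane by -bit when bit is negative, or
// left-shifts by bit when it is positive.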
static inline void round_shift_16bit_w16_avx2(__m256i *in, int size, int bit) {
  if (bit < 0) {
    bit = -bit;
    __m256i round = _mm256_set1_epi16(1 << (bit - 1));
    for (int i = 0; i < size; ++i) {
      in[i] = _mm256_adds_epi16(in[i], round);
      in[i] = _mm256_srai_epi16(in[i], bit);
    }
  } else if (bit > 0) {
    for (int i = 0; i < size; ++i) {
      in[i] = _mm256_slli_epi16(in[i], bit);
    }
  }
}

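// Rounding right shift of each 32-bit lane: (vec + (1 << (bit - 1))) >> bit.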
static inline __m256i round_shift_32_avx2(__m256i vec, int bit) {
  __m256i tmp, round;
  round = _mm256_set1_epi32(1 << (bit - 1));
  tmp = _mm256_add_epi32(vec, round);
  return _mm256_srai_epi32(tmp, bit);
}

static inline void round_shift_array_32_avx2(__m256i *input, __m256i *output,
                                             const int size, const int bit) {
  if (bit > 0) {
    int i;
    for (i = 0; i < size; i++) {
      output[i] = round_shift_32_avx2(input[i], bit);
    }
  } else {
    int i;
    for (i = 0; i < size; i++) {
      output[i] = _mm256_slli_epi32(input[i], -bit);
    }
  }
}

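// Round-shifts (or left-shifts, for bit <= 0) each 32-bit element, multiplies
// by the fixed-point scale `val`, then round-shifts by NewSqrt2Bits.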
static inline void round_shift_rect_array_32_avx2(__m256i *input,
                                                  __m256i *output,
                                                  const int size, const int bit,
                                                  const int val) {
  const __m256i sqrt2 = _mm256_set1_epi32(val);
  if (bit > 0) {
    int i;
    for (i = 0; i < size; i++) {
      const __m256i r0 = round_shift_32_avx2(input[i], bit);
      const __m256i r1 = _mm256_mullo_epi32(sqrt2, r0);
      output[i] = round_shift_32_avx2(r1, NewSqrt2Bits);
    }
  } else {
    int i;
    for (i = 0; i < size; i++) {
      const __m256i r0 = _mm256_slli_epi32(input[i], -bit);
      const __m256i r1 = _mm256_mullo_epi32(sqrt2, r0);
      output[i] = round_shift_32_avx2(r1, NewSqrt2Bits);
    }
  }
}

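// Multiply-adds adjacent 16-bit pairs of `a` against
// {scale, 1 << (NewSqrt2Bits - 1)} and right-shifts the 32-bit results by
// NewSqrt2Bits.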
static inline __m256i scale_round_avx2(const __m256i a, const int scale) {
  const __m256i scale_rounding =
      pair_set_w16_epi16(scale, 1 << (NewSqrt2Bits - 1));
  const __m256i b = _mm256_madd_epi16(a, scale_rounding);
  return _mm256_srai_epi32(b, NewSqrt2Bits);
}

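// Scales the 16 values in `a` by NewSqrt2 with rounding, widens them to
// 32 bits, and stores the first eight results at b[0..7] and the remaining
// eight at b[64..71].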
static inline void store_rect_16bit_to_32bit_w8_avx2(const __m256i a,
                                                     int32_t *const b) {
  const __m256i one = _mm256_set1_epi16(1);
  const __m256i a_lo = _mm256_unpacklo_epi16(a, one);
  const __m256i a_hi = _mm256_unpackhi_epi16(a, one);
  const __m256i b_lo = scale_round_avx2(a_lo, NewSqrt2);
  const __m256i b_hi = scale_round_avx2(a_hi, NewSqrt2);
  const __m256i temp = _mm256_permute2f128_si256(b_lo, b_hi, 0x31);
  _mm_store_si128((__m128i *)b, _mm256_castsi256_si128(b_lo));
  _mm_store_si128((__m128i *)(b + 4), _mm256_castsi256_si128(b_hi));
  _mm256_store_si256((__m256i *)(b + 64), temp);
}

static inline void store_rect_buffer_16bit_to_32bit_w8_avx2(
    const __m256i *const in, int32_t *const out, const int stride,
    const int out_size) {
  for (int i = 0; i < out_size; ++i) {
    store_rect_16bit_to_32bit_w8_avx2(in[i], out + i * stride);
  }
}

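// Packs eight pairs of 128-bit registers into eight 256-bit registers:
// out[i] holds in1[i] in its low lane and in2[i] in its high lane.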
static inline void pack_reg(const __m128i *in1, const __m128i *in2,
                            __m256i *out) {
  out[0] = _mm256_insertf128_si256(_mm256_castsi128_si256(in1[0]), in2[0], 0x1);
  out[1] = _mm256_insertf128_si256(_mm256_castsi128_si256(in1[1]), in2[1], 0x1);
  out[2] = _mm256_insertf128_si256(_mm256_castsi128_si256(in1[2]), in2[2], 0x1);
  out[3] = _mm256_insertf128_si256(_mm256_castsi128_si256(in1[3]), in2[3], 0x1);
  out[4] = _mm256_insertf128_si256(_mm256_castsi128_si256(in1[4]), in2[4], 0x1);
  out[5] = _mm256_insertf128_si256(_mm256_castsi128_si256(in1[5]), in2[5], 0x1);
  out[6] = _mm256_insertf128_si256(_mm256_castsi128_si256(in1[6]), in2[6], 0x1);
  out[7] = _mm256_insertf128_si256(_mm256_castsi128_si256(in1[7]), in2[7], 0x1);
}

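// Splits eight 256-bit registers into sixteen 128-bit registers: out1[0..7]
// receive the low lanes and out1[8..15] the high lanes.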
static inline void extract_reg(const __m256i *in, __m128i *out1) {
  out1[0] = _mm256_castsi256_si128(in[0]);
  out1[1] = _mm256_castsi256_si128(in[1]);
  out1[2] = _mm256_castsi256_si128(in[2]);
  out1[3] = _mm256_castsi256_si128(in[3]);
  out1[4] = _mm256_castsi256_si128(in[4]);
  out1[5] = _mm256_castsi256_si128(in[5]);
  out1[6] = _mm256_castsi256_si128(in[6]);
  out1[7] = _mm256_castsi256_si128(in[7]);

  out1[8] = _mm256_extracti128_si256(in[0], 0x01);
  out1[9] = _mm256_extracti128_si256(in[1], 0x01);
  out1[10] = _mm256_extracti128_si256(in[2], 0x01);
  out1[11] = _mm256_extracti128_si256(in[3], 0x01);
  out1[12] = _mm256_extracti128_si256(in[4], 0x01);
  out1[13] = _mm256_extracti128_si256(in[5], 0x01);
  out1[14] = _mm256_extracti128_si256(in[6], 0x01);
  out1[15] = _mm256_extracti128_si256(in[7], 0x01);
}

#ifdef __cplusplus
}
#endif

#endif  // AOM_AOM_DSP_X86_TXFM_COMMON_AVX2_H_