xref: /aosp_15_r20/external/libaom/av1/encoder/x86/highbd_fwd_txfm_sse4.c (revision 77c1e3ccc04c968bd2bc212e87364f250e820521)
1 /*
2  * Copyright (c) 2016, Alliance for Open Media. All rights reserved.
3  *
4  * This source code is subject to the terms of the BSD 2 Clause License and
5  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6  * was not distributed with this source code in the LICENSE file, you can
7  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8  * Media Patent License 1.0 was not distributed with this source code in the
9  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10  */
11 #include <assert.h>
12 #include <smmintrin.h> /* SSE4.1 */
13 
14 #include "aom_dsp/txfm_common.h"
15 #include "aom_dsp/x86/transpose_sse2.h"
16 #include "aom_dsp/x86/txfm_common_sse2.h"
17 #include "aom_ports/mem.h"
18 #include "av1/common/av1_txfm.h"
19 #include "av1/common/x86/highbd_txfm_utility_sse4.h"
20 #include "av1/encoder/av1_fwd_txfm1d_cfg.h"
21 #include "av1/encoder/x86/av1_txfm1d_sse4.h"
22 #include "config/aom_config.h"
23 #include "config/av1_rtcd.h"
24 
25 static inline void store_output_w4(int32_t *const out, const __m128i *const in,
26                                    const int stride, const int out_size) {
27   for (int i = 0; i < out_size; ++i) {
28     _mm_store_si128((__m128i *)(out + i * stride), in[i]);
29   }
30 }
31 
32 void av1_fwht4x4_sse4_1(const int16_t *input, tran_low_t *output, int stride) {
33   __m128i in[4];
34   in[0] = _mm_loadl_epi64((const __m128i *)(input + 0 * stride));
35   in[1] = _mm_loadl_epi64((const __m128i *)(input + 1 * stride));
36   in[2] = _mm_loadl_epi64((const __m128i *)(input + 2 * stride));
37   in[3] = _mm_loadl_epi64((const __m128i *)(input + 3 * stride));
38 
39   // Convert to int32_t.
40   __m128i op[4];
41   op[0] = _mm_cvtepi16_epi32(in[0]);
42   op[1] = _mm_cvtepi16_epi32(in[1]);
43   op[2] = _mm_cvtepi16_epi32(in[2]);
44   op[3] = _mm_cvtepi16_epi32(in[3]);
45 
46   for (int i = 0; i < 2; ++i) {
47     __m128i a1 = op[0];
48     __m128i b1 = op[1];
49     __m128i c1 = op[2];
50     __m128i d1 = op[3];
51     __m128i e1;
52 
53     a1 = _mm_add_epi32(a1, b1);  // a1 += b1
54     d1 = _mm_sub_epi32(d1, c1);  // d1 = d1 - c1
55     e1 = _mm_sub_epi32(a1, d1);  // e1 = (a1 - d1) >> 1
56     e1 = _mm_srai_epi32(e1, 1);
57     b1 = _mm_sub_epi32(e1, b1);  // b1 = e1 - b1
58     c1 = _mm_sub_epi32(e1, c1);  // c1 = e1 - c1
59     a1 = _mm_sub_epi32(a1, c1);  // a1 -= c1
60     d1 = _mm_add_epi32(d1, b1);  // d1 += b1
61 
62     op[0] = a1;
63     op[1] = c1;
64     op[2] = d1;
65     op[3] = b1;
66 
67     if (i == 0) {
68       transpose_32bit_4x4(op, op);
69     }
70   }
71 
72   op[0] = _mm_slli_epi32(op[0], UNIT_QUANT_SHIFT);
73   op[1] = _mm_slli_epi32(op[1], UNIT_QUANT_SHIFT);
74   op[2] = _mm_slli_epi32(op[2], UNIT_QUANT_SHIFT);
75   op[3] = _mm_slli_epi32(op[3], UNIT_QUANT_SHIFT);
76 
77   _mm_storeu_si128((__m128i *)(output + 0), op[0]);
78   _mm_storeu_si128((__m128i *)(output + 4), op[1]);
79   _mm_storeu_si128((__m128i *)(output + 8), op[2]);
80   _mm_storeu_si128((__m128i *)(output + 12), op[3]);
81 }
82 
83 static inline void load_buffer_4x4(const int16_t *input, __m128i *in,
84                                    int stride, int flipud, int fliplr,
85                                    int shift) {
86   if (!flipud) {
87     in[0] = _mm_loadl_epi64((const __m128i *)(input + 0 * stride));
88     in[1] = _mm_loadl_epi64((const __m128i *)(input + 1 * stride));
89     in[2] = _mm_loadl_epi64((const __m128i *)(input + 2 * stride));
90     in[3] = _mm_loadl_epi64((const __m128i *)(input + 3 * stride));
91   } else {
92     in[0] = _mm_loadl_epi64((const __m128i *)(input + 3 * stride));
93     in[1] = _mm_loadl_epi64((const __m128i *)(input + 2 * stride));
94     in[2] = _mm_loadl_epi64((const __m128i *)(input + 1 * stride));
95     in[3] = _mm_loadl_epi64((const __m128i *)(input + 0 * stride));
96   }
97 
98   if (fliplr) {
99     in[0] = _mm_shufflelo_epi16(in[0], 0x1b);
100     in[1] = _mm_shufflelo_epi16(in[1], 0x1b);
101     in[2] = _mm_shufflelo_epi16(in[2], 0x1b);
102     in[3] = _mm_shufflelo_epi16(in[3], 0x1b);
103   }
104 
105   in[0] = _mm_cvtepi16_epi32(in[0]);
106   in[1] = _mm_cvtepi16_epi32(in[1]);
107   in[2] = _mm_cvtepi16_epi32(in[2]);
108   in[3] = _mm_cvtepi16_epi32(in[3]);
109 
110   in[0] = _mm_slli_epi32(in[0], shift);
111   in[1] = _mm_slli_epi32(in[1], shift);
112   in[2] = _mm_slli_epi32(in[2], shift);
113   in[3] = _mm_slli_epi32(in[3], shift);
114 }
115 
116 // We only use the stage-2 bit here;
117 // shift[0] is used in load_buffer_4x4(),
118 // shift[1] is used in txfm_func_col(),
119 // shift[2] is used in txfm_func_row().
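//
// The butterflies below follow the usual fixed-point pattern (a sketch of the
// arithmetic the intrinsics implement; cospi[] holds cosine values scaled by
// 2^bit and rnding = 1 << (bit - 1)):
//   type0: out0 = (s0 * w  + s1 * w  + rnding) >> bit
//          out1 = (s0 * w  - s1 * w  + rnding) >> bit
//   type1: out0 = (s2 * w0 + s3 * w1 + rnding) >> bit
//          out1 = (s3 * w0 - s2 * w1 + rnding) >> bit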
120 static void fdct4x4_sse4_1(__m128i *in, __m128i *out, int bit,
121                            const int num_col) {
122   const int32_t *cospi = cospi_arr(bit);
123   const __m128i cospi32 = _mm_set1_epi32(cospi[32]);
124   const __m128i cospi48 = _mm_set1_epi32(cospi[48]);
125   const __m128i cospi16 = _mm_set1_epi32(cospi[16]);
126   const __m128i rnding = _mm_set1_epi32(1 << (bit - 1));
127   __m128i s0, s1, s2, s3;
128   __m128i u0, u1, u2, u3;
129   __m128i v0, v1, v2, v3;
130 
131   int endidx = 3 * num_col;
132   s0 = _mm_add_epi32(in[0], in[endidx]);
133   s3 = _mm_sub_epi32(in[0], in[endidx]);
134   endidx -= num_col;
135   s1 = _mm_add_epi32(in[num_col], in[endidx]);
136   s2 = _mm_sub_epi32(in[num_col], in[endidx]);
137 
138   // btf_32_sse4_1_type0(cospi32, cospi32, s[01], u[02], bit);
139   u0 = _mm_mullo_epi32(s0, cospi32);
140   u1 = _mm_mullo_epi32(s1, cospi32);
141   u2 = _mm_add_epi32(u0, u1);
142   v0 = _mm_sub_epi32(u0, u1);
143 
144   u3 = _mm_add_epi32(u2, rnding);
145   v1 = _mm_add_epi32(v0, rnding);
146 
147   u0 = _mm_srai_epi32(u3, bit);
148   u2 = _mm_srai_epi32(v1, bit);
149 
150   // btf_32_sse4_1_type1(cospi48, cospi16, s[23], u[13], bit);
151   v0 = _mm_mullo_epi32(s2, cospi48);
152   v1 = _mm_mullo_epi32(s3, cospi16);
153   v2 = _mm_add_epi32(v0, v1);
154 
155   v3 = _mm_add_epi32(v2, rnding);
156   u1 = _mm_srai_epi32(v3, bit);
157 
158   v0 = _mm_mullo_epi32(s2, cospi16);
159   v1 = _mm_mullo_epi32(s3, cospi48);
160   v2 = _mm_sub_epi32(v1, v0);
161 
162   v3 = _mm_add_epi32(v2, rnding);
163   u3 = _mm_srai_epi32(v3, bit);
164 
165   // Note: shift[1] and shift[2] are zeros
166 
167   out[0] = u0;
168   out[1] = u1;
169   out[2] = u2;
170   out[3] = u3;
171 }
172 
173 static inline void write_buffer_4x4(__m128i *res, int32_t *output) {
174   _mm_store_si128((__m128i *)(output + 0 * 4), res[0]);
175   _mm_store_si128((__m128i *)(output + 1 * 4), res[1]);
176   _mm_store_si128((__m128i *)(output + 2 * 4), res[2]);
177   _mm_store_si128((__m128i *)(output + 3 * 4), res[3]);
178 }
179 
180 static void fadst4x4_sse4_1(__m128i *in, __m128i *out, int bit,
181                             const int num_col) {
182   const int32_t *sinpi = sinpi_arr(bit);
183   const __m128i rnding = _mm_set1_epi32(1 << (bit - 1));
184   const __m128i sinpi1 = _mm_set1_epi32((int)sinpi[1]);
185   const __m128i sinpi2 = _mm_set1_epi32((int)sinpi[2]);
186   const __m128i sinpi3 = _mm_set1_epi32((int)sinpi[3]);
187   const __m128i sinpi4 = _mm_set1_epi32((int)sinpi[4]);
188   __m128i t;
189   __m128i s0, s1, s2, s3, s4, s5, s6, s7;
190   __m128i x0, x1, x2, x3;
191   __m128i u0, u1, u2, u3;
192 
193   int idx = 0 * num_col;
194   s0 = _mm_mullo_epi32(in[idx], sinpi1);
195   s1 = _mm_mullo_epi32(in[idx], sinpi4);
196   t = _mm_add_epi32(in[idx], in[idx + num_col]);
197   idx += num_col;
198   s2 = _mm_mullo_epi32(in[idx], sinpi2);
199   s3 = _mm_mullo_epi32(in[idx], sinpi1);
200   idx += num_col;
201   s4 = _mm_mullo_epi32(in[idx], sinpi3);
202   idx += num_col;
203   s5 = _mm_mullo_epi32(in[idx], sinpi4);
204   s6 = _mm_mullo_epi32(in[idx], sinpi2);
205   s7 = _mm_sub_epi32(t, in[idx]);
206 
207   t = _mm_add_epi32(s0, s2);
208   x0 = _mm_add_epi32(t, s5);
209   x1 = _mm_mullo_epi32(s7, sinpi3);
210   t = _mm_sub_epi32(s1, s3);
211   x2 = _mm_add_epi32(t, s6);
212   x3 = s4;
213 
214   s0 = _mm_add_epi32(x0, x3);
215   s1 = x1;
216   s2 = _mm_sub_epi32(x2, x3);
217   t = _mm_sub_epi32(x2, x0);
218   s3 = _mm_add_epi32(t, x3);
219 
220   u0 = _mm_add_epi32(s0, rnding);
221   u0 = _mm_srai_epi32(u0, bit);
222 
223   u1 = _mm_add_epi32(s1, rnding);
224   u1 = _mm_srai_epi32(u1, bit);
225 
226   u2 = _mm_add_epi32(s2, rnding);
227   u2 = _mm_srai_epi32(u2, bit);
228 
229   u3 = _mm_add_epi32(s3, rnding);
230   u3 = _mm_srai_epi32(u3, bit);
231 
232   out[0] = u0;
233   out[1] = u1;
234   out[2] = u2;
235   out[3] = u3;
236 }
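// Identity transform for a 4-point dimension. A sketch of what the loop below
// computes: out[i] = (in[i * col_num] * NewSqrt2 + (1 << (NewSqrt2Bits - 1)))
// >> NewSqrt2Bits, i.e. each coefficient is scaled by roughly sqrt(2) in
// fixed point.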
237 static void idtx4x4_sse4_1(__m128i *in, __m128i *out, int bit, int col_num) {
238   (void)bit;
239   __m128i fact = _mm_set1_epi32(NewSqrt2);
240   __m128i offset = _mm_set1_epi32(1 << (NewSqrt2Bits - 1));
241   __m128i a_low;
242 
243   for (int i = 0; i < 4; i++) {
244     a_low = _mm_mullo_epi32(in[i * col_num], fact);
245     a_low = _mm_add_epi32(a_low, offset);
246     out[i] = _mm_srai_epi32(a_low, NewSqrt2Bits);
247   }
248 }
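// 2-D 4x4 forward transform. Every case below follows the same pipeline:
// load (with optional flips and the shift[0] pre-shift), column transform,
// 32-bit transpose, row transform, store. A hypothetical call for a 10-bit
// DCT_DCT block would look like
//   av1_fwd_txfm2d_4x4_sse4_1(src_diff, coeff, diff_stride, DCT_DCT, 10);
// (the argument names here are illustrative only).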
249 void av1_fwd_txfm2d_4x4_sse4_1(const int16_t *input, int32_t *coeff,
250                                int input_stride, TX_TYPE tx_type, int bd) {
251   __m128i in[4];
252   const int8_t *shift = av1_fwd_txfm_shift_ls[TX_4X4];
253   const int txw_idx = get_txw_idx(TX_4X4);
254   const int txh_idx = get_txh_idx(TX_4X4);
255 
256   switch (tx_type) {
257     case DCT_DCT:
258       load_buffer_4x4(input, in, input_stride, 0, 0, shift[0]);
259       fdct4x4_sse4_1(in, in, av1_fwd_cos_bit_col[txw_idx][txh_idx], 1);
260       transpose_32bit_4x4(in, in);
261       fdct4x4_sse4_1(in, in, av1_fwd_cos_bit_row[txw_idx][txh_idx], 1);
262       write_buffer_4x4(in, coeff);
263       break;
264     case ADST_DCT:
265       load_buffer_4x4(input, in, input_stride, 0, 0, shift[0]);
266       fadst4x4_sse4_1(in, in, av1_fwd_cos_bit_col[txw_idx][txh_idx], 1);
267       transpose_32bit_4x4(in, in);
268       fdct4x4_sse4_1(in, in, av1_fwd_cos_bit_row[txw_idx][txh_idx], 1);
269       write_buffer_4x4(in, coeff);
270       break;
271     case DCT_ADST:
272       load_buffer_4x4(input, in, input_stride, 0, 0, shift[0]);
273       fdct4x4_sse4_1(in, in, av1_fwd_cos_bit_col[txw_idx][txh_idx], 1);
274       transpose_32bit_4x4(in, in);
275       fadst4x4_sse4_1(in, in, av1_fwd_cos_bit_row[txw_idx][txh_idx], 1);
276       write_buffer_4x4(in, coeff);
277       break;
278     case ADST_ADST:
279       load_buffer_4x4(input, in, input_stride, 0, 0, shift[0]);
280       fadst4x4_sse4_1(in, in, av1_fwd_cos_bit_col[txw_idx][txh_idx], 1);
281       transpose_32bit_4x4(in, in);
282       fadst4x4_sse4_1(in, in, av1_fwd_cos_bit_row[txw_idx][txh_idx], 1);
283       write_buffer_4x4(in, coeff);
284       break;
285     case FLIPADST_DCT:
286       load_buffer_4x4(input, in, input_stride, 1, 0, shift[0]);
287       fadst4x4_sse4_1(in, in, av1_fwd_cos_bit_col[txw_idx][txh_idx], 1);
288       transpose_32bit_4x4(in, in);
289       fdct4x4_sse4_1(in, in, av1_fwd_cos_bit_row[txw_idx][txh_idx], 1);
290       write_buffer_4x4(in, coeff);
291       break;
292     case DCT_FLIPADST:
293       load_buffer_4x4(input, in, input_stride, 0, 1, shift[0]);
294       fdct4x4_sse4_1(in, in, av1_fwd_cos_bit_col[txw_idx][txh_idx], 1);
295       transpose_32bit_4x4(in, in);
296       fadst4x4_sse4_1(in, in, av1_fwd_cos_bit_row[txw_idx][txh_idx], 1);
297       write_buffer_4x4(in, coeff);
298       break;
299     case FLIPADST_FLIPADST:
300       load_buffer_4x4(input, in, input_stride, 1, 1, shift[0]);
301       fadst4x4_sse4_1(in, in, av1_fwd_cos_bit_col[txw_idx][txh_idx], 1);
302       transpose_32bit_4x4(in, in);
303       fadst4x4_sse4_1(in, in, av1_fwd_cos_bit_row[txw_idx][txh_idx], 1);
304       write_buffer_4x4(in, coeff);
305       break;
306     case ADST_FLIPADST:
307       load_buffer_4x4(input, in, input_stride, 0, 1, shift[0]);
308       fadst4x4_sse4_1(in, in, av1_fwd_cos_bit_col[txw_idx][txh_idx], 1);
309       transpose_32bit_4x4(in, in);
310       fadst4x4_sse4_1(in, in, av1_fwd_cos_bit_row[txw_idx][txh_idx], 1);
311       write_buffer_4x4(in, coeff);
312       break;
313     case FLIPADST_ADST:
314       load_buffer_4x4(input, in, input_stride, 1, 0, shift[0]);
315       fadst4x4_sse4_1(in, in, av1_fwd_cos_bit_col[txw_idx][txh_idx], 1);
316       transpose_32bit_4x4(in, in);
317       fadst4x4_sse4_1(in, in, av1_fwd_cos_bit_row[txw_idx][txh_idx], 1);
318       write_buffer_4x4(in, coeff);
319       break;
320     case IDTX:
321       load_buffer_4x4(input, in, input_stride, 0, 0, shift[0]);
322       idtx4x4_sse4_1(in, in, av1_fwd_cos_bit_col[txw_idx][txh_idx], 1);
323       transpose_32bit_4x4(in, in);
324       idtx4x4_sse4_1(in, in, av1_fwd_cos_bit_row[txw_idx][txh_idx], 1);
325       write_buffer_4x4(in, coeff);
326       break;
327     case V_DCT:
328       load_buffer_4x4(input, in, input_stride, 0, 0, shift[0]);
329       fdct4x4_sse4_1(in, in, av1_fwd_cos_bit_col[txw_idx][txh_idx], 1);
330       transpose_32bit_4x4(in, in);
331       idtx4x4_sse4_1(in, in, av1_fwd_cos_bit_row[txw_idx][txh_idx], 1);
332       write_buffer_4x4(in, coeff);
333       break;
334     case H_DCT:
335       load_buffer_4x4(input, in, input_stride, 0, 0, shift[0]);
336       idtx4x4_sse4_1(in, in, av1_fwd_cos_bit_row[txw_idx][txh_idx], 1);
337       transpose_32bit_4x4(in, in);
338       fdct4x4_sse4_1(in, in, av1_fwd_cos_bit_col[txw_idx][txh_idx], 1);
339       write_buffer_4x4(in, coeff);
340       break;
341     case V_ADST:
342       load_buffer_4x4(input, in, input_stride, 0, 0, shift[0]);
343       fadst4x4_sse4_1(in, in, av1_fwd_cos_bit_col[txw_idx][txh_idx], 1);
344       transpose_32bit_4x4(in, in);
345       idtx4x4_sse4_1(in, in, av1_fwd_cos_bit_row[txw_idx][txh_idx], 1);
346       write_buffer_4x4(in, coeff);
347       break;
348     case H_ADST:
349       load_buffer_4x4(input, in, input_stride, 0, 0, shift[0]);
350       idtx4x4_sse4_1(in, in, av1_fwd_cos_bit_row[txw_idx][txh_idx], 1);
351       transpose_32bit_4x4(in, in);
352       fadst4x4_sse4_1(in, in, av1_fwd_cos_bit_col[txw_idx][txh_idx], 1);
353       write_buffer_4x4(in, coeff);
354       break;
355     case V_FLIPADST:
356       load_buffer_4x4(input, in, input_stride, 1, 0, shift[0]);
357       fadst4x4_sse4_1(in, in, av1_fwd_cos_bit_row[txw_idx][txh_idx], 1);
358       transpose_32bit_4x4(in, in);
359       idtx4x4_sse4_1(in, in, av1_fwd_cos_bit_row[txw_idx][txh_idx], 1);
360       write_buffer_4x4(in, coeff);
361       break;
362     case H_FLIPADST:
363       load_buffer_4x4(input, in, input_stride, 0, 1, shift[0]);
364       idtx4x4_sse4_1(in, in, av1_fwd_cos_bit_row[txw_idx][txh_idx], 1);
365       transpose_32bit_4x4(in, in);
366       fadst4x4_sse4_1(in, in, av1_fwd_cos_bit_row[txw_idx][txh_idx], 1);
367       write_buffer_4x4(in, coeff);
368       break;
369     default: assert(0);
370   }
371   (void)bd;
372 }
373 
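// Loads an 8x8 block of int16 into 16 __m128i of int32. On return, in[2 * r]
// holds columns 0..3 of row r and in[2 * r + 1] holds columns 4..7 of row r,
// each value left-shifted by `shift`. The 8x8 transforms below index this
// layout with col_num/num_col = 2.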
374 static inline void load_buffer_8x8(const int16_t *input, __m128i *in,
375                                    int stride, int flipud, int fliplr,
376                                    int shift) {
377   __m128i u;
378   if (!flipud) {
379     in[0] = _mm_load_si128((const __m128i *)(input + 0 * stride));
380     in[1] = _mm_load_si128((const __m128i *)(input + 1 * stride));
381     in[2] = _mm_load_si128((const __m128i *)(input + 2 * stride));
382     in[3] = _mm_load_si128((const __m128i *)(input + 3 * stride));
383     in[4] = _mm_load_si128((const __m128i *)(input + 4 * stride));
384     in[5] = _mm_load_si128((const __m128i *)(input + 5 * stride));
385     in[6] = _mm_load_si128((const __m128i *)(input + 6 * stride));
386     in[7] = _mm_load_si128((const __m128i *)(input + 7 * stride));
387   } else {
388     in[0] = _mm_load_si128((const __m128i *)(input + 7 * stride));
389     in[1] = _mm_load_si128((const __m128i *)(input + 6 * stride));
390     in[2] = _mm_load_si128((const __m128i *)(input + 5 * stride));
391     in[3] = _mm_load_si128((const __m128i *)(input + 4 * stride));
392     in[4] = _mm_load_si128((const __m128i *)(input + 3 * stride));
393     in[5] = _mm_load_si128((const __m128i *)(input + 2 * stride));
394     in[6] = _mm_load_si128((const __m128i *)(input + 1 * stride));
395     in[7] = _mm_load_si128((const __m128i *)(input + 0 * stride));
396   }
397 
398   if (fliplr) {
399     in[0] = mm_reverse_epi16(in[0]);
400     in[1] = mm_reverse_epi16(in[1]);
401     in[2] = mm_reverse_epi16(in[2]);
402     in[3] = mm_reverse_epi16(in[3]);
403     in[4] = mm_reverse_epi16(in[4]);
404     in[5] = mm_reverse_epi16(in[5]);
405     in[6] = mm_reverse_epi16(in[6]);
406     in[7] = mm_reverse_epi16(in[7]);
407   }
408 
409   u = _mm_unpackhi_epi64(in[4], in[4]);
410   in[8] = _mm_cvtepi16_epi32(in[4]);
411   in[9] = _mm_cvtepi16_epi32(u);
412 
413   u = _mm_unpackhi_epi64(in[5], in[5]);
414   in[10] = _mm_cvtepi16_epi32(in[5]);
415   in[11] = _mm_cvtepi16_epi32(u);
416 
417   u = _mm_unpackhi_epi64(in[6], in[6]);
418   in[12] = _mm_cvtepi16_epi32(in[6]);
419   in[13] = _mm_cvtepi16_epi32(u);
420 
421   u = _mm_unpackhi_epi64(in[7], in[7]);
422   in[14] = _mm_cvtepi16_epi32(in[7]);
423   in[15] = _mm_cvtepi16_epi32(u);
424 
425   u = _mm_unpackhi_epi64(in[3], in[3]);
426   in[6] = _mm_cvtepi16_epi32(in[3]);
427   in[7] = _mm_cvtepi16_epi32(u);
428 
429   u = _mm_unpackhi_epi64(in[2], in[2]);
430   in[4] = _mm_cvtepi16_epi32(in[2]);
431   in[5] = _mm_cvtepi16_epi32(u);
432 
433   u = _mm_unpackhi_epi64(in[1], in[1]);
434   in[2] = _mm_cvtepi16_epi32(in[1]);
435   in[3] = _mm_cvtepi16_epi32(u);
436 
437   u = _mm_unpackhi_epi64(in[0], in[0]);
438   in[0] = _mm_cvtepi16_epi32(in[0]);
439   in[1] = _mm_cvtepi16_epi32(u);
440 
441   in[0] = _mm_slli_epi32(in[0], shift);
442   in[1] = _mm_slli_epi32(in[1], shift);
443   in[2] = _mm_slli_epi32(in[2], shift);
444   in[3] = _mm_slli_epi32(in[3], shift);
445   in[4] = _mm_slli_epi32(in[4], shift);
446   in[5] = _mm_slli_epi32(in[5], shift);
447   in[6] = _mm_slli_epi32(in[6], shift);
448   in[7] = _mm_slli_epi32(in[7], shift);
449 
450   in[8] = _mm_slli_epi32(in[8], shift);
451   in[9] = _mm_slli_epi32(in[9], shift);
452   in[10] = _mm_slli_epi32(in[10], shift);
453   in[11] = _mm_slli_epi32(in[11], shift);
454   in[12] = _mm_slli_epi32(in[12], shift);
455   in[13] = _mm_slli_epi32(in[13], shift);
456   in[14] = _mm_slli_epi32(in[14], shift);
457   in[15] = _mm_slli_epi32(in[15], shift);
458 }
459 
460 static inline void col_txfm_8x8_rounding(__m128i *in, int shift) {
461   const __m128i rounding = _mm_set1_epi32(1 << (shift - 1));
462 
463   in[0] = _mm_add_epi32(in[0], rounding);
464   in[1] = _mm_add_epi32(in[1], rounding);
465   in[2] = _mm_add_epi32(in[2], rounding);
466   in[3] = _mm_add_epi32(in[3], rounding);
467   in[4] = _mm_add_epi32(in[4], rounding);
468   in[5] = _mm_add_epi32(in[5], rounding);
469   in[6] = _mm_add_epi32(in[6], rounding);
470   in[7] = _mm_add_epi32(in[7], rounding);
471   in[8] = _mm_add_epi32(in[8], rounding);
472   in[9] = _mm_add_epi32(in[9], rounding);
473   in[10] = _mm_add_epi32(in[10], rounding);
474   in[11] = _mm_add_epi32(in[11], rounding);
475   in[12] = _mm_add_epi32(in[12], rounding);
476   in[13] = _mm_add_epi32(in[13], rounding);
477   in[14] = _mm_add_epi32(in[14], rounding);
478   in[15] = _mm_add_epi32(in[15], rounding);
479 
480   in[0] = _mm_srai_epi32(in[0], shift);
481   in[1] = _mm_srai_epi32(in[1], shift);
482   in[2] = _mm_srai_epi32(in[2], shift);
483   in[3] = _mm_srai_epi32(in[3], shift);
484   in[4] = _mm_srai_epi32(in[4], shift);
485   in[5] = _mm_srai_epi32(in[5], shift);
486   in[6] = _mm_srai_epi32(in[6], shift);
487   in[7] = _mm_srai_epi32(in[7], shift);
488   in[8] = _mm_srai_epi32(in[8], shift);
489   in[9] = _mm_srai_epi32(in[9], shift);
490   in[10] = _mm_srai_epi32(in[10], shift);
491   in[11] = _mm_srai_epi32(in[11], shift);
492   in[12] = _mm_srai_epi32(in[12], shift);
493   in[13] = _mm_srai_epi32(in[13], shift);
494   in[14] = _mm_srai_epi32(in[14], shift);
495   in[15] = _mm_srai_epi32(in[15], shift);
496 }
497 
498 static inline void col_txfm_4x8_rounding(__m128i *in, int shift) {
499   const __m128i rounding = _mm_set1_epi32(1 << (shift - 1));
500 
501   in[0] = _mm_add_epi32(in[0], rounding);
502   in[1] = _mm_add_epi32(in[1], rounding);
503   in[2] = _mm_add_epi32(in[2], rounding);
504   in[3] = _mm_add_epi32(in[3], rounding);
505   in[4] = _mm_add_epi32(in[4], rounding);
506   in[5] = _mm_add_epi32(in[5], rounding);
507   in[6] = _mm_add_epi32(in[6], rounding);
508   in[7] = _mm_add_epi32(in[7], rounding);
509 
510   in[0] = _mm_srai_epi32(in[0], shift);
511   in[1] = _mm_srai_epi32(in[1], shift);
512   in[2] = _mm_srai_epi32(in[2], shift);
513   in[3] = _mm_srai_epi32(in[3], shift);
514   in[4] = _mm_srai_epi32(in[4], shift);
515   in[5] = _mm_srai_epi32(in[5], shift);
516   in[6] = _mm_srai_epi32(in[6], shift);
517   in[7] = _mm_srai_epi32(in[7], shift);
518 }
519 
520 static inline void write_buffer_8x8(const __m128i *res, int32_t *output) {
521   _mm_store_si128((__m128i *)(output + 0 * 4), res[0]);
522   _mm_store_si128((__m128i *)(output + 1 * 4), res[1]);
523   _mm_store_si128((__m128i *)(output + 2 * 4), res[2]);
524   _mm_store_si128((__m128i *)(output + 3 * 4), res[3]);
525 
526   _mm_store_si128((__m128i *)(output + 4 * 4), res[4]);
527   _mm_store_si128((__m128i *)(output + 5 * 4), res[5]);
528   _mm_store_si128((__m128i *)(output + 6 * 4), res[6]);
529   _mm_store_si128((__m128i *)(output + 7 * 4), res[7]);
530 
531   _mm_store_si128((__m128i *)(output + 8 * 4), res[8]);
532   _mm_store_si128((__m128i *)(output + 9 * 4), res[9]);
533   _mm_store_si128((__m128i *)(output + 10 * 4), res[10]);
534   _mm_store_si128((__m128i *)(output + 11 * 4), res[11]);
535 
536   _mm_store_si128((__m128i *)(output + 12 * 4), res[12]);
537   _mm_store_si128((__m128i *)(output + 13 * 4), res[13]);
538   _mm_store_si128((__m128i *)(output + 14 * 4), res[14]);
539   _mm_store_si128((__m128i *)(output + 15 * 4), res[15]);
540 }
541 
542 static inline void write_buffer_16x8(const __m128i *res, int32_t *output,
543                                      const int stride) {
544   _mm_storeu_si128((__m128i *)(output), res[0]);
545   _mm_storeu_si128((__m128i *)(output + 4), res[1]);
546   _mm_storeu_si128((__m128i *)(output + stride), res[2]);
547   _mm_storeu_si128((__m128i *)(output + stride + 4), res[3]);
548 
549   _mm_storeu_si128((__m128i *)(output + (stride * 2)), res[4]);
550   _mm_storeu_si128((__m128i *)(output + (stride * 2) + 4), res[5]);
551   _mm_storeu_si128((__m128i *)(output + (stride * 3)), res[6]);
552   _mm_storeu_si128((__m128i *)(output + (stride * 3) + 4), res[7]);
553 
554   _mm_storeu_si128((__m128i *)(output + (stride * 4)), res[8]);
555   _mm_storeu_si128((__m128i *)(output + (stride * 4) + 4), res[9]);
556   _mm_storeu_si128((__m128i *)(output + (stride * 5)), res[10]);
557   _mm_storeu_si128((__m128i *)(output + (stride * 5) + 4), res[11]);
558 
559   _mm_storeu_si128((__m128i *)(output + (stride * 6)), res[12]);
560   _mm_storeu_si128((__m128i *)(output + (stride * 6) + 4), res[13]);
561   _mm_storeu_si128((__m128i *)(output + (stride * 7)), res[14]);
562   _mm_storeu_si128((__m128i *)(output + (stride * 7) + 4), res[15]);
563 }
564 
565 static void fdct4x8_sse4_1(__m128i *in, __m128i *out, int bit,
566                            const int col_num) {
567   const int32_t *cospi = cospi_arr(bit);
568   const __m128i cospi32 = _mm_set1_epi32(cospi[32]);
569   const __m128i cospim32 = _mm_set1_epi32(-cospi[32]);
570   const __m128i cospi48 = _mm_set1_epi32(cospi[48]);
571   const __m128i cospi16 = _mm_set1_epi32(cospi[16]);
572   const __m128i cospi56 = _mm_set1_epi32(cospi[56]);
573   const __m128i cospi8 = _mm_set1_epi32(cospi[8]);
574   const __m128i cospi24 = _mm_set1_epi32(cospi[24]);
575   const __m128i cospi40 = _mm_set1_epi32(cospi[40]);
576   const __m128i rnding = _mm_set1_epi32(1 << (bit - 1));
577   __m128i u[8], v[8];
578 
579   int startidx = 0 * col_num;
580   int endidx = 7 * col_num;
581   // Even 8 points 0, 2, ..., 14
582   // stage 0
583   // stage 1
584   u[0] = _mm_add_epi32(in[startidx], in[endidx]);
585   v[7] = _mm_sub_epi32(in[startidx], in[endidx]);  // v[7]
586   startidx += col_num;
587   endidx -= col_num;
588   u[1] = _mm_add_epi32(in[startidx], in[endidx]);
589   u[6] = _mm_sub_epi32(in[startidx], in[endidx]);
590   startidx += col_num;
591   endidx -= col_num;
592   u[2] = _mm_add_epi32(in[startidx], in[endidx]);
593   u[5] = _mm_sub_epi32(in[startidx], in[endidx]);
594   startidx += col_num;
595   endidx -= col_num;
596   u[3] = _mm_add_epi32(in[startidx], in[endidx]);
597   v[4] = _mm_sub_epi32(in[startidx], in[endidx]);  // v[4]
598 
599   // stage 2
600   v[0] = _mm_add_epi32(u[0], u[3]);
601   v[3] = _mm_sub_epi32(u[0], u[3]);
602   v[1] = _mm_add_epi32(u[1], u[2]);
603   v[2] = _mm_sub_epi32(u[1], u[2]);
604 
605   v[5] = _mm_mullo_epi32(u[5], cospim32);
606   v[6] = _mm_mullo_epi32(u[6], cospi32);
607   v[5] = _mm_add_epi32(v[5], v[6]);
608   v[5] = _mm_add_epi32(v[5], rnding);
609   v[5] = _mm_srai_epi32(v[5], bit);
610 
611   u[0] = _mm_mullo_epi32(u[5], cospi32);
612   v[6] = _mm_mullo_epi32(u[6], cospim32);
613   v[6] = _mm_sub_epi32(u[0], v[6]);
614   v[6] = _mm_add_epi32(v[6], rnding);
615   v[6] = _mm_srai_epi32(v[6], bit);
616 
617   // stage 3
618   // type 0
619   v[0] = _mm_mullo_epi32(v[0], cospi32);
620   v[1] = _mm_mullo_epi32(v[1], cospi32);
621   u[0] = _mm_add_epi32(v[0], v[1]);
622   u[0] = _mm_add_epi32(u[0], rnding);
623   u[0] = _mm_srai_epi32(u[0], bit);
624 
625   u[1] = _mm_sub_epi32(v[0], v[1]);
626   u[1] = _mm_add_epi32(u[1], rnding);
627   u[1] = _mm_srai_epi32(u[1], bit);
628 
629   // type 1
630   v[0] = _mm_mullo_epi32(v[2], cospi48);
631   v[1] = _mm_mullo_epi32(v[3], cospi16);
632   u[2] = _mm_add_epi32(v[0], v[1]);
633   u[2] = _mm_add_epi32(u[2], rnding);
634   u[2] = _mm_srai_epi32(u[2], bit);
635 
636   v[0] = _mm_mullo_epi32(v[2], cospi16);
637   v[1] = _mm_mullo_epi32(v[3], cospi48);
638   u[3] = _mm_sub_epi32(v[1], v[0]);
639   u[3] = _mm_add_epi32(u[3], rnding);
640   u[3] = _mm_srai_epi32(u[3], bit);
641 
642   u[4] = _mm_add_epi32(v[4], v[5]);
643   u[5] = _mm_sub_epi32(v[4], v[5]);
644   u[6] = _mm_sub_epi32(v[7], v[6]);
645   u[7] = _mm_add_epi32(v[7], v[6]);
646 
647   // stage 4
648   // stage 5
649   v[0] = _mm_mullo_epi32(u[4], cospi56);
650   v[1] = _mm_mullo_epi32(u[7], cospi8);
651   v[0] = _mm_add_epi32(v[0], v[1]);
652   v[0] = _mm_add_epi32(v[0], rnding);
653   out[1 * col_num] = _mm_srai_epi32(v[0], bit);  // buf0[4]
654 
655   v[0] = _mm_mullo_epi32(u[4], cospi8);
656   v[1] = _mm_mullo_epi32(u[7], cospi56);
657   v[0] = _mm_sub_epi32(v[1], v[0]);
658   v[0] = _mm_add_epi32(v[0], rnding);
659   out[7 * col_num] = _mm_srai_epi32(v[0], bit);  // buf0[7]
660 
661   v[0] = _mm_mullo_epi32(u[5], cospi24);
662   v[1] = _mm_mullo_epi32(u[6], cospi40);
663   v[0] = _mm_add_epi32(v[0], v[1]);
664   v[0] = _mm_add_epi32(v[0], rnding);
665   out[5 * col_num] = _mm_srai_epi32(v[0], bit);  // buf0[5]
666 
667   v[0] = _mm_mullo_epi32(u[5], cospi40);
668   v[1] = _mm_mullo_epi32(u[6], cospi24);
669   v[0] = _mm_sub_epi32(v[1], v[0]);
670   v[0] = _mm_add_epi32(v[0], rnding);
671   out[3 * col_num] = _mm_srai_epi32(v[0], bit);  // buf0[6]
672 
673   out[0 * col_num] = u[0];  // buf0[0]
674   out[4 * col_num] = u[1];  // buf0[1]
675   out[2 * col_num] = u[2];  // buf0[2]
676   out[6 * col_num] = u[3];  // buf0[3]
677 }
678 
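// The 8x8 DCT is, in effect, two interleaved 4-lane 8-point DCTs: with
// col_num = 2 the first call handles the even registers (columns 0..3) and
// the second call, starting at in + 1, handles the odd registers
// (columns 4..7).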
679 static void fdct8x8_sse4_1(__m128i *in, __m128i *out, int bit,
680                            const int col_num) {
681   fdct4x8_sse4_1(in, out, bit, col_num);
682   fdct4x8_sse4_1(in + 1, out + 1, bit, col_num);
683 }
684 
685 static void fadst8x8_sse4_1(__m128i *in, __m128i *out, int bit,
686                             const int col_num) {
687   const int32_t *cospi = cospi_arr(bit);
688   const __m128i cospi32 = _mm_set1_epi32(cospi[32]);
689   const __m128i cospi16 = _mm_set1_epi32(cospi[16]);
690   const __m128i cospim16 = _mm_set1_epi32(-cospi[16]);
691   const __m128i cospi48 = _mm_set1_epi32(cospi[48]);
692   const __m128i cospim48 = _mm_set1_epi32(-cospi[48]);
693   const __m128i cospi4 = _mm_set1_epi32(cospi[4]);
694   const __m128i cospim4 = _mm_set1_epi32(-cospi[4]);
695   const __m128i cospi60 = _mm_set1_epi32(cospi[60]);
696   const __m128i cospi20 = _mm_set1_epi32(cospi[20]);
697   const __m128i cospim20 = _mm_set1_epi32(-cospi[20]);
698   const __m128i cospi44 = _mm_set1_epi32(cospi[44]);
699   const __m128i cospi28 = _mm_set1_epi32(cospi[28]);
700   const __m128i cospi36 = _mm_set1_epi32(cospi[36]);
701   const __m128i cospim36 = _mm_set1_epi32(-cospi[36]);
702   const __m128i cospi52 = _mm_set1_epi32(cospi[52]);
703   const __m128i cospim52 = _mm_set1_epi32(-cospi[52]);
704   const __m128i cospi12 = _mm_set1_epi32(cospi[12]);
705   const __m128i rnding = _mm_set1_epi32(1 << (bit - 1));
706   const __m128i zero = _mm_setzero_si128();
707   __m128i u0, u1, u2, u3, u4, u5, u6, u7;
708   __m128i v0, v1, v2, v3, v4, v5, v6, v7;
709   __m128i x, y;
710   int col;
711 
712   // Note:
713   //  Even column: 0, 2, ..., 14
714   //  Odd column: 1, 3, ..., 15
715   //  one even column plus one odd column constructs one row (8 coeffs)
716   //  total we have 8 rows (8x8).
717   for (col = 0; col < col_num; ++col) {
718     // stage 0
719     // stage 1
720     u0 = in[col_num * 0 + col];
721     u1 = _mm_sub_epi32(zero, in[col_num * 7 + col]);
722     u2 = _mm_sub_epi32(zero, in[col_num * 3 + col]);
723     u3 = in[col_num * 4 + col];
724     u4 = _mm_sub_epi32(zero, in[col_num * 1 + col]);
725     u5 = in[col_num * 6 + col];
726     u6 = in[col_num * 2 + col];
727     u7 = _mm_sub_epi32(zero, in[col_num * 5 + col]);
728 
729     // stage 2
730     v0 = u0;
731     v1 = u1;
732 
733     x = _mm_mullo_epi32(u2, cospi32);
734     y = _mm_mullo_epi32(u3, cospi32);
735     v2 = _mm_add_epi32(x, y);
736     v2 = _mm_add_epi32(v2, rnding);
737     v2 = _mm_srai_epi32(v2, bit);
738 
739     v3 = _mm_sub_epi32(x, y);
740     v3 = _mm_add_epi32(v3, rnding);
741     v3 = _mm_srai_epi32(v3, bit);
742 
743     v4 = u4;
744     v5 = u5;
745 
746     x = _mm_mullo_epi32(u6, cospi32);
747     y = _mm_mullo_epi32(u7, cospi32);
748     v6 = _mm_add_epi32(x, y);
749     v6 = _mm_add_epi32(v6, rnding);
750     v6 = _mm_srai_epi32(v6, bit);
751 
752     v7 = _mm_sub_epi32(x, y);
753     v7 = _mm_add_epi32(v7, rnding);
754     v7 = _mm_srai_epi32(v7, bit);
755 
756     // stage 3
757     u0 = _mm_add_epi32(v0, v2);
758     u1 = _mm_add_epi32(v1, v3);
759     u2 = _mm_sub_epi32(v0, v2);
760     u3 = _mm_sub_epi32(v1, v3);
761     u4 = _mm_add_epi32(v4, v6);
762     u5 = _mm_add_epi32(v5, v7);
763     u6 = _mm_sub_epi32(v4, v6);
764     u7 = _mm_sub_epi32(v5, v7);
765 
766     // stage 4
767     v0 = u0;
768     v1 = u1;
769     v2 = u2;
770     v3 = u3;
771 
772     x = _mm_mullo_epi32(u4, cospi16);
773     y = _mm_mullo_epi32(u5, cospi48);
774     v4 = _mm_add_epi32(x, y);
775     v4 = _mm_add_epi32(v4, rnding);
776     v4 = _mm_srai_epi32(v4, bit);
777 
778     x = _mm_mullo_epi32(u4, cospi48);
779     y = _mm_mullo_epi32(u5, cospim16);
780     v5 = _mm_add_epi32(x, y);
781     v5 = _mm_add_epi32(v5, rnding);
782     v5 = _mm_srai_epi32(v5, bit);
783 
784     x = _mm_mullo_epi32(u6, cospim48);
785     y = _mm_mullo_epi32(u7, cospi16);
786     v6 = _mm_add_epi32(x, y);
787     v6 = _mm_add_epi32(v6, rnding);
788     v6 = _mm_srai_epi32(v6, bit);
789 
790     x = _mm_mullo_epi32(u6, cospi16);
791     y = _mm_mullo_epi32(u7, cospi48);
792     v7 = _mm_add_epi32(x, y);
793     v7 = _mm_add_epi32(v7, rnding);
794     v7 = _mm_srai_epi32(v7, bit);
795 
796     // stage 5
797     u0 = _mm_add_epi32(v0, v4);
798     u1 = _mm_add_epi32(v1, v5);
799     u2 = _mm_add_epi32(v2, v6);
800     u3 = _mm_add_epi32(v3, v7);
801     u4 = _mm_sub_epi32(v0, v4);
802     u5 = _mm_sub_epi32(v1, v5);
803     u6 = _mm_sub_epi32(v2, v6);
804     u7 = _mm_sub_epi32(v3, v7);
805 
806     // stage 6
807     x = _mm_mullo_epi32(u0, cospi4);
808     y = _mm_mullo_epi32(u1, cospi60);
809     v0 = _mm_add_epi32(x, y);
810     v0 = _mm_add_epi32(v0, rnding);
811     v0 = _mm_srai_epi32(v0, bit);
812 
813     x = _mm_mullo_epi32(u0, cospi60);
814     y = _mm_mullo_epi32(u1, cospim4);
815     v1 = _mm_add_epi32(x, y);
816     v1 = _mm_add_epi32(v1, rnding);
817     v1 = _mm_srai_epi32(v1, bit);
818 
819     x = _mm_mullo_epi32(u2, cospi20);
820     y = _mm_mullo_epi32(u3, cospi44);
821     v2 = _mm_add_epi32(x, y);
822     v2 = _mm_add_epi32(v2, rnding);
823     v2 = _mm_srai_epi32(v2, bit);
824 
825     x = _mm_mullo_epi32(u2, cospi44);
826     y = _mm_mullo_epi32(u3, cospim20);
827     v3 = _mm_add_epi32(x, y);
828     v3 = _mm_add_epi32(v3, rnding);
829     v3 = _mm_srai_epi32(v3, bit);
830 
831     x = _mm_mullo_epi32(u4, cospi36);
832     y = _mm_mullo_epi32(u5, cospi28);
833     v4 = _mm_add_epi32(x, y);
834     v4 = _mm_add_epi32(v4, rnding);
835     v4 = _mm_srai_epi32(v4, bit);
836 
837     x = _mm_mullo_epi32(u4, cospi28);
838     y = _mm_mullo_epi32(u5, cospim36);
839     v5 = _mm_add_epi32(x, y);
840     v5 = _mm_add_epi32(v5, rnding);
841     v5 = _mm_srai_epi32(v5, bit);
842 
843     x = _mm_mullo_epi32(u6, cospi52);
844     y = _mm_mullo_epi32(u7, cospi12);
845     v6 = _mm_add_epi32(x, y);
846     v6 = _mm_add_epi32(v6, rnding);
847     v6 = _mm_srai_epi32(v6, bit);
848 
849     x = _mm_mullo_epi32(u6, cospi12);
850     y = _mm_mullo_epi32(u7, cospim52);
851     v7 = _mm_add_epi32(x, y);
852     v7 = _mm_add_epi32(v7, rnding);
853     v7 = _mm_srai_epi32(v7, bit);
854 
855     // stage 7
856     out[col_num * 0 + col] = v1;
857     out[col_num * 1 + col] = v6;
858     out[col_num * 2 + col] = v3;
859     out[col_num * 3 + col] = v4;
860     out[col_num * 4 + col] = v5;
861     out[col_num * 5 + col] = v2;
862     out[col_num * 6 + col] = v7;
863     out[col_num * 7 + col] = v0;
864   }
865 }
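// Identity transform for an 8-point dimension: each coefficient is simply
// doubled (scale 2), so the cos bit is unused.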
866 static void idtx8x8_sse4_1(__m128i *in, __m128i *out, int bit, int col_num) {
867   (void)bit;
868 
869   for (int i = 0; i < col_num; i += 1) {
870     out[0 + 8 * i] = _mm_add_epi32(in[0 + 8 * i], in[0 + 8 * i]);
871     out[1 + 8 * i] = _mm_add_epi32(in[1 + 8 * i], in[1 + 8 * i]);
872     out[2 + 8 * i] = _mm_add_epi32(in[2 + 8 * i], in[2 + 8 * i]);
873     out[3 + 8 * i] = _mm_add_epi32(in[3 + 8 * i], in[3 + 8 * i]);
874     out[4 + 8 * i] = _mm_add_epi32(in[4 + 8 * i], in[4 + 8 * i]);
875     out[5 + 8 * i] = _mm_add_epi32(in[5 + 8 * i], in[5 + 8 * i]);
876     out[6 + 8 * i] = _mm_add_epi32(in[6 + 8 * i], in[6 + 8 * i]);
877     out[7 + 8 * i] = _mm_add_epi32(in[7 + 8 * i], in[7 + 8 * i]);
878   }
879 }
880 #if !CONFIG_REALTIME_ONLY
881 static void idtx32x8_sse4_1(__m128i *in, __m128i *out, int bit, int col_num) {
882   (void)bit;
883   (void)col_num;
884   for (int j = 0; j < 2; j++) {
885     out[j + 8 * 0] = _mm_add_epi32(in[j + 8 * 0], in[j + 8 * 0]);
886     out[j + 8 * 1] = _mm_add_epi32(in[j + 8 * 1], in[j + 8 * 1]);
887     out[j + 8 * 2] = _mm_add_epi32(in[j + 8 * 2], in[j + 8 * 2]);
888     out[j + 8 * 3] = _mm_add_epi32(in[j + 8 * 3], in[j + 8 * 3]);
889     out[j + 8 * 4] = _mm_add_epi32(in[j + 8 * 4], in[j + 8 * 4]);
890     out[j + 8 * 5] = _mm_add_epi32(in[j + 8 * 5], in[j + 8 * 5]);
891     out[j + 8 * 6] = _mm_add_epi32(in[j + 8 * 6], in[j + 8 * 6]);
892     out[j + 8 * 7] = _mm_add_epi32(in[j + 8 * 7], in[j + 8 * 7]);
893   }
894 }
895 #endif
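// 2-D 8x8 forward transform. Unlike the 4x4 path, an intermediate rounding
// stage (col_txfm_8x8_rounding with -shift[1]) is applied after the column
// transform and before the 8x8 transpose and row transform.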
896 void av1_fwd_txfm2d_8x8_sse4_1(const int16_t *input, int32_t *coeff, int stride,
897                                TX_TYPE tx_type, int bd) {
898   __m128i in[16], out[16];
899   const int8_t *shift = av1_fwd_txfm_shift_ls[TX_8X8];
900   const int txw_idx = get_txw_idx(TX_8X8);
901   const int txh_idx = get_txh_idx(TX_8X8);
902 
903   switch (tx_type) {
904     case DCT_DCT:
905       load_buffer_8x8(input, in, stride, 0, 0, shift[0]);
906       fdct8x8_sse4_1(in, out, av1_fwd_cos_bit_col[txw_idx][txh_idx], 2);
907       col_txfm_8x8_rounding(out, -shift[1]);
908       transpose_8x8(out, in);
909       fdct8x8_sse4_1(in, out, av1_fwd_cos_bit_row[txw_idx][txh_idx], 2);
910       write_buffer_8x8(out, coeff);
911       break;
912     case ADST_DCT:
913       load_buffer_8x8(input, in, stride, 0, 0, shift[0]);
914       fadst8x8_sse4_1(in, out, av1_fwd_cos_bit_col[txw_idx][txh_idx], 2);
915       col_txfm_8x8_rounding(out, -shift[1]);
916       transpose_8x8(out, in);
917       fdct8x8_sse4_1(in, out, av1_fwd_cos_bit_row[txw_idx][txh_idx], 2);
918       write_buffer_8x8(out, coeff);
919       break;
920     case DCT_ADST:
921       load_buffer_8x8(input, in, stride, 0, 0, shift[0]);
922       fdct8x8_sse4_1(in, out, av1_fwd_cos_bit_col[txw_idx][txh_idx], 2);
923       col_txfm_8x8_rounding(out, -shift[1]);
924       transpose_8x8(out, in);
925       fadst8x8_sse4_1(in, out, av1_fwd_cos_bit_row[txw_idx][txh_idx], 2);
926       write_buffer_8x8(out, coeff);
927       break;
928     case ADST_ADST:
929       load_buffer_8x8(input, in, stride, 0, 0, shift[0]);
930       fadst8x8_sse4_1(in, out, av1_fwd_cos_bit_col[txw_idx][txh_idx], 2);
931       col_txfm_8x8_rounding(out, -shift[1]);
932       transpose_8x8(out, in);
933       fadst8x8_sse4_1(in, out, av1_fwd_cos_bit_row[txw_idx][txh_idx], 2);
934       write_buffer_8x8(out, coeff);
935       break;
936     case FLIPADST_DCT:
937       load_buffer_8x8(input, in, stride, 1, 0, shift[0]);
938       fadst8x8_sse4_1(in, out, av1_fwd_cos_bit_col[txw_idx][txh_idx], 2);
939       col_txfm_8x8_rounding(out, -shift[1]);
940       transpose_8x8(out, in);
941       fdct8x8_sse4_1(in, out, av1_fwd_cos_bit_row[txw_idx][txh_idx], 2);
942       write_buffer_8x8(out, coeff);
943       break;
944     case DCT_FLIPADST:
945       load_buffer_8x8(input, in, stride, 0, 1, shift[0]);
946       fdct8x8_sse4_1(in, out, av1_fwd_cos_bit_col[txw_idx][txh_idx], 2);
947       col_txfm_8x8_rounding(out, -shift[1]);
948       transpose_8x8(out, in);
949       fadst8x8_sse4_1(in, out, av1_fwd_cos_bit_row[txw_idx][txh_idx], 2);
950       write_buffer_8x8(out, coeff);
951       break;
952     case FLIPADST_FLIPADST:
953       load_buffer_8x8(input, in, stride, 1, 1, shift[0]);
954       fadst8x8_sse4_1(in, out, av1_fwd_cos_bit_col[txw_idx][txh_idx], 2);
955       col_txfm_8x8_rounding(out, -shift[1]);
956       transpose_8x8(out, in);
957       fadst8x8_sse4_1(in, out, av1_fwd_cos_bit_row[txw_idx][txh_idx], 2);
958       write_buffer_8x8(out, coeff);
959       break;
960     case ADST_FLIPADST:
961       load_buffer_8x8(input, in, stride, 0, 1, shift[0]);
962       fadst8x8_sse4_1(in, out, av1_fwd_cos_bit_col[txw_idx][txh_idx], 2);
963       col_txfm_8x8_rounding(out, -shift[1]);
964       transpose_8x8(out, in);
965       fadst8x8_sse4_1(in, out, av1_fwd_cos_bit_row[txw_idx][txh_idx], 2);
966       write_buffer_8x8(out, coeff);
967       break;
968     case FLIPADST_ADST:
969       load_buffer_8x8(input, in, stride, 1, 0, shift[0]);
970       fadst8x8_sse4_1(in, out, av1_fwd_cos_bit_col[txw_idx][txh_idx], 2);
971       col_txfm_8x8_rounding(out, -shift[1]);
972       transpose_8x8(out, in);
973       fadst8x8_sse4_1(in, out, av1_fwd_cos_bit_row[txw_idx][txh_idx], 2);
974       write_buffer_8x8(out, coeff);
975       break;
976     case IDTX:
977       load_buffer_8x8(input, in, stride, 0, 0, shift[0]);
978       idtx8x8_sse4_1(in, out, av1_fwd_cos_bit_col[txw_idx][txh_idx], 2);
979       col_txfm_8x8_rounding(out, -shift[1]);
980       transpose_8x8(out, in);
981       idtx8x8_sse4_1(in, out, av1_fwd_cos_bit_col[txw_idx][txh_idx], 2);
982       write_buffer_8x8(out, coeff);
983       break;
984     case V_DCT:
985       load_buffer_8x8(input, in, stride, 0, 0, shift[0]);
986       fdct8x8_sse4_1(in, out, av1_fwd_cos_bit_col[txw_idx][txh_idx], 2);
987       col_txfm_8x8_rounding(out, -shift[1]);
988       transpose_8x8(out, in);
989       idtx8x8_sse4_1(in, out, av1_fwd_cos_bit_col[txw_idx][txh_idx], 2);
990       write_buffer_8x8(out, coeff);
991       break;
992     case H_DCT:
993       load_buffer_8x8(input, in, stride, 0, 0, shift[0]);
994       idtx8x8_sse4_1(in, out, av1_fwd_cos_bit_col[txw_idx][txh_idx], 2);
995       col_txfm_8x8_rounding(out, -shift[1]);
996       transpose_8x8(out, in);
997       fdct8x8_sse4_1(in, out, av1_fwd_cos_bit_col[txw_idx][txh_idx], 2);
998       write_buffer_8x8(out, coeff);
999       break;
1000     case V_ADST:
1001       load_buffer_8x8(input, in, stride, 0, 0, shift[0]);
1002       fadst8x8_sse4_1(in, out, av1_fwd_cos_bit_col[txw_idx][txh_idx], 2);
1003       col_txfm_8x8_rounding(out, -shift[1]);
1004       transpose_8x8(out, in);
1005       idtx8x8_sse4_1(in, out, av1_fwd_cos_bit_col[txw_idx][txh_idx], 2);
1006       write_buffer_8x8(out, coeff);
1007       break;
1008     case H_ADST:
1009       load_buffer_8x8(input, in, stride, 0, 0, shift[0]);
1010       idtx8x8_sse4_1(in, out, av1_fwd_cos_bit_col[txw_idx][txh_idx], 2);
1011       col_txfm_8x8_rounding(out, -shift[1]);
1012       transpose_8x8(out, in);
1013       fadst8x8_sse4_1(in, out, av1_fwd_cos_bit_col[txw_idx][txh_idx], 2);
1014       write_buffer_8x8(out, coeff);
1015       break;
1016     case V_FLIPADST:
1017       load_buffer_8x8(input, in, stride, 1, 0, shift[0]);
1018       fadst8x8_sse4_1(in, out, av1_fwd_cos_bit_col[txw_idx][txh_idx], 2);
1019       col_txfm_8x8_rounding(out, -shift[1]);
1020       transpose_8x8(out, in);
1021       idtx8x8_sse4_1(in, out, av1_fwd_cos_bit_col[txw_idx][txh_idx], 2);
1022       write_buffer_8x8(out, coeff);
1023       break;
1024     case H_FLIPADST:
1025       load_buffer_8x8(input, in, stride, 0, 1, shift[0]);
1026       idtx8x8_sse4_1(in, out, av1_fwd_cos_bit_col[txw_idx][txh_idx], 2);
1027       col_txfm_8x8_rounding(out, -shift[1]);
1028       transpose_8x8(out, in);
1029       fadst8x8_sse4_1(in, out, av1_fwd_cos_bit_col[txw_idx][txh_idx], 2);
1030       write_buffer_8x8(out, coeff);
1031       break;
1032     default: assert(0);
1033   }
1034   (void)bd;
1035 }
1036 
1037 // Hybrid Transform 16x16
1038 
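// load_buffer_16x16() below loads the block as four 8x8 quadrants (topL,
// topR, botL, botR); convert_8x8_to_16x16() then interleaves them so that
// each 16-wide row occupies four consecutive __m128i in the output array.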
1039 static inline void convert_8x8_to_16x16(const __m128i *in, __m128i *out) {
1040   int row_index = 0;
1041   int dst_index = 0;
1042   int src_index = 0;
1043 
1044   // row 0, 1, .., 7
1045   do {
1046     out[dst_index] = in[src_index];
1047     out[dst_index + 1] = in[src_index + 1];
1048     out[dst_index + 2] = in[src_index + 16];
1049     out[dst_index + 3] = in[src_index + 17];
1050     dst_index += 4;
1051     src_index += 2;
1052     row_index += 1;
1053   } while (row_index < 8);
1054 
1055   // row 8, 9, ..., 15
1056   src_index += 16;
1057   do {
1058     out[dst_index] = in[src_index];
1059     out[dst_index + 1] = in[src_index + 1];
1060     out[dst_index + 2] = in[src_index + 16];
1061     out[dst_index + 3] = in[src_index + 17];
1062     dst_index += 4;
1063     src_index += 2;
1064     row_index += 1;
1065   } while (row_index < 16);
1066 }
1067 
1068 static inline void load_buffer_16x16(const int16_t *input, __m128i *out,
1069                                      int stride, int flipud, int fliplr,
1070                                      int shift) {
1071   __m128i in[64];
1072   // Load 4 8x8 blocks
1073   const int16_t *topL = input;
1074   const int16_t *topR = input + 8;
1075   const int16_t *botL = input + 8 * stride;
1076   const int16_t *botR = input + 8 * stride + 8;
1077 
1078   const int16_t *tmp;
1079 
1080   if (flipud) {
1081     // Swap left columns
1082     tmp = topL;
1083     topL = botL;
1084     botL = tmp;
1085     // Swap right columns
1086     tmp = topR;
1087     topR = botR;
1088     botR = tmp;
1089   }
1090 
1091   if (fliplr) {
1092     // Swap top rows
1093     tmp = topL;
1094     topL = topR;
1095     topR = tmp;
1096     // Swap bottom rows
1097     tmp = botL;
1098     botL = botR;
1099     botR = tmp;
1100   }
1101 
1102   // load first 8 columns
1103   load_buffer_8x8(topL, &in[0], stride, flipud, fliplr, shift);
1104   load_buffer_8x8(botL, &in[32], stride, flipud, fliplr, shift);
1105 
1106   // load second 8 columns
1107   load_buffer_8x8(topR, &in[16], stride, flipud, fliplr, shift);
1108   load_buffer_8x8(botR, &in[48], stride, flipud, fliplr, shift);
1109 
1110   convert_8x8_to_16x16(in, out);
1111 }
1112 
1113 static inline void load_buffer_8x16(const int16_t *input, __m128i *out,
1114                                     int stride, int flipud, int fliplr,
1115                                     int shift) {
1116   const int16_t *topL = input;
1117   const int16_t *botL = input + 8 * stride;
1118 
1119   const int16_t *tmp;
1120 
1121   if (flipud) {
1122     tmp = topL;
1123     topL = botL;
1124     botL = tmp;
1125   }
1126 
1127   load_buffer_8x8(topL, out, stride, flipud, fliplr, shift);
1128   load_buffer_8x8(botL, out + 16, stride, flipud, fliplr, shift);
1129 }
1130 
1131 static inline void load_buffer_8x4(const int16_t *input, __m128i *out,
1132                                    int stride, int flipud, int fliplr,
1133                                    int shift) {
1134   const int16_t *topL = input;
1135   const int16_t *topR = input + 4;
1136 
1137   const int16_t *tmp;
1138 
1139   if (fliplr) {
1140     tmp = topL;
1141     topL = topR;
1142     topR = tmp;
1143   }
1144 
1145   load_buffer_4x4(topL, out, stride, flipud, fliplr, shift);
1146   load_buffer_4x4(topR, out + 4, stride, flipud, fliplr, shift);
1147 }
1148 
1149 static inline void load_buffer_16x4(const int16_t *input, __m128i *out,
1150                                     int stride, int flipud, int fliplr,
1151                                     int shift) {
1152   const int16_t *topL = input;
1153   const int16_t *topR = input + 8;
1154 
1155   const int16_t *tmp;
1156 
1157   if (fliplr) {
1158     tmp = topL;
1159     topL = topR;
1160     topR = tmp;
1161   }
1162 
1163   load_buffer_8x4(topL, out, stride, flipud, fliplr, shift);
1164   load_buffer_8x4(topR, out + 8, stride, flipud, fliplr, shift);
1165 }
1166 
1167 static inline void load_buffer_4x8(const int16_t *input, __m128i *out,
1168                                    int stride, int flipud, int fliplr,
1169                                    int shift) {
1170   const int16_t *topL = input;
1171   const int16_t *botL = input + 4 * stride;
1172 
1173   const int16_t *tmp;
1174 
1175   if (flipud) {
1176     tmp = topL;
1177     topL = botL;
1178     botL = tmp;
1179   }
1180 
1181   load_buffer_4x4(topL, out, stride, flipud, fliplr, shift);
1182   load_buffer_4x4(botL, out + 4, stride, flipud, fliplr, shift);
1183 }
1184 
1185 #if !CONFIG_REALTIME_ONLY
1186 static inline void load_buffer_4x16(const int16_t *input, __m128i *out,
1187                                     const int stride, const int flipud,
1188                                     const int fliplr, const int shift) {
1189   const int16_t *topL = input;
1190   const int16_t *botL = input + 8 * stride;
1191 
1192   const int16_t *tmp;
1193 
1194   if (flipud) {
1195     tmp = topL;
1196     topL = botL;
1197     botL = tmp;
1198   }
1199   load_buffer_4x8(topL, out, stride, flipud, fliplr, shift);
1200   load_buffer_4x8(botL, out + 8, stride, flipud, fliplr, shift);
1201 }
1202 #endif
1203 
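// Each loop iteration below handles one row of 32 int16 samples: the row is
// read as two runs of 16 consecutive samples (load_buffer_4x4 with a stride
// of 4) and widened into eight __m128i of int32 at out + col * 8, each value
// left-shifted by `shift`.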
1204 static inline void load_buffer_32x8n(const int16_t *input, __m128i *out,
1205                                      int stride, int flipud, int fliplr,
1206                                      int shift, const int height) {
1207   const int16_t *in = input;
1208   __m128i *output = out;
1209   for (int col = 0; col < height; col++) {
1210     in = input + col * stride;
1211     output = out + col * 8;
1212     load_buffer_4x4(in, output, 4, flipud, fliplr, shift);
1213     load_buffer_4x4((in + 16), (output + 4), 4, flipud, fliplr, shift);
1214   }
1215 }
1216 
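// 16-point DCT applied to col_num groups of 4 lanes; the input is indexed as
// in[row * col_num + col]. For a full 16x16 block this is presumably called
// with col_num = 4 (see the "Calculate the column 0, 1, 2, 3" note below).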
1217 static void fdct16x16_sse4_1(__m128i *in, __m128i *out, int bit,
1218                              const int col_num) {
1219   const int32_t *cospi = cospi_arr(bit);
1220   const __m128i cospi32 = _mm_set1_epi32(cospi[32]);
1221   const __m128i cospim32 = _mm_set1_epi32(-cospi[32]);
1222   const __m128i cospi48 = _mm_set1_epi32(cospi[48]);
1223   const __m128i cospi16 = _mm_set1_epi32(cospi[16]);
1224   const __m128i cospim48 = _mm_set1_epi32(-cospi[48]);
1225   const __m128i cospim16 = _mm_set1_epi32(-cospi[16]);
1226   const __m128i cospi56 = _mm_set1_epi32(cospi[56]);
1227   const __m128i cospi8 = _mm_set1_epi32(cospi[8]);
1228   const __m128i cospi24 = _mm_set1_epi32(cospi[24]);
1229   const __m128i cospi40 = _mm_set1_epi32(cospi[40]);
1230   const __m128i cospi60 = _mm_set1_epi32(cospi[60]);
1231   const __m128i cospi4 = _mm_set1_epi32(cospi[4]);
1232   const __m128i cospi28 = _mm_set1_epi32(cospi[28]);
1233   const __m128i cospi36 = _mm_set1_epi32(cospi[36]);
1234   const __m128i cospi44 = _mm_set1_epi32(cospi[44]);
1235   const __m128i cospi20 = _mm_set1_epi32(cospi[20]);
1236   const __m128i cospi12 = _mm_set1_epi32(cospi[12]);
1237   const __m128i cospi52 = _mm_set1_epi32(cospi[52]);
1238   const __m128i rnding = _mm_set1_epi32(1 << (bit - 1));
1239   __m128i u[16], v[16], x;
1240   int col;
1241 
1242   // Calculate the column 0, 1, 2, 3
1243   for (col = 0; col < col_num; ++col) {
1244     // stage 0
1245     // stage 1
1246     u[0] = _mm_add_epi32(in[0 * col_num + col], in[15 * col_num + col]);
1247     u[15] = _mm_sub_epi32(in[0 * col_num + col], in[15 * col_num + col]);
1248     u[1] = _mm_add_epi32(in[1 * col_num + col], in[14 * col_num + col]);
1249     u[14] = _mm_sub_epi32(in[1 * col_num + col], in[14 * col_num + col]);
1250     u[2] = _mm_add_epi32(in[2 * col_num + col], in[13 * col_num + col]);
1251     u[13] = _mm_sub_epi32(in[2 * col_num + col], in[13 * col_num + col]);
1252     u[3] = _mm_add_epi32(in[3 * col_num + col], in[12 * col_num + col]);
1253     u[12] = _mm_sub_epi32(in[3 * col_num + col], in[12 * col_num + col]);
1254     u[4] = _mm_add_epi32(in[4 * col_num + col], in[11 * col_num + col]);
1255     u[11] = _mm_sub_epi32(in[4 * col_num + col], in[11 * col_num + col]);
1256     u[5] = _mm_add_epi32(in[5 * col_num + col], in[10 * col_num + col]);
1257     u[10] = _mm_sub_epi32(in[5 * col_num + col], in[10 * col_num + col]);
1258     u[6] = _mm_add_epi32(in[6 * col_num + col], in[9 * col_num + col]);
1259     u[9] = _mm_sub_epi32(in[6 * col_num + col], in[9 * col_num + col]);
1260     u[7] = _mm_add_epi32(in[7 * col_num + col], in[8 * col_num + col]);
1261     u[8] = _mm_sub_epi32(in[7 * col_num + col], in[8 * col_num + col]);
1262 
1263     // stage 2
1264     v[0] = _mm_add_epi32(u[0], u[7]);
1265     v[7] = _mm_sub_epi32(u[0], u[7]);
1266     v[1] = _mm_add_epi32(u[1], u[6]);
1267     v[6] = _mm_sub_epi32(u[1], u[6]);
1268     v[2] = _mm_add_epi32(u[2], u[5]);
1269     v[5] = _mm_sub_epi32(u[2], u[5]);
1270     v[3] = _mm_add_epi32(u[3], u[4]);
1271     v[4] = _mm_sub_epi32(u[3], u[4]);
1272     v[8] = u[8];
1273     v[9] = u[9];
1274 
1275     v[10] = _mm_mullo_epi32(u[10], cospim32);
1276     x = _mm_mullo_epi32(u[13], cospi32);
1277     v[10] = _mm_add_epi32(v[10], x);
1278     v[10] = _mm_add_epi32(v[10], rnding);
1279     v[10] = _mm_srai_epi32(v[10], bit);
1280 
1281     v[13] = _mm_mullo_epi32(u[10], cospi32);
1282     x = _mm_mullo_epi32(u[13], cospim32);
1283     v[13] = _mm_sub_epi32(v[13], x);
1284     v[13] = _mm_add_epi32(v[13], rnding);
1285     v[13] = _mm_srai_epi32(v[13], bit);
1286 
1287     v[11] = _mm_mullo_epi32(u[11], cospim32);
1288     x = _mm_mullo_epi32(u[12], cospi32);
1289     v[11] = _mm_add_epi32(v[11], x);
1290     v[11] = _mm_add_epi32(v[11], rnding);
1291     v[11] = _mm_srai_epi32(v[11], bit);
1292 
1293     v[12] = _mm_mullo_epi32(u[11], cospi32);
1294     x = _mm_mullo_epi32(u[12], cospim32);
1295     v[12] = _mm_sub_epi32(v[12], x);
1296     v[12] = _mm_add_epi32(v[12], rnding);
1297     v[12] = _mm_srai_epi32(v[12], bit);
1298     v[14] = u[14];
1299     v[15] = u[15];
1300 
1301     // stage 3
1302     u[0] = _mm_add_epi32(v[0], v[3]);
1303     u[3] = _mm_sub_epi32(v[0], v[3]);
1304     u[1] = _mm_add_epi32(v[1], v[2]);
1305     u[2] = _mm_sub_epi32(v[1], v[2]);
1306     u[4] = v[4];
1307 
1308     u[5] = _mm_mullo_epi32(v[5], cospim32);
1309     x = _mm_mullo_epi32(v[6], cospi32);
1310     u[5] = _mm_add_epi32(u[5], x);
1311     u[5] = _mm_add_epi32(u[5], rnding);
1312     u[5] = _mm_srai_epi32(u[5], bit);
1313 
1314     u[6] = _mm_mullo_epi32(v[5], cospi32);
1315     x = _mm_mullo_epi32(v[6], cospim32);
1316     u[6] = _mm_sub_epi32(u[6], x);
1317     u[6] = _mm_add_epi32(u[6], rnding);
1318     u[6] = _mm_srai_epi32(u[6], bit);
1319 
1320     u[7] = v[7];
1321     u[8] = _mm_add_epi32(v[8], v[11]);
1322     u[11] = _mm_sub_epi32(v[8], v[11]);
1323     u[9] = _mm_add_epi32(v[9], v[10]);
1324     u[10] = _mm_sub_epi32(v[9], v[10]);
1325     u[12] = _mm_sub_epi32(v[15], v[12]);
1326     u[15] = _mm_add_epi32(v[15], v[12]);
1327     u[13] = _mm_sub_epi32(v[14], v[13]);
1328     u[14] = _mm_add_epi32(v[14], v[13]);
1329 
1330     // stage 4
1331     u[0] = _mm_mullo_epi32(u[0], cospi32);
1332     u[1] = _mm_mullo_epi32(u[1], cospi32);
1333     v[0] = _mm_add_epi32(u[0], u[1]);
1334     v[0] = _mm_add_epi32(v[0], rnding);
1335     v[0] = _mm_srai_epi32(v[0], bit);
1336 
1337     v[1] = _mm_sub_epi32(u[0], u[1]);
1338     v[1] = _mm_add_epi32(v[1], rnding);
1339     v[1] = _mm_srai_epi32(v[1], bit);
1340 
1341     v[2] = _mm_mullo_epi32(u[2], cospi48);
1342     x = _mm_mullo_epi32(u[3], cospi16);
1343     v[2] = _mm_add_epi32(v[2], x);
1344     v[2] = _mm_add_epi32(v[2], rnding);
1345     v[2] = _mm_srai_epi32(v[2], bit);
1346 
1347     v[3] = _mm_mullo_epi32(u[2], cospi16);
1348     x = _mm_mullo_epi32(u[3], cospi48);
1349     v[3] = _mm_sub_epi32(x, v[3]);
1350     v[3] = _mm_add_epi32(v[3], rnding);
1351     v[3] = _mm_srai_epi32(v[3], bit);
1352 
1353     v[4] = _mm_add_epi32(u[4], u[5]);
1354     v[5] = _mm_sub_epi32(u[4], u[5]);
1355     v[6] = _mm_sub_epi32(u[7], u[6]);
1356     v[7] = _mm_add_epi32(u[7], u[6]);
1357     v[8] = u[8];
1358 
1359     v[9] = _mm_mullo_epi32(u[9], cospim16);
1360     x = _mm_mullo_epi32(u[14], cospi48);
1361     v[9] = _mm_add_epi32(v[9], x);
1362     v[9] = _mm_add_epi32(v[9], rnding);
1363     v[9] = _mm_srai_epi32(v[9], bit);
1364 
1365     v[14] = _mm_mullo_epi32(u[9], cospi48);
1366     x = _mm_mullo_epi32(u[14], cospim16);
1367     v[14] = _mm_sub_epi32(v[14], x);
1368     v[14] = _mm_add_epi32(v[14], rnding);
1369     v[14] = _mm_srai_epi32(v[14], bit);
1370 
1371     v[10] = _mm_mullo_epi32(u[10], cospim48);
1372     x = _mm_mullo_epi32(u[13], cospim16);
1373     v[10] = _mm_add_epi32(v[10], x);
1374     v[10] = _mm_add_epi32(v[10], rnding);
1375     v[10] = _mm_srai_epi32(v[10], bit);
1376 
1377     v[13] = _mm_mullo_epi32(u[10], cospim16);
1378     x = _mm_mullo_epi32(u[13], cospim48);
1379     v[13] = _mm_sub_epi32(v[13], x);
1380     v[13] = _mm_add_epi32(v[13], rnding);
1381     v[13] = _mm_srai_epi32(v[13], bit);
1382 
1383     v[11] = u[11];
1384     v[12] = u[12];
1385     v[15] = u[15];
1386 
1387     // stage 5
1388     u[0] = v[0];
1389     u[1] = v[1];
1390     u[2] = v[2];
1391     u[3] = v[3];
1392 
1393     u[4] = _mm_mullo_epi32(v[4], cospi56);
1394     x = _mm_mullo_epi32(v[7], cospi8);
1395     u[4] = _mm_add_epi32(u[4], x);
1396     u[4] = _mm_add_epi32(u[4], rnding);
1397     u[4] = _mm_srai_epi32(u[4], bit);
1398 
1399     u[7] = _mm_mullo_epi32(v[4], cospi8);
1400     x = _mm_mullo_epi32(v[7], cospi56);
1401     u[7] = _mm_sub_epi32(x, u[7]);
1402     u[7] = _mm_add_epi32(u[7], rnding);
1403     u[7] = _mm_srai_epi32(u[7], bit);
1404 
1405     u[5] = _mm_mullo_epi32(v[5], cospi24);
1406     x = _mm_mullo_epi32(v[6], cospi40);
1407     u[5] = _mm_add_epi32(u[5], x);
1408     u[5] = _mm_add_epi32(u[5], rnding);
1409     u[5] = _mm_srai_epi32(u[5], bit);
1410 
1411     u[6] = _mm_mullo_epi32(v[5], cospi40);
1412     x = _mm_mullo_epi32(v[6], cospi24);
1413     u[6] = _mm_sub_epi32(x, u[6]);
1414     u[6] = _mm_add_epi32(u[6], rnding);
1415     u[6] = _mm_srai_epi32(u[6], bit);
1416 
1417     u[8] = _mm_add_epi32(v[8], v[9]);
1418     u[9] = _mm_sub_epi32(v[8], v[9]);
1419     u[10] = _mm_sub_epi32(v[11], v[10]);
1420     u[11] = _mm_add_epi32(v[11], v[10]);
1421     u[12] = _mm_add_epi32(v[12], v[13]);
1422     u[13] = _mm_sub_epi32(v[12], v[13]);
1423     u[14] = _mm_sub_epi32(v[15], v[14]);
1424     u[15] = _mm_add_epi32(v[15], v[14]);
1425 
1426     // stage 6
1427     v[0] = u[0];
1428     v[1] = u[1];
1429     v[2] = u[2];
1430     v[3] = u[3];
1431     v[4] = u[4];
1432     v[5] = u[5];
1433     v[6] = u[6];
1434     v[7] = u[7];
1435 
1436     v[8] = _mm_mullo_epi32(u[8], cospi60);
1437     x = _mm_mullo_epi32(u[15], cospi4);
1438     v[8] = _mm_add_epi32(v[8], x);
1439     v[8] = _mm_add_epi32(v[8], rnding);
1440     v[8] = _mm_srai_epi32(v[8], bit);
1441 
1442     v[15] = _mm_mullo_epi32(u[8], cospi4);
1443     x = _mm_mullo_epi32(u[15], cospi60);
1444     v[15] = _mm_sub_epi32(x, v[15]);
1445     v[15] = _mm_add_epi32(v[15], rnding);
1446     v[15] = _mm_srai_epi32(v[15], bit);
1447 
1448     v[9] = _mm_mullo_epi32(u[9], cospi28);
1449     x = _mm_mullo_epi32(u[14], cospi36);
1450     v[9] = _mm_add_epi32(v[9], x);
1451     v[9] = _mm_add_epi32(v[9], rnding);
1452     v[9] = _mm_srai_epi32(v[9], bit);
1453 
1454     v[14] = _mm_mullo_epi32(u[9], cospi36);
1455     x = _mm_mullo_epi32(u[14], cospi28);
1456     v[14] = _mm_sub_epi32(x, v[14]);
1457     v[14] = _mm_add_epi32(v[14], rnding);
1458     v[14] = _mm_srai_epi32(v[14], bit);
1459 
1460     v[10] = _mm_mullo_epi32(u[10], cospi44);
1461     x = _mm_mullo_epi32(u[13], cospi20);
1462     v[10] = _mm_add_epi32(v[10], x);
1463     v[10] = _mm_add_epi32(v[10], rnding);
1464     v[10] = _mm_srai_epi32(v[10], bit);
1465 
1466     v[13] = _mm_mullo_epi32(u[10], cospi20);
1467     x = _mm_mullo_epi32(u[13], cospi44);
1468     v[13] = _mm_sub_epi32(x, v[13]);
1469     v[13] = _mm_add_epi32(v[13], rnding);
1470     v[13] = _mm_srai_epi32(v[13], bit);
1471 
1472     v[11] = _mm_mullo_epi32(u[11], cospi12);
1473     x = _mm_mullo_epi32(u[12], cospi52);
1474     v[11] = _mm_add_epi32(v[11], x);
1475     v[11] = _mm_add_epi32(v[11], rnding);
1476     v[11] = _mm_srai_epi32(v[11], bit);
1477 
1478     v[12] = _mm_mullo_epi32(u[11], cospi52);
1479     x = _mm_mullo_epi32(u[12], cospi12);
1480     v[12] = _mm_sub_epi32(x, v[12]);
1481     v[12] = _mm_add_epi32(v[12], rnding);
1482     v[12] = _mm_srai_epi32(v[12], bit);
1483 
1484     out[0 * col_num + col] = v[0];
1485     out[1 * col_num + col] = v[8];
1486     out[2 * col_num + col] = v[4];
1487     out[3 * col_num + col] = v[12];
1488     out[4 * col_num + col] = v[2];
1489     out[5 * col_num + col] = v[10];
1490     out[6 * col_num + col] = v[6];
1491     out[7 * col_num + col] = v[14];
1492     out[8 * col_num + col] = v[1];
1493     out[9 * col_num + col] = v[9];
1494     out[10 * col_num + col] = v[5];
1495     out[11 * col_num + col] = v[13];
1496     out[12 * col_num + col] = v[3];
1497     out[13 * col_num + col] = v[11];
1498     out[14 * col_num + col] = v[7];
1499     out[15 * col_num + col] = v[15];
1500   }
1501 }
1502 
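// 1-D 16-point forward ADST on 32-bit lanes. Each __m128i holds 4 columns,
// so one call transforms 4 * num_cols columns; `bit` selects the cosine
// table precision and the rounding applied after each butterfly stage.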
1503 static void fadst16x16_sse4_1(__m128i *in, __m128i *out, int bit,
1504                               const int num_cols) {
1505   const int32_t *cospi = cospi_arr(bit);
1506   const __m128i cospi32 = _mm_set1_epi32(cospi[32]);
1507   const __m128i cospi48 = _mm_set1_epi32(cospi[48]);
1508   const __m128i cospi16 = _mm_set1_epi32(cospi[16]);
1509   const __m128i cospim16 = _mm_set1_epi32(-cospi[16]);
1510   const __m128i cospim48 = _mm_set1_epi32(-cospi[48]);
1511   const __m128i cospi8 = _mm_set1_epi32(cospi[8]);
1512   const __m128i cospi56 = _mm_set1_epi32(cospi[56]);
1513   const __m128i cospim56 = _mm_set1_epi32(-cospi[56]);
1514   const __m128i cospim8 = _mm_set1_epi32(-cospi[8]);
1515   const __m128i cospi24 = _mm_set1_epi32(cospi[24]);
1516   const __m128i cospim24 = _mm_set1_epi32(-cospi[24]);
1517   const __m128i cospim40 = _mm_set1_epi32(-cospi[40]);
1518   const __m128i cospi40 = _mm_set1_epi32(cospi[40]);
1519   const __m128i cospi2 = _mm_set1_epi32(cospi[2]);
1520   const __m128i cospi62 = _mm_set1_epi32(cospi[62]);
1521   const __m128i cospim2 = _mm_set1_epi32(-cospi[2]);
1522   const __m128i cospi10 = _mm_set1_epi32(cospi[10]);
1523   const __m128i cospi54 = _mm_set1_epi32(cospi[54]);
1524   const __m128i cospim10 = _mm_set1_epi32(-cospi[10]);
1525   const __m128i cospi18 = _mm_set1_epi32(cospi[18]);
1526   const __m128i cospi46 = _mm_set1_epi32(cospi[46]);
1527   const __m128i cospim18 = _mm_set1_epi32(-cospi[18]);
1528   const __m128i cospi26 = _mm_set1_epi32(cospi[26]);
1529   const __m128i cospi38 = _mm_set1_epi32(cospi[38]);
1530   const __m128i cospim26 = _mm_set1_epi32(-cospi[26]);
1531   const __m128i cospi34 = _mm_set1_epi32(cospi[34]);
1532   const __m128i cospi30 = _mm_set1_epi32(cospi[30]);
1533   const __m128i cospim34 = _mm_set1_epi32(-cospi[34]);
1534   const __m128i cospi42 = _mm_set1_epi32(cospi[42]);
1535   const __m128i cospi22 = _mm_set1_epi32(cospi[22]);
1536   const __m128i cospim42 = _mm_set1_epi32(-cospi[42]);
1537   const __m128i cospi50 = _mm_set1_epi32(cospi[50]);
1538   const __m128i cospi14 = _mm_set1_epi32(cospi[14]);
1539   const __m128i cospim50 = _mm_set1_epi32(-cospi[50]);
1540   const __m128i cospi58 = _mm_set1_epi32(cospi[58]);
1541   const __m128i cospi6 = _mm_set1_epi32(cospi[6]);
1542   const __m128i cospim58 = _mm_set1_epi32(-cospi[58]);
1543   const __m128i rnding = _mm_set1_epi32(1 << (bit - 1));
1544   const __m128i zero = _mm_setzero_si128();
1545 
1546   __m128i u[16], v[16], x, y;
1547   int col;
1548 
1549   for (col = 0; col < num_cols; ++col) {
1550     // stage 0
1551     // stage 1
1552     u[0] = in[0 * num_cols + col];
1553     u[1] = _mm_sub_epi32(zero, in[15 * num_cols + col]);
1554     u[2] = _mm_sub_epi32(zero, in[7 * num_cols + col]);
1555     u[3] = in[8 * num_cols + col];
1556     u[4] = _mm_sub_epi32(zero, in[3 * num_cols + col]);
1557     u[5] = in[12 * num_cols + col];
1558     u[6] = in[4 * num_cols + col];
1559     u[7] = _mm_sub_epi32(zero, in[11 * num_cols + col]);
1560     u[8] = _mm_sub_epi32(zero, in[1 * num_cols + col]);
1561     u[9] = in[14 * num_cols + col];
1562     u[10] = in[6 * num_cols + col];
1563     u[11] = _mm_sub_epi32(zero, in[9 * num_cols + col]);
1564     u[12] = in[2 * num_cols + col];
1565     u[13] = _mm_sub_epi32(zero, in[13 * num_cols + col]);
1566     u[14] = _mm_sub_epi32(zero, in[5 * num_cols + col]);
1567     u[15] = in[10 * num_cols + col];
1568 
1569     // stage 2
1570     v[0] = u[0];
1571     v[1] = u[1];
1572 
1573     x = _mm_mullo_epi32(u[2], cospi32);
1574     y = _mm_mullo_epi32(u[3], cospi32);
1575     v[2] = _mm_add_epi32(x, y);
1576     v[2] = _mm_add_epi32(v[2], rnding);
1577     v[2] = _mm_srai_epi32(v[2], bit);
1578 
1579     v[3] = _mm_sub_epi32(x, y);
1580     v[3] = _mm_add_epi32(v[3], rnding);
1581     v[3] = _mm_srai_epi32(v[3], bit);
1582 
1583     v[4] = u[4];
1584     v[5] = u[5];
1585 
1586     x = _mm_mullo_epi32(u[6], cospi32);
1587     y = _mm_mullo_epi32(u[7], cospi32);
1588     v[6] = _mm_add_epi32(x, y);
1589     v[6] = _mm_add_epi32(v[6], rnding);
1590     v[6] = _mm_srai_epi32(v[6], bit);
1591 
1592     v[7] = _mm_sub_epi32(x, y);
1593     v[7] = _mm_add_epi32(v[7], rnding);
1594     v[7] = _mm_srai_epi32(v[7], bit);
1595 
1596     v[8] = u[8];
1597     v[9] = u[9];
1598 
1599     x = _mm_mullo_epi32(u[10], cospi32);
1600     y = _mm_mullo_epi32(u[11], cospi32);
1601     v[10] = _mm_add_epi32(x, y);
1602     v[10] = _mm_add_epi32(v[10], rnding);
1603     v[10] = _mm_srai_epi32(v[10], bit);
1604 
1605     v[11] = _mm_sub_epi32(x, y);
1606     v[11] = _mm_add_epi32(v[11], rnding);
1607     v[11] = _mm_srai_epi32(v[11], bit);
1608 
1609     v[12] = u[12];
1610     v[13] = u[13];
1611 
1612     x = _mm_mullo_epi32(u[14], cospi32);
1613     y = _mm_mullo_epi32(u[15], cospi32);
1614     v[14] = _mm_add_epi32(x, y);
1615     v[14] = _mm_add_epi32(v[14], rnding);
1616     v[14] = _mm_srai_epi32(v[14], bit);
1617 
1618     v[15] = _mm_sub_epi32(x, y);
1619     v[15] = _mm_add_epi32(v[15], rnding);
1620     v[15] = _mm_srai_epi32(v[15], bit);
1621 
1622     // stage 3
1623     u[0] = _mm_add_epi32(v[0], v[2]);
1624     u[1] = _mm_add_epi32(v[1], v[3]);
1625     u[2] = _mm_sub_epi32(v[0], v[2]);
1626     u[3] = _mm_sub_epi32(v[1], v[3]);
1627     u[4] = _mm_add_epi32(v[4], v[6]);
1628     u[5] = _mm_add_epi32(v[5], v[7]);
1629     u[6] = _mm_sub_epi32(v[4], v[6]);
1630     u[7] = _mm_sub_epi32(v[5], v[7]);
1631     u[8] = _mm_add_epi32(v[8], v[10]);
1632     u[9] = _mm_add_epi32(v[9], v[11]);
1633     u[10] = _mm_sub_epi32(v[8], v[10]);
1634     u[11] = _mm_sub_epi32(v[9], v[11]);
1635     u[12] = _mm_add_epi32(v[12], v[14]);
1636     u[13] = _mm_add_epi32(v[13], v[15]);
1637     u[14] = _mm_sub_epi32(v[12], v[14]);
1638     u[15] = _mm_sub_epi32(v[13], v[15]);
1639 
1640     // stage 4
1641     v[0] = u[0];
1642     v[1] = u[1];
1643     v[2] = u[2];
1644     v[3] = u[3];
1645     v[4] = half_btf_sse4_1(&cospi16, &u[4], &cospi48, &u[5], &rnding, bit);
1646     v[5] = half_btf_sse4_1(&cospi48, &u[4], &cospim16, &u[5], &rnding, bit);
1647     v[6] = half_btf_sse4_1(&cospim48, &u[6], &cospi16, &u[7], &rnding, bit);
1648     v[7] = half_btf_sse4_1(&cospi16, &u[6], &cospi48, &u[7], &rnding, bit);
1649     v[8] = u[8];
1650     v[9] = u[9];
1651     v[10] = u[10];
1652     v[11] = u[11];
1653     v[12] = half_btf_sse4_1(&cospi16, &u[12], &cospi48, &u[13], &rnding, bit);
1654     v[13] = half_btf_sse4_1(&cospi48, &u[12], &cospim16, &u[13], &rnding, bit);
1655     v[14] = half_btf_sse4_1(&cospim48, &u[14], &cospi16, &u[15], &rnding, bit);
1656     v[15] = half_btf_sse4_1(&cospi16, &u[14], &cospi48, &u[15], &rnding, bit);
1657 
1658     // stage 5
1659     u[0] = _mm_add_epi32(v[0], v[4]);
1660     u[1] = _mm_add_epi32(v[1], v[5]);
1661     u[2] = _mm_add_epi32(v[2], v[6]);
1662     u[3] = _mm_add_epi32(v[3], v[7]);
1663     u[4] = _mm_sub_epi32(v[0], v[4]);
1664     u[5] = _mm_sub_epi32(v[1], v[5]);
1665     u[6] = _mm_sub_epi32(v[2], v[6]);
1666     u[7] = _mm_sub_epi32(v[3], v[7]);
1667     u[8] = _mm_add_epi32(v[8], v[12]);
1668     u[9] = _mm_add_epi32(v[9], v[13]);
1669     u[10] = _mm_add_epi32(v[10], v[14]);
1670     u[11] = _mm_add_epi32(v[11], v[15]);
1671     u[12] = _mm_sub_epi32(v[8], v[12]);
1672     u[13] = _mm_sub_epi32(v[9], v[13]);
1673     u[14] = _mm_sub_epi32(v[10], v[14]);
1674     u[15] = _mm_sub_epi32(v[11], v[15]);
1675 
1676     // stage 6
1677     v[0] = u[0];
1678     v[1] = u[1];
1679     v[2] = u[2];
1680     v[3] = u[3];
1681     v[4] = u[4];
1682     v[5] = u[5];
1683     v[6] = u[6];
1684     v[7] = u[7];
1685     v[8] = half_btf_sse4_1(&cospi8, &u[8], &cospi56, &u[9], &rnding, bit);
1686     v[9] = half_btf_sse4_1(&cospi56, &u[8], &cospim8, &u[9], &rnding, bit);
1687     v[10] = half_btf_sse4_1(&cospi40, &u[10], &cospi24, &u[11], &rnding, bit);
1688     v[11] = half_btf_sse4_1(&cospi24, &u[10], &cospim40, &u[11], &rnding, bit);
1689     v[12] = half_btf_sse4_1(&cospim56, &u[12], &cospi8, &u[13], &rnding, bit);
1690     v[13] = half_btf_sse4_1(&cospi8, &u[12], &cospi56, &u[13], &rnding, bit);
1691     v[14] = half_btf_sse4_1(&cospim24, &u[14], &cospi40, &u[15], &rnding, bit);
1692     v[15] = half_btf_sse4_1(&cospi40, &u[14], &cospi24, &u[15], &rnding, bit);
1693 
1694     // stage 7
1695     u[0] = _mm_add_epi32(v[0], v[8]);
1696     u[1] = _mm_add_epi32(v[1], v[9]);
1697     u[2] = _mm_add_epi32(v[2], v[10]);
1698     u[3] = _mm_add_epi32(v[3], v[11]);
1699     u[4] = _mm_add_epi32(v[4], v[12]);
1700     u[5] = _mm_add_epi32(v[5], v[13]);
1701     u[6] = _mm_add_epi32(v[6], v[14]);
1702     u[7] = _mm_add_epi32(v[7], v[15]);
1703     u[8] = _mm_sub_epi32(v[0], v[8]);
1704     u[9] = _mm_sub_epi32(v[1], v[9]);
1705     u[10] = _mm_sub_epi32(v[2], v[10]);
1706     u[11] = _mm_sub_epi32(v[3], v[11]);
1707     u[12] = _mm_sub_epi32(v[4], v[12]);
1708     u[13] = _mm_sub_epi32(v[5], v[13]);
1709     u[14] = _mm_sub_epi32(v[6], v[14]);
1710     u[15] = _mm_sub_epi32(v[7], v[15]);
1711 
1712     // stage 8
1713     v[0] = half_btf_sse4_1(&cospi2, &u[0], &cospi62, &u[1], &rnding, bit);
1714     v[1] = half_btf_sse4_1(&cospi62, &u[0], &cospim2, &u[1], &rnding, bit);
1715     v[2] = half_btf_sse4_1(&cospi10, &u[2], &cospi54, &u[3], &rnding, bit);
1716     v[3] = half_btf_sse4_1(&cospi54, &u[2], &cospim10, &u[3], &rnding, bit);
1717     v[4] = half_btf_sse4_1(&cospi18, &u[4], &cospi46, &u[5], &rnding, bit);
1718     v[5] = half_btf_sse4_1(&cospi46, &u[4], &cospim18, &u[5], &rnding, bit);
1719     v[6] = half_btf_sse4_1(&cospi26, &u[6], &cospi38, &u[7], &rnding, bit);
1720     v[7] = half_btf_sse4_1(&cospi38, &u[6], &cospim26, &u[7], &rnding, bit);
1721     v[8] = half_btf_sse4_1(&cospi34, &u[8], &cospi30, &u[9], &rnding, bit);
1722     v[9] = half_btf_sse4_1(&cospi30, &u[8], &cospim34, &u[9], &rnding, bit);
1723     v[10] = half_btf_sse4_1(&cospi42, &u[10], &cospi22, &u[11], &rnding, bit);
1724     v[11] = half_btf_sse4_1(&cospi22, &u[10], &cospim42, &u[11], &rnding, bit);
1725     v[12] = half_btf_sse4_1(&cospi50, &u[12], &cospi14, &u[13], &rnding, bit);
1726     v[13] = half_btf_sse4_1(&cospi14, &u[12], &cospim50, &u[13], &rnding, bit);
1727     v[14] = half_btf_sse4_1(&cospi58, &u[14], &cospi6, &u[15], &rnding, bit);
1728     v[15] = half_btf_sse4_1(&cospi6, &u[14], &cospim58, &u[15], &rnding, bit);
1729 
1730     // stage 9
1731     out[0 * num_cols + col] = v[1];
1732     out[1 * num_cols + col] = v[14];
1733     out[2 * num_cols + col] = v[3];
1734     out[3 * num_cols + col] = v[12];
1735     out[4 * num_cols + col] = v[5];
1736     out[5 * num_cols + col] = v[10];
1737     out[6 * num_cols + col] = v[7];
1738     out[7 * num_cols + col] = v[8];
1739     out[8 * num_cols + col] = v[9];
1740     out[9 * num_cols + col] = v[6];
1741     out[10 * num_cols + col] = v[11];
1742     out[11 * num_cols + col] = v[4];
1743     out[12 * num_cols + col] = v[13];
1744     out[13 * num_cols + col] = v[2];
1745     out[14 * num_cols + col] = v[15];
1746     out[15 * num_cols + col] = v[0];
1747   }
1748 }
1749 
1750 static void col_txfm_16x16_rounding(__m128i *in, int shift) {
1751   // Note:
1752   //  The 16x16 rounding is split into four 8x8 sections of rounding
1753   //  rather than four column strips.
1754   col_txfm_8x8_rounding(&in[0], shift);
1755   col_txfm_8x8_rounding(&in[16], shift);
1756   col_txfm_8x8_rounding(&in[32], shift);
1757   col_txfm_8x8_rounding(&in[48], shift);
1758 }
1759 
1760 static void col_txfm_8x16_rounding(__m128i *in, int shift) {
1761   col_txfm_8x8_rounding(&in[0], shift);
1762   col_txfm_8x8_rounding(&in[16], shift);
1763 }
1764 
1765 static void write_buffer_16x16(const __m128i *in, int32_t *output) {
1766   const int size_8x8 = 16 * 4;
1767   write_buffer_8x8(&in[0], output);
1768   output += size_8x8;
1769   write_buffer_8x8(&in[16], output);
1770   output += size_8x8;
1771   write_buffer_8x8(&in[32], output);
1772   output += size_8x8;
1773   write_buffer_8x8(&in[48], output);
1774 }
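// 16-point identity transform: every coefficient is scaled by 2 * NewSqrt2
// (Q(NewSqrt2Bits)) with rounding; `bit` is unused because there are no
// butterfly stages.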
1775 static void idtx16x16_sse4_1(__m128i *in, __m128i *out, int bit, int col_num) {
1776   (void)bit;
1777   __m128i fact = _mm_set1_epi32(2 * NewSqrt2);
1778   __m128i offset = _mm_set1_epi32(1 << (NewSqrt2Bits - 1));
1779   __m128i a_low;
1780 
1781   int num_iters = 16 * col_num;
1782   for (int i = 0; i < num_iters; i++) {
1783     a_low = _mm_mullo_epi32(in[i], fact);
1784     a_low = _mm_add_epi32(a_low, offset);
1785     out[i] = _mm_srai_epi32(a_low, NewSqrt2Bits);
1786   }
1787 }
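// 2-D 16x16 forward transform: load and pre-shift the residual block, apply
// the column transform, apply the intermediate rounding shift (-shift[1]),
// transpose, apply the row transform, and store the 32-bit coefficients.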
1788 void av1_fwd_txfm2d_16x16_sse4_1(const int16_t *input, int32_t *coeff,
1789                                  int stride, TX_TYPE tx_type, int bd) {
1790   __m128i in[64], out[64];
1791   const int8_t *shift = av1_fwd_txfm_shift_ls[TX_16X16];
1792   const int txw_idx = get_txw_idx(TX_16X16);
1793   const int txh_idx = get_txh_idx(TX_16X16);
1794   const int col_num = 4;
1795   switch (tx_type) {
1796     case DCT_DCT:
1797       load_buffer_16x16(input, in, stride, 0, 0, shift[0]);
1798       fdct16x16_sse4_1(in, out, av1_fwd_cos_bit_col[txw_idx][txh_idx], col_num);
1799       col_txfm_16x16_rounding(out, -shift[1]);
1800       transpose_16x16(out, in);
1801       fdct16x16_sse4_1(in, out, av1_fwd_cos_bit_row[txw_idx][txh_idx], col_num);
1802       write_buffer_16x16(out, coeff);
1803       break;
1804     case ADST_DCT:
1805       load_buffer_16x16(input, in, stride, 0, 0, shift[0]);
1806       fadst16x16_sse4_1(in, out, av1_fwd_cos_bit_col[txw_idx][txh_idx],
1807                         col_num);
1808       col_txfm_16x16_rounding(out, -shift[1]);
1809       transpose_16x16(out, in);
1810       fdct16x16_sse4_1(in, out, av1_fwd_cos_bit_row[txw_idx][txh_idx], col_num);
1811       write_buffer_16x16(out, coeff);
1812       break;
1813     case DCT_ADST:
1814       load_buffer_16x16(input, in, stride, 0, 0, shift[0]);
1815       fdct16x16_sse4_1(in, out, av1_fwd_cos_bit_col[txw_idx][txh_idx], col_num);
1816       col_txfm_16x16_rounding(out, -shift[1]);
1817       transpose_16x16(out, in);
1818       fadst16x16_sse4_1(in, out, av1_fwd_cos_bit_row[txw_idx][txh_idx],
1819                         col_num);
1820       write_buffer_16x16(out, coeff);
1821       break;
1822     case ADST_ADST:
1823       load_buffer_16x16(input, in, stride, 0, 0, shift[0]);
1824       fadst16x16_sse4_1(in, out, av1_fwd_cos_bit_col[txw_idx][txh_idx],
1825                         col_num);
1826       col_txfm_16x16_rounding(out, -shift[1]);
1827       transpose_16x16(out, in);
1828       fadst16x16_sse4_1(in, out, av1_fwd_cos_bit_row[txw_idx][txh_idx],
1829                         col_num);
1830       write_buffer_16x16(out, coeff);
1831       break;
1832     case FLIPADST_DCT:
1833       load_buffer_16x16(input, in, stride, 1, 0, shift[0]);
1834       fadst16x16_sse4_1(in, out, av1_fwd_cos_bit_col[txw_idx][txh_idx],
1835                         col_num);
1836       col_txfm_16x16_rounding(out, -shift[1]);
1837       transpose_16x16(out, in);
1838       fdct16x16_sse4_1(in, out, av1_fwd_cos_bit_row[txw_idx][txh_idx], col_num);
1839       write_buffer_16x16(out, coeff);
1840       break;
1841     case DCT_FLIPADST:
1842       load_buffer_16x16(input, in, stride, 0, 1, shift[0]);
1843       fdct16x16_sse4_1(in, out, av1_fwd_cos_bit_col[txw_idx][txh_idx], col_num);
1844       col_txfm_16x16_rounding(out, -shift[1]);
1845       transpose_16x16(out, in);
1846       fadst16x16_sse4_1(in, out, av1_fwd_cos_bit_row[txw_idx][txh_idx],
1847                         col_num);
1848       write_buffer_16x16(out, coeff);
1849       break;
1850     case FLIPADST_FLIPADST:
1851       load_buffer_16x16(input, in, stride, 1, 1, shift[0]);
1852       fadst16x16_sse4_1(in, out, av1_fwd_cos_bit_col[txw_idx][txh_idx],
1853                         col_num);
1854       col_txfm_16x16_rounding(out, -shift[1]);
1855       transpose_16x16(out, in);
1856       fadst16x16_sse4_1(in, out, av1_fwd_cos_bit_row[txw_idx][txh_idx],
1857                         col_num);
1858       write_buffer_16x16(out, coeff);
1859       break;
1860     case ADST_FLIPADST:
1861       load_buffer_16x16(input, in, stride, 0, 1, shift[0]);
1862       fadst16x16_sse4_1(in, out, av1_fwd_cos_bit_col[txw_idx][txh_idx],
1863                         col_num);
1864       col_txfm_16x16_rounding(out, -shift[1]);
1865       transpose_16x16(out, in);
1866       fadst16x16_sse4_1(in, out, av1_fwd_cos_bit_row[txw_idx][txh_idx],
1867                         col_num);
1868       write_buffer_16x16(out, coeff);
1869       break;
1870     case FLIPADST_ADST:
1871       load_buffer_16x16(input, in, stride, 1, 0, shift[0]);
1872       fadst16x16_sse4_1(in, out, av1_fwd_cos_bit_col[txw_idx][txh_idx],
1873                         col_num);
1874       col_txfm_16x16_rounding(out, -shift[1]);
1875       transpose_16x16(out, in);
1876       fadst16x16_sse4_1(in, out, av1_fwd_cos_bit_row[txw_idx][txh_idx],
1877                         col_num);
1878       write_buffer_16x16(out, coeff);
1879       break;
1880     case IDTX:
1881       load_buffer_16x16(input, in, stride, 0, 0, shift[0]);
1882       idtx16x16_sse4_1(in, out, av1_fwd_cos_bit_col[txw_idx][txh_idx], col_num);
1883       col_txfm_16x16_rounding(out, -shift[1]);
1884       transpose_16x16(out, in);
1885       idtx16x16_sse4_1(in, out, av1_fwd_cos_bit_row[txw_idx][txh_idx], col_num);
1886       write_buffer_16x16(out, coeff);
1887       break;
1888     case V_DCT:
1889       load_buffer_16x16(input, in, stride, 0, 0, shift[0]);
1890       fdct16x16_sse4_1(in, out, av1_fwd_cos_bit_col[txw_idx][txh_idx], col_num);
1891       col_txfm_16x16_rounding(out, -shift[1]);
1892       transpose_16x16(out, in);
1893       idtx16x16_sse4_1(in, out, av1_fwd_cos_bit_row[txw_idx][txh_idx], col_num);
1894       write_buffer_16x16(out, coeff);
1895       break;
1896     case H_DCT:
1897       load_buffer_16x16(input, in, stride, 0, 0, shift[0]);
1898       idtx16x16_sse4_1(in, out, av1_fwd_cos_bit_col[txw_idx][txh_idx], col_num);
1899       col_txfm_16x16_rounding(out, -shift[1]);
1900       transpose_16x16(out, in);
1901       fdct16x16_sse4_1(in, out, av1_fwd_cos_bit_row[txw_idx][txh_idx], col_num);
1902       write_buffer_16x16(out, coeff);
1903       break;
1904     case V_ADST:
1905       load_buffer_16x16(input, in, stride, 0, 0, shift[0]);
1906       fadst16x16_sse4_1(in, out, av1_fwd_cos_bit_col[txw_idx][txh_idx],
1907                         col_num);
1908       col_txfm_16x16_rounding(out, -shift[1]);
1909       transpose_16x16(out, in);
1910       idtx16x16_sse4_1(in, out, av1_fwd_cos_bit_row[txw_idx][txh_idx], col_num);
1911       write_buffer_16x16(out, coeff);
1912       break;
1913     case H_ADST:
1914       load_buffer_16x16(input, in, stride, 0, 0, shift[0]);
1915       idtx16x16_sse4_1(in, out, av1_fwd_cos_bit_col[txw_idx][txh_idx], col_num);
1916       col_txfm_16x16_rounding(out, -shift[1]);
1917       transpose_16x16(out, in);
1918       fadst16x16_sse4_1(in, out, av1_fwd_cos_bit_row[txw_idx][txh_idx],
1919                         col_num);
1920       write_buffer_16x16(out, coeff);
1921       break;
1922     case V_FLIPADST:
1923       load_buffer_16x16(input, in, stride, 1, 0, shift[0]);
1924       fadst16x16_sse4_1(in, out, av1_fwd_cos_bit_col[txw_idx][txh_idx],
1925                         col_num);
1926       col_txfm_16x16_rounding(out, -shift[1]);
1927       transpose_16x16(out, in);
1928       idtx16x16_sse4_1(in, out, av1_fwd_cos_bit_row[txw_idx][txh_idx], col_num);
1929       write_buffer_16x16(out, coeff);
1930       break;
1931     case H_FLIPADST:
1932       load_buffer_16x16(input, in, stride, 0, 1, shift[0]);
1933       idtx16x16_sse4_1(in, out, av1_fwd_cos_bit_col[txw_idx][txh_idx], col_num);
1934       col_txfm_16x16_rounding(out, -shift[1]);
1935       transpose_16x16(out, in);
1936       fadst16x16_sse4_1(in, out, av1_fwd_cos_bit_row[txw_idx][txh_idx],
1937                         col_num);
1938       write_buffer_16x16(out, coeff);
1939       break;
1940     default: assert(0);
1941   }
1942   (void)bd;
1943 }
1944 
1945 static inline void flip_buf_sse4_1(__m128i *in, __m128i *out, int size) {
1946   for (int i = 0; i < size; i += 2) in[30 - i] = out[i];
1947   for (int i = 1; i < size; i += 2) in[size - i] = out[i];
1948 }
1949 
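// Per-TX_TYPE dispatch tables used by the rectangular 2-D wrappers below:
// each entry selects the 1-D DCT, ADST or identity kernel for the column or
// row pass. NULL entries correspond to tx types that are not used at the
// associated block sizes.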
1950 static const fwd_transform_1d_sse4_1 col_highbd_txfm8x8_arr[TX_TYPES] = {
1951   fdct8x8_sse4_1,   // DCT_DCT
1952   fadst8x8_sse4_1,  // ADST_DCT
1953   fdct8x8_sse4_1,   // DCT_ADST
1954   fadst8x8_sse4_1,  // ADST_ADST
1955   fadst8x8_sse4_1,  // FLIPADST_DCT
1956   fdct8x8_sse4_1,   // DCT_FLIPADST
1957   fadst8x8_sse4_1,  // FLIPADST_FLIPADST
1958   fadst8x8_sse4_1,  // ADST_FLIPADST
1959   fadst8x8_sse4_1,  // FLIPADST_ADST
1960   idtx8x8_sse4_1,   // IDTX
1961   fdct8x8_sse4_1,   // V_DCT
1962   idtx8x8_sse4_1,   // H_DCT
1963   fadst8x8_sse4_1,  // V_ADST
1964   idtx8x8_sse4_1,   // H_ADST
1965   fadst8x8_sse4_1,  // V_FLIPADST
1966   idtx8x8_sse4_1    // H_FLIPADST
1967 };
1968 #if !CONFIG_REALTIME_ONLY
1969 static const fwd_transform_1d_sse4_1 row_highbd_txfm32x8_arr[TX_TYPES] = {
1970   fdct8x8_sse4_1,   // DCT_DCT
1971   NULL,             // ADST_DCT
1972   NULL,             // DCT_ADST
1973   NULL,             // ADST_ADST
1974   NULL,             // FLIPADST_DCT
1975   NULL,             // DCT_FLIPADST
1976   NULL,             // FLIPADST_FLIPADST
1977   NULL,             // ADST_FLIPADST
1978   NULL,             // FLIPADST_ADST
1979   idtx32x8_sse4_1,  // IDTX
1980   NULL,             // V_DCT
1981   NULL,             // H_DCT
1982   NULL,             // V_ADST
1983   NULL,             // H_ADST
1984   NULL,             // V_FLIPADST
1985   NULL,             // H_FLIPADST
1986 };
1987 #endif
1988 static const fwd_transform_1d_sse4_1 col_highbd_txfm4x8_arr[TX_TYPES] = {
1989   fdct4x8_sse4_1,   // DCT_DCT
1990   fadst8x8_sse4_1,  // ADST_DCT
1991   fdct4x8_sse4_1,   // DCT_ADST
1992   fadst8x8_sse4_1,  // ADST_ADST
1993   fadst8x8_sse4_1,  // FLIPADST_DCT
1994   fdct4x8_sse4_1,   // DCT_FLIPADST
1995   fadst8x8_sse4_1,  // FLIPADST_FLIPADST
1996   fadst8x8_sse4_1,  // ADST_FLIPADST
1997   fadst8x8_sse4_1,  // FLIPADST_ADST
1998   idtx8x8_sse4_1,   // IDTX
1999   fdct4x8_sse4_1,   // V_DCT
2000   idtx8x8_sse4_1,   // H_DCT
2001   fadst8x8_sse4_1,  // V_ADST
2002   idtx8x8_sse4_1,   // H_ADST
2003   fadst8x8_sse4_1,  // V_FLIPADST
2004   idtx8x8_sse4_1    // H_FLIPADST
2005 };
2006 
2007 static const fwd_transform_1d_sse4_1 row_highbd_txfm8x16_arr[TX_TYPES] = {
2008   fdct16x16_sse4_1,   // DCT_DCT
2009   fdct16x16_sse4_1,   // ADST_DCT
2010   fadst16x16_sse4_1,  // DCT_ADST
2011   fadst16x16_sse4_1,  // ADST_ADST
2012   fdct16x16_sse4_1,   // FLIPADST_DCT
2013   fadst16x16_sse4_1,  // DCT_FLIPADST
2014   fadst16x16_sse4_1,  // FLIPADST_FLIPADST
2015   fadst16x16_sse4_1,  // ADST_FLIPADST
2016   fadst16x16_sse4_1,  // FLIPADST_ADST
2017   idtx16x16_sse4_1,   // IDTX
2018   idtx16x16_sse4_1,   // V_DCT
2019   fdct16x16_sse4_1,   // H_DCT
2020   idtx16x16_sse4_1,   // V_ADST
2021   fadst16x16_sse4_1,  // H_ADST
2022   idtx16x16_sse4_1,   // V_FLIPADST
2023   fadst16x16_sse4_1   // H_FLIPADST
2024 };
2025 
2026 static const fwd_transform_1d_sse4_1 col_highbd_txfm8x16_arr[TX_TYPES] = {
2027   fdct16x16_sse4_1,   // DCT_DCT
2028   fadst16x16_sse4_1,  // ADST_DCT
2029   fdct16x16_sse4_1,   // DCT_ADST
2030   fadst16x16_sse4_1,  // ADST_ADST
2031   fadst16x16_sse4_1,  // FLIPADST_DCT
2032   fdct16x16_sse4_1,   // DCT_FLIPADST
2033   fadst16x16_sse4_1,  // FLIPADST_FLIPADST
2034   fadst16x16_sse4_1,  // ADST_FLIPADST
2035   fadst16x16_sse4_1,  // FLIPADST_ADST
2036   idtx16x16_sse4_1,   // IDTX
2037   fdct16x16_sse4_1,   // V_DCT
2038   idtx16x16_sse4_1,   // H_DCT
2039   fadst16x16_sse4_1,  // V_ADST
2040   idtx16x16_sse4_1,   // H_ADST
2041   fadst16x16_sse4_1,  // V_FLIPADST
2042   idtx16x16_sse4_1    // H_FLIPADST
2043 };
2044 static const fwd_transform_1d_sse4_1 row_highbd_txfm8x8_arr[TX_TYPES] = {
2045   fdct8x8_sse4_1,   // DCT_DCT
2046   fdct8x8_sse4_1,   // ADST_DCT
2047   fadst8x8_sse4_1,  // DCT_ADST
2048   fadst8x8_sse4_1,  // ADST_ADST
2049   fdct8x8_sse4_1,   // FLIPADST_DCT
2050   fadst8x8_sse4_1,  // DCT_FLIPADST
2051   fadst8x8_sse4_1,  // FLIPADST_FLIPADST
2052   fadst8x8_sse4_1,  // ADST_FLIPADST
2053   fadst8x8_sse4_1,  // FLIPADST_ADST
2054   idtx8x8_sse4_1,   // IDTX
2055   idtx8x8_sse4_1,   // V_DCT
2056   fdct8x8_sse4_1,   // H_DCT
2057   idtx8x8_sse4_1,   // V_ADST
2058   fadst8x8_sse4_1,  // H_ADST
2059   idtx8x8_sse4_1,   // V_FLIPADST
2060   fadst8x8_sse4_1   // H_FLIPADST
2061 };
2062 
2063 static const fwd_transform_1d_sse4_1 row_highbd_txfm4x8_arr[TX_TYPES] = {
2064   fdct4x8_sse4_1,   // DCT_DCT
2065   fdct4x8_sse4_1,   // ADST_DCT
2066   fadst8x8_sse4_1,  // DCT_ADST
2067   fadst8x8_sse4_1,  // ADST_ADST
2068   fdct4x8_sse4_1,   // FLIPADST_DCT
2069   fadst8x8_sse4_1,  // DCT_FLIPADST
2070   fadst8x8_sse4_1,  // FLIPADST_FLIPADST
2071   fadst8x8_sse4_1,  // ADST_FLIPADST
2072   fadst8x8_sse4_1,  // FLIPADST_ADST
2073   idtx8x8_sse4_1,   // IDTX
2074   idtx8x8_sse4_1,   // V_DCT
2075   fdct4x8_sse4_1,   // H_DCT
2076   idtx8x8_sse4_1,   // V_ADST
2077   fadst8x8_sse4_1,  // H_ADST
2078   idtx8x8_sse4_1,   // V_FLIPADST
2079   fadst8x8_sse4_1   // H_FLIPADST
2080 };
2081 
2082 static const fwd_transform_1d_sse4_1 row_highbd_txfm4x4_arr[TX_TYPES] = {
2083   fdct4x4_sse4_1,   // DCT_DCT
2084   fdct4x4_sse4_1,   // ADST_DCT
2085   fadst4x4_sse4_1,  // DCT_ADST
2086   fadst4x4_sse4_1,  // ADST_ADST
2087   fdct4x4_sse4_1,   // FLIPADST_DCT
2088   fadst4x4_sse4_1,  // DCT_FLIPADST
2089   fadst4x4_sse4_1,  // FLIPADST_FLIPADST
2090   fadst4x4_sse4_1,  // ADST_FLIPADST
2091   fadst4x4_sse4_1,  // FLIPADST_ADST
2092   idtx4x4_sse4_1,   // IDTX
2093   idtx4x4_sse4_1,   // V_DCT
2094   fdct4x4_sse4_1,   // H_DCT
2095   idtx4x4_sse4_1,   // V_ADST
2096   fadst4x4_sse4_1,  // H_ADST
2097   idtx4x4_sse4_1,   // V_FLIPADST
2098   fadst4x4_sse4_1   // H_FLIPADST
2099 };
2100 
2101 static const fwd_transform_1d_sse4_1 col_highbd_txfm4x4_arr[TX_TYPES] = {
2102   fdct4x4_sse4_1,   // DCT_DCT
2103   fadst4x4_sse4_1,  // ADST_DCT
2104   fdct4x4_sse4_1,   // DCT_ADST
2105   fadst4x4_sse4_1,  // ADST_ADST
2106   fadst4x4_sse4_1,  // FLIPADST_DCT
2107   fdct4x4_sse4_1,   // DCT_FLIPADST
2108   fadst4x4_sse4_1,  // FLIPADST_FLIPADST
2109   fadst4x4_sse4_1,  // ADST_FLIPADST
2110   fadst4x4_sse4_1,  // FLIPADST_ADST
2111   idtx4x4_sse4_1,   // IDTX
2112   fdct4x4_sse4_1,   // V_DCT
2113   idtx4x4_sse4_1,   // H_DCT
2114   fadst4x4_sse4_1,  // V_ADST
2115   idtx4x4_sse4_1,   // H_ADST
2116   fadst4x4_sse4_1,  // V_FLIPADST
2117   idtx4x4_sse4_1    // H_FLIPADST
2118 };
2119 
2120 static const fwd_transform_1d_sse4_1 col_highbd_txfm8x32_arr[TX_TYPES] = {
2121   av1_fdct32_sse4_1,  // DCT_DCT
2122   NULL,               // ADST_DCT
2123   NULL,               // DCT_ADST
2124   NULL,               // ADST_ADST
2125   NULL,               // FLIPADST_DCT
2126   NULL,               // DCT_FLIPADST
2127   NULL,               // FLIPADST_FLIPADST
2128   NULL,               // ADST_FLIPADST
2129   NULL,               // FLIPADST_ADST
2130   av1_idtx32_sse4_1,  // IDTX
2131   NULL,               // V_DCT
2132   NULL,               // H_DCT
2133   NULL,               // V_ADST
2134   NULL,               // H_ADST
2135   NULL,               // V_FLIPADST
2136   NULL                // H_FLIPADST
2137 };
2138 
2139 static const fwd_transform_1d_sse4_1 row_highbd_txfm8x32_arr[TX_TYPES] = {
2140   fdct16x16_sse4_1,  // DCT_DCT
2141   NULL,              // ADST_DCT
2142   NULL,              // DCT_ADST
2143   NULL,              // ADST_ADST
2144   NULL,              // FLIPADST_DCT
2145   NULL,              // DCT_FLIPADST
2146   NULL,              // FLIPADST_FLIPADST
2147   NULL,              // ADST_FLIPADST
2148   NULL,              // FLIPADST_ADST
2149   idtx16x16_sse4_1,  // IDTX
2150   NULL,              // V_DCT
2151   NULL,              // H_DCT
2152   NULL,              // V_ADST
2153   NULL,              // H_ADST
2154   NULL,              // V_FLIPADST
2155   NULL               // H_FLIPADST
2156 };
2157 
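// 16x8: the left and right 8x8 halves are column-transformed, rounded and
// transposed independently, then the 16-point row transform runs over the
// joined buffer and the result goes through the rectangular round-shift
// (NewSqrt2 scaling plus a -shift[2] round shift) before being stored.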
2158 void av1_fwd_txfm2d_16x8_sse4_1(const int16_t *input, int32_t *coeff,
2159                                 int stride, TX_TYPE tx_type, int bd) {
2160   __m128i in[32], out[32];
2161   const int8_t *shift = av1_fwd_txfm_shift_ls[TX_16X8];
2162   const int txw_idx = get_txw_idx(TX_16X8);
2163   const int txh_idx = get_txh_idx(TX_16X8);
2164   const fwd_transform_1d_sse4_1 col_txfm = col_highbd_txfm8x8_arr[tx_type];
2165   const fwd_transform_1d_sse4_1 row_txfm = row_highbd_txfm8x16_arr[tx_type];
2166   int bit = av1_fwd_cos_bit_col[txw_idx][txh_idx];
2167   int ud_flip, lr_flip;
2168   get_flip_cfg(tx_type, &ud_flip, &lr_flip);
2169 
2170   for (int i = 0; i < 2; i++) {
2171     load_buffer_8x8(input + i * 8, in, stride, ud_flip, 0, shift[0]);
2172     col_txfm(in, in, bit, 2);
2173     col_txfm_8x8_rounding(in, -shift[1]);
2174     transpose_8x8(in, out + i * 16);
2175   }
2176 
2177   if (lr_flip) {
2178     flip_buf_sse4_1(in, out, 32);
2179     row_txfm(in, out, bit, 2);
2180   } else {
2181     row_txfm(out, out, bit, 2);
2182   }
2183 
2184   for (int i = 0; i < 2; i++) {
2185     av1_round_shift_rect_array_32_sse4_1(out + i * 16, in, 16, -shift[2],
2186                                          NewSqrt2);
2187     write_buffer_8x8(in, coeff + i * 64);
2188   }
2189   (void)bd;
2190 }
2191 
2192 void av1_fwd_txfm2d_8x16_sse4_1(const int16_t *input, int32_t *coeff,
2193                                 int stride, TX_TYPE tx_type, int bd) {
2194   __m128i in[32], out[32];
2195   const int8_t *shift = av1_fwd_txfm_shift_ls[TX_8X16];
2196   const int txw_idx = get_txw_idx(TX_8X16);
2197   const int txh_idx = get_txh_idx(TX_8X16);
2198   const fwd_transform_1d_sse4_1 col_txfm = col_highbd_txfm8x16_arr[tx_type];
2199   const fwd_transform_1d_sse4_1 row_txfm = row_highbd_txfm8x8_arr[tx_type];
2200   int bit = av1_fwd_cos_bit_col[txw_idx][txh_idx];
2201   int ud_flip, lr_flip;
2202   get_flip_cfg(tx_type, &ud_flip, &lr_flip);
2203 
2204   load_buffer_8x16(input, in, stride, ud_flip, lr_flip, shift[0]);
2205   col_txfm(in, in, bit, 2);
2206   col_txfm_8x16_rounding(in, -shift[1]);
2207   transpose_8x8(in, out);
2208   transpose_8x8(in + 16, out + 16);
2209 
2210   for (int i = 0; i < 2; i++) {
2211     row_txfm(out + i * 16, out, bit, 2);
2212     av1_round_shift_rect_array_32_sse4_1(out, out, 16, -shift[2], NewSqrt2);
2213     write_buffer_16x8(out, coeff + i * 8, 16);
2214   }
2215   (void)bd;
2216 }
2217 
2218 #if !CONFIG_REALTIME_ONLY
2219 void av1_fwd_txfm2d_4x16_sse4_1(const int16_t *input, int32_t *coeff,
2220                                 int stride, TX_TYPE tx_type, int bd) {
2221   __m128i in[16];
2222   __m128i *outcoeff128 = (__m128i *)coeff;
2223   const int8_t *shift = av1_fwd_txfm_shift_ls[TX_4X16];
2224   const int txw_idx = get_txw_idx(TX_4X16);
2225   const int txh_idx = get_txh_idx(TX_4X16);
2226   const int txfm_size_col = tx_size_wide[TX_4X16];
2227   const int txfm_size_row = tx_size_high[TX_4X16];
2228   int bitcol = av1_fwd_cos_bit_col[txw_idx][txh_idx];
2229   int bitrow = av1_fwd_cos_bit_row[txw_idx][txh_idx];
2230   const fwd_transform_1d_sse4_1 col_txfm = col_highbd_txfm8x16_arr[tx_type];
2231   const fwd_transform_1d_sse4_1 row_txfm = row_highbd_txfm4x4_arr[tx_type];
2232 
2233   int ud_flip, lr_flip;
2234   get_flip_cfg(tx_type, &ud_flip, &lr_flip);
2235   // col transform
2236   load_buffer_4x16(input, in, stride, ud_flip, lr_flip, shift[0]);
2237   col_txfm(in, outcoeff128, bitcol, 1);
2238   col_txfm_8x8_rounding(outcoeff128, -shift[1]);
2239   transpose_8nx8n(outcoeff128, in, txfm_size_col, txfm_size_row);
2240 
2241   // row transform
2242   for (int i = 0; i < 4; i++) {
2243     __m128i tmp[4];
2244     row_txfm(in + i, tmp, bitrow, txfm_size_row >> 2);
2245     store_output_w4(coeff + i * 4, tmp, txfm_size_row, txfm_size_col);
2246   }
2247   (void)bd;
2248 }
2249 #endif
2250 
2251 void av1_fwd_txfm2d_16x4_sse4_1(const int16_t *input, int32_t *coeff,
2252                                 int stride, TX_TYPE tx_type, int bd) {
2253   __m128i in[16];
2254   __m128i *outcoeff128 = (__m128i *)coeff;
2255   const int8_t *shift = av1_fwd_txfm_shift_ls[TX_16X4];
2256   const int txw_idx = get_txw_idx(TX_16X4);
2257   const int txh_idx = get_txh_idx(TX_16X4);
2258   const int txfm_size_col = tx_size_wide[TX_16X4];
2259   const int txfm_size_row = tx_size_high[TX_16X4];
2260   int bitcol = av1_fwd_cos_bit_col[txw_idx][txh_idx];
2261   int bitrow = av1_fwd_cos_bit_row[txw_idx][txh_idx];
2262   const fwd_transform_1d_sse4_1 col_txfm = col_highbd_txfm4x4_arr[tx_type];
2263   const fwd_transform_1d_sse4_1 row_txfm = row_highbd_txfm8x16_arr[tx_type];
2264   int ud_flip, lr_flip;
2265   get_flip_cfg(tx_type, &ud_flip, &lr_flip);
2266 
2267   // col transform
2268   load_buffer_16x4(input, in, stride, ud_flip, lr_flip, shift[0]);
2269 
2270   for (int i = 0; i < (txfm_size_col >> 2); i++) {
2271     __m128i *cur_in = &in[i * txfm_size_row];
2272     col_txfm(cur_in, cur_in, bitcol, 1);
2273     transpose_32bit_4x4(cur_in, cur_in);
2274   }
2275   col_txfm_8x8_rounding(in, -shift[1]);
2276 
2277   // row transform
2278   row_txfm(in, outcoeff128, bitrow, 1);
2279   (void)bd;
2280 }
2281 
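// 16x32: the 32-row column transform is applied to four 4-column strips,
// followed by per-16x16 rounding, a transpose, the 16-point row transform,
// and the rectangular NewSqrt2 round-shift.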
2282 void av1_fwd_txfm2d_16x32_sse4_1(const int16_t *input, int32_t *coeff,
2283                                  int stride, TX_TYPE tx_type, int bd) {
2284   __m128i in[128];
2285   __m128i *outcoef128 = (__m128i *)coeff;
2286   const int8_t *shift = av1_fwd_txfm_shift_ls[TX_16X32];
2287   const int txw_idx = get_txw_idx(TX_16X32);
2288   const int txh_idx = get_txh_idx(TX_16X32);
2289   const fwd_transform_1d_sse4_1 col_txfm = col_highbd_txfm8x32_arr[tx_type];
2290   const fwd_transform_1d_sse4_1 row_txfm = row_highbd_txfm8x32_arr[tx_type];
2291   int bitcol = av1_fwd_cos_bit_col[txw_idx][txh_idx];
2292   int bitrow = av1_fwd_cos_bit_row[txw_idx][txh_idx];
2293 
2294   // column transform
2295   load_buffer_16x16(input, in, stride, 0, 0, shift[0]);
2296   load_buffer_16x16(input + 16 * stride, in + 64, stride, 0, 0, shift[0]);
2297 
2298   for (int i = 0; i < 4; i++) {
2299     col_txfm((in + i), (in + i), bitcol, 4);
2300   }
2301   col_txfm_16x16_rounding(&in[0], -shift[1]);
2302   col_txfm_16x16_rounding(&in[64], -shift[1]);
2303   transpose_8nx8n(in, outcoef128, 16, 32);
2304 
2305   // row transform
2306   row_txfm(outcoef128, in, bitrow, 8);
2307   av1_round_shift_rect_array_32_sse4_1(in, outcoef128, 128, -shift[2],
2308                                        NewSqrt2);
2309   (void)bd;
2310 }
2311 
2312 void av1_fwd_txfm2d_32x64_sse4_1(const int16_t *input, int32_t *coeff,
2313                                  int stride, TX_TYPE tx_type, int bd) {
2314   (void)tx_type;
2315   __m128i in[512];
2316   __m128i *outcoef128 = (__m128i *)coeff;
2317   const int8_t *shift = av1_fwd_txfm_shift_ls[TX_32X64];
2318   const int txw_idx = get_txw_idx(TX_32X64);
2319   const int txh_idx = get_txh_idx(TX_32X64);
2320   const int txfm_size_col = tx_size_wide[TX_32X64];
2321   const int txfm_size_row = tx_size_high[TX_32X64];
2322   int bitcol = av1_fwd_cos_bit_col[txw_idx][txh_idx];
2323   int bitrow = av1_fwd_cos_bit_row[txw_idx][txh_idx];
2324   const int num_row = txfm_size_row >> 2;
2325   const int num_col = txfm_size_col >> 2;
2326 
2327   // column transform
2328   load_buffer_32x8n(input, in, stride, 0, 0, shift[0], txfm_size_row);
2329   for (int i = 0; i < num_col; i++) {
2330     av1_fdct64_sse4_1((in + i), (in + i), bitcol, num_col, num_col);
2331   }
2332   for (int i = 0; i < num_col; i++) {
2333     col_txfm_16x16_rounding((in + i * txfm_size_row), -shift[1]);
2334   }
2335   transpose_8nx8n(in, outcoef128, txfm_size_col, txfm_size_row);
2336 
2337   // row transform
2338   for (int i = 0; i < num_row; i++) {
2339     av1_fdct32_sse4_1((outcoef128 + i), (in + i), bitrow, num_row);
2340   }
2341   for (int i = 0; i < txfm_size_col; i++) {
2342     av1_round_shift_rect_array_32_sse4_1(in + i * 16, outcoef128 + i * 8, 8,
2343                                          -shift[2], NewSqrt2);
2344   }
2345   (void)bd;
2346 }
2347 
2348 void av1_fwd_txfm2d_64x32_sse4_1(const int16_t *input, int32_t *coeff,
2349                                  int stride, TX_TYPE tx_type, int bd) {
2350   (void)tx_type;
2351   __m128i in[512];
2352   __m128i *outcoef128 = (__m128i *)coeff;
2353   const int8_t *shift = av1_fwd_txfm_shift_ls[TX_64X32];
2354   const int txw_idx = get_txw_idx(TX_64X32);
2355   const int txh_idx = get_txh_idx(TX_64X32);
2356   const int txfm_size_col = tx_size_wide[TX_64X32];
2357   const int txfm_size_row = tx_size_high[TX_64X32];
2358   int bitcol = av1_fwd_cos_bit_col[txw_idx][txh_idx];
2359   int bitrow = av1_fwd_cos_bit_row[txw_idx][txh_idx];
2360   const int num_row = txfm_size_row >> 2;
2361   const int num_col = txfm_size_col >> 2;
2362 
2363   // column transform
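  // Each 64-sample input row is split into four 16-sample chunks, each
  // loaded as a 4x4 tile (stride 4), so one row occupies 16 consecutive
  // __m128i in `in`.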
2364   for (int i = 0; i < 32; i++) {
2365     load_buffer_4x4(input + 0 + i * stride, in + 0 + i * 16, 4, 0, 0, shift[0]);
2366     load_buffer_4x4(input + 16 + i * stride, in + 4 + i * 16, 4, 0, 0,
2367                     shift[0]);
2368     load_buffer_4x4(input + 32 + i * stride, in + 8 + i * 16, 4, 0, 0,
2369                     shift[0]);
2370     load_buffer_4x4(input + 48 + i * stride, in + 12 + i * 16, 4, 0, 0,
2371                     shift[0]);
2372   }
2373 
2374   for (int i = 0; i < num_col; i++) {
2375     av1_fdct32_sse4_1((in + i), (in + i), bitcol, num_col);
2376   }
2377 
2378   for (int i = 0; i < num_row; i++) {
2379     col_txfm_16x16_rounding((in + i * txfm_size_col), -shift[1]);
2380   }
2381   transpose_8nx8n(in, outcoef128, txfm_size_col, txfm_size_row);
2382 
2383   // row transform
2384   for (int i = 0; i < num_row; i++) {
2385     av1_fdct64_sse4_1((outcoef128 + i), (in + i), bitrow, num_row, num_row);
2386   }
2387   av1_round_shift_rect_array_32_sse4_1(in, outcoef128, 512, -shift[2],
2388                                        NewSqrt2);
2389   (void)bd;
2390 }
2391 
2392 void av1_fwd_txfm2d_32x16_sse4_1(const int16_t *input, int32_t *coeff,
2393                                  int stride, TX_TYPE tx_type, int bd) {
2394   __m128i in[128];
2395   __m128i *outcoef128 = (__m128i *)coeff;
2396   const int8_t *shift = av1_fwd_txfm_shift_ls[TX_32X16];
2397   const int txw_idx = get_txw_idx(TX_32X16);
2398   const int txh_idx = get_txh_idx(TX_32X16);
2399   const fwd_transform_1d_sse4_1 col_txfm = row_highbd_txfm8x32_arr[tx_type];
2400   const fwd_transform_1d_sse4_1 row_txfm = col_highbd_txfm8x32_arr[tx_type];
2401   int bitcol = av1_fwd_cos_bit_col[txw_idx][txh_idx];
2402   int bitrow = av1_fwd_cos_bit_row[txw_idx][txh_idx];
2403 
2404   // column transform
2405   load_buffer_32x8n(input, in, stride, 0, 0, shift[0], 16);
2406   col_txfm(in, in, bitcol, 8);
2407   col_txfm_16x16_rounding(&in[0], -shift[1]);
2408   col_txfm_16x16_rounding(&in[64], -shift[1]);
2409   transpose_8nx8n(in, outcoef128, 32, 16);
2410 
2411   // row transform
2412   for (int i = 0; i < 4; i++) {
2413     row_txfm((outcoef128 + i), (in + i), bitrow, 4);
2414   }
2415   av1_round_shift_rect_array_32_sse4_1(in, outcoef128, 128, -shift[2],
2416                                        NewSqrt2);
2417   (void)bd;
2418 }
2419 
2420 #if !CONFIG_REALTIME_ONLY
2421 void av1_fwd_txfm2d_8x32_sse4_1(const int16_t *input, int32_t *coeff,
2422                                 int stride, TX_TYPE tx_type, int bd) {
2423   __m128i in[64];
2424   __m128i *outcoef128 = (__m128i *)coeff;
2425   const int8_t *shift = av1_fwd_txfm_shift_ls[TX_8X32];
2426   const int txw_idx = get_txw_idx(TX_8X32);
2427   const int txh_idx = get_txh_idx(TX_8X32);
2428   const fwd_transform_1d_sse4_1 col_txfm = col_highbd_txfm8x32_arr[tx_type];
2429   const fwd_transform_1d_sse4_1 row_txfm = row_highbd_txfm32x8_arr[tx_type];
2430   int bitcol = av1_fwd_cos_bit_col[txw_idx][txh_idx];
2431   int bitrow = av1_fwd_cos_bit_row[txw_idx][txh_idx];
2432 
2433   const int txfm_size_col = tx_size_wide[TX_8X32];
2434   const int txfm_size_row = tx_size_high[TX_8X32];
2435   const int num_col = txfm_size_col >> 2;
2436 
2437   // column transform
2438   load_buffer_8x16(input, in, stride, 0, 0, shift[0]);
2439   load_buffer_8x16(input + (txfm_size_row >> 1) * stride, in + txfm_size_row,
2440                    stride, 0, 0, shift[0]);
2441 
2442   for (int i = 0; i < num_col; i++) {
2443     col_txfm((in + i), (in + i), bitcol, num_col);
2444   }
2445   col_txfm_16x16_rounding(in, -shift[1]);
2446   transpose_8nx8n(in, outcoef128, txfm_size_col, txfm_size_row);
2447 
2448   // row transform
2449   for (int i = 0; i < txfm_size_col; i += 2) {
2450     row_txfm((outcoef128 + i), (outcoef128 + i), bitrow, txfm_size_col);
2451   }
2452   (void)bd;
2453 }
2454 
2455 void av1_fwd_txfm2d_32x8_sse4_1(const int16_t *input, int32_t *coeff,
2456                                 int stride, TX_TYPE tx_type, int bd) {
2457   __m128i in[64];
2458   __m128i *outcoef128 = (__m128i *)coeff;
2459   const int8_t *shift = av1_fwd_txfm_shift_ls[TX_32X8];
2460   const int txw_idx = get_txw_idx(TX_32X8);
2461   const int txh_idx = get_txh_idx(TX_32X8);
2462   const fwd_transform_1d_sse4_1 col_txfm = row_highbd_txfm32x8_arr[tx_type];
2463   const fwd_transform_1d_sse4_1 row_txfm = col_highbd_txfm8x32_arr[tx_type];
2464   int bitcol = av1_fwd_cos_bit_col[txw_idx][txh_idx];
2465   int bitrow = av1_fwd_cos_bit_row[txw_idx][txh_idx];
2466 
2467   const int txfm_size_col = tx_size_wide[TX_32X8];
2468   const int txfm_size_row = tx_size_high[TX_32X8];
2469   const int num_col = txfm_size_row >> 2;
2470 
2471   // column transform
2472   load_buffer_32x8n(input, in, stride, 0, 0, shift[0], 8);
2473   for (int i = 0; i < txfm_size_row; i += 2) {
2474     col_txfm((in + i), (in + i), bitcol, txfm_size_row);
2475   }
2476 
2477   col_txfm_16x16_rounding(&in[0], -shift[1]);
2478   transpose_8nx8n(in, outcoef128, txfm_size_col, txfm_size_row);
2479 
2480   // row transform
2481   for (int i = 0; i < num_col; i++) {
2482     row_txfm((outcoef128 + i), (outcoef128 + i), bitrow, num_col);
2483   }
2484   (void)bd;
2485 }
2486 #endif
2487 
2488 void av1_fwd_txfm2d_4x8_sse4_1(const int16_t *input, int32_t *coeff, int stride,
2489                                TX_TYPE tx_type, int bd) {
2490   __m128i in[8];
2491   const int8_t *shift = av1_fwd_txfm_shift_ls[TX_4X8];
2492   const int txw_idx = get_txw_idx(TX_4X8);
2493   const int txh_idx = get_txh_idx(TX_4X8);
2494   const int txfm_size_col = tx_size_wide[TX_4X8];
2495   const int txfm_size_row = tx_size_high[TX_4X8];
2496   int bitcol = av1_fwd_cos_bit_col[txw_idx][txh_idx];
2497   int bitrow = av1_fwd_cos_bit_row[txw_idx][txh_idx];
2498   const fwd_transform_1d_sse4_1 col_txfm = col_highbd_txfm4x8_arr[tx_type];
2499   const fwd_transform_1d_sse4_1 row_txfm = row_highbd_txfm4x4_arr[tx_type];
2500 
2501   int ud_flip, lr_flip;
2502   get_flip_cfg(tx_type, &ud_flip, &lr_flip);
2503 
2504   load_buffer_4x8(input, in, stride, ud_flip, lr_flip, shift[0]);
2505   col_txfm(in, in, bitcol, 1);
2506   col_txfm_4x8_rounding(in, -shift[1]);
2507 
2508   for (int i = 0; i < 2; i++) {
2509     __m128i *cur_in = &in[i * 4];
2510     transpose_32bit_4x4(cur_in, cur_in);
2511     row_txfm(cur_in, cur_in, bitrow, 1);
2512     av1_round_shift_rect_array_32_sse4_1(cur_in, cur_in, txfm_size_col,
2513                                          -shift[2], NewSqrt2);
2514     store_output_w4(coeff + i * 4, cur_in, txfm_size_row, 4);
2515   }
2516   (void)bd;
2517 }
2518 
2519 void av1_fwd_txfm2d_8x4_sse4_1(const int16_t *input, int32_t *coeff, int stride,
2520                                TX_TYPE tx_type, int bd) {
2521   __m128i in[8];
2522   __m128i *outcoeff128 = (__m128i *)coeff;
2523   const int8_t *shift = av1_fwd_txfm_shift_ls[TX_8X4];
2524   const int txw_idx = get_txw_idx(TX_8X4);
2525   const int txh_idx = get_txh_idx(TX_8X4);
2526   const int txfm_size_col = tx_size_wide[TX_8X4];
2527   const int txfm_size_row = tx_size_high[TX_8X4];
2528   int bitcol = av1_fwd_cos_bit_col[txw_idx][txh_idx];
2529   int bitrow = av1_fwd_cos_bit_row[txw_idx][txh_idx];
2530   const fwd_transform_1d_sse4_1 col_txfm = col_highbd_txfm4x4_arr[tx_type];
2531   const fwd_transform_1d_sse4_1 row_txfm = row_highbd_txfm4x8_arr[tx_type];
2532   int ud_flip, lr_flip;
2533   get_flip_cfg(tx_type, &ud_flip, &lr_flip);
2534   // col transform
2535   load_buffer_8x4(input, in, stride, ud_flip, lr_flip, shift[0]);
2536   for (int i = 0; i < 2; i++) {
2537     __m128i *cur_in = &in[i * txfm_size_row];
2538     col_txfm(cur_in, cur_in, bitcol, 1);
2539     transpose_32bit_4x4(cur_in, cur_in);
2540   }
2541   col_txfm_4x8_rounding(in, -shift[1]);
2542 
2543   // row transform
2544   row_txfm(in, outcoeff128, bitrow, 1);
2545   av1_round_shift_rect_array_32_sse4_1(outcoeff128, outcoeff128, txfm_size_col,
2546                                        -shift[2], NewSqrt2);
2547   (void)bd;
2548 }
2549 
2550 #if !CONFIG_REALTIME_ONLY
2551 void av1_fwd_txfm2d_16x64_sse4_1(const int16_t *input, int32_t *coeff,
2552                                  int stride, TX_TYPE tx_type, int bd) {
2553   __m128i in[256];
2554   __m128i *outcoeff128 = (__m128i *)coeff;
2555   const int8_t *shift = av1_fwd_txfm_shift_ls[TX_16X64];
2556   const int txw_idx = get_txw_idx(TX_16X64);
2557   const int txh_idx = get_txh_idx(TX_16X64);
2558   const int txfm_size_col = tx_size_wide[TX_16X64];
2559   const int txfm_size_row = tx_size_high[TX_16X64];
2560   int bitcol = av1_fwd_cos_bit_col[txw_idx][txh_idx];
2561   int bitrow = av1_fwd_cos_bit_row[txw_idx][txh_idx];
2562   int ud_flip, lr_flip;
2563   get_flip_cfg(tx_type, &ud_flip, &lr_flip);
2564   const int num_col = txfm_size_col >> 2;
2565   // col transform
2566   for (int i = 0; i < txfm_size_row; i += num_col) {
2567     load_buffer_4x4(input + (i + 0) * stride, in + (i + 0) * num_col, num_col,
2568                     ud_flip, lr_flip, shift[0]);
2569     load_buffer_4x4(input + (i + 1) * stride, in + (i + 1) * num_col, num_col,
2570                     ud_flip, lr_flip, shift[0]);
2571     load_buffer_4x4(input + (i + 2) * stride, in + (i + 2) * num_col, num_col,
2572                     ud_flip, lr_flip, shift[0]);
2573     load_buffer_4x4(input + (i + 3) * stride, in + (i + 3) * num_col, num_col,
2574                     ud_flip, lr_flip, shift[0]);
2575   }
2576 
2577   for (int i = 0; i < num_col; i++) {
2578     av1_fdct64_sse4_1(in + i, outcoeff128 + i, bitcol, num_col, num_col);
2579   }
2580 
2581   col_txfm_16x16_rounding(outcoeff128, -shift[1]);
2582   col_txfm_16x16_rounding(outcoeff128 + 64, -shift[1]);
2583   col_txfm_16x16_rounding(outcoeff128 + 128, -shift[1]);
2584   col_txfm_16x16_rounding(outcoeff128 + 192, -shift[1]);
2585 
2586   transpose_8nx8n(outcoeff128, in, txfm_size_col, 32);
2587   fdct16x16_sse4_1(in, outcoeff128, bitrow, 8);
2588   (void)bd;
2589 }
2590 
2591 void av1_fwd_txfm2d_64x16_sse4_1(const int16_t *input, int32_t *coeff,
2592                                  int stride, TX_TYPE tx_type, int bd) {
2593   __m128i in[256];
2594   __m128i *outcoeff128 = (__m128i *)coeff;
2595   const int8_t *shift = av1_fwd_txfm_shift_ls[TX_64X16];
2596   const int txw_idx = get_txw_idx(TX_64X16);
2597   const int txh_idx = get_txh_idx(TX_64X16);
2598   const int txfm_size_col = tx_size_wide[TX_64X16];
2599   const int txfm_size_row = tx_size_high[TX_64X16];
2600   int bitcol = av1_fwd_cos_bit_col[txw_idx][txh_idx];
2601   int bitrow = av1_fwd_cos_bit_row[txw_idx][txh_idx];
2602   int ud_flip, lr_flip;
2603   get_flip_cfg(tx_type, &ud_flip, &lr_flip);
2604   // col transform
2605   for (int i = 0; i < txfm_size_row; i++) {
2606     load_buffer_4x4(input + 0 + i * stride, in + 0 + i * txfm_size_row, 4,
2607                     ud_flip, lr_flip, shift[0]);
2608     load_buffer_4x4(input + 16 + i * stride, in + 4 + i * txfm_size_row, 4,
2609                     ud_flip, lr_flip, shift[0]);
2610     load_buffer_4x4(input + 32 + i * stride, in + 8 + i * txfm_size_row, 4,
2611                     ud_flip, lr_flip, shift[0]);
2612     load_buffer_4x4(input + 48 + i * stride, in + 12 + i * txfm_size_row, 4,
2613                     ud_flip, lr_flip, shift[0]);
2614   }
2615 
2616   fdct16x16_sse4_1(in, outcoeff128, bitcol, txfm_size_row);
2617   col_txfm_16x16_rounding(outcoeff128, -shift[1]);
2618   col_txfm_16x16_rounding(outcoeff128 + 64, -shift[1]);
2619   col_txfm_16x16_rounding(outcoeff128 + 128, -shift[1]);
2620   col_txfm_16x16_rounding(outcoeff128 + 192, -shift[1]);
2621 
2622   transpose_8nx8n(outcoeff128, in, txfm_size_col, txfm_size_row);
2623   for (int i = 0; i < 4; i++) {
2624     av1_fdct64_sse4_1(in + i, outcoeff128 + i, bitrow, 4, 4);
2625   }
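  // Only the lowest 32 of the 64 row-transform outputs are kept; zero the
  // upper half of the coefficient buffer.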
2626   memset(coeff + txfm_size_row * 32, 0, txfm_size_row * 32 * sizeof(*coeff));
2627   (void)bd;
2628 }
2629 #endif
2630