xref: /aosp_15_r20/external/libaom/av1/encoder/x86/av1_fwd_txfm2d_sse4.c (revision 77c1e3ccc04c968bd2bc212e87364f250e820521)
1 /*
2  * Copyright (c) 2016, Alliance for Open Media. All rights reserved.
3  *
4  * This source code is subject to the terms of the BSD 2 Clause License and
5  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6  * was not distributed with this source code in the LICENSE file, you can
7  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8  * Media Patent License 1.0 was not distributed with this source code in the
9  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10  */
11 
12 #include "config/av1_rtcd.h"
13 
14 #include "av1/common/enums.h"
15 #include "av1/common/av1_txfm.h"
16 #include "av1/common/x86/av1_txfm_sse2.h"
17 #include "av1/common/x86/highbd_txfm_utility_sse4.h"
18 #include "av1/encoder/av1_fwd_txfm1d_cfg.h"
19 #include "av1/encoder/x86/av1_txfm1d_sse4.h"
20 #include "av1/encoder/x86/av1_fwd_txfm_sse2.h"
21 
static inline void int16_array_with_stride_to_int32_array_without_stride(
    const int16_t *input, int stride, int32_t *output, int txfm1d_size) {
  // Widen a txfm1d_size x txfm1d_size block of 16-bit residues to 32 bits,
  // packing the rows contiguously (output stride == txfm1d_size).
  for (int r = 0; r < txfm1d_size; ++r) {
    const int16_t *in_row = input + r * stride;
    int32_t *out_row = output + r * txfm1d_size;
    for (int c = 0; c < txfm1d_size; ++c) {
      out_row[c] = in_row[c];
    }
  }
}
31 
static inline void store_output_32bit_w8(int32_t *const out,
                                         const __m128i *const in1,
                                         const __m128i *const in2,
                                         const int stride, const int out_size) {
  // Store out_size rows of 8 int32 coefficients: the low 4 lanes come from
  // in1[row], the high 4 from in2[row]. `out` must be 16-byte aligned.
  int32_t *row = out;
  for (int r = 0; r < out_size; ++r, row += stride) {
    _mm_store_si128((__m128i *)row, in1[r]);
    _mm_store_si128((__m128i *)(row + 4), in2[r]);
  }
}
41 
// 1-D transform kernel over packed int32 data (4 lanes per __m128i).
// stage_range is unused by the SSE4.1 kernels below but kept for signature
// compatibility with the generic transform function table.
typedef void (*TxfmFuncSSE2)(__m128i *input, __m128i *output,
                             const int8_t cos_bit, const int8_t *stage_range);
44 
static void fdct32_sse4_1(__m128i *input, __m128i *output, const int8_t cos_bit,
                          const int8_t *stage_range) {
  (void)stage_range;
  // A 32-point 1-D DCT spans 32 / 4 = 8 vector columns; each kernel call
  // transforms one vector column with stride col_num.
  const int col_num = 32 / 4;
  for (int c = 0; c < col_num; ++c) {
    av1_fdct32_sse4_1(input + c, output + c, cos_bit, col_num);
  }
}
56 
static void fdct64_new_sse4_1(__m128i *input, __m128i *output,
                              const int8_t cos_bit, const int8_t *stage_range) {
  (void)stage_range;
  // A 64-point 1-D DCT spans 64 / 4 = 16 vector columns; in- and out-stride
  // are both col_num here (full-width output).
  const int col_num = 64 / 4;
  for (int c = 0; c < col_num; ++c) {
    av1_fdct64_sse4_1(input + c, output + c, cos_bit, col_num, col_num);
  }
}
static void idtx32x32_sse4_1(__m128i *input, __m128i *output,
                             const int8_t cos_bit, const int8_t *stage_range) {
  (void)stage_range;
  // Identity transform over the whole 32x32 block: 8 contiguous chunks of
  // 32 vectors each, processed with unit stride.
  int chunk = 0;
  while (chunk < 8) {
    av1_idtx32_sse4_1(&input[chunk * 32], &output[chunk * 32], cos_bit, 1);
    ++chunk;
  }
}
75 
fwd_txfm_type_to_func(TXFM_TYPE txfm_type)76 static inline TxfmFuncSSE2 fwd_txfm_type_to_func(TXFM_TYPE txfm_type) {
77   switch (txfm_type) {
78     case TXFM_TYPE_DCT32: return fdct32_sse4_1;
79     case TXFM_TYPE_DCT64: return fdct64_new_sse4_1;
80     case TXFM_TYPE_IDENTITY32: return idtx32x32_sse4_1;
81     default: assert(0);
82   }
83   return NULL;
84 }
85 
fwd_txfm2d_sse4_1(const int16_t * input,int32_t * output,const int stride,const TXFM_2D_FLIP_CFG * cfg,int32_t * txfm_buf)86 static inline void fwd_txfm2d_sse4_1(const int16_t *input, int32_t *output,
87                                      const int stride,
88                                      const TXFM_2D_FLIP_CFG *cfg,
89                                      int32_t *txfm_buf) {
90   // TODO(sarahparker) This does not currently support rectangular transforms
91   // and will break without splitting txfm_size out into row and col size.
92   // Rectangular transforms use c code only, so it should be ok for now.
93   // It will be corrected when there are sse implementations for rectangular
94   // transforms.
95   assert(cfg->tx_size < TX_SIZES);
96   const int txfm_size = tx_size_wide[cfg->tx_size];
97   const int8_t *shift = cfg->shift;
98   const int8_t *stage_range_col = cfg->stage_range_col;
99   const int8_t *stage_range_row = cfg->stage_range_row;
100   const int8_t cos_bit_col = cfg->cos_bit_col;
101   const int8_t cos_bit_row = cfg->cos_bit_row;
102   const TxfmFuncSSE2 txfm_func_col = fwd_txfm_type_to_func(cfg->txfm_type_col);
103   const TxfmFuncSSE2 txfm_func_row = fwd_txfm_type_to_func(cfg->txfm_type_row);
104 
105   __m128i *buf_128 = (__m128i *)txfm_buf;
106   __m128i *out_128 = (__m128i *)output;
107   int num_per_128 = 4;
108   int txfm2d_size_128 = txfm_size * txfm_size / num_per_128;
109 
110   int16_array_with_stride_to_int32_array_without_stride(input, stride, txfm_buf,
111                                                         txfm_size);
112   av1_round_shift_array_32_sse4_1(buf_128, out_128, txfm2d_size_128, -shift[0]);
113   txfm_func_col(out_128, buf_128, cos_bit_col, stage_range_col);
114   av1_round_shift_array_32_sse4_1(buf_128, out_128, txfm2d_size_128, -shift[1]);
115   transpose_32(txfm_size, out_128, buf_128);
116   txfm_func_row(buf_128, out_128, cos_bit_row, stage_range_row);
117   av1_round_shift_array_32_sse4_1(out_128, out_128, txfm2d_size_128, -shift[2]);
118 }
119 
fwd_txfm2d_64x64_sse4_1(const int16_t * input,int32_t * output,const int stride,const TXFM_2D_FLIP_CFG * cfg,int32_t * txfm_buf)120 static inline void fwd_txfm2d_64x64_sse4_1(const int16_t *input,
121                                            int32_t *output, const int stride,
122                                            const TXFM_2D_FLIP_CFG *cfg,
123                                            int32_t *txfm_buf) {
124   assert(cfg->tx_size < TX_SIZES);
125   const int txfm_size = tx_size_wide[cfg->tx_size];
126   const int8_t *shift = cfg->shift;
127   const int8_t *stage_range_col = cfg->stage_range_col;
128   const int8_t cos_bit_col = cfg->cos_bit_col;
129   const int8_t cos_bit_row = cfg->cos_bit_row;
130   const TxfmFuncSSE2 txfm_func_col = fwd_txfm_type_to_func(cfg->txfm_type_col);
131   __m128i *buf_128 = (__m128i *)txfm_buf;
132   __m128i *out_128 = (__m128i *)output;
133 
134   const int num_per_128 = 4;
135   int txfm2d_size_128 = txfm_size * txfm_size / num_per_128;
136   int col_num = txfm_size / num_per_128;
137 
138   int16_array_with_stride_to_int32_array_without_stride(input, stride, output,
139                                                         txfm_size);
140   /*col wise transform*/
141   txfm_func_col(out_128, buf_128, cos_bit_col, stage_range_col);
142   av1_round_shift_array_32_sse4_1(buf_128, out_128, txfm2d_size_128, -shift[1]);
143   transpose_32(txfm_size, out_128, buf_128);
144 
145   /*row wise transform*/
146   for (int col = 0; col < (col_num >> 1); col++) {
147     av1_fdct64_sse4_1((buf_128 + col), (out_128 + col), cos_bit_row, col_num,
148                       (col_num >> 1));
149   }
150 
151   txfm2d_size_128 = (col_num >> 1) * (txfm_size >> 1);
152   av1_round_shift_array_32_sse4_1(out_128, out_128, txfm2d_size_128, -shift[2]);
153 }
154 
void av1_fwd_txfm2d_32x32_sse4_1(const int16_t *input, int32_t *output,
                                 int stride, TX_TYPE tx_type, int bd) {
  (void)bd;  // bit depth is unused by this path
  DECLARE_ALIGNED(16, int32_t, txfm_buf[1024]);  // 32 * 32 scratch
  TXFM_2D_FLIP_CFG cfg;
  av1_get_fwd_txfm_cfg(tx_type, TX_32X32, &cfg);
  fwd_txfm2d_sse4_1(input, output, stride, &cfg, txfm_buf);
}
163 
void av1_fwd_txfm2d_64x64_sse4_1(const int16_t *input, int32_t *output,
                                 int stride, TX_TYPE tx_type, int bd) {
  (void)bd;  // bit depth is unused by this path
  DECLARE_ALIGNED(16, int32_t, txfm_buf[4096]);  // 64 * 64 scratch
  TXFM_2D_FLIP_CFG cfg;
  av1_get_fwd_txfm_cfg(tx_type, TX_64X64, &cfg);
  fwd_txfm2d_64x64_sse4_1(input, output, stride, &cfg, txfm_buf);
}
172 
// 64x64 lowbd forward transform, DCT_DCT only (asserted below). Column pass
// runs in 16-bit SSE2 on 8-column strips; only the first AOMMIN(4,
// height_div8) = 4 eight-row groups of the column output are transposed into
// buf1, i.e. rows 32..63 are dropped -- NOTE(review): this looks like the
// 64-point high-frequency zero-out; confirm against the C reference path.
// Row pass widens to 32 bits, runs the SSE4.1 64-point DCT, round-shifts,
// and stores the upper-left 32x32 coefficients with stride 32.
static void lowbd_fwd_txfm2d_64x64_sse4_1(const int16_t *input, int32_t *output,
                                          int stride, TX_TYPE tx_type, int bd) {
  (void)bd;
  (void)tx_type;
  assert(tx_type == DCT_DCT);
  const TX_SIZE tx_size = TX_64X64;
  // buf0 holds one 8-wide column strip (64 rows); buf1 holds the transposed
  // retained region (4 row-groups x 64 columns of 8x16-bit vectors).
  __m128i buf0[64], buf1[512];
  const int8_t *shift = av1_fwd_txfm_shift_ls[tx_size];
  const int txw_idx = get_txw_idx(tx_size);
  const int txh_idx = get_txh_idx(tx_size);
  const int cos_bit_col = av1_fwd_cos_bit_col[txw_idx][txh_idx];
  const int cos_bit_row = av1_fwd_cos_bit_row[txw_idx][txh_idx];
  const int width = tx_size_wide[tx_size];
  const int height = tx_size_high[tx_size];
  const transform_1d_sse2 col_txfm = av1_fdct8x64_new_sse2;
  const int width_div8 = (width >> 3);
  const int height_div8 = (height >> 3);

  // Column pass over each 8-column strip of the input.
  for (int i = 0; i < width_div8; i++) {
    load_buffer_16bit_to_16bit(input + 8 * i, stride, buf0, height);
    round_shift_16bit(buf0, height, shift[0]);
    col_txfm(buf0, buf0, cos_bit_col);
    round_shift_16bit(buf0, height, shift[1]);
    // Transpose only the first 4 eight-row groups into buf1.
    for (int j = 0; j < AOMMIN(4, height_div8); ++j) {
      transpose_16bit_8x8(buf0 + j * 8, buf1 + j * width + 8 * i);
    }
  }
  // Row pass over each retained 8-row group.
  for (int i = 0; i < AOMMIN(4, height_div8); i++) {
    __m128i bufA[64];
    __m128i bufB[64];
    __m128i *buf = buf1 + width * i;
    // Split each 8x16-bit vector into two 4x32-bit vectors (low/high lanes).
    for (int j = 0; j < width; ++j) {
      bufA[j] = _mm_cvtepi16_epi32(buf[j]);
      bufB[j] = _mm_cvtepi16_epi32(_mm_unpackhi_epi64(buf[j], buf[j]));
    }
    av1_fdct64_sse4_1(bufA, bufA, cos_bit_row, 1, 1);
    av1_fdct64_sse4_1(bufB, bufB, cos_bit_row, 1, 1);
    av1_round_shift_array_32_sse4_1(bufA, bufA, 32, -shift[2]);
    av1_round_shift_array_32_sse4_1(bufB, bufB, 32, -shift[2]);

    // Store 32 rows of 8 coefficients at output stride 32.
    store_output_32bit_w8(output + i * 8, bufA, bufB, 32, 32);
  }
}
216 
// 64x32 lowbd forward transform. Column pass uses the 16-bit 32-point kernel
// selected by tx_type (DCT_DCT only, asserted below); row pass is the 32-bit
// SSE4.1 64-point DCT. The rectangular scale factor NewSqrt2 is folded into
// the final round-shift. Output is stored as 32 rows x 32 columns with
// stride 32 -- NOTE(review): the dropped half matches the 64-point
// high-frequency zero-out; confirm against the C reference path.
static void lowbd_fwd_txfm2d_64x32_sse4_1(const int16_t *input, int32_t *output,
                                          int stride, TX_TYPE tx_type, int bd) {
  (void)bd;
  const TX_SIZE tx_size = TX_64X32;
  // buf0: one 8-wide column strip; buf1: transposed 32x64 region.
  __m128i buf0[64], buf1[256];
  const int8_t *shift = av1_fwd_txfm_shift_ls[tx_size];
  const int txw_idx = get_txw_idx(tx_size);
  const int txh_idx = get_txh_idx(tx_size);
  const int cos_bit_col = av1_fwd_cos_bit_col[txw_idx][txh_idx];
  const int cos_bit_row = av1_fwd_cos_bit_row[txw_idx][txh_idx];
  const int width = tx_size_wide[tx_size];
  const int height = tx_size_high[tx_size];
  const transform_1d_sse2 col_txfm = col_txfm8x32_arr[tx_type];
  const int width_div8 = (width >> 3);
  const int height_div8 = (height >> 3);

  // Column pass over each 8-column strip; height_div8 == 4 here, so all
  // four eight-row groups are transposed into buf1.
  for (int i = 0; i < width_div8; i++) {
    load_buffer_16bit_to_16bit(input + 8 * i, stride, buf0, height);
    round_shift_16bit(buf0, height, shift[0]);
    col_txfm(buf0, buf0, cos_bit_col);
    round_shift_16bit(buf0, height, shift[1]);
    for (int j = 0; j < AOMMIN(4, height_div8); ++j) {
      transpose_16bit_8x8(buf0 + j * 8, buf1 + j * width + 8 * i);
    }
  }
  assert(tx_type == DCT_DCT);
  // Row pass over each 8-row group.
  for (int i = 0; i < AOMMIN(4, height_div8); i++) {
    __m128i bufA[64];
    __m128i bufB[64];
    __m128i *buf = buf1 + width * i;
    // Split each 8x16-bit vector into low/high 4x32-bit halves.
    for (int j = 0; j < width; ++j) {
      bufA[j] = _mm_cvtepi16_epi32(buf[j]);
      bufB[j] = _mm_cvtepi16_epi32(_mm_unpackhi_epi64(buf[j], buf[j]));
    }
    av1_fdct64_sse4_1(bufA, bufA, cos_bit_row, 1, 1);
    av1_fdct64_sse4_1(bufB, bufB, cos_bit_row, 1, 1);
    // Rectangular transforms scale by sqrt(2) (NewSqrt2) during the shift.
    av1_round_shift_rect_array_32_sse4_1(bufA, bufA, 32, -shift[2], NewSqrt2);
    av1_round_shift_rect_array_32_sse4_1(bufB, bufB, 32, -shift[2], NewSqrt2);

    store_output_32bit_w8(output + i * 8, bufA, bufB, 32, 32);
  }
}
259 
// 32x64 lowbd forward transform, DCT_DCT only (asserted below). Column pass
// is the 16-bit 64-point kernel; only the first AOMMIN(4, height_div8) = 4
// eight-row groups are transposed into buf1, dropping column-output rows
// 32..63 -- NOTE(review): looks like the 64-point high-frequency zero-out;
// confirm against the C reference path. Row pass is the 32-bit SSE4.1
// 32-point DCT with the NewSqrt2 rectangular scale folded into the final
// round-shift; 32x32 coefficients are stored with stride 32.
static void lowbd_fwd_txfm2d_32x64_sse4_1(const int16_t *input, int32_t *output,
                                          int stride, TX_TYPE tx_type, int bd) {
  (void)bd;
  (void)tx_type;
  assert(tx_type == DCT_DCT);
  const TX_SIZE tx_size = TX_32X64;
  // buf0: one 8-wide column strip (64 rows); buf1: transposed 32x32 region.
  __m128i buf0[64], buf1[256];
  const int8_t *shift = av1_fwd_txfm_shift_ls[tx_size];
  const int txw_idx = get_txw_idx(tx_size);
  const int txh_idx = get_txh_idx(tx_size);
  const int cos_bit_col = av1_fwd_cos_bit_col[txw_idx][txh_idx];
  const int cos_bit_row = av1_fwd_cos_bit_row[txw_idx][txh_idx];
  const int width = tx_size_wide[tx_size];
  const int height = tx_size_high[tx_size];
  const transform_1d_sse2 col_txfm = av1_fdct8x64_new_sse2;
  const int width_div8 = (width >> 3);
  const int height_div8 = (height >> 3);

  // Column pass over each 8-column strip of the input.
  for (int i = 0; i < width_div8; i++) {
    load_buffer_16bit_to_16bit(input + 8 * i, stride, buf0, height);
    round_shift_16bit(buf0, height, shift[0]);
    col_txfm(buf0, buf0, cos_bit_col);
    round_shift_16bit(buf0, height, shift[1]);
    // Keep only the first 4 eight-row groups of the 64-point output.
    for (int j = 0; j < AOMMIN(4, height_div8); ++j) {
      transpose_16bit_8x8(buf0 + j * 8, buf1 + j * width + 8 * i);
    }
  }

  // Row pass over each retained 8-row group.
  for (int i = 0; i < AOMMIN(4, height_div8); i++) {
    __m128i bufA[32];
    __m128i bufB[32];
    __m128i *buf = buf1 + width * i;
    // Split each 8x16-bit vector into low/high 4x32-bit halves.
    for (int j = 0; j < width; ++j) {
      bufA[j] = _mm_cvtepi16_epi32(buf[j]);
      bufB[j] = _mm_cvtepi16_epi32(_mm_unpackhi_epi64(buf[j], buf[j]));
    }
    av1_fdct32_sse4_1(bufA, bufA, cos_bit_row, 1);
    av1_fdct32_sse4_1(bufB, bufB, cos_bit_row, 1);
    // Rectangular transforms scale by sqrt(2) (NewSqrt2) during the shift.
    av1_round_shift_rect_array_32_sse4_1(bufA, bufA, 32, -shift[2], NewSqrt2);
    av1_round_shift_rect_array_32_sse4_1(bufB, bufB, 32, -shift[2], NewSqrt2);

    store_output_32bit_w8(output + i * 8, bufA, bufB, 32, 32);
  }
}
304 
// Lowbd forward-transform dispatch table, indexed by TX_SIZE. Sizes with a
// 64-point dimension paired with 32 or 64 use the local SSE4.1
// implementations above; all other sizes use the SSE2 implementations.
static FwdTxfm2dFunc fwd_txfm2d_func_ls[TX_SIZES_ALL] = {
  av1_lowbd_fwd_txfm2d_4x4_sse2,    // 4x4 transform
  av1_lowbd_fwd_txfm2d_8x8_sse2,    // 8x8 transform
  av1_lowbd_fwd_txfm2d_16x16_sse2,  // 16x16 transform
  av1_lowbd_fwd_txfm2d_32x32_sse2,  // 32x32 transform
  lowbd_fwd_txfm2d_64x64_sse4_1,    // 64x64 transform
  av1_lowbd_fwd_txfm2d_4x8_sse2,    // 4x8 transform
  av1_lowbd_fwd_txfm2d_8x4_sse2,    // 8x4 transform
  av1_lowbd_fwd_txfm2d_8x16_sse2,   // 8x16 transform
  av1_lowbd_fwd_txfm2d_16x8_sse2,   // 16x8 transform
  av1_lowbd_fwd_txfm2d_16x32_sse2,  // 16x32 transform
  av1_lowbd_fwd_txfm2d_32x16_sse2,  // 32x16 transform
  lowbd_fwd_txfm2d_32x64_sse4_1,    // 32x64 transform
  lowbd_fwd_txfm2d_64x32_sse4_1,    // 64x32 transform
  av1_lowbd_fwd_txfm2d_4x16_sse2,   // 4x16 transform
  av1_lowbd_fwd_txfm2d_16x4_sse2,   // 16x4 transform
  av1_lowbd_fwd_txfm2d_8x32_sse2,   // 8x32 transform
  av1_lowbd_fwd_txfm2d_32x8_sse2,   // 32x8 transform
  av1_lowbd_fwd_txfm2d_16x64_sse2,  // 16x64 transform
  av1_lowbd_fwd_txfm2d_64x16_sse2,  // 64x16 transform
};
326 
void av1_lowbd_fwd_txfm_sse4_1(const int16_t *src_diff, tran_low_t *coeff,
                               int diff_stride, TxfmParam *txfm_param) {
  // The lossless 4x4 case is not handled by the SIMD table; fall back to C.
  if (txfm_param->lossless && txfm_param->tx_size == TX_4X4) {
    av1_lowbd_fwd_txfm_c(src_diff, coeff, diff_stride, txfm_param);
    return;
  }
  // Dispatch by transform size.
  fwd_txfm2d_func_ls[txfm_param->tx_size](src_diff, coeff, diff_stride,
                                          txfm_param->tx_type, txfm_param->bd);
}
337