1 /*
2  *  Copyright (c) 2022 The WebM project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #include <arm_neon.h>
12 
13 #include "./vpx_config.h"
14 #include "./vp9_rtcd.h"
15 #include "./vpx_dsp_rtcd.h"
16 
17 #include "vpx_dsp/txfm_common.h"
18 #include "vpx_dsp/arm/mem_neon.h"
19 #include "vpx_dsp/arm/transpose_neon.h"
20 #include "vpx_dsp/arm/fdct_neon.h"
21 #include "vpx_dsp/arm/fdct4x4_neon.h"
22 #include "vpx_dsp/arm/fdct8x8_neon.h"
23 #include "vpx_dsp/arm/fdct16x16_neon.h"
24 
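/* Load a 4x4 block of residuals and pre-scale it by 16 (<< 4), matching the
 * first-pass scaling of the C/SSE forward 4x4 transform. The first element is
 * additionally biased by one when it is non-zero (see the mask trick below). */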
25 static INLINE void load_buffer_4x4(const int16_t *input, int16x8_t *in,
26                                    int stride) {
27   // { 0, 1, 1, 1 };
28   const int16x4_t nonzero_bias_a = vext_s16(vdup_n_s16(0), vdup_n_s16(1), 3);
29   // { 1, 0, 0, 0 };
30   const int16x4_t nonzero_bias_b = vext_s16(vdup_n_s16(1), vdup_n_s16(0), 3);
31   int16x4_t mask;
32 
33   int16x4_t input_0 = vshl_n_s16(vld1_s16(input + 0 * stride), 4);
34   int16x4_t input_1 = vshl_n_s16(vld1_s16(input + 1 * stride), 4);
35   int16x4_t input_2 = vshl_n_s16(vld1_s16(input + 2 * stride), 4);
36   int16x4_t input_3 = vshl_n_s16(vld1_s16(input + 3 * stride), 4);
37 
38   // Copy the SSE method: use a mask instead of an 'if' branch to add one to
39   // the first element when it is non-zero.
40   mask = vreinterpret_s16_u16(vceq_s16(input_0, nonzero_bias_a));
41   input_0 = vadd_s16(input_0, mask);
42   input_0 = vadd_s16(input_0, nonzero_bias_b);
43 
44   in[0] = vcombine_s16(input_0, input_1);
45   in[1] = vcombine_s16(input_2, input_3);
46 }
47 
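/* Final rounding of the 4x4 coefficients: (x + 1) >> 2, stored as tran_low_t. */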
48 static INLINE void write_buffer_4x4(tran_low_t *output, int16x8_t *res) {
49   const int16x8_t one_s16 = vdupq_n_s16(1);
50   res[0] = vaddq_s16(res[0], one_s16);
51   res[1] = vaddq_s16(res[1], one_s16);
52   res[0] = vshrq_n_s16(res[0], 2);
53   res[1] = vshrq_n_s16(res[1], 2);
54   store_s16q_to_tran_low(output + 0 * 8, res[0]);
55   store_s16q_to_tran_low(output + 1 * 8, res[1]);
56 }
57 
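/* One pass of the 4-point ADST over all four rows, using 32-bit intermediates
 * for the sinpi multiplies, followed by a 4x4 transpose so the next pass again
 * operates on rows. */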
58 static INLINE void fadst4x4_neon(int16x8_t *in) {
59   int32x4_t u[4], t[4];
60   int16x4_t s[4], out[4];
61 
62   s[0] = vget_low_s16(in[0]);   // | x_00 | x_01 | x_02 | x_03 |
63   s[1] = vget_high_s16(in[0]);  // | x_10 | x_11 | x_12 | x_13 |
64   s[2] = vget_low_s16(in[1]);   // | x_20 | x_21 | x_22 | x_23 |
65   s[3] = vget_high_s16(in[1]);  // | x_30 | x_31 | x_32 | x_33 |
66 
67   // Must expand all elements to s32. See 'needs32' comment in fwd_txfm.c.
68   // t0 = s0 * sinpi_1_9 + s1 * sinpi_2_9 + s3 * sinpi_4_9
69   t[0] = vmull_n_s16(s[0], sinpi_1_9);
70   t[0] = vmlal_n_s16(t[0], s[1], sinpi_2_9);
71   t[0] = vmlal_n_s16(t[0], s[3], sinpi_4_9);
72 
73   // t1 = (s0 + s1) * sinpi_3_9 - s3 * sinpi_3_9
74   t[1] = vmull_n_s16(s[0], sinpi_3_9);
75   t[1] = vmlal_n_s16(t[1], s[1], sinpi_3_9);
76   t[1] = vmlsl_n_s16(t[1], s[3], sinpi_3_9);
77 
78   // t2 = s0 * sinpi_4_9 - s1 * sinpi_1_9 + s3 * sinpi_2_9
79   t[2] = vmull_n_s16(s[0], sinpi_4_9);
80   t[2] = vmlsl_n_s16(t[2], s[1], sinpi_1_9);
81   t[2] = vmlal_n_s16(t[2], s[3], sinpi_2_9);
82 
83   // t3 = s2 * sinpi_3_9
84   t[3] = vmull_n_s16(s[2], sinpi_3_9);
85 
86   /*
87    * u0 = t0 + t3
88    * u1 = t1
89    * u2 = t2 - t3
90    * u3 = t2 - t0 + t3
91    */
92   u[0] = vaddq_s32(t[0], t[3]);
93   u[1] = t[1];
94   u[2] = vsubq_s32(t[2], t[3]);
95   u[3] = vaddq_s32(vsubq_s32(t[2], t[0]), t[3]);
96 
97   // fdct_round_shift
98   out[0] = vrshrn_n_s32(u[0], DCT_CONST_BITS);
99   out[1] = vrshrn_n_s32(u[1], DCT_CONST_BITS);
100   out[2] = vrshrn_n_s32(u[2], DCT_CONST_BITS);
101   out[3] = vrshrn_n_s32(u[3], DCT_CONST_BITS);
102 
103   transpose_s16_4x4d(&out[0], &out[1], &out[2], &out[3]);
104 
105   in[0] = vcombine_s16(out[0], out[1]);
106   in[1] = vcombine_s16(out[2], out[3]);
107 }
108 
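/* Forward 4x4 hybrid transform: tx_type selects DCT or ADST per direction.
 * Each non-DCT_DCT path runs two 1-D passes (the transpose happens inside the
 * pass helpers) followed by the final rounding in write_buffer_4x4(). */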
109 void vp9_fht4x4_neon(const int16_t *input, tran_low_t *output, int stride,
110                      int tx_type) {
111   int16x8_t in[2];
112 
113   switch (tx_type) {
114     case DCT_DCT: vpx_fdct4x4_neon(input, output, stride); break;
115     case ADST_DCT:
116       load_buffer_4x4(input, in, stride);
117       fadst4x4_neon(in);
118       // pass1 variant is not accurate enough
119       vpx_fdct4x4_pass2_neon((int16x4_t *)in);
120       write_buffer_4x4(output, in);
121       break;
122     case DCT_ADST:
123       load_buffer_4x4(input, in, stride);
124       // pass1 variant is not accurate enough
125       vpx_fdct4x4_pass2_neon((int16x4_t *)in);
126       fadst4x4_neon(in);
127       write_buffer_4x4(output, in);
128       break;
129     default:
130       assert(tx_type == ADST_ADST);
131       load_buffer_4x4(input, in, stride);
132       fadst4x4_neon(in);
133       fadst4x4_neon(in);
134       write_buffer_4x4(output, in);
135       break;
136   }
137 }
138 
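/* Load an 8x8 block of residuals and pre-scale it by 4 (<< 2), matching the
 * first-pass scaling of the C forward 8x8 transform. */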
139 static INLINE void load_buffer_8x8(const int16_t *input, int16x8_t *in,
140                                    int stride) {
141   in[0] = vshlq_n_s16(vld1q_s16(input + 0 * stride), 2);
142   in[1] = vshlq_n_s16(vld1q_s16(input + 1 * stride), 2);
143   in[2] = vshlq_n_s16(vld1q_s16(input + 2 * stride), 2);
144   in[3] = vshlq_n_s16(vld1q_s16(input + 3 * stride), 2);
145   in[4] = vshlq_n_s16(vld1q_s16(input + 4 * stride), 2);
146   in[5] = vshlq_n_s16(vld1q_s16(input + 5 * stride), 2);
147   in[6] = vshlq_n_s16(vld1q_s16(input + 6 * stride), 2);
148   in[7] = vshlq_n_s16(vld1q_s16(input + 7 * stride), 2);
149 }
150 
151 /* Rounding right shift by 'bit' (1 or 2).
152  * First extract the sign bit (bit 15) of each element.
153  * If bit == 1, the sign is subtracted before the shift right by one bit.
154  * If bit == 2, it computes the expression:
155  *
156  * out[j * 16 + i] = (temp_out[j] + 1 + (temp_out[j] < 0)) >> 2;
157  *
158  * for each element of each row.
159  */
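/* For example, with bit == 2:  5 -> (5 + 1 + 0) >> 2 = 1
 *                             -5 -> (-5 + 1 + 1) >> 2 = -1 */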
160 static INLINE void right_shift_8x8(int16x8_t *res, const int bit) {
161   int16x8_t sign0 = vshrq_n_s16(res[0], 15);
162   int16x8_t sign1 = vshrq_n_s16(res[1], 15);
163   int16x8_t sign2 = vshrq_n_s16(res[2], 15);
164   int16x8_t sign3 = vshrq_n_s16(res[3], 15);
165   int16x8_t sign4 = vshrq_n_s16(res[4], 15);
166   int16x8_t sign5 = vshrq_n_s16(res[5], 15);
167   int16x8_t sign6 = vshrq_n_s16(res[6], 15);
168   int16x8_t sign7 = vshrq_n_s16(res[7], 15);
169 
170   if (bit == 2) {
171     const int16x8_t const_rounding = vdupq_n_s16(1);
172     res[0] = vaddq_s16(res[0], const_rounding);
173     res[1] = vaddq_s16(res[1], const_rounding);
174     res[2] = vaddq_s16(res[2], const_rounding);
175     res[3] = vaddq_s16(res[3], const_rounding);
176     res[4] = vaddq_s16(res[4], const_rounding);
177     res[5] = vaddq_s16(res[5], const_rounding);
178     res[6] = vaddq_s16(res[6], const_rounding);
179     res[7] = vaddq_s16(res[7], const_rounding);
180   }
181 
182   res[0] = vsubq_s16(res[0], sign0);
183   res[1] = vsubq_s16(res[1], sign1);
184   res[2] = vsubq_s16(res[2], sign2);
185   res[3] = vsubq_s16(res[3], sign3);
186   res[4] = vsubq_s16(res[4], sign4);
187   res[5] = vsubq_s16(res[5], sign5);
188   res[6] = vsubq_s16(res[6], sign6);
189   res[7] = vsubq_s16(res[7], sign7);
190 
191   if (bit == 1) {
192     res[0] = vshrq_n_s16(res[0], 1);
193     res[1] = vshrq_n_s16(res[1], 1);
194     res[2] = vshrq_n_s16(res[2], 1);
195     res[3] = vshrq_n_s16(res[3], 1);
196     res[4] = vshrq_n_s16(res[4], 1);
197     res[5] = vshrq_n_s16(res[5], 1);
198     res[6] = vshrq_n_s16(res[6], 1);
199     res[7] = vshrq_n_s16(res[7], 1);
200   } else {
201     res[0] = vshrq_n_s16(res[0], 2);
202     res[1] = vshrq_n_s16(res[1], 2);
203     res[2] = vshrq_n_s16(res[2], 2);
204     res[3] = vshrq_n_s16(res[3], 2);
205     res[4] = vshrq_n_s16(res[4], 2);
206     res[5] = vshrq_n_s16(res[5], 2);
207     res[6] = vshrq_n_s16(res[6], 2);
208     res[7] = vshrq_n_s16(res[7], 2);
209   }
210 }
211 
212 static INLINE void write_buffer_8x8(tran_low_t *output, int16x8_t *res,
213                                     int stride) {
214   store_s16q_to_tran_low(output + 0 * stride, res[0]);
215   store_s16q_to_tran_low(output + 1 * stride, res[1]);
216   store_s16q_to_tran_low(output + 2 * stride, res[2]);
217   store_s16q_to_tran_low(output + 3 * stride, res[3]);
218   store_s16q_to_tran_low(output + 4 * stride, res[4]);
219   store_s16q_to_tran_low(output + 5 * stride, res[5]);
220   store_s16q_to_tran_low(output + 6 * stride, res[6]);
221   store_s16q_to_tran_low(output + 7 * stride, res[7]);
222 }
223 
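/* One pass of the 8-point ADST over all eight rows. The butterflies keep
 * 32-bit precision; the outputs are written back permuted and sign-flipped
 * per the VP9 fadst8 ordering, then the 8x8 block is transposed. */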
224 static INLINE void fadst8x8_neon(int16x8_t *in) {
225   int16x4_t x_lo[8], x_hi[8];
226   int32x4_t s_lo[8], s_hi[8];
227   int32x4_t t_lo[8], t_hi[8];
228 
229   x_lo[0] = vget_low_s16(in[7]);
230   x_hi[0] = vget_high_s16(in[7]);
231   x_lo[1] = vget_low_s16(in[0]);
232   x_hi[1] = vget_high_s16(in[0]);
233   x_lo[2] = vget_low_s16(in[5]);
234   x_hi[2] = vget_high_s16(in[5]);
235   x_lo[3] = vget_low_s16(in[2]);
236   x_hi[3] = vget_high_s16(in[2]);
237   x_lo[4] = vget_low_s16(in[3]);
238   x_hi[4] = vget_high_s16(in[3]);
239   x_lo[5] = vget_low_s16(in[4]);
240   x_hi[5] = vget_high_s16(in[4]);
241   x_lo[6] = vget_low_s16(in[1]);
242   x_hi[6] = vget_high_s16(in[1]);
243   x_lo[7] = vget_low_s16(in[6]);
244   x_hi[7] = vget_high_s16(in[6]);
245 
246   // stage 1
247   // s0 = cospi_2_64 * x0 + cospi_30_64 * x1;
248   // s1 = cospi_30_64 * x0 - cospi_2_64 * x1;
249   butterfly_two_coeff_s16_s32_noround(x_lo[0], x_hi[0], x_lo[1], x_hi[1],
250                                       cospi_2_64, cospi_30_64, &s_lo[0],
251                                       &s_hi[0], &s_lo[1], &s_hi[1]);
252 
253   // s2 = cospi_10_64 * x2 + cospi_22_64 * x3;
254   // s3 = cospi_22_64 * x2 - cospi_10_64 * x3;
255   butterfly_two_coeff_s16_s32_noround(x_lo[2], x_hi[2], x_lo[3], x_hi[3],
256                                       cospi_10_64, cospi_22_64, &s_lo[2],
257                                       &s_hi[2], &s_lo[3], &s_hi[3]);
258 
259   // s4 = cospi_18_64 * x4 + cospi_14_64 * x5;
260   // s5 = cospi_14_64 * x4 - cospi_18_64 * x5;
261   butterfly_two_coeff_s16_s32_noround(x_lo[4], x_hi[4], x_lo[5], x_hi[5],
262                                       cospi_18_64, cospi_14_64, &s_lo[4],
263                                       &s_hi[4], &s_lo[5], &s_hi[5]);
264 
265   // s6 = cospi_26_64 * x6 + cospi_6_64 * x7;
266   // s7 = cospi_6_64 * x6 - cospi_26_64 * x7;
267   butterfly_two_coeff_s16_s32_noround(x_lo[6], x_hi[6], x_lo[7], x_hi[7],
268                                       cospi_26_64, cospi_6_64, &s_lo[6],
269                                       &s_hi[6], &s_lo[7], &s_hi[7]);
270 
271   // fdct_round_shift
272   t_lo[0] = vrshrq_n_s32(vaddq_s32(s_lo[0], s_lo[4]), DCT_CONST_BITS);
273   t_hi[0] = vrshrq_n_s32(vaddq_s32(s_hi[0], s_hi[4]), DCT_CONST_BITS);
274   t_lo[1] = vrshrq_n_s32(vaddq_s32(s_lo[1], s_lo[5]), DCT_CONST_BITS);
275   t_hi[1] = vrshrq_n_s32(vaddq_s32(s_hi[1], s_hi[5]), DCT_CONST_BITS);
276   t_lo[2] = vrshrq_n_s32(vaddq_s32(s_lo[2], s_lo[6]), DCT_CONST_BITS);
277   t_hi[2] = vrshrq_n_s32(vaddq_s32(s_hi[2], s_hi[6]), DCT_CONST_BITS);
278   t_lo[3] = vrshrq_n_s32(vaddq_s32(s_lo[3], s_lo[7]), DCT_CONST_BITS);
279   t_hi[3] = vrshrq_n_s32(vaddq_s32(s_hi[3], s_hi[7]), DCT_CONST_BITS);
280   t_lo[4] = vrshrq_n_s32(vsubq_s32(s_lo[0], s_lo[4]), DCT_CONST_BITS);
281   t_hi[4] = vrshrq_n_s32(vsubq_s32(s_hi[0], s_hi[4]), DCT_CONST_BITS);
282   t_lo[5] = vrshrq_n_s32(vsubq_s32(s_lo[1], s_lo[5]), DCT_CONST_BITS);
283   t_hi[5] = vrshrq_n_s32(vsubq_s32(s_hi[1], s_hi[5]), DCT_CONST_BITS);
284   t_lo[6] = vrshrq_n_s32(vsubq_s32(s_lo[2], s_lo[6]), DCT_CONST_BITS);
285   t_hi[6] = vrshrq_n_s32(vsubq_s32(s_hi[2], s_hi[6]), DCT_CONST_BITS);
286   t_lo[7] = vrshrq_n_s32(vsubq_s32(s_lo[3], s_lo[7]), DCT_CONST_BITS);
287   t_hi[7] = vrshrq_n_s32(vsubq_s32(s_hi[3], s_hi[7]), DCT_CONST_BITS);
288 
289   // stage 2
290   s_lo[0] = t_lo[0];
291   s_hi[0] = t_hi[0];
292   s_lo[1] = t_lo[1];
293   s_hi[1] = t_hi[1];
294   s_lo[2] = t_lo[2];
295   s_hi[2] = t_hi[2];
296   s_lo[3] = t_lo[3];
297   s_hi[3] = t_hi[3];
298   // s4 = cospi_8_64 * x4 + cospi_24_64 * x5;
299   // s5 = cospi_24_64 * x4 - cospi_8_64 * x5;
300   butterfly_two_coeff_s32_noround(t_lo[4], t_hi[4], t_lo[5], t_hi[5],
301                                   cospi_8_64, cospi_24_64, &s_lo[4], &s_hi[4],
302                                   &s_lo[5], &s_hi[5]);
303 
304   // s6 = -cospi_24_64 * x6 + cospi_8_64 * x7;
305   // s7 = cospi_8_64 * x6 + cospi_24_64 * x7;
306   butterfly_two_coeff_s32_noround(t_lo[6], t_hi[6], t_lo[7], t_hi[7],
307                                   -cospi_24_64, cospi_8_64, &s_lo[6], &s_hi[6],
308                                   &s_lo[7], &s_hi[7]);
309 
310   // fdct_round_shift
311   // s0 + s2
312   t_lo[0] = vaddq_s32(s_lo[0], s_lo[2]);
313   t_hi[0] = vaddq_s32(s_hi[0], s_hi[2]);
314   // s1 + s3
315   t_lo[1] = vaddq_s32(s_lo[1], s_lo[3]);
316   t_hi[1] = vaddq_s32(s_hi[1], s_hi[3]);
317   // s0 - s2
318   t_lo[2] = vsubq_s32(s_lo[0], s_lo[2]);
319   t_hi[2] = vsubq_s32(s_hi[0], s_hi[2]);
320   // s1 - s3
321   t_lo[3] = vsubq_s32(s_lo[1], s_lo[3]);
322   t_hi[3] = vsubq_s32(s_hi[1], s_hi[3]);
323   // s4 + s6
324   t_lo[4] = vrshrq_n_s32(vaddq_s32(s_lo[4], s_lo[6]), DCT_CONST_BITS);
325   t_hi[4] = vrshrq_n_s32(vaddq_s32(s_hi[4], s_hi[6]), DCT_CONST_BITS);
326   // s5 + s7
327   t_lo[5] = vrshrq_n_s32(vaddq_s32(s_lo[5], s_lo[7]), DCT_CONST_BITS);
328   t_hi[5] = vrshrq_n_s32(vaddq_s32(s_hi[5], s_hi[7]), DCT_CONST_BITS);
329   // s4 - s6
330   t_lo[6] = vrshrq_n_s32(vsubq_s32(s_lo[4], s_lo[6]), DCT_CONST_BITS);
331   t_hi[6] = vrshrq_n_s32(vsubq_s32(s_hi[4], s_hi[6]), DCT_CONST_BITS);
332   // s5 - s7
333   t_lo[7] = vrshrq_n_s32(vsubq_s32(s_lo[5], s_lo[7]), DCT_CONST_BITS);
334   t_hi[7] = vrshrq_n_s32(vsubq_s32(s_hi[5], s_hi[7]), DCT_CONST_BITS);
335 
336   // stage 3
337   // cospi_16_64 * (x2 + x3)
338   // cospi_16_64 * (x2 - x3)
339   butterfly_one_coeff_s32_noround(t_lo[2], t_hi[2], t_lo[3], t_hi[3],
340                                   cospi_16_64, &s_lo[2], &s_hi[2], &s_lo[3],
341                                   &s_hi[3]);
342 
343   // cospi_16_64 * (x6 + x7)
344   // cospi_16_64 * (x6 - x7)
345   butterfly_one_coeff_s32_noround(t_lo[6], t_hi[6], t_lo[7], t_hi[7],
346                                   cospi_16_64, &s_lo[6], &s_hi[6], &s_lo[7],
347                                   &s_hi[7]);
348 
349   // final fdct_round_shift
350   x_lo[2] = vrshrn_n_s32(s_lo[2], DCT_CONST_BITS);
351   x_hi[2] = vrshrn_n_s32(s_hi[2], DCT_CONST_BITS);
352   x_lo[3] = vrshrn_n_s32(s_lo[3], DCT_CONST_BITS);
353   x_hi[3] = vrshrn_n_s32(s_hi[3], DCT_CONST_BITS);
354   x_lo[6] = vrshrn_n_s32(s_lo[6], DCT_CONST_BITS);
355   x_hi[6] = vrshrn_n_s32(s_hi[6], DCT_CONST_BITS);
356   x_lo[7] = vrshrn_n_s32(s_lo[7], DCT_CONST_BITS);
357   x_hi[7] = vrshrn_n_s32(s_hi[7], DCT_CONST_BITS);
358 
359   // x0, x1, x4, x5 narrow down to 16-bits directly
360   x_lo[0] = vmovn_s32(t_lo[0]);
361   x_hi[0] = vmovn_s32(t_hi[0]);
362   x_lo[1] = vmovn_s32(t_lo[1]);
363   x_hi[1] = vmovn_s32(t_hi[1]);
364   x_lo[4] = vmovn_s32(t_lo[4]);
365   x_hi[4] = vmovn_s32(t_hi[4]);
366   x_lo[5] = vmovn_s32(t_lo[5]);
367   x_hi[5] = vmovn_s32(t_hi[5]);
368 
369   in[0] = vcombine_s16(x_lo[0], x_hi[0]);
370   in[1] = vnegq_s16(vcombine_s16(x_lo[4], x_hi[4]));
371   in[2] = vcombine_s16(x_lo[6], x_hi[6]);
372   in[3] = vnegq_s16(vcombine_s16(x_lo[2], x_hi[2]));
373   in[4] = vcombine_s16(x_lo[3], x_hi[3]);
374   in[5] = vnegq_s16(vcombine_s16(x_lo[7], x_hi[7]));
375   in[6] = vcombine_s16(x_lo[5], x_hi[5]);
376   in[7] = vnegq_s16(vcombine_s16(x_lo[1], x_hi[1]));
377 
378   transpose_s16_8x8(&in[0], &in[1], &in[2], &in[3], &in[4], &in[5], &in[6],
379                     &in[7]);
380 }
381 
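/* Forward 8x8 hybrid transform. The two 1-D passes leave one extra bit of
 * precision, removed by right_shift_8x8(in, 1) before the coefficients are
 * stored. */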
382 void vp9_fht8x8_neon(const int16_t *input, tran_low_t *output, int stride,
383                      int tx_type) {
384   int16x8_t in[8];
385 
386   switch (tx_type) {
387     case DCT_DCT: vpx_fdct8x8_neon(input, output, stride); break;
388     case ADST_DCT:
389       load_buffer_8x8(input, in, stride);
390       fadst8x8_neon(in);
391       // pass1 variant is not accurate enough
392       vpx_fdct8x8_pass2_neon(in);
393       right_shift_8x8(in, 1);
394       write_buffer_8x8(output, in, 8);
395       break;
396     case DCT_ADST:
397       load_buffer_8x8(input, in, stride);
398       // pass1 variant is not accurate enough
399       vpx_fdct8x8_pass2_neon(in);
400       fadst8x8_neon(in);
401       right_shift_8x8(in, 1);
402       write_buffer_8x8(output, in, 8);
403       break;
404     default:
405       assert(tx_type == ADST_ADST);
406       load_buffer_8x8(input, in, stride);
407       fadst8x8_neon(in);
408       fadst8x8_neon(in);
409       right_shift_8x8(in, 1);
410       write_buffer_8x8(output, in, 8);
411       break;
412   }
413 }
414 
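/* 16x16 load as four 8x8 quadrants: in0 receives the left 8 columns (16 rows),
 * in1 the right 8 columns. */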
415 static INLINE void load_buffer_16x16(const int16_t *input, int16x8_t *in0,
416                                      int16x8_t *in1, int stride) {
417   // load first 8 columns
418   load_buffer_8x8(input, in0, stride);
419   load_buffer_8x8(input + 8 * stride, in0 + 8, stride);
420 
421   input += 8;
422   // load second 8 columns
423   load_buffer_8x8(input, in1, stride);
424   load_buffer_8x8(input + 8 * stride, in1 + 8, stride);
425 }
426 
427 static INLINE void write_buffer_16x16(tran_low_t *output, int16x8_t *in0,
428                                       int16x8_t *in1, int stride) {
429   // write first 8 columns
430   write_buffer_8x8(output, in0, stride);
431   write_buffer_8x8(output + 8 * stride, in0 + 8, stride);
432 
433   // write second 8 columns
434   output += 8;
435   write_buffer_8x8(output, in1, stride);
436   write_buffer_8x8(output + 8 * stride, in1 + 8, stride);
437 }
438 
439 static INLINE void right_shift_16x16(int16x8_t *res0, int16x8_t *res1) {
440   // perform rounding operations
441   right_shift_8x8(res0, 2);
442   right_shift_8x8(res0 + 8, 2);
443   right_shift_8x8(res1, 2);
444   right_shift_8x8(res1 + 8, 2);
445 }
446 
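/* 1-D 16-point DCT applied to 8 columns. The even-indexed outputs reuse the
 * 8x8 pass-2 DCT on the row sums; the odd-indexed outputs follow the C fdct16
 * butterfly steps with 32-bit intermediates. */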
447 static void fdct16_8col(int16x8_t *in) {
448   // perform 16x16 1-D DCT for 8 columns
449   int16x8_t i[8], s1[8], s2[8], s3[8], t[8];
450   int16x4_t t_lo[8], t_hi[8];
451   int32x4_t u_lo[8], u_hi[8];
452 
453   // stage 1
454   i[0] = vaddq_s16(in[0], in[15]);
455   i[1] = vaddq_s16(in[1], in[14]);
456   i[2] = vaddq_s16(in[2], in[13]);
457   i[3] = vaddq_s16(in[3], in[12]);
458   i[4] = vaddq_s16(in[4], in[11]);
459   i[5] = vaddq_s16(in[5], in[10]);
460   i[6] = vaddq_s16(in[6], in[9]);
461   i[7] = vaddq_s16(in[7], in[8]);
462 
463   // pass1 variant is not accurate enough
464   vpx_fdct8x8_pass2_neon(i);
465   transpose_s16_8x8(&i[0], &i[1], &i[2], &i[3], &i[4], &i[5], &i[6], &i[7]);
466 
467   // step 2
468   s1[0] = vsubq_s16(in[7], in[8]);
469   s1[1] = vsubq_s16(in[6], in[9]);
470   s1[2] = vsubq_s16(in[5], in[10]);
471   s1[3] = vsubq_s16(in[4], in[11]);
472   s1[4] = vsubq_s16(in[3], in[12]);
473   s1[5] = vsubq_s16(in[2], in[13]);
474   s1[6] = vsubq_s16(in[1], in[14]);
475   s1[7] = vsubq_s16(in[0], in[15]);
476 
477   t[2] = vsubq_s16(s1[5], s1[2]);
478   t[3] = vsubq_s16(s1[4], s1[3]);
479   t[4] = vaddq_s16(s1[4], s1[3]);
480   t[5] = vaddq_s16(s1[5], s1[2]);
481 
482   t_lo[2] = vget_low_s16(t[2]);
483   t_hi[2] = vget_high_s16(t[2]);
484   t_lo[3] = vget_low_s16(t[3]);
485   t_hi[3] = vget_high_s16(t[3]);
486   t_lo[4] = vget_low_s16(t[4]);
487   t_hi[4] = vget_high_s16(t[4]);
488   t_lo[5] = vget_low_s16(t[5]);
489   t_hi[5] = vget_high_s16(t[5]);
490 
491   u_lo[2] = vmull_n_s16(t_lo[2], cospi_16_64);
492   u_hi[2] = vmull_n_s16(t_hi[2], cospi_16_64);
493   u_lo[3] = vmull_n_s16(t_lo[3], cospi_16_64);
494   u_hi[3] = vmull_n_s16(t_hi[3], cospi_16_64);
495   u_lo[4] = vmull_n_s16(t_lo[4], cospi_16_64);
496   u_hi[4] = vmull_n_s16(t_hi[4], cospi_16_64);
497   u_lo[5] = vmull_n_s16(t_lo[5], cospi_16_64);
498   u_hi[5] = vmull_n_s16(t_hi[5], cospi_16_64);
499 
500   t_lo[2] = vrshrn_n_s32(u_lo[2], DCT_CONST_BITS);
501   t_hi[2] = vrshrn_n_s32(u_hi[2], DCT_CONST_BITS);
502   t_lo[3] = vrshrn_n_s32(u_lo[3], DCT_CONST_BITS);
503   t_hi[3] = vrshrn_n_s32(u_hi[3], DCT_CONST_BITS);
504   t_lo[4] = vrshrn_n_s32(u_lo[4], DCT_CONST_BITS);
505   t_hi[4] = vrshrn_n_s32(u_hi[4], DCT_CONST_BITS);
506   t_lo[5] = vrshrn_n_s32(u_lo[5], DCT_CONST_BITS);
507   t_hi[5] = vrshrn_n_s32(u_hi[5], DCT_CONST_BITS);
508 
509   s2[2] = vcombine_s16(t_lo[2], t_hi[2]);
510   s2[3] = vcombine_s16(t_lo[3], t_hi[3]);
511   s2[4] = vcombine_s16(t_lo[4], t_hi[4]);
512   s2[5] = vcombine_s16(t_lo[5], t_hi[5]);
513 
514   // step 3
515   s3[0] = vaddq_s16(s1[0], s2[3]);
516   s3[1] = vaddq_s16(s1[1], s2[2]);
517   s3[2] = vsubq_s16(s1[1], s2[2]);
518   s3[3] = vsubq_s16(s1[0], s2[3]);
519   s3[4] = vsubq_s16(s1[7], s2[4]);
520   s3[5] = vsubq_s16(s1[6], s2[5]);
521   s3[6] = vaddq_s16(s1[6], s2[5]);
522   s3[7] = vaddq_s16(s1[7], s2[4]);
523 
524   // step 4
525   t_lo[0] = vget_low_s16(s3[0]);
526   t_hi[0] = vget_high_s16(s3[0]);
527   t_lo[1] = vget_low_s16(s3[1]);
528   t_hi[1] = vget_high_s16(s3[1]);
529   t_lo[2] = vget_low_s16(s3[2]);
530   t_hi[2] = vget_high_s16(s3[2]);
531   t_lo[3] = vget_low_s16(s3[3]);
532   t_hi[3] = vget_high_s16(s3[3]);
533   t_lo[4] = vget_low_s16(s3[4]);
534   t_hi[4] = vget_high_s16(s3[4]);
535   t_lo[5] = vget_low_s16(s3[5]);
536   t_hi[5] = vget_high_s16(s3[5]);
537   t_lo[6] = vget_low_s16(s3[6]);
538   t_hi[6] = vget_high_s16(s3[6]);
539   t_lo[7] = vget_low_s16(s3[7]);
540   t_hi[7] = vget_high_s16(s3[7]);
541 
542   // u[1] = -cospi_8_64 * t[1] + cospi_24_64 * t[6]
543   // u[6] = cospi_24_64 * t[1] + cospi_8_64 * t[6]
544   butterfly_two_coeff_s16_s32_noround(t_lo[1], t_hi[1], t_lo[6], t_hi[6],
545                                       -cospi_8_64, cospi_24_64, &u_lo[1],
546                                       &u_hi[1], &u_lo[6], &u_hi[6]);
547 
548   // u[5] = -cospi_24_64 * t[5] + cospi_8_64 * t[2]
549   // u[2] = cospi_8_64 * t[5]   + cospi_24_64 * t[2]
550   butterfly_two_coeff_s16_s32_noround(t_lo[5], t_hi[5], t_lo[2], t_hi[2],
551                                       -cospi_24_64, cospi_8_64, &u_lo[5],
552                                       &u_hi[5], &u_lo[2], &u_hi[2]);
553 
554   t_lo[1] = vrshrn_n_s32(u_lo[1], DCT_CONST_BITS);
555   t_hi[1] = vrshrn_n_s32(u_hi[1], DCT_CONST_BITS);
556   t_lo[2] = vrshrn_n_s32(u_lo[2], DCT_CONST_BITS);
557   t_hi[2] = vrshrn_n_s32(u_hi[2], DCT_CONST_BITS);
558   t_lo[5] = vrshrn_n_s32(u_lo[5], DCT_CONST_BITS);
559   t_hi[5] = vrshrn_n_s32(u_hi[5], DCT_CONST_BITS);
560   t_lo[6] = vrshrn_n_s32(u_lo[6], DCT_CONST_BITS);
561   t_hi[6] = vrshrn_n_s32(u_hi[6], DCT_CONST_BITS);
562 
563   s2[1] = vcombine_s16(t_lo[1], t_hi[1]);
564   s2[2] = vcombine_s16(t_lo[2], t_hi[2]);
565   s2[5] = vcombine_s16(t_lo[5], t_hi[5]);
566   s2[6] = vcombine_s16(t_lo[6], t_hi[6]);
567 
568   // step 5
569   s1[0] = vaddq_s16(s3[0], s2[1]);
570   s1[1] = vsubq_s16(s3[0], s2[1]);
571   s1[2] = vaddq_s16(s3[3], s2[2]);
572   s1[3] = vsubq_s16(s3[3], s2[2]);
573   s1[4] = vsubq_s16(s3[4], s2[5]);
574   s1[5] = vaddq_s16(s3[4], s2[5]);
575   s1[6] = vsubq_s16(s3[7], s2[6]);
576   s1[7] = vaddq_s16(s3[7], s2[6]);
577 
578   // step 6
579   t_lo[0] = vget_low_s16(s1[0]);
580   t_hi[0] = vget_high_s16(s1[0]);
581   t_lo[1] = vget_low_s16(s1[1]);
582   t_hi[1] = vget_high_s16(s1[1]);
583   t_lo[2] = vget_low_s16(s1[2]);
584   t_hi[2] = vget_high_s16(s1[2]);
585   t_lo[3] = vget_low_s16(s1[3]);
586   t_hi[3] = vget_high_s16(s1[3]);
587   t_lo[4] = vget_low_s16(s1[4]);
588   t_hi[4] = vget_high_s16(s1[4]);
589   t_lo[5] = vget_low_s16(s1[5]);
590   t_hi[5] = vget_high_s16(s1[5]);
591   t_lo[6] = vget_low_s16(s1[6]);
592   t_hi[6] = vget_high_s16(s1[6]);
593   t_lo[7] = vget_low_s16(s1[7]);
594   t_hi[7] = vget_high_s16(s1[7]);
595 
596   // u[0] = step1[7] * cospi_2_64 + step1[0] * cospi_30_64
597   // u[7] = step1[7] * cospi_30_64 - step1[0] * cospi_2_64
598   butterfly_two_coeff_s16_s32_noround(t_lo[7], t_hi[7], t_lo[0], t_hi[0],
599                                       cospi_2_64, cospi_30_64, &u_lo[0],
600                                       &u_hi[0], &u_lo[7], &u_hi[7]);
601 
602   // u[1] = step1[6] * cospi_18_64 + step1[1] * cospi_14_64
603   // u[6] = step1[6] * cospi_14_64 - step1[1] * cospi_18_64
604   butterfly_two_coeff_s16_s32_noround(t_lo[6], t_hi[6], t_lo[1], t_hi[1],
605                                       cospi_18_64, cospi_14_64, &u_lo[1],
606                                       &u_hi[1], &u_lo[6], &u_hi[6]);
607 
608   // u[2] = step1[5] * cospi_10_64 + step1[2] * cospi_22_64
609   // u[5] = step1[5] * cospi_22_64 - step1[2] * cospi_10_64
610   butterfly_two_coeff_s16_s32_noround(t_lo[5], t_hi[5], t_lo[2], t_hi[2],
611                                       cospi_10_64, cospi_22_64, &u_lo[2],
612                                       &u_hi[2], &u_lo[5], &u_hi[5]);
613 
614   // u[3] = step1[4] * cospi_26_64 + step1[3] * cospi_6_64
615   // u[4] = step1[4] * cospi_6_64  - step1[3] * cospi_26_64
616   butterfly_two_coeff_s16_s32_noround(t_lo[4], t_hi[4], t_lo[3], t_hi[3],
617                                       cospi_26_64, cospi_6_64, &u_lo[3],
618                                       &u_hi[3], &u_lo[4], &u_hi[4]);
619 
620   // final fdct_round_shift
621   t_lo[0] = vrshrn_n_s32(u_lo[0], DCT_CONST_BITS);
622   t_hi[0] = vrshrn_n_s32(u_hi[0], DCT_CONST_BITS);
623   t_lo[1] = vrshrn_n_s32(u_lo[1], DCT_CONST_BITS);
624   t_hi[1] = vrshrn_n_s32(u_hi[1], DCT_CONST_BITS);
625   t_lo[2] = vrshrn_n_s32(u_lo[2], DCT_CONST_BITS);
626   t_hi[2] = vrshrn_n_s32(u_hi[2], DCT_CONST_BITS);
627   t_lo[3] = vrshrn_n_s32(u_lo[3], DCT_CONST_BITS);
628   t_hi[3] = vrshrn_n_s32(u_hi[3], DCT_CONST_BITS);
629   t_lo[4] = vrshrn_n_s32(u_lo[4], DCT_CONST_BITS);
630   t_hi[4] = vrshrn_n_s32(u_hi[4], DCT_CONST_BITS);
631   t_lo[5] = vrshrn_n_s32(u_lo[5], DCT_CONST_BITS);
632   t_hi[5] = vrshrn_n_s32(u_hi[5], DCT_CONST_BITS);
633   t_lo[6] = vrshrn_n_s32(u_lo[6], DCT_CONST_BITS);
634   t_hi[6] = vrshrn_n_s32(u_hi[6], DCT_CONST_BITS);
635   t_lo[7] = vrshrn_n_s32(u_lo[7], DCT_CONST_BITS);
636   t_hi[7] = vrshrn_n_s32(u_hi[7], DCT_CONST_BITS);
637 
638   in[0] = i[0];
639   in[2] = i[1];
640   in[4] = i[2];
641   in[6] = i[3];
642   in[8] = i[4];
643   in[10] = i[5];
644   in[12] = i[6];
645   in[14] = i[7];
646   in[1] = vcombine_s16(t_lo[0], t_hi[0]);
647   in[3] = vcombine_s16(t_lo[4], t_hi[4]);
648   in[5] = vcombine_s16(t_lo[2], t_hi[2]);
649   in[7] = vcombine_s16(t_lo[6], t_hi[6]);
650   in[9] = vcombine_s16(t_lo[1], t_hi[1]);
651   in[11] = vcombine_s16(t_lo[5], t_hi[5]);
652   in[13] = vcombine_s16(t_lo[3], t_hi[3]);
653   in[15] = vcombine_s16(t_lo[7], t_hi[7]);
654 }
655 
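/* 1-D 16-point ADST applied to 8 columns, mirroring the C fadst16: butterfly
 * stages with fdct_round_shift between them, with the outputs permuted and
 * sign-flipped at the end. */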
656 static void fadst16_8col(int16x8_t *in) {
657   // perform 16x16 1-D ADST for 8 columns
658   int16x4_t x_lo[16], x_hi[16];
659   int32x4_t s_lo[16], s_hi[16];
660   int32x4_t t_lo[16], t_hi[16];
661 
662   x_lo[0] = vget_low_s16(in[15]);
663   x_hi[0] = vget_high_s16(in[15]);
664   x_lo[1] = vget_low_s16(in[0]);
665   x_hi[1] = vget_high_s16(in[0]);
666   x_lo[2] = vget_low_s16(in[13]);
667   x_hi[2] = vget_high_s16(in[13]);
668   x_lo[3] = vget_low_s16(in[2]);
669   x_hi[3] = vget_high_s16(in[2]);
670   x_lo[4] = vget_low_s16(in[11]);
671   x_hi[4] = vget_high_s16(in[11]);
672   x_lo[5] = vget_low_s16(in[4]);
673   x_hi[5] = vget_high_s16(in[4]);
674   x_lo[6] = vget_low_s16(in[9]);
675   x_hi[6] = vget_high_s16(in[9]);
676   x_lo[7] = vget_low_s16(in[6]);
677   x_hi[7] = vget_high_s16(in[6]);
678   x_lo[8] = vget_low_s16(in[7]);
679   x_hi[8] = vget_high_s16(in[7]);
680   x_lo[9] = vget_low_s16(in[8]);
681   x_hi[9] = vget_high_s16(in[8]);
682   x_lo[10] = vget_low_s16(in[5]);
683   x_hi[10] = vget_high_s16(in[5]);
684   x_lo[11] = vget_low_s16(in[10]);
685   x_hi[11] = vget_high_s16(in[10]);
686   x_lo[12] = vget_low_s16(in[3]);
687   x_hi[12] = vget_high_s16(in[3]);
688   x_lo[13] = vget_low_s16(in[12]);
689   x_hi[13] = vget_high_s16(in[12]);
690   x_lo[14] = vget_low_s16(in[1]);
691   x_hi[14] = vget_high_s16(in[1]);
692   x_lo[15] = vget_low_s16(in[14]);
693   x_hi[15] = vget_high_s16(in[14]);
694 
695   // stage 1
696   // s0 = cospi_1_64 * x0 + cospi_31_64 * x1;
697   // s1 = cospi_31_64 * x0 - cospi_1_64 * x1;
698   butterfly_two_coeff_s16_s32_noround(x_lo[0], x_hi[0], x_lo[1], x_hi[1],
699                                       cospi_1_64, cospi_31_64, &s_lo[0],
700                                       &s_hi[0], &s_lo[1], &s_hi[1]);
701   // s2 = cospi_5_64 * x2 + cospi_27_64 * x3;
702   // s3 = cospi_27_64 * x2 - cospi_5_64 * x3;
703   butterfly_two_coeff_s16_s32_noround(x_lo[2], x_hi[2], x_lo[3], x_hi[3],
704                                       cospi_5_64, cospi_27_64, &s_lo[2],
705                                       &s_hi[2], &s_lo[3], &s_hi[3]);
706   // s4 = cospi_9_64 * x4 + cospi_23_64 * x5;
707   // s5 = cospi_23_64 * x4 - cospi_9_64 * x5;
708   butterfly_two_coeff_s16_s32_noround(x_lo[4], x_hi[4], x_lo[5], x_hi[5],
709                                       cospi_9_64, cospi_23_64, &s_lo[4],
710                                       &s_hi[4], &s_lo[5], &s_hi[5]);
711   // s6 = cospi_13_64 * x6 + cospi_19_64 * x7;
712   // s7 = cospi_19_64 * x6 - cospi_13_64 * x7;
713   butterfly_two_coeff_s16_s32_noround(x_lo[6], x_hi[6], x_lo[7], x_hi[7],
714                                       cospi_13_64, cospi_19_64, &s_lo[6],
715                                       &s_hi[6], &s_lo[7], &s_hi[7]);
716   // s8 = cospi_17_64 * x8 + cospi_15_64 * x9;
717   // s9 = cospi_15_64 * x8 - cospi_17_64 * x9;
718   butterfly_two_coeff_s16_s32_noround(x_lo[8], x_hi[8], x_lo[9], x_hi[9],
719                                       cospi_17_64, cospi_15_64, &s_lo[8],
720                                       &s_hi[8], &s_lo[9], &s_hi[9]);
721   // s10 = cospi_21_64 * x10 + cospi_11_64 * x11;
722   // s11 = cospi_11_64 * x10 - cospi_21_64 * x11;
723   butterfly_two_coeff_s16_s32_noround(x_lo[10], x_hi[10], x_lo[11], x_hi[11],
724                                       cospi_21_64, cospi_11_64, &s_lo[10],
725                                       &s_hi[10], &s_lo[11], &s_hi[11]);
726   // s12 = cospi_25_64 * x12 + cospi_7_64 * x13;
727   // s13 = cospi_7_64 * x12 - cospi_25_64 * x13;
728   butterfly_two_coeff_s16_s32_noround(x_lo[12], x_hi[12], x_lo[13], x_hi[13],
729                                       cospi_25_64, cospi_7_64, &s_lo[12],
730                                       &s_hi[12], &s_lo[13], &s_hi[13]);
731   // s14 = cospi_29_64 * x14 + cospi_3_64 * x15;
732   // s15 = cospi_3_64 * x14 - cospi_29_64 * x15;
733   butterfly_two_coeff_s16_s32_noround(x_lo[14], x_hi[14], x_lo[15], x_hi[15],
734                                       cospi_29_64, cospi_3_64, &s_lo[14],
735                                       &s_hi[14], &s_lo[15], &s_hi[15]);
736 
737   // fdct_round_shift
738   t_lo[0] = vrshrq_n_s32(vaddq_s32(s_lo[0], s_lo[8]), DCT_CONST_BITS);
739   t_hi[0] = vrshrq_n_s32(vaddq_s32(s_hi[0], s_hi[8]), DCT_CONST_BITS);
740   t_lo[1] = vrshrq_n_s32(vaddq_s32(s_lo[1], s_lo[9]), DCT_CONST_BITS);
741   t_hi[1] = vrshrq_n_s32(vaddq_s32(s_hi[1], s_hi[9]), DCT_CONST_BITS);
742   t_lo[2] = vrshrq_n_s32(vaddq_s32(s_lo[2], s_lo[10]), DCT_CONST_BITS);
743   t_hi[2] = vrshrq_n_s32(vaddq_s32(s_hi[2], s_hi[10]), DCT_CONST_BITS);
744   t_lo[3] = vrshrq_n_s32(vaddq_s32(s_lo[3], s_lo[11]), DCT_CONST_BITS);
745   t_hi[3] = vrshrq_n_s32(vaddq_s32(s_hi[3], s_hi[11]), DCT_CONST_BITS);
746   t_lo[4] = vrshrq_n_s32(vaddq_s32(s_lo[4], s_lo[12]), DCT_CONST_BITS);
747   t_hi[4] = vrshrq_n_s32(vaddq_s32(s_hi[4], s_hi[12]), DCT_CONST_BITS);
748   t_lo[5] = vrshrq_n_s32(vaddq_s32(s_lo[5], s_lo[13]), DCT_CONST_BITS);
749   t_hi[5] = vrshrq_n_s32(vaddq_s32(s_hi[5], s_hi[13]), DCT_CONST_BITS);
750   t_lo[6] = vrshrq_n_s32(vaddq_s32(s_lo[6], s_lo[14]), DCT_CONST_BITS);
751   t_hi[6] = vrshrq_n_s32(vaddq_s32(s_hi[6], s_hi[14]), DCT_CONST_BITS);
752   t_lo[7] = vrshrq_n_s32(vaddq_s32(s_lo[7], s_lo[15]), DCT_CONST_BITS);
753   t_hi[7] = vrshrq_n_s32(vaddq_s32(s_hi[7], s_hi[15]), DCT_CONST_BITS);
754   t_lo[8] = vrshrq_n_s32(vsubq_s32(s_lo[0], s_lo[8]), DCT_CONST_BITS);
755   t_hi[8] = vrshrq_n_s32(vsubq_s32(s_hi[0], s_hi[8]), DCT_CONST_BITS);
756   t_lo[9] = vrshrq_n_s32(vsubq_s32(s_lo[1], s_lo[9]), DCT_CONST_BITS);
757   t_hi[9] = vrshrq_n_s32(vsubq_s32(s_hi[1], s_hi[9]), DCT_CONST_BITS);
758   t_lo[10] = vrshrq_n_s32(vsubq_s32(s_lo[2], s_lo[10]), DCT_CONST_BITS);
759   t_hi[10] = vrshrq_n_s32(vsubq_s32(s_hi[2], s_hi[10]), DCT_CONST_BITS);
760   t_lo[11] = vrshrq_n_s32(vsubq_s32(s_lo[3], s_lo[11]), DCT_CONST_BITS);
761   t_hi[11] = vrshrq_n_s32(vsubq_s32(s_hi[3], s_hi[11]), DCT_CONST_BITS);
762   t_lo[12] = vrshrq_n_s32(vsubq_s32(s_lo[4], s_lo[12]), DCT_CONST_BITS);
763   t_hi[12] = vrshrq_n_s32(vsubq_s32(s_hi[4], s_hi[12]), DCT_CONST_BITS);
764   t_lo[13] = vrshrq_n_s32(vsubq_s32(s_lo[5], s_lo[13]), DCT_CONST_BITS);
765   t_hi[13] = vrshrq_n_s32(vsubq_s32(s_hi[5], s_hi[13]), DCT_CONST_BITS);
766   t_lo[14] = vrshrq_n_s32(vsubq_s32(s_lo[6], s_lo[14]), DCT_CONST_BITS);
767   t_hi[14] = vrshrq_n_s32(vsubq_s32(s_hi[6], s_hi[14]), DCT_CONST_BITS);
768   t_lo[15] = vrshrq_n_s32(vsubq_s32(s_lo[7], s_lo[15]), DCT_CONST_BITS);
769   t_hi[15] = vrshrq_n_s32(vsubq_s32(s_hi[7], s_hi[15]), DCT_CONST_BITS);
770 
771   // stage 2
772   s_lo[0] = t_lo[0];
773   s_hi[0] = t_hi[0];
774   s_lo[1] = t_lo[1];
775   s_hi[1] = t_hi[1];
776   s_lo[2] = t_lo[2];
777   s_hi[2] = t_hi[2];
778   s_lo[3] = t_lo[3];
779   s_hi[3] = t_hi[3];
780   s_lo[4] = t_lo[4];
781   s_hi[4] = t_hi[4];
782   s_lo[5] = t_lo[5];
783   s_hi[5] = t_hi[5];
784   s_lo[6] = t_lo[6];
785   s_hi[6] = t_hi[6];
786   s_lo[7] = t_lo[7];
787   s_hi[7] = t_hi[7];
788   // s8 = x8 * cospi_4_64 + x9 * cospi_28_64;
789   // s9 = x8 * cospi_28_64 - x9 * cospi_4_64;
790   butterfly_two_coeff_s32_noround(t_lo[8], t_hi[8], t_lo[9], t_hi[9],
791                                   cospi_4_64, cospi_28_64, &s_lo[8], &s_hi[8],
792                                   &s_lo[9], &s_hi[9]);
793   // s10 = x10 * cospi_20_64 + x11 * cospi_12_64;
794   // s11 = x10 * cospi_12_64 - x11 * cospi_20_64;
795   butterfly_two_coeff_s32_noround(t_lo[10], t_hi[10], t_lo[11], t_hi[11],
796                                   cospi_20_64, cospi_12_64, &s_lo[10],
797                                   &s_hi[10], &s_lo[11], &s_hi[11]);
798   // s12 = -x12 * cospi_28_64 + x13 * cospi_4_64;
799   // s13 = x12 * cospi_4_64 + x13 * cospi_28_64;
800   butterfly_two_coeff_s32_noround(t_lo[13], t_hi[13], t_lo[12], t_hi[12],
801                                   cospi_28_64, cospi_4_64, &s_lo[13], &s_hi[13],
802                                   &s_lo[12], &s_hi[12]);
803   // s14 = -x14 * cospi_12_64 + x15 * cospi_20_64;
804   // s15 = x14 * cospi_20_64 + x15 * cospi_12_64;
805   butterfly_two_coeff_s32_noround(t_lo[15], t_hi[15], t_lo[14], t_hi[14],
806                                   cospi_12_64, cospi_20_64, &s_lo[15],
807                                   &s_hi[15], &s_lo[14], &s_hi[14]);
808 
809   // s0 + s4
810   t_lo[0] = vaddq_s32(s_lo[0], s_lo[4]);
811   t_hi[0] = vaddq_s32(s_hi[0], s_hi[4]);
812   // s1 + s5
813   t_lo[1] = vaddq_s32(s_lo[1], s_lo[5]);
814   t_hi[1] = vaddq_s32(s_hi[1], s_hi[5]);
815   // s2 + s6
816   t_lo[2] = vaddq_s32(s_lo[2], s_lo[6]);
817   t_hi[2] = vaddq_s32(s_hi[2], s_hi[6]);
818   // s3 + s7
819   t_lo[3] = vaddq_s32(s_lo[3], s_lo[7]);
820   t_hi[3] = vaddq_s32(s_hi[3], s_hi[7]);
821   // s0 - s4
822   t_lo[4] = vsubq_s32(s_lo[0], s_lo[4]);
823   t_hi[4] = vsubq_s32(s_hi[0], s_hi[4]);
824   // s1 - s5
825   t_lo[5] = vsubq_s32(s_lo[1], s_lo[5]);
826   t_hi[5] = vsubq_s32(s_hi[1], s_hi[5]);
827   // s2 - s6
828   t_lo[6] = vsubq_s32(s_lo[2], s_lo[6]);
829   t_hi[6] = vsubq_s32(s_hi[2], s_hi[6]);
830   // s3 - s7
831   t_lo[7] = vsubq_s32(s_lo[3], s_lo[7]);
832   t_hi[7] = vsubq_s32(s_hi[3], s_hi[7]);
833   // s8 + s12
834   t_lo[8] = vaddq_s32(s_lo[8], s_lo[12]);
835   t_hi[8] = vaddq_s32(s_hi[8], s_hi[12]);
836   // s9 + s13
837   t_lo[9] = vaddq_s32(s_lo[9], s_lo[13]);
838   t_hi[9] = vaddq_s32(s_hi[9], s_hi[13]);
839   // s10 + s14
840   t_lo[10] = vaddq_s32(s_lo[10], s_lo[14]);
841   t_hi[10] = vaddq_s32(s_hi[10], s_hi[14]);
842   // s11 + s15
843   t_lo[11] = vaddq_s32(s_lo[11], s_lo[15]);
844   t_hi[11] = vaddq_s32(s_hi[11], s_hi[15]);
845   // s8 - s12
846   t_lo[12] = vsubq_s32(s_lo[8], s_lo[12]);
847   t_hi[12] = vsubq_s32(s_hi[8], s_hi[12]);
848   // s9 - s13
849   t_lo[13] = vsubq_s32(s_lo[9], s_lo[13]);
850   t_hi[13] = vsubq_s32(s_hi[9], s_hi[13]);
851   // s10 - s14
852   t_lo[14] = vsubq_s32(s_lo[10], s_lo[14]);
853   t_hi[14] = vsubq_s32(s_hi[10], s_hi[14]);
854   // s11 - s15
855   t_lo[15] = vsubq_s32(s_lo[11], s_lo[15]);
856   t_hi[15] = vsubq_s32(s_hi[11], s_hi[15]);
857 
858   t_lo[8] = vrshrq_n_s32(t_lo[8], DCT_CONST_BITS);
859   t_hi[8] = vrshrq_n_s32(t_hi[8], DCT_CONST_BITS);
860   t_lo[9] = vrshrq_n_s32(t_lo[9], DCT_CONST_BITS);
861   t_hi[9] = vrshrq_n_s32(t_hi[9], DCT_CONST_BITS);
862   t_lo[10] = vrshrq_n_s32(t_lo[10], DCT_CONST_BITS);
863   t_hi[10] = vrshrq_n_s32(t_hi[10], DCT_CONST_BITS);
864   t_lo[11] = vrshrq_n_s32(t_lo[11], DCT_CONST_BITS);
865   t_hi[11] = vrshrq_n_s32(t_hi[11], DCT_CONST_BITS);
866   t_lo[12] = vrshrq_n_s32(t_lo[12], DCT_CONST_BITS);
867   t_hi[12] = vrshrq_n_s32(t_hi[12], DCT_CONST_BITS);
868   t_lo[13] = vrshrq_n_s32(t_lo[13], DCT_CONST_BITS);
869   t_hi[13] = vrshrq_n_s32(t_hi[13], DCT_CONST_BITS);
870   t_lo[14] = vrshrq_n_s32(t_lo[14], DCT_CONST_BITS);
871   t_hi[14] = vrshrq_n_s32(t_hi[14], DCT_CONST_BITS);
872   t_lo[15] = vrshrq_n_s32(t_lo[15], DCT_CONST_BITS);
873   t_hi[15] = vrshrq_n_s32(t_hi[15], DCT_CONST_BITS);
874 
875   // stage 3
876   s_lo[0] = t_lo[0];
877   s_hi[0] = t_hi[0];
878   s_lo[1] = t_lo[1];
879   s_hi[1] = t_hi[1];
880   s_lo[2] = t_lo[2];
881   s_hi[2] = t_hi[2];
882   s_lo[3] = t_lo[3];
883   s_hi[3] = t_hi[3];
884   // s4 = x4 * cospi_8_64 + x5 * cospi_24_64;
885   // s5 = x4 * cospi_24_64 - x5 * cospi_8_64;
886   butterfly_two_coeff_s32_noround(t_lo[4], t_hi[4], t_lo[5], t_hi[5],
887                                   cospi_8_64, cospi_24_64, &s_lo[4], &s_hi[4],
888                                   &s_lo[5], &s_hi[5]);
889   // s6 = -x6 * cospi_24_64 + x7 * cospi_8_64;
890   // s7 = x6 * cospi_8_64 + x7 * cospi_24_64;
891   butterfly_two_coeff_s32_noround(t_lo[7], t_hi[7], t_lo[6], t_hi[6],
892                                   cospi_24_64, cospi_8_64, &s_lo[7], &s_hi[7],
893                                   &s_lo[6], &s_hi[6]);
894   s_lo[8] = t_lo[8];
895   s_hi[8] = t_hi[8];
896   s_lo[9] = t_lo[9];
897   s_hi[9] = t_hi[9];
898   s_lo[10] = t_lo[10];
899   s_hi[10] = t_hi[10];
900   s_lo[11] = t_lo[11];
901   s_hi[11] = t_hi[11];
902   // s12 = x12 * cospi_8_64 + x13 * cospi_24_64;
903   // s13 = x12 * cospi_24_64 - x13 * cospi_8_64;
904   butterfly_two_coeff_s32_noround(t_lo[12], t_hi[12], t_lo[13], t_hi[13],
905                                   cospi_8_64, cospi_24_64, &s_lo[12], &s_hi[12],
906                                   &s_lo[13], &s_hi[13]);
907   // s14 = -x14 * cospi_24_64 + x15 * cospi_8_64;
908   // s15 = x14 * cospi_8_64 + x15 * cospi_24_64;
909   butterfly_two_coeff_s32_noround(t_lo[15], t_hi[15], t_lo[14], t_hi[14],
910                                   cospi_24_64, cospi_8_64, &s_lo[15], &s_hi[15],
911                                   &s_lo[14], &s_hi[14]);
912 
913   // s0 + s2
914   t_lo[0] = vaddq_s32(s_lo[0], s_lo[2]);
915   t_hi[0] = vaddq_s32(s_hi[0], s_hi[2]);
916   // s1 + s3
917   t_lo[1] = vaddq_s32(s_lo[1], s_lo[3]);
918   t_hi[1] = vaddq_s32(s_hi[1], s_hi[3]);
919   // s0 - s2
920   t_lo[2] = vsubq_s32(s_lo[0], s_lo[2]);
921   t_hi[2] = vsubq_s32(s_hi[0], s_hi[2]);
922   // s1 - s3
923   t_lo[3] = vsubq_s32(s_lo[1], s_lo[3]);
924   t_hi[3] = vsubq_s32(s_hi[1], s_hi[3]);
925   // s4 + s6
926   t_lo[4] = vaddq_s32(s_lo[4], s_lo[6]);
927   t_hi[4] = vaddq_s32(s_hi[4], s_hi[6]);
928   // s5 + s7
929   t_lo[5] = vaddq_s32(s_lo[5], s_lo[7]);
930   t_hi[5] = vaddq_s32(s_hi[5], s_hi[7]);
931   // s4 - s6
932   t_lo[6] = vsubq_s32(s_lo[4], s_lo[6]);
933   t_hi[6] = vsubq_s32(s_hi[4], s_hi[6]);
934   // s5 - s7
935   t_lo[7] = vsubq_s32(s_lo[5], s_lo[7]);
936   t_hi[7] = vsubq_s32(s_hi[5], s_hi[7]);
937   // s8 + s10
938   t_lo[8] = vaddq_s32(s_lo[8], s_lo[10]);
939   t_hi[8] = vaddq_s32(s_hi[8], s_hi[10]);
940   // s9 + s11
941   t_lo[9] = vaddq_s32(s_lo[9], s_lo[11]);
942   t_hi[9] = vaddq_s32(s_hi[9], s_hi[11]);
943   // s8 - s10
944   t_lo[10] = vsubq_s32(s_lo[8], s_lo[10]);
945   t_hi[10] = vsubq_s32(s_hi[8], s_hi[10]);
946   // s9 - s11
947   t_lo[11] = vsubq_s32(s_lo[9], s_lo[11]);
948   t_hi[11] = vsubq_s32(s_hi[9], s_hi[11]);
949   // s12 + s14
950   t_lo[12] = vaddq_s32(s_lo[12], s_lo[14]);
951   t_hi[12] = vaddq_s32(s_hi[12], s_hi[14]);
952   // s13 + s15
953   t_lo[13] = vaddq_s32(s_lo[13], s_lo[15]);
954   t_hi[13] = vaddq_s32(s_hi[13], s_hi[15]);
955   // s12 - s14
956   t_lo[14] = vsubq_s32(s_lo[12], s_lo[14]);
957   t_hi[14] = vsubq_s32(s_hi[12], s_hi[14]);
958   // s13 - s15
959   t_lo[15] = vsubq_s32(s_lo[13], s_lo[15]);
960   t_hi[15] = vsubq_s32(s_hi[13], s_hi[15]);
961 
962   t_lo[4] = vrshrq_n_s32(t_lo[4], DCT_CONST_BITS);
963   t_hi[4] = vrshrq_n_s32(t_hi[4], DCT_CONST_BITS);
964   t_lo[5] = vrshrq_n_s32(t_lo[5], DCT_CONST_BITS);
965   t_hi[5] = vrshrq_n_s32(t_hi[5], DCT_CONST_BITS);
966   t_lo[6] = vrshrq_n_s32(t_lo[6], DCT_CONST_BITS);
967   t_hi[6] = vrshrq_n_s32(t_hi[6], DCT_CONST_BITS);
968   t_lo[7] = vrshrq_n_s32(t_lo[7], DCT_CONST_BITS);
969   t_hi[7] = vrshrq_n_s32(t_hi[7], DCT_CONST_BITS);
970   t_lo[12] = vrshrq_n_s32(t_lo[12], DCT_CONST_BITS);
971   t_hi[12] = vrshrq_n_s32(t_hi[12], DCT_CONST_BITS);
972   t_lo[13] = vrshrq_n_s32(t_lo[13], DCT_CONST_BITS);
973   t_hi[13] = vrshrq_n_s32(t_hi[13], DCT_CONST_BITS);
974   t_lo[14] = vrshrq_n_s32(t_lo[14], DCT_CONST_BITS);
975   t_hi[14] = vrshrq_n_s32(t_hi[14], DCT_CONST_BITS);
976   t_lo[15] = vrshrq_n_s32(t_lo[15], DCT_CONST_BITS);
977   t_hi[15] = vrshrq_n_s32(t_hi[15], DCT_CONST_BITS);
978 
979   // stage 4
980   // s2 = (-cospi_16_64) * (x2 + x3);
981   // s3 = cospi_16_64 * (x2 - x3);
982   butterfly_one_coeff_s32_noround(t_lo[3], t_hi[3], t_lo[2], t_hi[2],
983                                   -cospi_16_64, &s_lo[2], &s_hi[2], &s_lo[3],
984                                   &s_hi[3]);
985   // s6 = cospi_16_64 * (x6 + x7);
986   // s7 = cospi_16_64 * (-x6 + x7);
987   butterfly_one_coeff_s32_noround(t_lo[7], t_hi[7], t_lo[6], t_hi[6],
988                                   cospi_16_64, &s_lo[6], &s_hi[6], &s_lo[7],
989                                   &s_hi[7]);
990   // s10 = cospi_16_64 * (x10 + x11);
991   // s11 = cospi_16_64 * (-x10 + x11);
992   butterfly_one_coeff_s32_noround(t_lo[11], t_hi[11], t_lo[10], t_hi[10],
993                                   cospi_16_64, &s_lo[10], &s_hi[10], &s_lo[11],
994                                   &s_hi[11]);
995   // s14 = (-cospi_16_64) * (x14 + x15);
996   // s15 = cospi_16_64 * (x14 - x15);
997   butterfly_one_coeff_s32_noround(t_lo[15], t_hi[15], t_lo[14], t_hi[14],
998                                   -cospi_16_64, &s_lo[14], &s_hi[14], &s_lo[15],
999                                   &s_hi[15]);
1000 
1001   // final fdct_round_shift
1002   x_lo[2] = vrshrn_n_s32(s_lo[2], DCT_CONST_BITS);
1003   x_hi[2] = vrshrn_n_s32(s_hi[2], DCT_CONST_BITS);
1004   x_lo[3] = vrshrn_n_s32(s_lo[3], DCT_CONST_BITS);
1005   x_hi[3] = vrshrn_n_s32(s_hi[3], DCT_CONST_BITS);
1006   x_lo[6] = vrshrn_n_s32(s_lo[6], DCT_CONST_BITS);
1007   x_hi[6] = vrshrn_n_s32(s_hi[6], DCT_CONST_BITS);
1008   x_lo[7] = vrshrn_n_s32(s_lo[7], DCT_CONST_BITS);
1009   x_hi[7] = vrshrn_n_s32(s_hi[7], DCT_CONST_BITS);
1010   x_lo[10] = vrshrn_n_s32(s_lo[10], DCT_CONST_BITS);
1011   x_hi[10] = vrshrn_n_s32(s_hi[10], DCT_CONST_BITS);
1012   x_lo[11] = vrshrn_n_s32(s_lo[11], DCT_CONST_BITS);
1013   x_hi[11] = vrshrn_n_s32(s_hi[11], DCT_CONST_BITS);
1014   x_lo[14] = vrshrn_n_s32(s_lo[14], DCT_CONST_BITS);
1015   x_hi[14] = vrshrn_n_s32(s_hi[14], DCT_CONST_BITS);
1016   x_lo[15] = vrshrn_n_s32(s_lo[15], DCT_CONST_BITS);
1017   x_hi[15] = vrshrn_n_s32(s_hi[15], DCT_CONST_BITS);
1018 
1019   // x0, x1, x4, x5, x8, x9, x12, x13 narrow down to 16-bits directly
1020   x_lo[0] = vmovn_s32(t_lo[0]);
1021   x_hi[0] = vmovn_s32(t_hi[0]);
1022   x_lo[1] = vmovn_s32(t_lo[1]);
1023   x_hi[1] = vmovn_s32(t_hi[1]);
1024   x_lo[4] = vmovn_s32(t_lo[4]);
1025   x_hi[4] = vmovn_s32(t_hi[4]);
1026   x_lo[5] = vmovn_s32(t_lo[5]);
1027   x_hi[5] = vmovn_s32(t_hi[5]);
1028   x_lo[8] = vmovn_s32(t_lo[8]);
1029   x_hi[8] = vmovn_s32(t_hi[8]);
1030   x_lo[9] = vmovn_s32(t_lo[9]);
1031   x_hi[9] = vmovn_s32(t_hi[9]);
1032   x_lo[12] = vmovn_s32(t_lo[12]);
1033   x_hi[12] = vmovn_s32(t_hi[12]);
1034   x_lo[13] = vmovn_s32(t_lo[13]);
1035   x_hi[13] = vmovn_s32(t_hi[13]);
1036 
1037   in[0] = vcombine_s16(x_lo[0], x_hi[0]);
1038   in[1] = vnegq_s16(vcombine_s16(x_lo[8], x_hi[8]));
1039   in[2] = vcombine_s16(x_lo[12], x_hi[12]);
1040   in[3] = vnegq_s16(vcombine_s16(x_lo[4], x_hi[4]));
1041   in[4] = vcombine_s16(x_lo[6], x_hi[6]);
1042   in[5] = vcombine_s16(x_lo[14], x_hi[14]);
1043   in[6] = vcombine_s16(x_lo[10], x_hi[10]);
1044   in[7] = vcombine_s16(x_lo[2], x_hi[2]);
1045   in[8] = vcombine_s16(x_lo[3], x_hi[3]);
1046   in[9] = vcombine_s16(x_lo[11], x_hi[11]);
1047   in[10] = vcombine_s16(x_lo[15], x_hi[15]);
1048   in[11] = vcombine_s16(x_lo[7], x_hi[7]);
1049   in[12] = vcombine_s16(x_lo[5], x_hi[5]);
1050   in[13] = vnegq_s16(vcombine_s16(x_lo[13], x_hi[13]));
1051   in[14] = vcombine_s16(x_lo[9], x_hi[9]);
1052   in[15] = vnegq_s16(vcombine_s16(x_lo[1], x_hi[1]));
1053 }
1054 
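/* One full 16x16 1-D pass: transform the left and right 8-column halves, then
 * transpose the 16x16 block so the next pass works on the other direction. */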
1055 static void fdct16x16_neon(int16x8_t *in0, int16x8_t *in1) {
1056   // Left half.
1057   fdct16_8col(in0);
1058   // Right half.
1059   fdct16_8col(in1);
1060   transpose_s16_16x16(in0, in1);
1061 }
1062 
1063 static void fadst16x16_neon(int16x8_t *in0, int16x8_t *in1) {
1064   fadst16_8col(in0);
1065   fadst16_8col(in1);
1066   transpose_s16_16x16(in0, in1);
1067 }
1068 
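/* Forward 16x16 hybrid transform. Unlike the 4x4/8x8 paths, the intermediate
 * rounding shift (right_shift_16x16) is applied between the two passes. */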
1069 void vp9_fht16x16_neon(const int16_t *input, tran_low_t *output, int stride,
1070                        int tx_type) {
1071   int16x8_t in0[16], in1[16];
1072 
1073   switch (tx_type) {
1074     case DCT_DCT: vpx_fdct16x16_neon(input, output, stride); break;
1075     case ADST_DCT:
1076       load_buffer_16x16(input, in0, in1, stride);
1077       fadst16x16_neon(in0, in1);
1078       right_shift_16x16(in0, in1);
1079       fdct16x16_neon(in0, in1);
1080       write_buffer_16x16(output, in0, in1, 16);
1081       break;
1082     case DCT_ADST:
1083       load_buffer_16x16(input, in0, in1, stride);
1084       fdct16x16_neon(in0, in1);
1085       right_shift_16x16(in0, in1);
1086       fadst16x16_neon(in0, in1);
1087       write_buffer_16x16(output, in0, in1, 16);
1088       break;
1089     default:
1090       assert(tx_type == ADST_ADST);
1091       load_buffer_16x16(input, in0, in1, stride);
1092       fadst16x16_neon(in0, in1);
1093       right_shift_16x16(in0, in1);
1094       fadst16x16_neon(in0, in1);
1095       write_buffer_16x16(output, in0, in1, 16);
1096       break;
1097   }
1098 }
1099 
1100 #if CONFIG_VP9_HIGHBITDEPTH
1101 
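/* High bitdepth variants: same structure as the functions above, but with
 * 32-bit coefficient lanes and 64-bit multiply-accumulate intermediates where
 * the butterflies require them. */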
1102 static INLINE void highbd_load_buffer_4x4(const int16_t *input,
1103                                           int32x4_t *in /*[4]*/, int stride) {
1104   // { 0, 1, 1, 1 };
1105   const int32x4_t nonzero_bias_a = vextq_s32(vdupq_n_s32(0), vdupq_n_s32(1), 3);
1106   // { 1, 0, 0, 0 };
1107   const int32x4_t nonzero_bias_b = vextq_s32(vdupq_n_s32(1), vdupq_n_s32(0), 3);
1108   int32x4_t mask;
1109 
1110   in[0] = vshll_n_s16(vld1_s16(input + 0 * stride), 4);
1111   in[1] = vshll_n_s16(vld1_s16(input + 1 * stride), 4);
1112   in[2] = vshll_n_s16(vld1_s16(input + 2 * stride), 4);
1113   in[3] = vshll_n_s16(vld1_s16(input + 3 * stride), 4);
1114 
1115   // Copy the SSE method: use a mask instead of an 'if' branch to add one to
1116   // the first element when it is non-zero.
1117   mask = vreinterpretq_s32_u32(vceqq_s32(in[0], nonzero_bias_a));
1118   in[0] = vaddq_s32(in[0], mask);
1119   in[0] = vaddq_s32(in[0], nonzero_bias_b);
1120 }
1121 
1122 static INLINE void highbd_write_buffer_4x4(tran_low_t *output, int32x4_t *res) {
1123   const int32x4_t one = vdupq_n_s32(1);
1124   res[0] = vshrq_n_s32(vaddq_s32(res[0], one), 2);
1125   res[1] = vshrq_n_s32(vaddq_s32(res[1], one), 2);
1126   res[2] = vshrq_n_s32(vaddq_s32(res[2], one), 2);
1127   res[3] = vshrq_n_s32(vaddq_s32(res[3], one), 2);
1128   vst1q_s32(output + 0 * 4, res[0]);
1129   vst1q_s32(output + 1 * 4, res[1]);
1130   vst1q_s32(output + 2 * 4, res[2]);
1131   vst1q_s32(output + 3 * 4, res[3]);
1132 }
1133 
1134 static INLINE void highbd_fadst4x4_neon(int32x4_t *in /*[4]*/) {
1135   int32x2_t s_lo[4], s_hi[4];
1136   int64x2_t u_lo[4], u_hi[4], t_lo[4], t_hi[4];
1137 
1138   s_lo[0] = vget_low_s32(in[0]);
1139   s_hi[0] = vget_high_s32(in[0]);
1140   s_lo[1] = vget_low_s32(in[1]);
1141   s_hi[1] = vget_high_s32(in[1]);
1142   s_lo[2] = vget_low_s32(in[2]);
1143   s_hi[2] = vget_high_s32(in[2]);
1144   s_lo[3] = vget_low_s32(in[3]);
1145   s_hi[3] = vget_high_s32(in[3]);
1146 
1147   // t0 = s0 * sinpi_1_9 + s1 * sinpi_2_9 + s3 * sinpi_4_9
1148   t_lo[0] = vmull_n_s32(s_lo[0], sinpi_1_9);
1149   t_lo[0] = vmlal_n_s32(t_lo[0], s_lo[1], sinpi_2_9);
1150   t_lo[0] = vmlal_n_s32(t_lo[0], s_lo[3], sinpi_4_9);
1151   t_hi[0] = vmull_n_s32(s_hi[0], sinpi_1_9);
1152   t_hi[0] = vmlal_n_s32(t_hi[0], s_hi[1], sinpi_2_9);
1153   t_hi[0] = vmlal_n_s32(t_hi[0], s_hi[3], sinpi_4_9);
1154 
1155   // t1 = (s0 + s1) * sinpi_3_9 - s3 * sinpi_3_9
1156   t_lo[1] = vmull_n_s32(s_lo[0], sinpi_3_9);
1157   t_lo[1] = vmlal_n_s32(t_lo[1], s_lo[1], sinpi_3_9);
1158   t_lo[1] = vmlsl_n_s32(t_lo[1], s_lo[3], sinpi_3_9);
1159   t_hi[1] = vmull_n_s32(s_hi[0], sinpi_3_9);
1160   t_hi[1] = vmlal_n_s32(t_hi[1], s_hi[1], sinpi_3_9);
1161   t_hi[1] = vmlsl_n_s32(t_hi[1], s_hi[3], sinpi_3_9);
1162 
1163   // t2 = s0 * sinpi_4_9 - s1 * sinpi_1_9 + s3 * sinpi_2_9
1164   t_lo[2] = vmull_n_s32(s_lo[0], sinpi_4_9);
1165   t_lo[2] = vmlsl_n_s32(t_lo[2], s_lo[1], sinpi_1_9);
1166   t_lo[2] = vmlal_n_s32(t_lo[2], s_lo[3], sinpi_2_9);
1167   t_hi[2] = vmull_n_s32(s_hi[0], sinpi_4_9);
1168   t_hi[2] = vmlsl_n_s32(t_hi[2], s_hi[1], sinpi_1_9);
1169   t_hi[2] = vmlal_n_s32(t_hi[2], s_hi[3], sinpi_2_9);
1170 
1171   // t3 = s2 * sinpi_3_9
1172   t_lo[3] = vmull_n_s32(s_lo[2], sinpi_3_9);
1173   t_hi[3] = vmull_n_s32(s_hi[2], sinpi_3_9);
1174 
1175   /*
1176    * u0 = t0 + t3
1177    * u1 = t1
1178    * u2 = t2 - t3
1179    * u3 = t2 - t0 + t3
1180    */
1181   u_lo[0] = vaddq_s64(t_lo[0], t_lo[3]);
1182   u_hi[0] = vaddq_s64(t_hi[0], t_hi[3]);
1183   u_lo[1] = t_lo[1];
1184   u_hi[1] = t_hi[1];
1185   u_lo[2] = vsubq_s64(t_lo[2], t_lo[3]);
1186   u_hi[2] = vsubq_s64(t_hi[2], t_hi[3]);
1187   u_lo[3] = vaddq_s64(vsubq_s64(t_lo[2], t_lo[0]), t_lo[3]);
1188   u_hi[3] = vaddq_s64(vsubq_s64(t_hi[2], t_hi[0]), t_hi[3]);
1189 
1190   // fdct_round_shift
1191   in[0] = vcombine_s32(vrshrn_n_s64(u_lo[0], DCT_CONST_BITS),
1192                        vrshrn_n_s64(u_hi[0], DCT_CONST_BITS));
1193   in[1] = vcombine_s32(vrshrn_n_s64(u_lo[1], DCT_CONST_BITS),
1194                        vrshrn_n_s64(u_hi[1], DCT_CONST_BITS));
1195   in[2] = vcombine_s32(vrshrn_n_s64(u_lo[2], DCT_CONST_BITS),
1196                        vrshrn_n_s64(u_hi[2], DCT_CONST_BITS));
1197   in[3] = vcombine_s32(vrshrn_n_s64(u_lo[3], DCT_CONST_BITS),
1198                        vrshrn_n_s64(u_hi[3], DCT_CONST_BITS));
1199 
1200   transpose_s32_4x4(&in[0], &in[1], &in[2], &in[3]);
1201 }
1202 
1203 void vp9_highbd_fht4x4_neon(const int16_t *input, tran_low_t *output,
1204                             int stride, int tx_type) {
1205   int32x4_t in[4];
1206   // int i;
1207 
1208   switch (tx_type) {
1209     case DCT_DCT: vpx_highbd_fdct4x4_neon(input, output, stride); break;
1210     case ADST_DCT:
1211       highbd_load_buffer_4x4(input, in, stride);
1212       highbd_fadst4x4_neon(in);
1213       vpx_highbd_fdct4x4_pass1_neon(in);
1214       highbd_write_buffer_4x4(output, in);
1215       break;
1216     case DCT_ADST:
1217       highbd_load_buffer_4x4(input, in, stride);
1218       vpx_highbd_fdct4x4_pass1_neon(in);
1219       highbd_fadst4x4_neon(in);
1220       highbd_write_buffer_4x4(output, in);
1221       break;
1222     default:
1223       assert(tx_type == ADST_ADST);
1224       highbd_load_buffer_4x4(input, in, stride);
1225       highbd_fadst4x4_neon(in);
1226       highbd_fadst4x4_neon(in);
1227       highbd_write_buffer_4x4(output, in);
1228       break;
1229   }
1230 }
1231 
1232 static INLINE void highbd_load_buffer_8x8(const int16_t *input,
1233                                           int32x4_t *lo /*[8]*/,
1234                                           int32x4_t *hi /*[8]*/, int stride) {
1235   int16x8_t in[8];
1236   in[0] = vld1q_s16(input + 0 * stride);
1237   in[1] = vld1q_s16(input + 1 * stride);
1238   in[2] = vld1q_s16(input + 2 * stride);
1239   in[3] = vld1q_s16(input + 3 * stride);
1240   in[4] = vld1q_s16(input + 4 * stride);
1241   in[5] = vld1q_s16(input + 5 * stride);
1242   in[6] = vld1q_s16(input + 6 * stride);
1243   in[7] = vld1q_s16(input + 7 * stride);
1244   lo[0] = vshll_n_s16(vget_low_s16(in[0]), 2);
1245   hi[0] = vshll_n_s16(vget_high_s16(in[0]), 2);
1246   lo[1] = vshll_n_s16(vget_low_s16(in[1]), 2);
1247   hi[1] = vshll_n_s16(vget_high_s16(in[1]), 2);
1248   lo[2] = vshll_n_s16(vget_low_s16(in[2]), 2);
1249   hi[2] = vshll_n_s16(vget_high_s16(in[2]), 2);
1250   lo[3] = vshll_n_s16(vget_low_s16(in[3]), 2);
1251   hi[3] = vshll_n_s16(vget_high_s16(in[3]), 2);
1252   lo[4] = vshll_n_s16(vget_low_s16(in[4]), 2);
1253   hi[4] = vshll_n_s16(vget_high_s16(in[4]), 2);
1254   lo[5] = vshll_n_s16(vget_low_s16(in[5]), 2);
1255   hi[5] = vshll_n_s16(vget_high_s16(in[5]), 2);
1256   lo[6] = vshll_n_s16(vget_low_s16(in[6]), 2);
1257   hi[6] = vshll_n_s16(vget_high_s16(in[6]), 2);
1258   lo[7] = vshll_n_s16(vget_low_s16(in[7]), 2);
1259   hi[7] = vshll_n_s16(vget_high_s16(in[7]), 2);
1260 }
1261 
1262 /* Rounding right shift by 'bit' (1 or 2).
1263  * First extract the sign bit (bit 31) of each element.
1264  * If bit == 1, the sign is subtracted before the shift right by one bit.
1265  * If bit == 2, it computes the expression:
1266  *
1267  * out[j * 16 + i] = (temp_out[j] + 1 + (temp_out[j] < 0)) >> 2;
1268  *
1269  * for each element of each row.
1270  */
1271 static INLINE void highbd_right_shift_8x8(int32x4_t *lo, int32x4_t *hi,
1272                                           const int bit) {
1273   int32x4_t sign_lo[8], sign_hi[8];
1274   sign_lo[0] = vshrq_n_s32(lo[0], 31);
1275   sign_hi[0] = vshrq_n_s32(hi[0], 31);
1276   sign_lo[1] = vshrq_n_s32(lo[1], 31);
1277   sign_hi[1] = vshrq_n_s32(hi[1], 31);
1278   sign_lo[2] = vshrq_n_s32(lo[2], 31);
1279   sign_hi[2] = vshrq_n_s32(hi[2], 31);
1280   sign_lo[3] = vshrq_n_s32(lo[3], 31);
1281   sign_hi[3] = vshrq_n_s32(hi[3], 31);
1282   sign_lo[4] = vshrq_n_s32(lo[4], 31);
1283   sign_hi[4] = vshrq_n_s32(hi[4], 31);
1284   sign_lo[5] = vshrq_n_s32(lo[5], 31);
1285   sign_hi[5] = vshrq_n_s32(hi[5], 31);
1286   sign_lo[6] = vshrq_n_s32(lo[6], 31);
1287   sign_hi[6] = vshrq_n_s32(hi[6], 31);
1288   sign_lo[7] = vshrq_n_s32(lo[7], 31);
1289   sign_hi[7] = vshrq_n_s32(hi[7], 31);
1290 
1291   if (bit == 2) {
1292     const int32x4_t const_rounding = vdupq_n_s32(1);
1293     lo[0] = vaddq_s32(lo[0], const_rounding);
1294     hi[0] = vaddq_s32(hi[0], const_rounding);
1295     lo[1] = vaddq_s32(lo[1], const_rounding);
1296     hi[1] = vaddq_s32(hi[1], const_rounding);
1297     lo[2] = vaddq_s32(lo[2], const_rounding);
1298     hi[2] = vaddq_s32(hi[2], const_rounding);
1299     lo[3] = vaddq_s32(lo[3], const_rounding);
1300     hi[3] = vaddq_s32(hi[3], const_rounding);
1301     lo[4] = vaddq_s32(lo[4], const_rounding);
1302     hi[4] = vaddq_s32(hi[4], const_rounding);
1303     lo[5] = vaddq_s32(lo[5], const_rounding);
1304     hi[5] = vaddq_s32(hi[5], const_rounding);
1305     lo[6] = vaddq_s32(lo[6], const_rounding);
1306     hi[6] = vaddq_s32(hi[6], const_rounding);
1307     lo[7] = vaddq_s32(lo[7], const_rounding);
1308     hi[7] = vaddq_s32(hi[7], const_rounding);
1309   }
1310 
1311   lo[0] = vsubq_s32(lo[0], sign_lo[0]);
1312   hi[0] = vsubq_s32(hi[0], sign_hi[0]);
1313   lo[1] = vsubq_s32(lo[1], sign_lo[1]);
1314   hi[1] = vsubq_s32(hi[1], sign_hi[1]);
1315   lo[2] = vsubq_s32(lo[2], sign_lo[2]);
1316   hi[2] = vsubq_s32(hi[2], sign_hi[2]);
1317   lo[3] = vsubq_s32(lo[3], sign_lo[3]);
1318   hi[3] = vsubq_s32(hi[3], sign_hi[3]);
1319   lo[4] = vsubq_s32(lo[4], sign_lo[4]);
1320   hi[4] = vsubq_s32(hi[4], sign_hi[4]);
1321   lo[5] = vsubq_s32(lo[5], sign_lo[5]);
1322   hi[5] = vsubq_s32(hi[5], sign_hi[5]);
1323   lo[6] = vsubq_s32(lo[6], sign_lo[6]);
1324   hi[6] = vsubq_s32(hi[6], sign_hi[6]);
1325   lo[7] = vsubq_s32(lo[7], sign_lo[7]);
1326   hi[7] = vsubq_s32(hi[7], sign_hi[7]);
1327 
1328   if (bit == 1) {
1329     lo[0] = vshrq_n_s32(lo[0], 1);
1330     hi[0] = vshrq_n_s32(hi[0], 1);
1331     lo[1] = vshrq_n_s32(lo[1], 1);
1332     hi[1] = vshrq_n_s32(hi[1], 1);
1333     lo[2] = vshrq_n_s32(lo[2], 1);
1334     hi[2] = vshrq_n_s32(hi[2], 1);
1335     lo[3] = vshrq_n_s32(lo[3], 1);
1336     hi[3] = vshrq_n_s32(hi[3], 1);
1337     lo[4] = vshrq_n_s32(lo[4], 1);
1338     hi[4] = vshrq_n_s32(hi[4], 1);
1339     lo[5] = vshrq_n_s32(lo[5], 1);
1340     hi[5] = vshrq_n_s32(hi[5], 1);
1341     lo[6] = vshrq_n_s32(lo[6], 1);
1342     hi[6] = vshrq_n_s32(hi[6], 1);
1343     lo[7] = vshrq_n_s32(lo[7], 1);
1344     hi[7] = vshrq_n_s32(hi[7], 1);
1345   } else {
1346     lo[0] = vshrq_n_s32(lo[0], 2);
1347     hi[0] = vshrq_n_s32(hi[0], 2);
1348     lo[1] = vshrq_n_s32(lo[1], 2);
1349     hi[1] = vshrq_n_s32(hi[1], 2);
1350     lo[2] = vshrq_n_s32(lo[2], 2);
1351     hi[2] = vshrq_n_s32(hi[2], 2);
1352     lo[3] = vshrq_n_s32(lo[3], 2);
1353     hi[3] = vshrq_n_s32(hi[3], 2);
1354     lo[4] = vshrq_n_s32(lo[4], 2);
1355     hi[4] = vshrq_n_s32(hi[4], 2);
1356     lo[5] = vshrq_n_s32(lo[5], 2);
1357     hi[5] = vshrq_n_s32(hi[5], 2);
1358     lo[6] = vshrq_n_s32(lo[6], 2);
1359     hi[6] = vshrq_n_s32(hi[6], 2);
1360     lo[7] = vshrq_n_s32(lo[7], 2);
1361     hi[7] = vshrq_n_s32(hi[7], 2);
1362   }
1363 }
1364 
1365 static INLINE void highbd_write_buffer_8x8(tran_low_t *output, int32x4_t *lo,
1366                                            int32x4_t *hi, int stride) {
1367   vst1q_s32(output + 0 * stride, lo[0]);
1368   vst1q_s32(output + 0 * stride + 4, hi[0]);
1369   vst1q_s32(output + 1 * stride, lo[1]);
1370   vst1q_s32(output + 1 * stride + 4, hi[1]);
1371   vst1q_s32(output + 2 * stride, lo[2]);
1372   vst1q_s32(output + 2 * stride + 4, hi[2]);
1373   vst1q_s32(output + 3 * stride, lo[3]);
1374   vst1q_s32(output + 3 * stride + 4, hi[3]);
1375   vst1q_s32(output + 4 * stride, lo[4]);
1376   vst1q_s32(output + 4 * stride + 4, hi[4]);
1377   vst1q_s32(output + 5 * stride, lo[5]);
1378   vst1q_s32(output + 5 * stride + 4, hi[5]);
1379   vst1q_s32(output + 6 * stride, lo[6]);
1380   vst1q_s32(output + 6 * stride + 4, hi[6]);
1381   vst1q_s32(output + 7 * stride, lo[7]);
1382   vst1q_s32(output + 7 * stride + 4, hi[7]);
1383 }
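/* tran_low_t is 32 bits wide in high-bitdepth builds, so each row can be
 * stored directly with vst1q_s32: lo[i] covers columns 0-3 of row i and
 * hi[i] covers columns 4-7, at a caller-supplied row stride. */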
1384 
1385 static INLINE void highbd_fadst8x8_neon(int32x4_t *lo /*[8]*/,
1386                                         int32x4_t *hi /*[8]*/) {
1387   int32x4_t s_lo[8], s_hi[8];
1388   int32x4_t t_lo[8], t_hi[8];
1389   int32x4_t x_lo[8], x_hi[8];
1390   int64x2_t s64_lo[16], s64_hi[16];
1391 
1392   x_lo[0] = lo[7];
1393   x_hi[0] = hi[7];
1394   x_lo[1] = lo[0];
1395   x_hi[1] = hi[0];
1396   x_lo[2] = lo[5];
1397   x_hi[2] = hi[5];
1398   x_lo[3] = lo[2];
1399   x_hi[3] = hi[2];
1400   x_lo[4] = lo[3];
1401   x_hi[4] = hi[3];
1402   x_lo[5] = lo[4];
1403   x_hi[5] = hi[4];
1404   x_lo[6] = lo[1];
1405   x_hi[6] = hi[1];
1406   x_lo[7] = lo[6];
1407   x_hi[7] = hi[6];
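  // This reordering appears to follow the scalar fadst8 input permutation
  // (x0 = in[7], x1 = in[0], x2 = in[5], ..., x7 = in[6]), so the stage
  // comments below can be read against the C reference transform.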
1408 
1409   // stage 1
1410   // s0 = cospi_2_64 * x0 + cospi_30_64 * x1;
1411   // s1 = cospi_30_64 * x0 - cospi_2_64 * x1;
1412   butterfly_two_coeff_s32_s64_noround(
1413       x_lo[0], x_hi[0], x_lo[1], x_hi[1], cospi_2_64, cospi_30_64,
1414       &s64_lo[2 * 0], &s64_hi[2 * 0], &s64_lo[2 * 1], &s64_hi[2 * 1]);
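  // Per the expression pairs in these comments, each
  // butterfly_two_coeff_s32_s64_noround(a, b, c1, c2, &o1, &o2) call is
  // expected to produce o1 = c1 * a + c2 * b and o2 = c2 * a - c1 * b as
  // 64-bit lo/hi halves, with the fdct_round_shift() rounding deferred to
  // the add/sub_s64_round_narrow() calls further down.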
1415   // s2 = cospi_10_64 * x2 + cospi_22_64 * x3;
1416   // s3 = cospi_22_64 * x2 - cospi_10_64 * x3;
1417   butterfly_two_coeff_s32_s64_noround(
1418       x_lo[2], x_hi[2], x_lo[3], x_hi[3], cospi_10_64, cospi_22_64,
1419       &s64_lo[2 * 2], &s64_hi[2 * 2], &s64_lo[2 * 3], &s64_hi[2 * 3]);
1420 
1421   // s4 = cospi_18_64 * x4 + cospi_14_64 * x5;
1422   // s5 = cospi_14_64 * x4 - cospi_18_64 * x5;
1423   butterfly_two_coeff_s32_s64_noround(
1424       x_lo[4], x_hi[4], x_lo[5], x_hi[5], cospi_18_64, cospi_14_64,
1425       &s64_lo[2 * 4], &s64_hi[2 * 4], &s64_lo[2 * 5], &s64_hi[2 * 5]);
1426 
1427   // s6 = cospi_26_64 * x6 + cospi_6_64 * x7;
1428   // s7 = cospi_6_64 * x6 - cospi_26_64 * x7;
1429   butterfly_two_coeff_s32_s64_noround(
1430       x_lo[6], x_hi[6], x_lo[7], x_hi[7], cospi_26_64, cospi_6_64,
1431       &s64_lo[2 * 6], &s64_hi[2 * 6], &s64_lo[2 * 7], &s64_hi[2 * 7]);
1432 
1433   // fdct_round_shift, indices are doubled
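  // As a sketch, each add/sub_s64_round_narrow(&a, &b) below should match
  // the scalar fdct_round_shift(a +/- b), i.e.
  // ROUND_POWER_OF_TWO(a +/- b, DCT_CONST_BITS), narrowed back to 32 bits.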
1434   t_lo[0] = add_s64_round_narrow(&s64_lo[2 * 0], &s64_lo[2 * 4]);
1435   t_hi[0] = add_s64_round_narrow(&s64_hi[2 * 0], &s64_hi[2 * 4]);
1436   t_lo[1] = add_s64_round_narrow(&s64_lo[2 * 1], &s64_lo[2 * 5]);
1437   t_hi[1] = add_s64_round_narrow(&s64_hi[2 * 1], &s64_hi[2 * 5]);
1438   t_lo[2] = add_s64_round_narrow(&s64_lo[2 * 2], &s64_lo[2 * 6]);
1439   t_hi[2] = add_s64_round_narrow(&s64_hi[2 * 2], &s64_hi[2 * 6]);
1440   t_lo[3] = add_s64_round_narrow(&s64_lo[2 * 3], &s64_lo[2 * 7]);
1441   t_hi[3] = add_s64_round_narrow(&s64_hi[2 * 3], &s64_hi[2 * 7]);
1442   t_lo[4] = sub_s64_round_narrow(&s64_lo[2 * 0], &s64_lo[2 * 4]);
1443   t_hi[4] = sub_s64_round_narrow(&s64_hi[2 * 0], &s64_hi[2 * 4]);
1444   t_lo[5] = sub_s64_round_narrow(&s64_lo[2 * 1], &s64_lo[2 * 5]);
1445   t_hi[5] = sub_s64_round_narrow(&s64_hi[2 * 1], &s64_hi[2 * 5]);
1446   t_lo[6] = sub_s64_round_narrow(&s64_lo[2 * 2], &s64_lo[2 * 6]);
1447   t_hi[6] = sub_s64_round_narrow(&s64_hi[2 * 2], &s64_hi[2 * 6]);
1448   t_lo[7] = sub_s64_round_narrow(&s64_lo[2 * 3], &s64_lo[2 * 7]);
1449   t_hi[7] = sub_s64_round_narrow(&s64_hi[2 * 3], &s64_hi[2 * 7]);
1450 
1451   // stage 2
1452   s_lo[0] = t_lo[0];
1453   s_hi[0] = t_hi[0];
1454   s_lo[1] = t_lo[1];
1455   s_hi[1] = t_hi[1];
1456   s_lo[2] = t_lo[2];
1457   s_hi[2] = t_hi[2];
1458   s_lo[3] = t_lo[3];
1459   s_hi[3] = t_hi[3];
1460   // s4 = cospi_8_64 * x4 + cospi_24_64 * x5;
1461   // s5 = cospi_24_64 * x4 - cospi_8_64 * x5;
1462   butterfly_two_coeff_s32_s64_noround(
1463       t_lo[4], t_hi[4], t_lo[5], t_hi[5], cospi_8_64, cospi_24_64,
1464       &s64_lo[2 * 4], &s64_hi[2 * 4], &s64_lo[2 * 5], &s64_hi[2 * 5]);
1465 
1466   // s6 = -cospi_24_64 * x6 + cospi_8_64 * x7;
1467   // s7 = cospi_8_64 * x6 + cospi_24_64 * x7;
1468   butterfly_two_coeff_s32_s64_noround(
1469       t_lo[6], t_hi[6], t_lo[7], t_hi[7], -cospi_24_64, cospi_8_64,
1470       &s64_lo[2 * 6], &s64_hi[2 * 6], &s64_lo[2 * 7], &s64_hi[2 * 7]);
1471 
1472   // fdct_round_shift
1473   // s0 + s2
1474   t_lo[0] = add_s32_s64_narrow(s_lo[0], s_lo[2]);
1475   t_hi[0] = add_s32_s64_narrow(s_hi[0], s_hi[2]);
1476   // s0 - s2
1477   t_lo[2] = sub_s32_s64_narrow(s_lo[0], s_lo[2]);
1478   t_hi[2] = sub_s32_s64_narrow(s_hi[0], s_hi[2]);
1479 
1480   // s1 + s3
1481   t_lo[1] = add_s32_s64_narrow(s_lo[1], s_lo[3]);
1482   t_hi[1] = add_s32_s64_narrow(s_hi[1], s_hi[3]);
1483   // s1 - s3
1484   t_lo[3] = sub_s32_s64_narrow(s_lo[1], s_lo[3]);
1485   t_hi[3] = sub_s32_s64_narrow(s_hi[1], s_hi[3]);
1486 
1487   // s4 + s6
1488   t_lo[4] = add_s64_round_narrow(&s64_lo[2 * 4], &s64_lo[2 * 6]);
1489   t_hi[4] = add_s64_round_narrow(&s64_hi[2 * 4], &s64_hi[2 * 6]);
1490   // s4 - s6
1491   t_lo[6] = sub_s64_round_narrow(&s64_lo[2 * 4], &s64_lo[2 * 6]);
1492   t_hi[6] = sub_s64_round_narrow(&s64_hi[2 * 4], &s64_hi[2 * 6]);
1493 
1494   // s5 + s7
1495   t_lo[5] = add_s64_round_narrow(&s64_lo[2 * 5], &s64_lo[2 * 7]);
1496   t_hi[5] = add_s64_round_narrow(&s64_hi[2 * 5], &s64_hi[2 * 7]);
1497   // s5 - s7
1498   t_lo[7] = sub_s64_round_narrow(&s64_lo[2 * 5], &s64_lo[2 * 7]);
1499   t_hi[7] = sub_s64_round_narrow(&s64_hi[2 * 5], &s64_hi[2 * 7]);
1500 
1501   // stage 3
1502   // s2 = cospi_16_64 * (x2 + x3)
1503   // s3 = cospi_16_64 * (x2 - x3)
1504   butterfly_one_coeff_s32_fast(t_lo[2], t_hi[2], t_lo[3], t_hi[3], cospi_16_64,
1505                                &s_lo[2], &s_hi[2], &s_lo[3], &s_hi[3]);
1506 
1507   // s6 = cospi_16_64 * (x6 + x7)
1508   // s7 = cospi_16_64 * (x6 - x7)
1509   butterfly_one_coeff_s32_fast(t_lo[6], t_hi[6], t_lo[7], t_hi[7], cospi_16_64,
1510                                &s_lo[6], &s_hi[6], &s_lo[7], &s_hi[7]);
1511 
1512   // x0, x2, x4, x6 pass through
1513   lo[0] = t_lo[0];
1514   hi[0] = t_hi[0];
1515   lo[2] = s_lo[6];
1516   hi[2] = s_hi[6];
1517   lo[4] = s_lo[3];
1518   hi[4] = s_hi[3];
1519   lo[6] = t_lo[5];
1520   hi[6] = t_hi[5];
1521 
1522   lo[1] = vnegq_s32(t_lo[4]);
1523   hi[1] = vnegq_s32(t_hi[4]);
1524   lo[3] = vnegq_s32(s_lo[2]);
1525   hi[3] = vnegq_s32(s_hi[2]);
1526   lo[5] = vnegq_s32(s_lo[7]);
1527   hi[5] = vnegq_s32(s_hi[7]);
1528   lo[7] = vnegq_s32(t_lo[1]);
1529   hi[7] = vnegq_s32(t_hi[1]);
1530 
1531   transpose_s32_8x8_2(lo, hi, lo, hi);
1532 }
1533 
1534 void vp9_highbd_fht8x8_neon(const int16_t *input, tran_low_t *output,
1535                             int stride, int tx_type) {
1536   int32x4_t lo[8], hi[8];
1537 
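  // The first transform call in each case below works down the columns of
  // the loaded block; the transpose at the end of each 1-D helper flips the
  // block so the second call works across the original rows. The final
  // right shift by 1 corresponds to the
  // (temp_out[j] + (temp_out[j] < 0)) >> 1 rounding of the scalar 8x8 FHT.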
1538   switch (tx_type) {
1539     case DCT_DCT: vpx_highbd_fdct8x8_neon(input, output, stride); break;
1540     case ADST_DCT:
1541       highbd_load_buffer_8x8(input, lo, hi, stride);
1542       highbd_fadst8x8_neon(lo, hi);
1543       // pass1 variant is not precise enough
1544       vpx_highbd_fdct8x8_pass2_neon(lo, hi);
1545       highbd_right_shift_8x8(lo, hi, 1);
1546       highbd_write_buffer_8x8(output, lo, hi, 8);
1547       break;
1548     case DCT_ADST:
1549       highbd_load_buffer_8x8(input, lo, hi, stride);
1550       // pass1 variant is not precise enough
1551       vpx_highbd_fdct8x8_pass2_neon(lo, hi);
1552       highbd_fadst8x8_neon(lo, hi);
1553       highbd_right_shift_8x8(lo, hi, 1);
1554       highbd_write_buffer_8x8(output, lo, hi, 8);
1555       break;
1556     default:
1557       assert(tx_type == ADST_ADST);
1558       highbd_load_buffer_8x8(input, lo, hi, stride);
1559       highbd_fadst8x8_neon(lo, hi);
1560       highbd_fadst8x8_neon(lo, hi);
1561       highbd_right_shift_8x8(lo, hi, 1);
1562       highbd_write_buffer_8x8(output, lo, hi, 8);
1563       break;
1564   }
1565 }
1566 
1567 static INLINE void highbd_load_buffer_16x16(
1568     const int16_t *input, int32x4_t *left1 /*[16]*/, int32x4_t *right1 /*[16]*/,
1569     int32x4_t *left2 /*[16]*/, int32x4_t *right2 /*[16]*/, int stride) {
1570   // load first 8 columns
1571   highbd_load_buffer_8x8(input, left1, right1, stride);
1572   highbd_load_buffer_8x8(input + 8 * stride, left1 + 8, right1 + 8, stride);
1573 
1574   input += 8;
1575   // load second 8 columns
1576   highbd_load_buffer_8x8(input, left2, right2, stride);
1577   highbd_load_buffer_8x8(input + 8 * stride, left2 + 8, right2 + 8, stride);
1578 }
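/* The 16x16 block is handled as two 8-column halves of 16 rows each:
 * left1/right1 hold columns 0-7 and left2/right2 hold columns 8-15, with
 * each row split into a "left" vector (columns 0-3 / 8-11) and a "right"
 * vector (columns 4-7 / 12-15). */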
1579 
1580 static INLINE void highbd_write_buffer_16x16(
1581     tran_low_t *output, int32x4_t *left1 /*[16]*/, int32x4_t *right1 /*[16]*/,
1582     int32x4_t *left2 /*[16]*/, int32x4_t *right2 /*[16]*/, int stride) {
1583   // write first 8 columns
1584   highbd_write_buffer_8x8(output, left1, right1, stride);
1585   highbd_write_buffer_8x8(output + 8 * stride, left1 + 8, right1 + 8, stride);
1586 
1587   // write second 8 columns
1588   output += 8;
1589   highbd_write_buffer_8x8(output, left2, right2, stride);
1590   highbd_write_buffer_8x8(output + 8 * stride, left2 + 8, right2 + 8, stride);
1591 }
1592 
1593 static INLINE void highbd_right_shift_16x16(int32x4_t *left1 /*[16]*/,
1594                                             int32x4_t *right1 /*[16]*/,
1595                                             int32x4_t *left2 /*[16]*/,
1596                                             int32x4_t *right2 /*[16]*/,
1597                                             const int bit) {
1598   // perform rounding operations
1599   highbd_right_shift_8x8(left1, right1, bit);
1600   highbd_right_shift_8x8(left1 + 8, right1 + 8, bit);
1601   highbd_right_shift_8x8(left2, right2, bit);
1602   highbd_right_shift_8x8(left2 + 8, right2 + 8, bit);
1603 }
1604 
1605 static void highbd_fdct16_8col(int32x4_t *left, int32x4_t *right) {
1606   // perform a 16-point 1-D DCT on each of 8 columns
1607   int32x4_t s1_lo[8], s1_hi[8], s2_lo[8], s2_hi[8], s3_lo[8], s3_hi[8];
1608   int32x4_t left8[8], right8[8];
1609 
1610   // stage 1
1611   left8[0] = vaddq_s32(left[0], left[15]);
1612   right8[0] = vaddq_s32(right[0], right[15]);
1613   left8[1] = vaddq_s32(left[1], left[14]);
1614   right8[1] = vaddq_s32(right[1], right[14]);
1615   left8[2] = vaddq_s32(left[2], left[13]);
1616   right8[2] = vaddq_s32(right[2], right[13]);
1617   left8[3] = vaddq_s32(left[3], left[12]);
1618   right8[3] = vaddq_s32(right[3], right[12]);
1619   left8[4] = vaddq_s32(left[4], left[11]);
1620   right8[4] = vaddq_s32(right[4], right[11]);
1621   left8[5] = vaddq_s32(left[5], left[10]);
1622   right8[5] = vaddq_s32(right[5], right[10]);
1623   left8[6] = vaddq_s32(left[6], left[9]);
1624   right8[6] = vaddq_s32(right[6], right[9]);
1625   left8[7] = vaddq_s32(left[7], left[8]);
1626   right8[7] = vaddq_s32(right[7], right[8]);
1627 
1628   // step 1
1629   s1_lo[0] = vsubq_s32(left[7], left[8]);
1630   s1_hi[0] = vsubq_s32(right[7], right[8]);
1631   s1_lo[1] = vsubq_s32(left[6], left[9]);
1632   s1_hi[1] = vsubq_s32(right[6], right[9]);
1633   s1_lo[2] = vsubq_s32(left[5], left[10]);
1634   s1_hi[2] = vsubq_s32(right[5], right[10]);
1635   s1_lo[3] = vsubq_s32(left[4], left[11]);
1636   s1_hi[3] = vsubq_s32(right[4], right[11]);
1637   s1_lo[4] = vsubq_s32(left[3], left[12]);
1638   s1_hi[4] = vsubq_s32(right[3], right[12]);
1639   s1_lo[5] = vsubq_s32(left[2], left[13]);
1640   s1_hi[5] = vsubq_s32(right[2], right[13]);
1641   s1_lo[6] = vsubq_s32(left[1], left[14]);
1642   s1_hi[6] = vsubq_s32(right[1], right[14]);
1643   s1_lo[7] = vsubq_s32(left[0], left[15]);
1644   s1_hi[7] = vsubq_s32(right[0], right[15]);
1645 
1646   // pass1 variant is not accurate enough
1647   vpx_highbd_fdct8x8_pass2_notranspose_neon(left8, right8);
1648 
1649   // step 2
1650   // step2[2] = (step1[5] - step1[2]) * cospi_16_64;
1651   // step2[5] = (step1[5] + step1[2]) * cospi_16_64;
1652   butterfly_one_coeff_s32_s64_narrow(s1_lo[5], s1_hi[5], s1_lo[2], s1_hi[2],
1653                                      cospi_16_64, &s2_lo[5], &s2_hi[5],
1654                                      &s2_lo[2], &s2_hi[2]);
1655   // step2[3] = (step1[4] - step1[3]) * cospi_16_64;
1656   // step2[4] = (step1[4] + step1[3]) * cospi_16_64;
1657   butterfly_one_coeff_s32_s64_narrow(s1_lo[4], s1_hi[4], s1_lo[3], s1_hi[3],
1658                                      cospi_16_64, &s2_lo[4], &s2_hi[4],
1659                                      &s2_lo[3], &s2_hi[3]);
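  // Per the expressions above, butterfly_one_coeff_s32_s64_narrow(a, b, c,
  // &o_add, &o_sub) is expected to compute o_add = c * (a + b) and
  // o_sub = c * (a - b), each passed through fdct_round_shift() before
  // being narrowed back to 32 bits.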
1660 
1661   // step 3
1662   s3_lo[0] = vaddq_s32(s1_lo[0], s2_lo[3]);
1663   s3_hi[0] = vaddq_s32(s1_hi[0], s2_hi[3]);
1664   s3_lo[1] = vaddq_s32(s1_lo[1], s2_lo[2]);
1665   s3_hi[1] = vaddq_s32(s1_hi[1], s2_hi[2]);
1666   s3_lo[2] = vsubq_s32(s1_lo[1], s2_lo[2]);
1667   s3_hi[2] = vsubq_s32(s1_hi[1], s2_hi[2]);
1668   s3_lo[3] = vsubq_s32(s1_lo[0], s2_lo[3]);
1669   s3_hi[3] = vsubq_s32(s1_hi[0], s2_hi[3]);
1670   s3_lo[4] = vsubq_s32(s1_lo[7], s2_lo[4]);
1671   s3_hi[4] = vsubq_s32(s1_hi[7], s2_hi[4]);
1672   s3_lo[5] = vsubq_s32(s1_lo[6], s2_lo[5]);
1673   s3_hi[5] = vsubq_s32(s1_hi[6], s2_hi[5]);
1674   s3_lo[6] = vaddq_s32(s1_lo[6], s2_lo[5]);
1675   s3_hi[6] = vaddq_s32(s1_hi[6], s2_hi[5]);
1676   s3_lo[7] = vaddq_s32(s1_lo[7], s2_lo[4]);
1677   s3_hi[7] = vaddq_s32(s1_hi[7], s2_hi[4]);
1678 
1679   // step 4
1680   // s2[1] = cospi_24_64 * s3[6] - cospi_8_64 * s3[1]
1681   // s2[6] = cospi_8_64 * s3[6]  + cospi_24_64 * s3[1]
1682   butterfly_two_coeff_s32_s64_narrow(s3_lo[6], s3_hi[6], s3_lo[1], s3_hi[1],
1683                                      cospi_8_64, cospi_24_64, &s2_lo[6],
1684                                      &s2_hi[6], &s2_lo[1], &s2_hi[1]);
1685 
1686   // s2[5] =  cospi_8_64 * s3[2] - cospi_24_64 * s3[5]
1687   // s2[2] = cospi_24_64 * s3[2] + cospi_8_64 * s3[5]
1688   butterfly_two_coeff_s32_s64_narrow(s3_lo[2], s3_hi[2], s3_lo[5], s3_hi[5],
1689                                      cospi_24_64, cospi_8_64, &s2_lo[2],
1690                                      &s2_hi[2], &s2_lo[5], &s2_hi[5]);
1691 
1692   // step 5
1693   s1_lo[0] = vaddq_s32(s3_lo[0], s2_lo[1]);
1694   s1_hi[0] = vaddq_s32(s3_hi[0], s2_hi[1]);
1695   s1_lo[1] = vsubq_s32(s3_lo[0], s2_lo[1]);
1696   s1_hi[1] = vsubq_s32(s3_hi[0], s2_hi[1]);
1697   s1_lo[2] = vaddq_s32(s3_lo[3], s2_lo[2]);
1698   s1_hi[2] = vaddq_s32(s3_hi[3], s2_hi[2]);
1699   s1_lo[3] = vsubq_s32(s3_lo[3], s2_lo[2]);
1700   s1_hi[3] = vsubq_s32(s3_hi[3], s2_hi[2]);
1701   s1_lo[4] = vsubq_s32(s3_lo[4], s2_lo[5]);
1702   s1_hi[4] = vsubq_s32(s3_hi[4], s2_hi[5]);
1703   s1_lo[5] = vaddq_s32(s3_lo[4], s2_lo[5]);
1704   s1_hi[5] = vaddq_s32(s3_hi[4], s2_hi[5]);
1705   s1_lo[6] = vsubq_s32(s3_lo[7], s2_lo[6]);
1706   s1_hi[6] = vsubq_s32(s3_hi[7], s2_hi[6]);
1707   s1_lo[7] = vaddq_s32(s3_lo[7], s2_lo[6]);
1708   s1_hi[7] = vaddq_s32(s3_hi[7], s2_hi[6]);
1709 
1710   // step 6
1711   // out[1]  = step1[7] * cospi_2_64 + step1[0] * cospi_30_64
1712   // out[15] = step1[7] * cospi_30_64 - step1[0] * cospi_2_64
1713   butterfly_two_coeff_s32_s64_narrow(s1_lo[7], s1_hi[7], s1_lo[0], s1_hi[0],
1714                                      cospi_2_64, cospi_30_64, &left[1],
1715                                      &right[1], &left[15], &right[15]);
1716 
1717   // out[9] = step1[6] * cospi_18_64 + step1[1] * cospi_14_64
1718   // out[7] = step1[6] * cospi_14_64 - step1[1] * cospi_18_64
1719   butterfly_two_coeff_s32_s64_narrow(s1_lo[6], s1_hi[6], s1_lo[1], s1_hi[1],
1720                                      cospi_18_64, cospi_14_64, &left[9],
1721                                      &right[9], &left[7], &right[7]);
1722 
1723   // out[5]  = step1[5] * cospi_10_64 + step1[2] * cospi_22_64
1724   // out[11] = step1[5] * cospi_22_64 - step1[2] * cospi_10_64
1725   butterfly_two_coeff_s32_s64_narrow(s1_lo[5], s1_hi[5], s1_lo[2], s1_hi[2],
1726                                      cospi_10_64, cospi_22_64, &left[5],
1727                                      &right[5], &left[11], &right[11]);
1728 
1729   // out[13] = step1[4] * cospi_26_64 + step1[3] * cospi_6_64
1730   // out[3]  = step1[4] * cospi_6_64  - step1[3] * cospi_26_64
1731   butterfly_two_coeff_s32_s64_narrow(s1_lo[4], s1_hi[4], s1_lo[3], s1_hi[3],
1732                                      cospi_26_64, cospi_6_64, &left[13],
1733                                      &right[13], &left[3], &right[3]);
1734 
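  // Interleave the 8-point DCT of the pairwise sums (left8/right8, computed
  // by the pass2 helper above) into the even output rows; the odd rows were
  // filled by the step 6 butterflies.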
1735   left[0] = left8[0];
1736   right[0] = right8[0];
1737   left[2] = left8[1];
1738   right[2] = right8[1];
1739   left[4] = left8[2];
1740   right[4] = right8[2];
1741   left[6] = left8[3];
1742   right[6] = right8[3];
1743   left[8] = left8[4];
1744   right[8] = right8[4];
1745   left[10] = left8[5];
1746   right[10] = right8[5];
1747   left[12] = left8[6];
1748   right[12] = right8[6];
1749   left[14] = left8[7];
1750   right[14] = right8[7];
1751 }
1752 
1753 static void highbd_fadst16_8col(int32x4_t *left, int32x4_t *right) {
1754   // perform a 16-point 1-D ADST on each of 8 columns
1755   int32x4_t x_lo[16], x_hi[16];
1756   int32x4_t s_lo[16], s_hi[16];
1757   int32x4_t t_lo[16], t_hi[16];
1758   int64x2_t s64_lo[32], s64_hi[32];
1759 
1760   x_lo[0] = left[15];
1761   x_hi[0] = right[15];
1762   x_lo[1] = left[0];
1763   x_hi[1] = right[0];
1764   x_lo[2] = left[13];
1765   x_hi[2] = right[13];
1766   x_lo[3] = left[2];
1767   x_hi[3] = right[2];
1768   x_lo[4] = left[11];
1769   x_hi[4] = right[11];
1770   x_lo[5] = left[4];
1771   x_hi[5] = right[4];
1772   x_lo[6] = left[9];
1773   x_hi[6] = right[9];
1774   x_lo[7] = left[6];
1775   x_hi[7] = right[6];
1776   x_lo[8] = left[7];
1777   x_hi[8] = right[7];
1778   x_lo[9] = left[8];
1779   x_hi[9] = right[8];
1780   x_lo[10] = left[5];
1781   x_hi[10] = right[5];
1782   x_lo[11] = left[10];
1783   x_hi[11] = right[10];
1784   x_lo[12] = left[3];
1785   x_hi[12] = right[3];
1786   x_lo[13] = left[12];
1787   x_hi[13] = right[12];
1788   x_lo[14] = left[1];
1789   x_hi[14] = right[1];
1790   x_lo[15] = left[14];
1791   x_hi[15] = right[14];
1792 
1793   // stage 1, indices are doubled
1794   // s0 = cospi_1_64 * x0 + cospi_31_64 * x1;
1795   // s1 = cospi_31_64 * x0 - cospi_1_64 * x1;
1796   butterfly_two_coeff_s32_s64_noround(
1797       x_lo[0], x_hi[0], x_lo[1], x_hi[1], cospi_1_64, cospi_31_64,
1798       &s64_lo[2 * 0], &s64_hi[2 * 0], &s64_lo[2 * 1], &s64_hi[2 * 1]);
1799   // s2 = cospi_5_64 * x2 + cospi_27_64 * x3;
1800   // s3 = cospi_27_64 * x2 - cospi_5_64 * x3;
1801   butterfly_two_coeff_s32_s64_noround(
1802       x_lo[2], x_hi[2], x_lo[3], x_hi[3], cospi_5_64, cospi_27_64,
1803       &s64_lo[2 * 2], &s64_hi[2 * 2], &s64_lo[2 * 3], &s64_hi[2 * 3]);
1804   // s4 = cospi_9_64 * x4 + cospi_23_64 * x5;
1805   // s5 = cospi_23_64 * x4 - cospi_9_64 * x5;
1806   butterfly_two_coeff_s32_s64_noround(
1807       x_lo[4], x_hi[4], x_lo[5], x_hi[5], cospi_9_64, cospi_23_64,
1808       &s64_lo[2 * 4], &s64_hi[2 * 4], &s64_lo[2 * 5], &s64_hi[2 * 5]);
1809   // s6 = cospi_13_64 * x6 + cospi_19_64 * x7;
1810   // s7 = cospi_19_64 * x6 - cospi_13_64 * x7;
1811   butterfly_two_coeff_s32_s64_noround(
1812       x_lo[6], x_hi[6], x_lo[7], x_hi[7], cospi_13_64, cospi_19_64,
1813       &s64_lo[2 * 6], &s64_hi[2 * 6], &s64_lo[2 * 7], &s64_hi[2 * 7]);
1814   // s8 = cospi_17_64 * x8 + cospi_15_64 * x9;
1815   // s9 = cospi_15_64 * x8 - cospi_17_64 * x9;
1816   butterfly_two_coeff_s32_s64_noround(
1817       x_lo[8], x_hi[8], x_lo[9], x_hi[9], cospi_17_64, cospi_15_64,
1818       &s64_lo[2 * 8], &s64_hi[2 * 8], &s64_lo[2 * 9], &s64_hi[2 * 9]);
1819   // s10 = cospi_21_64 * x10 + cospi_11_64 * x11;
1820   // s11 = cospi_11_64 * x10 - cospi_21_64 * x11;
1821   butterfly_two_coeff_s32_s64_noround(
1822       x_lo[10], x_hi[10], x_lo[11], x_hi[11], cospi_21_64, cospi_11_64,
1823       &s64_lo[2 * 10], &s64_hi[2 * 10], &s64_lo[2 * 11], &s64_hi[2 * 11]);
1824   // s12 = cospi_25_64 * x12 + cospi_7_64 * x13;
1825   // s13 = cospi_7_64 * x12 - cospi_25_64 * x13;
1826   butterfly_two_coeff_s32_s64_noround(
1827       x_lo[12], x_hi[12], x_lo[13], x_hi[13], cospi_25_64, cospi_7_64,
1828       &s64_lo[2 * 12], &s64_hi[2 * 12], &s64_lo[2 * 13], &s64_hi[2 * 13]);
1829   // s14 = cospi_29_64 * x14 + cospi_3_64 * x15;
1830   // s15 = cospi_3_64 * x14 - cospi_29_64 * x15;
1831   butterfly_two_coeff_s32_s64_noround(
1832       x_lo[14], x_hi[14], x_lo[15], x_hi[15], cospi_29_64, cospi_3_64,
1833       &s64_lo[2 * 14], &s64_hi[2 * 14], &s64_lo[2 * 15], &s64_hi[2 * 15]);
1834 
1835   // fdct_round_shift, indices are doubled
1836   t_lo[0] = add_s64_round_narrow(&s64_lo[2 * 0], &s64_lo[2 * 8]);
1837   t_hi[0] = add_s64_round_narrow(&s64_hi[2 * 0], &s64_hi[2 * 8]);
1838   t_lo[1] = add_s64_round_narrow(&s64_lo[2 * 1], &s64_lo[2 * 9]);
1839   t_hi[1] = add_s64_round_narrow(&s64_hi[2 * 1], &s64_hi[2 * 9]);
1840   t_lo[2] = add_s64_round_narrow(&s64_lo[2 * 2], &s64_lo[2 * 10]);
1841   t_hi[2] = add_s64_round_narrow(&s64_hi[2 * 2], &s64_hi[2 * 10]);
1842   t_lo[3] = add_s64_round_narrow(&s64_lo[2 * 3], &s64_lo[2 * 11]);
1843   t_hi[3] = add_s64_round_narrow(&s64_hi[2 * 3], &s64_hi[2 * 11]);
1844   t_lo[4] = add_s64_round_narrow(&s64_lo[2 * 4], &s64_lo[2 * 12]);
1845   t_hi[4] = add_s64_round_narrow(&s64_hi[2 * 4], &s64_hi[2 * 12]);
1846   t_lo[5] = add_s64_round_narrow(&s64_lo[2 * 5], &s64_lo[2 * 13]);
1847   t_hi[5] = add_s64_round_narrow(&s64_hi[2 * 5], &s64_hi[2 * 13]);
1848   t_lo[6] = add_s64_round_narrow(&s64_lo[2 * 6], &s64_lo[2 * 14]);
1849   t_hi[6] = add_s64_round_narrow(&s64_hi[2 * 6], &s64_hi[2 * 14]);
1850   t_lo[7] = add_s64_round_narrow(&s64_lo[2 * 7], &s64_lo[2 * 15]);
1851   t_hi[7] = add_s64_round_narrow(&s64_hi[2 * 7], &s64_hi[2 * 15]);
1852   t_lo[8] = sub_s64_round_narrow(&s64_lo[2 * 0], &s64_lo[2 * 8]);
1853   t_hi[8] = sub_s64_round_narrow(&s64_hi[2 * 0], &s64_hi[2 * 8]);
1854   t_lo[9] = sub_s64_round_narrow(&s64_lo[2 * 1], &s64_lo[2 * 9]);
1855   t_hi[9] = sub_s64_round_narrow(&s64_hi[2 * 1], &s64_hi[2 * 9]);
1856   t_lo[10] = sub_s64_round_narrow(&s64_lo[2 * 2], &s64_lo[2 * 10]);
1857   t_hi[10] = sub_s64_round_narrow(&s64_hi[2 * 2], &s64_hi[2 * 10]);
1858   t_lo[11] = sub_s64_round_narrow(&s64_lo[2 * 3], &s64_lo[2 * 11]);
1859   t_hi[11] = sub_s64_round_narrow(&s64_hi[2 * 3], &s64_hi[2 * 11]);
1860   t_lo[12] = sub_s64_round_narrow(&s64_lo[2 * 4], &s64_lo[2 * 12]);
1861   t_hi[12] = sub_s64_round_narrow(&s64_hi[2 * 4], &s64_hi[2 * 12]);
1862   t_lo[13] = sub_s64_round_narrow(&s64_lo[2 * 5], &s64_lo[2 * 13]);
1863   t_hi[13] = sub_s64_round_narrow(&s64_hi[2 * 5], &s64_hi[2 * 13]);
1864   t_lo[14] = sub_s64_round_narrow(&s64_lo[2 * 6], &s64_lo[2 * 14]);
1865   t_hi[14] = sub_s64_round_narrow(&s64_hi[2 * 6], &s64_hi[2 * 14]);
1866   t_lo[15] = sub_s64_round_narrow(&s64_lo[2 * 7], &s64_lo[2 * 15]);
1867   t_hi[15] = sub_s64_round_narrow(&s64_hi[2 * 7], &s64_hi[2 * 15]);
1868 
1869   // stage 2
1870   s_lo[0] = t_lo[0];
1871   s_hi[0] = t_hi[0];
1872   s_lo[1] = t_lo[1];
1873   s_hi[1] = t_hi[1];
1874   s_lo[2] = t_lo[2];
1875   s_hi[2] = t_hi[2];
1876   s_lo[3] = t_lo[3];
1877   s_hi[3] = t_hi[3];
1878   s_lo[4] = t_lo[4];
1879   s_hi[4] = t_hi[4];
1880   s_lo[5] = t_lo[5];
1881   s_hi[5] = t_hi[5];
1882   s_lo[6] = t_lo[6];
1883   s_hi[6] = t_hi[6];
1884   s_lo[7] = t_lo[7];
1885   s_hi[7] = t_hi[7];
1886   // s8 = x8 * cospi_4_64 + x9 * cospi_28_64;
1887   // s9 = x8 * cospi_28_64 - x9 * cospi_4_64;
1888   butterfly_two_coeff_s32_s64_noround(
1889       t_lo[8], t_hi[8], t_lo[9], t_hi[9], cospi_4_64, cospi_28_64,
1890       &s64_lo[2 * 8], &s64_hi[2 * 8], &s64_lo[2 * 9], &s64_hi[2 * 9]);
1891   // s10 = x10 * cospi_20_64 + x11 * cospi_12_64;
1892   // s11 = x10 * cospi_12_64 - x11 * cospi_20_64;
1893   butterfly_two_coeff_s32_s64_noround(
1894       t_lo[10], t_hi[10], t_lo[11], t_hi[11], cospi_20_64, cospi_12_64,
1895       &s64_lo[2 * 10], &s64_hi[2 * 10], &s64_lo[2 * 11], &s64_hi[2 * 11]);
1896   // s12 = -x12 * cospi_28_64 + x13 * cospi_4_64;
1897   // s13 = x12 * cospi_4_64 + x13 * cospi_28_64;
1898   butterfly_two_coeff_s32_s64_noround(
1899       t_lo[13], t_hi[13], t_lo[12], t_hi[12], cospi_28_64, cospi_4_64,
1900       &s64_lo[2 * 13], &s64_hi[2 * 13], &s64_lo[2 * 12], &s64_hi[2 * 12]);
1901   // s14 = -x14 * cospi_12_64 + x15 * cospi_20_64;
1902   // s15 = x14 * cospi_20_64 + x15 * cospi_12_64;
1903   butterfly_two_coeff_s32_s64_noround(
1904       t_lo[15], t_hi[15], t_lo[14], t_hi[14], cospi_12_64, cospi_20_64,
1905       &s64_lo[2 * 15], &s64_hi[2 * 15], &s64_lo[2 * 14], &s64_hi[2 * 14]);
1906 
1907   // s0 + s4
1908   t_lo[0] = add_s32_s64_narrow(s_lo[0], s_lo[4]);
1909   t_hi[0] = add_s32_s64_narrow(s_hi[0], s_hi[4]);
1910   // s1 + s5
1911   t_lo[1] = add_s32_s64_narrow(s_lo[1], s_lo[5]);
1912   t_hi[1] = add_s32_s64_narrow(s_hi[1], s_hi[5]);
1913   // s2 + s6
1914   t_lo[2] = add_s32_s64_narrow(s_lo[2], s_lo[6]);
1915   t_hi[2] = add_s32_s64_narrow(s_hi[2], s_hi[6]);
1916   // s3 + s7
1917   t_lo[3] = add_s32_s64_narrow(s_lo[3], s_lo[7]);
1918   t_hi[3] = add_s32_s64_narrow(s_hi[3], s_hi[7]);
1919 
1920   // s0 - s4
1921   t_lo[4] = sub_s32_s64_narrow(s_lo[0], s_lo[4]);
1922   t_hi[4] = sub_s32_s64_narrow(s_hi[0], s_hi[4]);
1923   // s1 - s5
1924   t_lo[5] = sub_s32_s64_narrow(s_lo[1], s_lo[5]);
1925   t_hi[5] = sub_s32_s64_narrow(s_hi[1], s_hi[5]);
1926   // s2 - s6
1927   t_lo[6] = sub_s32_s64_narrow(s_lo[2], s_lo[6]);
1928   t_hi[6] = sub_s32_s64_narrow(s_hi[2], s_hi[6]);
1929   // s3 - s7
1930   t_lo[7] = sub_s32_s64_narrow(s_lo[3], s_lo[7]);
1931   t_hi[7] = sub_s32_s64_narrow(s_hi[3], s_hi[7]);
1932 
1933   // fdct_round_shift()
1934   // s8 + s12
1935   t_lo[8] = add_s64_round_narrow(&s64_lo[2 * 8], &s64_lo[2 * 12]);
1936   t_hi[8] = add_s64_round_narrow(&s64_hi[2 * 8], &s64_hi[2 * 12]);
1937   // s9 + s13
1938   t_lo[9] = add_s64_round_narrow(&s64_lo[2 * 9], &s64_lo[2 * 13]);
1939   t_hi[9] = add_s64_round_narrow(&s64_hi[2 * 9], &s64_hi[2 * 13]);
1940   // s10 + s14
1941   t_lo[10] = add_s64_round_narrow(&s64_lo[2 * 10], &s64_lo[2 * 14]);
1942   t_hi[10] = add_s64_round_narrow(&s64_hi[2 * 10], &s64_hi[2 * 14]);
1943   // s11 + s15
1944   t_lo[11] = add_s64_round_narrow(&s64_lo[2 * 11], &s64_lo[2 * 15]);
1945   t_hi[11] = add_s64_round_narrow(&s64_hi[2 * 11], &s64_hi[2 * 15]);
1946 
1947   // s8 - s12
1948   t_lo[12] = sub_s64_round_narrow(&s64_lo[2 * 8], &s64_lo[2 * 12]);
1949   t_hi[12] = sub_s64_round_narrow(&s64_hi[2 * 8], &s64_hi[2 * 12]);
1950   // s9 - s13
1951   t_lo[13] = sub_s64_round_narrow(&s64_lo[2 * 9], &s64_lo[2 * 13]);
1952   t_hi[13] = sub_s64_round_narrow(&s64_hi[2 * 9], &s64_hi[2 * 13]);
1953   // s10 - s14
1954   t_lo[14] = sub_s64_round_narrow(&s64_lo[2 * 10], &s64_lo[2 * 14]);
1955   t_hi[14] = sub_s64_round_narrow(&s64_hi[2 * 10], &s64_hi[2 * 14]);
1956   // s11 - s15
1957   t_lo[15] = sub_s64_round_narrow(&s64_lo[2 * 11], &s64_lo[2 * 15]);
1958   t_hi[15] = sub_s64_round_narrow(&s64_hi[2 * 11], &s64_hi[2 * 15]);
1959 
1960   // stage 3
1961   s_lo[0] = t_lo[0];
1962   s_hi[0] = t_hi[0];
1963   s_lo[1] = t_lo[1];
1964   s_hi[1] = t_hi[1];
1965   s_lo[2] = t_lo[2];
1966   s_hi[2] = t_hi[2];
1967   s_lo[3] = t_lo[3];
1968   s_hi[3] = t_hi[3];
1969   // s4 = x4 * cospi_8_64 + x5 * cospi_24_64;
1970   // s5 = x4 * cospi_24_64 - x5 * cospi_8_64;
1971   butterfly_two_coeff_s32_s64_noround(
1972       t_lo[4], t_hi[4], t_lo[5], t_hi[5], cospi_8_64, cospi_24_64,
1973       &s64_lo[2 * 4], &s64_hi[2 * 4], &s64_lo[2 * 5], &s64_hi[2 * 5]);
1974   // s6 = -x6 * cospi_24_64 + x7 * cospi_8_64;
1975   // s7 = x6 * cospi_8_64 + x7 * cospi_24_64;
1976   butterfly_two_coeff_s32_s64_noround(
1977       t_lo[7], t_hi[7], t_lo[6], t_hi[6], cospi_24_64, cospi_8_64,
1978       &s64_lo[2 * 7], &s64_hi[2 * 7], &s64_lo[2 * 6], &s64_hi[2 * 6]);
1979   s_lo[8] = t_lo[8];
1980   s_hi[8] = t_hi[8];
1981   s_lo[9] = t_lo[9];
1982   s_hi[9] = t_hi[9];
1983   s_lo[10] = t_lo[10];
1984   s_hi[10] = t_hi[10];
1985   s_lo[11] = t_lo[11];
1986   s_hi[11] = t_hi[11];
1987   // s12 = x12 * cospi_8_64 + x13 * cospi_24_64;
1988   // s13 = x12 * cospi_24_64 - x13 * cospi_8_64;
1989   butterfly_two_coeff_s32_s64_noround(
1990       t_lo[12], t_hi[12], t_lo[13], t_hi[13], cospi_8_64, cospi_24_64,
1991       &s64_lo[2 * 12], &s64_hi[2 * 12], &s64_lo[2 * 13], &s64_hi[2 * 13]);
1992   // s14 = -x14 * cospi_24_64 + x15 * cospi_8_64;
1993   // s15 = x14 * cospi_8_64 + x15 * cospi_24_64;
1994   butterfly_two_coeff_s32_s64_noround(
1995       t_lo[15], t_hi[15], t_lo[14], t_hi[14], cospi_24_64, cospi_8_64,
1996       &s64_lo[2 * 15], &s64_hi[2 * 15], &s64_lo[2 * 14], &s64_hi[2 * 14]);
1997 
1998   // s0 + s2
1999   t_lo[0] = add_s32_s64_narrow(s_lo[0], s_lo[2]);
2000   t_hi[0] = add_s32_s64_narrow(s_hi[0], s_hi[2]);
2001   // s1 + s3
2002   t_lo[1] = add_s32_s64_narrow(s_lo[1], s_lo[3]);
2003   t_hi[1] = add_s32_s64_narrow(s_hi[1], s_hi[3]);
2004   // s0 - s2
2005   t_lo[2] = sub_s32_s64_narrow(s_lo[0], s_lo[2]);
2006   t_hi[2] = sub_s32_s64_narrow(s_hi[0], s_hi[2]);
2007   // s1 - s3
2008   t_lo[3] = sub_s32_s64_narrow(s_lo[1], s_lo[3]);
2009   t_hi[3] = sub_s32_s64_narrow(s_hi[1], s_hi[3]);
2010   // fdct_round_shift()
2011   // s4 + s6
2012   t_lo[4] = add_s64_round_narrow(&s64_lo[2 * 4], &s64_lo[2 * 6]);
2013   t_hi[4] = add_s64_round_narrow(&s64_hi[2 * 4], &s64_hi[2 * 6]);
2014   // s5 + s7
2015   t_lo[5] = add_s64_round_narrow(&s64_lo[2 * 5], &s64_lo[2 * 7]);
2016   t_hi[5] = add_s64_round_narrow(&s64_hi[2 * 5], &s64_hi[2 * 7]);
2017   // s4 - s6
2018   t_lo[6] = sub_s64_round_narrow(&s64_lo[2 * 4], &s64_lo[2 * 6]);
2019   t_hi[6] = sub_s64_round_narrow(&s64_hi[2 * 4], &s64_hi[2 * 6]);
2020   // s5 - s7
2021   t_lo[7] = sub_s64_round_narrow(&s64_lo[2 * 5], &s64_lo[2 * 7]);
2022   t_hi[7] = sub_s64_round_narrow(&s64_hi[2 * 5], &s64_hi[2 * 7]);
2023   // s8 + s10
2024   t_lo[8] = add_s32_s64_narrow(s_lo[8], s_lo[10]);
2025   t_hi[8] = add_s32_s64_narrow(s_hi[8], s_hi[10]);
2026   // s9 + s11
2027   t_lo[9] = add_s32_s64_narrow(s_lo[9], s_lo[11]);
2028   t_hi[9] = add_s32_s64_narrow(s_hi[9], s_hi[11]);
2029   // s8 - s10
2030   t_lo[10] = sub_s32_s64_narrow(s_lo[8], s_lo[10]);
2031   t_hi[10] = sub_s32_s64_narrow(s_hi[8], s_hi[10]);
2032   // s9 - s11
2033   t_lo[11] = sub_s32_s64_narrow(s_lo[9], s_lo[11]);
2034   t_hi[11] = sub_s32_s64_narrow(s_hi[9], s_hi[11]);
2035   // fdct_round_shift()
2036   // s12 + s14
2037   t_lo[12] = add_s64_round_narrow(&s64_lo[2 * 12], &s64_lo[2 * 14]);
2038   t_hi[12] = add_s64_round_narrow(&s64_hi[2 * 12], &s64_hi[2 * 14]);
2039   // s13 + s15
2040   t_lo[13] = add_s64_round_narrow(&s64_lo[2 * 13], &s64_lo[2 * 15]);
2041   t_hi[13] = add_s64_round_narrow(&s64_hi[2 * 13], &s64_hi[2 * 15]);
2042   // s12 - s14
2043   t_lo[14] = sub_s64_round_narrow(&s64_lo[2 * 12], &s64_lo[2 * 14]);
2044   t_hi[14] = sub_s64_round_narrow(&s64_hi[2 * 12], &s64_hi[2 * 14]);
2045   // s13 - s15
2046   t_lo[15] = sub_s64_round_narrow(&s64_lo[2 * 13], &s64_lo[2 * 15]);
2047   t_hi[15] = sub_s64_round_narrow(&s64_hi[2 * 13], &s64_hi[2 * 15]);
2048 
2049   // stage 4, with fdct_round_shift
2050   // s2 = (-cospi_16_64) * (x2 + x3);
2051   // s3 = cospi_16_64 * (x2 - x3);
2052   butterfly_one_coeff_s32_s64_narrow(t_lo[3], t_hi[3], t_lo[2], t_hi[2],
2053                                      -cospi_16_64, &x_lo[2], &x_hi[2], &x_lo[3],
2054                                      &x_hi[3]);
2055   // s6 = cospi_16_64 * (x6 + x7);
2056   // s7 = cospi_16_64 * (-x6 + x7);
2057   butterfly_one_coeff_s32_s64_narrow(t_lo[7], t_hi[7], t_lo[6], t_hi[6],
2058                                      cospi_16_64, &x_lo[6], &x_hi[6], &x_lo[7],
2059                                      &x_hi[7]);
2060   // s10 = cospi_16_64 * (x10 + x11);
2061   // s11 = cospi_16_64 * (-x10 + x11);
2062   butterfly_one_coeff_s32_s64_narrow(t_lo[11], t_hi[11], t_lo[10], t_hi[10],
2063                                      cospi_16_64, &x_lo[10], &x_hi[10],
2064                                      &x_lo[11], &x_hi[11]);
2065   // s14 = (-cospi_16_64) * (x14 + x15);
2066   // s15 = cospi_16_64 * (x14 - x15);
2067   butterfly_one_coeff_s32_s64_narrow(t_lo[15], t_hi[15], t_lo[14], t_hi[14],
2068                                      -cospi_16_64, &x_lo[14], &x_hi[14],
2069                                      &x_lo[15], &x_hi[15]);
2070 
2071   // Just copy x0, x1, x4, x5, x8, x9, x12, x13
2072   x_lo[0] = t_lo[0];
2073   x_hi[0] = t_hi[0];
2074   x_lo[1] = t_lo[1];
2075   x_hi[1] = t_hi[1];
2076   x_lo[4] = t_lo[4];
2077   x_hi[4] = t_hi[4];
2078   x_lo[5] = t_lo[5];
2079   x_hi[5] = t_hi[5];
2080   x_lo[8] = t_lo[8];
2081   x_hi[8] = t_hi[8];
2082   x_lo[9] = t_lo[9];
2083   x_hi[9] = t_hi[9];
2084   x_lo[12] = t_lo[12];
2085   x_hi[12] = t_hi[12];
2086   x_lo[13] = t_lo[13];
2087   x_hi[13] = t_hi[13];
2088 
2089   left[0] = x_lo[0];
2090   right[0] = x_hi[0];
2091   left[1] = vnegq_s32(x_lo[8]);
2092   right[1] = vnegq_s32(x_hi[8]);
2093   left[2] = x_lo[12];
2094   right[2] = x_hi[12];
2095   left[3] = vnegq_s32(x_lo[4]);
2096   right[3] = vnegq_s32(x_hi[4]);
2097   left[4] = x_lo[6];
2098   right[4] = x_hi[6];
2099   left[5] = x_lo[14];
2100   right[5] = x_hi[14];
2101   left[6] = x_lo[10];
2102   right[6] = x_hi[10];
2103   left[7] = x_lo[2];
2104   right[7] = x_hi[2];
2105   left[8] = x_lo[3];
2106   right[8] = x_hi[3];
2107   left[9] = x_lo[11];
2108   right[9] = x_hi[11];
2109   left[10] = x_lo[15];
2110   right[10] = x_hi[15];
2111   left[11] = x_lo[7];
2112   right[11] = x_hi[7];
2113   left[12] = x_lo[5];
2114   right[12] = x_hi[5];
2115   left[13] = vnegq_s32(x_lo[13]);
2116   right[13] = vnegq_s32(x_hi[13]);
2117   left[14] = x_lo[9];
2118   right[14] = x_hi[9];
2119   left[15] = vnegq_s32(x_lo[1]);
2120   right[15] = vnegq_s32(x_hi[1]);
2121 }
2122 
2123 static void highbd_fdct16x16_neon(int32x4_t *left1, int32x4_t *right1,
2124                                   int32x4_t *left2, int32x4_t *right2) {
2125   // Left half.
2126   highbd_fdct16_8col(left1, right1);
2127   // Right half.
2128   highbd_fdct16_8col(left2, right2);
2129   transpose_s32_16x16(left1, right1, left2, right2);
2130 }
2131 
2132 static void highbd_fadst16x16_neon(int32x4_t *left1, int32x4_t *right1,
2133                                    int32x4_t *left2, int32x4_t *right2) {
2134   // Left half.
2135   highbd_fadst16_8col(left1, right1);
2136   // Right half.
2137   highbd_fadst16_8col(left2, right2);
2138   transpose_s32_16x16(left1, right1, left2, right2);
2139 }
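/* Each of the two 16x16 1-D wrappers above runs the 8-column kernel on both
 * halves of the block and then transposes, so applying two of them
 * back-to-back (with the intermediate right shift in the caller below)
 * yields the full 2-D transform. */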
2140 
2141 void vp9_highbd_fht16x16_neon(const int16_t *input, tran_low_t *output,
2142                               int stride, int tx_type) {
2143   int32x4_t left1[16], right1[16], left2[16], right2[16];
2144 
2145   switch (tx_type) {
2146     case DCT_DCT: vpx_highbd_fdct16x16_neon(input, output, stride); break;
2147     case ADST_DCT:
2148       highbd_load_buffer_16x16(input, left1, right1, left2, right2, stride);
2149       highbd_fadst16x16_neon(left1, right1, left2, right2);
2151       highbd_right_shift_16x16(left1, right1, left2, right2, 2);
2152       highbd_fdct16x16_neon(left1, right1, left2, right2);
2153       highbd_write_buffer_16x16(output, left1, right1, left2, right2, 16);
2154       break;
2155     case DCT_ADST:
2156       highbd_load_buffer_16x16(input, left1, right1, left2, right2, stride);
2157       highbd_fdct16x16_neon(left1, right1, left2, right2);
2158       highbd_right_shift_16x16(left1, right1, left2, right2, 2);
2159       highbd_fadst16x16_neon(left1, right1, left2, right2);
2160       highbd_write_buffer_16x16(output, left1, right1, left2, right2, 16);
2161       break;
2162     default:
2163       assert(tx_type == ADST_ADST);
2164       highbd_load_buffer_16x16(input, left1, right1, left2, right2, stride);
2165       highbd_fadst16x16_neon(left1, right1, left2, right2);
2166       highbd_right_shift_16x16(left1, right1, left2, right2, 2);
2167       highbd_fadst16x16_neon(left1, right1, left2, right2);
2168       highbd_write_buffer_16x16(output, left1, right1, left2, right2, 16);
2169       break;
2170   }
2171 }
2172 
2173 #endif  // CONFIG_VP9_HIGHBITDEPTH
2174