xref: /aosp_15_r20/external/libgav1/src/dsp/arm/inverse_transform_neon.cc (revision 095378508e87ed692bf8dfeb34008b65b3735891)
1 // Copyright 2019 The libgav1 Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "src/dsp/inverse_transform.h"
16 #include "src/utils/cpu.h"
17 
18 #if LIBGAV1_ENABLE_NEON
19 
20 #include <arm_neon.h>
21 
22 #include <algorithm>
23 #include <cassert>
24 #include <cstdint>
25 
26 #include "src/dsp/arm/common_neon.h"
27 #include "src/dsp/constants.h"
28 #include "src/dsp/dsp.h"
29 #include "src/utils/array_2d.h"
30 #include "src/utils/common.h"
31 #include "src/utils/compiler_attributes.h"
32 #include "src/utils/constants.h"
33 
34 namespace libgav1 {
35 namespace dsp {
36 namespace low_bitdepth {
37 namespace {
38 
39 // Include the constants and utility functions inside the anonymous namespace.
40 #include "src/dsp/inverse_transform.inc"
41 
42 //------------------------------------------------------------------------------
43 
44 // Note this is only used in the final stage of Dct32/64 and Adst16 as the in
45 // place version causes additional stack usage with clang.
Transpose8x8(const int16x8_t in[8],int16x8_t out[8])46 LIBGAV1_ALWAYS_INLINE void Transpose8x8(const int16x8_t in[8],
47                                         int16x8_t out[8]) {
48   // Swap 16 bit elements. Goes from:
49   // a0: 00 01 02 03 04 05 06 07
50   // a1: 10 11 12 13 14 15 16 17
51   // a2: 20 21 22 23 24 25 26 27
52   // a3: 30 31 32 33 34 35 36 37
53   // a4: 40 41 42 43 44 45 46 47
54   // a5: 50 51 52 53 54 55 56 57
55   // a6: 60 61 62 63 64 65 66 67
56   // a7: 70 71 72 73 74 75 76 77
57   // to:
58   // b0.val[0]: 00 10 02 12 04 14 06 16
59   // b0.val[1]: 01 11 03 13 05 15 07 17
60   // b1.val[0]: 20 30 22 32 24 34 26 36
61   // b1.val[1]: 21 31 23 33 25 35 27 37
62   // b2.val[0]: 40 50 42 52 44 54 46 56
63   // b2.val[1]: 41 51 43 53 45 55 47 57
64   // b3.val[0]: 60 70 62 72 64 74 66 76
65   // b3.val[1]: 61 71 63 73 65 75 67 77
66 
67   const int16x8x2_t b0 = vtrnq_s16(in[0], in[1]);
68   const int16x8x2_t b1 = vtrnq_s16(in[2], in[3]);
69   const int16x8x2_t b2 = vtrnq_s16(in[4], in[5]);
70   const int16x8x2_t b3 = vtrnq_s16(in[6], in[7]);
71 
72   // Swap 32 bit elements resulting in:
73   // c0.val[0]: 00 10 20 30 04 14 24 34
74   // c0.val[1]: 02 12 22 32 06 16 26 36
75   // c1.val[0]: 01 11 21 31 05 15 25 35
76   // c1.val[1]: 03 13 23 33 07 17 27 37
77   // c2.val[0]: 40 50 60 70 44 54 64 74
78   // c2.val[1]: 42 52 62 72 46 56 66 76
79   // c3.val[0]: 41 51 61 71 45 55 65 75
80   // c3.val[1]: 43 53 63 73 47 57 67 77
81 
82   const int32x4x2_t c0 = vtrnq_s32(vreinterpretq_s32_s16(b0.val[0]),
83                                    vreinterpretq_s32_s16(b1.val[0]));
84   const int32x4x2_t c1 = vtrnq_s32(vreinterpretq_s32_s16(b0.val[1]),
85                                    vreinterpretq_s32_s16(b1.val[1]));
86   const int32x4x2_t c2 = vtrnq_s32(vreinterpretq_s32_s16(b2.val[0]),
87                                    vreinterpretq_s32_s16(b3.val[0]));
88   const int32x4x2_t c3 = vtrnq_s32(vreinterpretq_s32_s16(b2.val[1]),
89                                    vreinterpretq_s32_s16(b3.val[1]));
90 
91   // Swap 64 bit elements resulting in:
92   // d0.val[0]: 00 10 20 30 40 50 60 70
93   // d0.val[1]: 04 14 24 34 44 54 64 74
94   // d1.val[0]: 01 11 21 31 41 51 61 71
95   // d1.val[1]: 05 15 25 35 45 55 65 75
96   // d2.val[0]: 02 12 22 32 42 52 62 72
97   // d2.val[1]: 06 16 26 36 46 56 66 76
98   // d3.val[0]: 03 13 23 33 43 53 63 73
99   // d3.val[1]: 07 17 27 37 47 57 67 77
100   const int16x8x2_t d0 = VtrnqS64(c0.val[0], c2.val[0]);
101   const int16x8x2_t d1 = VtrnqS64(c1.val[0], c3.val[0]);
102   const int16x8x2_t d2 = VtrnqS64(c0.val[1], c2.val[1]);
103   const int16x8x2_t d3 = VtrnqS64(c1.val[1], c3.val[1]);
104 
105   out[0] = d0.val[0];
106   out[1] = d1.val[0];
107   out[2] = d2.val[0];
108   out[3] = d3.val[0];
109   out[4] = d0.val[1];
110   out[5] = d1.val[1];
111   out[6] = d2.val[1];
112   out[7] = d3.val[1];
113 }
114 
Transpose4x8To8x4(const uint16x8_t in[8],uint16x8_t out[4])115 LIBGAV1_ALWAYS_INLINE void Transpose4x8To8x4(const uint16x8_t in[8],
116                                              uint16x8_t out[4]) {
117   // Swap 16 bit elements. Goes from:
118   // a0: 00 01 02 03
119   // a1: 10 11 12 13
120   // a2: 20 21 22 23
121   // a3: 30 31 32 33
122   // a4: 40 41 42 43
123   // a5: 50 51 52 53
124   // a6: 60 61 62 63
125   // a7: 70 71 72 73
126   // to:
127   // b0.val[0]: 00 10 02 12
128   // b0.val[1]: 01 11 03 13
129   // b1.val[0]: 20 30 22 32
130   // b1.val[1]: 21 31 23 33
131   // b2.val[0]: 40 50 42 52
132   // b2.val[1]: 41 51 43 53
133   // b3.val[0]: 60 70 62 72
134   // b3.val[1]: 61 71 63 73
135 
136   uint16x4x2_t b0 = vtrn_u16(vget_low_u16(in[0]), vget_low_u16(in[1]));
137   uint16x4x2_t b1 = vtrn_u16(vget_low_u16(in[2]), vget_low_u16(in[3]));
138   uint16x4x2_t b2 = vtrn_u16(vget_low_u16(in[4]), vget_low_u16(in[5]));
139   uint16x4x2_t b3 = vtrn_u16(vget_low_u16(in[6]), vget_low_u16(in[7]));
140 
141   // Swap 32 bit elements resulting in:
142   // c0.val[0]: 00 10 20 30
143   // c0.val[1]: 02 12 22 32
144   // c1.val[0]: 01 11 21 31
145   // c1.val[1]: 03 13 23 33
146   // c2.val[0]: 40 50 60 70
147   // c2.val[1]: 42 52 62 72
148   // c3.val[0]: 41 51 61 71
149   // c3.val[1]: 43 53 63 73
150 
151   uint32x2x2_t c0 = vtrn_u32(vreinterpret_u32_u16(b0.val[0]),
152                              vreinterpret_u32_u16(b1.val[0]));
153   uint32x2x2_t c1 = vtrn_u32(vreinterpret_u32_u16(b0.val[1]),
154                              vreinterpret_u32_u16(b1.val[1]));
155   uint32x2x2_t c2 = vtrn_u32(vreinterpret_u32_u16(b2.val[0]),
156                              vreinterpret_u32_u16(b3.val[0]));
157   uint32x2x2_t c3 = vtrn_u32(vreinterpret_u32_u16(b2.val[1]),
158                              vreinterpret_u32_u16(b3.val[1]));
159 
160   // Swap 64 bit elements resulting in:
161   // o0: 00 10 20 30 40 50 60 70
162   // o1: 01 11 21 31 41 51 61 71
163   // o2: 02 12 22 32 42 52 62 72
164   // o3: 03 13 23 33 43 53 63 73
165 
166   out[0] = vcombine_u16(vreinterpret_u16_u32(c0.val[0]),
167                         vreinterpret_u16_u32(c2.val[0]));
168   out[1] = vcombine_u16(vreinterpret_u16_u32(c1.val[0]),
169                         vreinterpret_u16_u32(c3.val[0]));
170   out[2] = vcombine_u16(vreinterpret_u16_u32(c0.val[1]),
171                         vreinterpret_u16_u32(c2.val[1]));
172   out[3] = vcombine_u16(vreinterpret_u16_u32(c1.val[1]),
173                         vreinterpret_u16_u32(c3.val[1]));
174 }
175 
Transpose4x8To8x4(const int16x8_t in[8],int16x8_t out[4])176 LIBGAV1_ALWAYS_INLINE void Transpose4x8To8x4(const int16x8_t in[8],
177                                              int16x8_t out[4]) {
178   Transpose4x8To8x4(reinterpret_cast<const uint16x8_t*>(in),
179                     reinterpret_cast<uint16x8_t*>(out));
180 }
181 
Transpose8x4To4x8(const int16x8_t in[4],int16x8_t out[8])182 LIBGAV1_ALWAYS_INLINE void Transpose8x4To4x8(const int16x8_t in[4],
183                                              int16x8_t out[8]) {
184   // Swap 16 bit elements. Goes from:
185   // a0: 00 01 02 03 04 05 06 07
186   // a1: 10 11 12 13 14 15 16 17
187   // a2: 20 21 22 23 24 25 26 27
188   // a3: 30 31 32 33 34 35 36 37
189   // to:
190   // b0.val[0]: 00 10 02 12 04 14 06 16
191   // b0.val[1]: 01 11 03 13 05 15 07 17
192   // b1.val[0]: 20 30 22 32 24 34 26 36
193   // b1.val[1]: 21 31 23 33 25 35 27 37
194   const int16x8x2_t b0 = vtrnq_s16(in[0], in[1]);
195   const int16x8x2_t b1 = vtrnq_s16(in[2], in[3]);
196 
197   // Swap 32 bit elements resulting in:
198   // c0.val[0]: 00 10 20 30 04 14 24 34
199   // c0.val[1]: 02 12 22 32 06 16 26 36
200   // c1.val[0]: 01 11 21 31 05 15 25 35
201   // c1.val[1]: 03 13 23 33 07 17 27 37
202   const int32x4x2_t c0 = vtrnq_s32(vreinterpretq_s32_s16(b0.val[0]),
203                                    vreinterpretq_s32_s16(b1.val[0]));
204   const int32x4x2_t c1 = vtrnq_s32(vreinterpretq_s32_s16(b0.val[1]),
205                                    vreinterpretq_s32_s16(b1.val[1]));
206 
207   // The upper 8 bytes are don't cares.
208   // out[0]: 00 10 20 30 04 14 24 34
209   // out[1]: 01 11 21 31 05 15 25 35
210   // out[2]: 02 12 22 32 06 16 26 36
211   // out[3]: 03 13 23 33 07 17 27 37
212   // out[4]: 04 14 24 34 04 14 24 34
213   // out[5]: 05 15 25 35 05 15 25 35
214   // out[6]: 06 16 26 36 06 16 26 36
215   // out[7]: 07 17 27 37 07 17 27 37
216   out[0] = vreinterpretq_s16_s32(c0.val[0]);
217   out[1] = vreinterpretq_s16_s32(c1.val[0]);
218   out[2] = vreinterpretq_s16_s32(c0.val[1]);
219   out[3] = vreinterpretq_s16_s32(c1.val[1]);
220   out[4] = vreinterpretq_s16_s32(
221       vcombine_s32(vget_high_s32(c0.val[0]), vget_high_s32(c0.val[0])));
222   out[5] = vreinterpretq_s16_s32(
223       vcombine_s32(vget_high_s32(c1.val[0]), vget_high_s32(c1.val[0])));
224   out[6] = vreinterpretq_s16_s32(
225       vcombine_s32(vget_high_s32(c0.val[1]), vget_high_s32(c0.val[1])));
226   out[7] = vreinterpretq_s16_s32(
227       vcombine_s32(vget_high_s32(c1.val[1]), vget_high_s32(c1.val[1])));
228 }
229 
230 //------------------------------------------------------------------------------
231 template <int store_width, int store_count>
StoreDst(int16_t * LIBGAV1_RESTRICT dst,int32_t stride,int32_t idx,const int16x8_t * const s)232 LIBGAV1_ALWAYS_INLINE void StoreDst(int16_t* LIBGAV1_RESTRICT dst,
233                                     int32_t stride, int32_t idx,
234                                     const int16x8_t* const s) {
235   assert(store_count % 4 == 0);
236   assert(store_width == 8 || store_width == 16);
237   // NOTE: It is expected that the compiler will unroll these loops.
238   if (store_width == 16) {
239     for (int i = 0; i < store_count; i += 4) {
240       vst1q_s16(&dst[i * stride + idx], (s[i]));
241       vst1q_s16(&dst[(i + 1) * stride + idx], (s[i + 1]));
242       vst1q_s16(&dst[(i + 2) * stride + idx], (s[i + 2]));
243       vst1q_s16(&dst[(i + 3) * stride + idx], (s[i + 3]));
244     }
245   } else {
246     // store_width == 8
247     for (int i = 0; i < store_count; i += 4) {
248       vst1_s16(&dst[i * stride + idx], vget_low_s16(s[i]));
249       vst1_s16(&dst[(i + 1) * stride + idx], vget_low_s16(s[i + 1]));
250       vst1_s16(&dst[(i + 2) * stride + idx], vget_low_s16(s[i + 2]));
251       vst1_s16(&dst[(i + 3) * stride + idx], vget_low_s16(s[i + 3]));
252     }
253   }
254 }
255 
256 template <int load_width, int load_count>
LoadSrc(const int16_t * LIBGAV1_RESTRICT src,int32_t stride,int32_t idx,int16x8_t * x)257 LIBGAV1_ALWAYS_INLINE void LoadSrc(const int16_t* LIBGAV1_RESTRICT src,
258                                    int32_t stride, int32_t idx, int16x8_t* x) {
259   assert(load_count % 4 == 0);
260   assert(load_width == 8 || load_width == 16);
261   // NOTE: It is expected that the compiler will unroll these loops.
262   if (load_width == 16) {
263     for (int i = 0; i < load_count; i += 4) {
264       x[i] = vld1q_s16(&src[i * stride + idx]);
265       x[i + 1] = vld1q_s16(&src[(i + 1) * stride + idx]);
266       x[i + 2] = vld1q_s16(&src[(i + 2) * stride + idx]);
267       x[i + 3] = vld1q_s16(&src[(i + 3) * stride + idx]);
268     }
269   } else {
270     // load_width == 8
271     const int64x2_t zero = vdupq_n_s64(0);
272     for (int i = 0; i < load_count; i += 4) {
273       // The src buffer is aligned to 32 bytes.  Each load will always be 8
274       // byte aligned.
275       x[i] = vreinterpretq_s16_s64(vld1q_lane_s64(
276           reinterpret_cast<const int64_t*>(&src[i * stride + idx]), zero, 0));
277       x[i + 1] = vreinterpretq_s16_s64(vld1q_lane_s64(
278           reinterpret_cast<const int64_t*>(&src[(i + 1) * stride + idx]), zero,
279           0));
280       x[i + 2] = vreinterpretq_s16_s64(vld1q_lane_s64(
281           reinterpret_cast<const int64_t*>(&src[(i + 2) * stride + idx]), zero,
282           0));
283       x[i + 3] = vreinterpretq_s16_s64(vld1q_lane_s64(
284           reinterpret_cast<const int64_t*>(&src[(i + 3) * stride + idx]), zero,
285           0));
286     }
287   }
288 }
289 
290 // Butterfly rotate 4 values.
ButterflyRotation_4(int16x8_t * a,int16x8_t * b,const int angle,const bool flip)291 LIBGAV1_ALWAYS_INLINE void ButterflyRotation_4(int16x8_t* a, int16x8_t* b,
292                                                const int angle,
293                                                const bool flip) {
294   const int16_t cos128 = Cos128(angle);
295   const int16_t sin128 = Sin128(angle);
296   const int32x4_t acc_x = vmull_n_s16(vget_low_s16(*a), cos128);
297   const int32x4_t acc_y = vmull_n_s16(vget_low_s16(*a), sin128);
298   const int32x4_t x0 = vmlsl_n_s16(acc_x, vget_low_s16(*b), sin128);
299   const int32x4_t y0 = vmlal_n_s16(acc_y, vget_low_s16(*b), cos128);
300   const int16x4_t x1 = vqrshrn_n_s32(x0, 12);
301   const int16x4_t y1 = vqrshrn_n_s32(y0, 12);
302   const int16x8_t x = vcombine_s16(x1, x1);
303   const int16x8_t y = vcombine_s16(y1, y1);
304   if (flip) {
305     *a = y;
306     *b = x;
307   } else {
308     *a = x;
309     *b = y;
310   }
311 }
312 
313 // Butterfly rotate 8 values.
ButterflyRotation_8(int16x8_t * a,int16x8_t * b,const int angle,const bool flip)314 LIBGAV1_ALWAYS_INLINE void ButterflyRotation_8(int16x8_t* a, int16x8_t* b,
315                                                const int angle,
316                                                const bool flip) {
317   const int16_t cos128 = Cos128(angle);
318   const int16_t sin128 = Sin128(angle);
319   const int32x4_t acc_x = vmull_n_s16(vget_low_s16(*a), cos128);
320   const int32x4_t acc_y = vmull_n_s16(vget_low_s16(*a), sin128);
321   const int32x4_t x0 = vmlsl_n_s16(acc_x, vget_low_s16(*b), sin128);
322   const int32x4_t y0 = vmlal_n_s16(acc_y, vget_low_s16(*b), cos128);
323   const int16x4_t x1 = vqrshrn_n_s32(x0, 12);
324   const int16x4_t y1 = vqrshrn_n_s32(y0, 12);
325 
326   const int32x4_t acc_x_hi = vmull_n_s16(vget_high_s16(*a), cos128);
327   const int32x4_t acc_y_hi = vmull_n_s16(vget_high_s16(*a), sin128);
328   const int32x4_t x0_hi = vmlsl_n_s16(acc_x_hi, vget_high_s16(*b), sin128);
329   const int32x4_t y0_hi = vmlal_n_s16(acc_y_hi, vget_high_s16(*b), cos128);
330   const int16x4_t x1_hi = vqrshrn_n_s32(x0_hi, 12);
331   const int16x4_t y1_hi = vqrshrn_n_s32(y0_hi, 12);
332 
333   const int16x8_t x = vcombine_s16(x1, x1_hi);
334   const int16x8_t y = vcombine_s16(y1, y1_hi);
335   if (flip) {
336     *a = y;
337     *b = x;
338   } else {
339     *a = x;
340     *b = y;
341   }
342 }
343 
ButterflyRotation_FirstIsZero(int16x8_t * a,int16x8_t * b,const int angle,const bool flip)344 LIBGAV1_ALWAYS_INLINE void ButterflyRotation_FirstIsZero(int16x8_t* a,
345                                                          int16x8_t* b,
346                                                          const int angle,
347                                                          const bool flip) {
348   // Clang < 14 targeting armv8.1-a+ optimizes vqrdmulhq_n_s16 and vqsubq_s16
349   // (in HadamardRotation) into vqrdmlshq_s16 resulting in an "off by one"
350   // error. This behavior was fixed in 14.0.0:
351   // https://github.com/llvm/llvm-project/commit/82973edfb72a95b442fa6d2bb404e15a4031855e
352 #if defined(__ARM_FEATURE_QRDMX) && defined(__aarch64__) && \
353     defined(__clang__) && __clang_major__ < 14
354   const int16_t cos128 = Cos128(angle);
355   const int16_t sin128 = Sin128(angle);
356   const int32x4_t x0 = vmull_n_s16(vget_low_s16(*b), -sin128);
357   const int32x4_t y0 = vmull_n_s16(vget_low_s16(*b), cos128);
358   const int16x4_t x1 = vqrshrn_n_s32(x0, 12);
359   const int16x4_t y1 = vqrshrn_n_s32(y0, 12);
360 
361   const int32x4_t x0_hi = vmull_n_s16(vget_high_s16(*b), -sin128);
362   const int32x4_t y0_hi = vmull_n_s16(vget_high_s16(*b), cos128);
363   const int16x4_t x1_hi = vqrshrn_n_s32(x0_hi, 12);
364   const int16x4_t y1_hi = vqrshrn_n_s32(y0_hi, 12);
365 
366   const int16x8_t x = vcombine_s16(x1, x1_hi);
367   const int16x8_t y = vcombine_s16(y1, y1_hi);
368   if (flip) {
369     *a = y;
370     *b = x;
371   } else {
372     *a = x;
373     *b = y;
374   }
375 #else
376   const int16_t cos128 = Cos128(angle);
377   const int16_t sin128 = Sin128(angle);
378   // For this function, the max value returned by Sin128() is 4091, which fits
379   // inside 12 bits.  This leaves room for the sign bit and the 3 left shifted
380   // bits.
381   assert(sin128 <= 0xfff);
382   const int16x8_t x = vqrdmulhq_n_s16(*b, -sin128 << 3);
383   const int16x8_t y = vqrdmulhq_n_s16(*b, cos128 << 3);
384   if (flip) {
385     *a = y;
386     *b = x;
387   } else {
388     *a = x;
389     *b = y;
390   }
391 #endif
392 }
393 
ButterflyRotation_SecondIsZero(int16x8_t * a,int16x8_t * b,const int angle,const bool flip)394 LIBGAV1_ALWAYS_INLINE void ButterflyRotation_SecondIsZero(int16x8_t* a,
395                                                           int16x8_t* b,
396                                                           const int angle,
397                                                           const bool flip) {
398 #if defined(__ARM_FEATURE_QRDMX) && defined(__aarch64__) && \
399     defined(__clang__)  // ARM v8.1-A
400   // Clang optimizes vqrdmulhq_n_s16 and vqsubq_s16 (in HadamardRotation) into
401   // vqrdmlshq_s16 resulting in an "off by one" error. For now, do not use
402   // vqrdmulhq_n_s16().
403   const int16_t cos128 = Cos128(angle);
404   const int16_t sin128 = Sin128(angle);
405   const int32x4_t x0 = vmull_n_s16(vget_low_s16(*a), cos128);
406   const int32x4_t y0 = vmull_n_s16(vget_low_s16(*a), sin128);
407   const int16x4_t x1 = vqrshrn_n_s32(x0, 12);
408   const int16x4_t y1 = vqrshrn_n_s32(y0, 12);
409 
410   const int32x4_t x0_hi = vmull_n_s16(vget_high_s16(*a), cos128);
411   const int32x4_t y0_hi = vmull_n_s16(vget_high_s16(*a), sin128);
412   const int16x4_t x1_hi = vqrshrn_n_s32(x0_hi, 12);
413   const int16x4_t y1_hi = vqrshrn_n_s32(y0_hi, 12);
414 
415   const int16x8_t x = vcombine_s16(x1, x1_hi);
416   const int16x8_t y = vcombine_s16(y1, y1_hi);
417   if (flip) {
418     *a = y;
419     *b = x;
420   } else {
421     *a = x;
422     *b = y;
423   }
424 #else
425   const int16_t cos128 = Cos128(angle);
426   const int16_t sin128 = Sin128(angle);
427   const int16x8_t x = vqrdmulhq_n_s16(*a, cos128 << 3);
428   const int16x8_t y = vqrdmulhq_n_s16(*a, sin128 << 3);
429   if (flip) {
430     *a = y;
431     *b = x;
432   } else {
433     *a = x;
434     *b = y;
435   }
436 #endif
437 }
438 
HadamardRotation(int16x8_t * a,int16x8_t * b,bool flip)439 LIBGAV1_ALWAYS_INLINE void HadamardRotation(int16x8_t* a, int16x8_t* b,
440                                             bool flip) {
441   int16x8_t x, y;
442   if (flip) {
443     y = vqaddq_s16(*b, *a);
444     x = vqsubq_s16(*b, *a);
445   } else {
446     x = vqaddq_s16(*a, *b);
447     y = vqsubq_s16(*a, *b);
448   }
449   *a = x;
450   *b = y;
451 }
452 
453 using ButterflyRotationFunc = void (*)(int16x8_t* a, int16x8_t* b, int angle,
454                                        bool flip);
455 
456 //------------------------------------------------------------------------------
457 // Discrete Cosine Transforms (DCT).
458 
459 template <int width>
DctDcOnly(void * dest,int adjusted_tx_height,bool should_round,int row_shift)460 LIBGAV1_ALWAYS_INLINE bool DctDcOnly(void* dest, int adjusted_tx_height,
461                                      bool should_round, int row_shift) {
462   if (adjusted_tx_height > 1) return false;
463 
464   auto* dst = static_cast<int16_t*>(dest);
465   const int16x8_t v_src = vdupq_n_s16(dst[0]);
466   const uint16x8_t v_mask = vdupq_n_u16(should_round ? 0xffff : 0);
467   const int16x8_t v_src_round =
468       vqrdmulhq_n_s16(v_src, kTransformRowMultiplier << 3);
469   const int16x8_t s0 = vbslq_s16(v_mask, v_src_round, v_src);
470   const int16_t cos128 = Cos128(32);
471   const int16x8_t xy = vqrdmulhq_n_s16(s0, cos128 << 3);
472   // vqrshlq_s16 will shift right if shift value is negative.
473   const int16x8_t xy_shifted = vqrshlq_s16(xy, vdupq_n_s16(-row_shift));
474 
475   if (width == 4) {
476     vst1_s16(dst, vget_low_s16(xy_shifted));
477   } else {
478     for (int i = 0; i < width; i += 8) {
479       vst1q_s16(dst, xy_shifted);
480       dst += 8;
481     }
482   }
483   return true;
484 }
485 
486 template <int height>
DctDcOnlyColumn(void * dest,int adjusted_tx_height,int width)487 LIBGAV1_ALWAYS_INLINE bool DctDcOnlyColumn(void* dest, int adjusted_tx_height,
488                                            int width) {
489   if (adjusted_tx_height > 1) return false;
490 
491   auto* dst = static_cast<int16_t*>(dest);
492   const int16_t cos128 = Cos128(32);
493 
494   // Calculate dc values for first row.
495   if (width == 4) {
496     const int16x4_t v_src = vld1_s16(dst);
497     const int16x4_t xy = vqrdmulh_n_s16(v_src, cos128 << 3);
498     vst1_s16(dst, xy);
499   } else {
500     int i = 0;
501     do {
502       const int16x8_t v_src = vld1q_s16(&dst[i]);
503       const int16x8_t xy = vqrdmulhq_n_s16(v_src, cos128 << 3);
504       vst1q_s16(&dst[i], xy);
505       i += 8;
506     } while (i < width);
507   }
508 
509   // Copy first row to the rest of the block.
510   for (int y = 1; y < height; ++y) {
511     memcpy(&dst[y * width], dst, width * sizeof(dst[0]));
512   }
513   return true;
514 }
515 
516 template <ButterflyRotationFunc butterfly_rotation,
517           bool is_fast_butterfly = false>
Dct4Stages(int16x8_t * s)518 LIBGAV1_ALWAYS_INLINE void Dct4Stages(int16x8_t* s) {
519   // stage 12.
520   if (is_fast_butterfly) {
521     ButterflyRotation_SecondIsZero(&s[0], &s[1], 32, true);
522     ButterflyRotation_SecondIsZero(&s[2], &s[3], 48, false);
523   } else {
524     butterfly_rotation(&s[0], &s[1], 32, true);
525     butterfly_rotation(&s[2], &s[3], 48, false);
526   }
527 
528   // stage 17.
529   HadamardRotation(&s[0], &s[3], false);
530   HadamardRotation(&s[1], &s[2], false);
531 }
532 
533 template <ButterflyRotationFunc butterfly_rotation, bool stage_is_rectangular>
Dct4_NEON(void * dest,int32_t step,bool transpose)534 LIBGAV1_ALWAYS_INLINE void Dct4_NEON(void* dest, int32_t step, bool transpose) {
535   auto* const dst = static_cast<int16_t*>(dest);
536   int16x8_t s[4], x[4];
537 
538   if (stage_is_rectangular) {
539     if (transpose) {
540       assert(step == 4);
541       int16x8x4_t y = vld4q_s16(dst);
542       for (int i = 0; i < 4; ++i) x[i] = y.val[i];
543     } else {
544       LoadSrc<16, 4>(dst, step, 0, x);
545     }
546   } else {
547     if (transpose) {
548       assert(step == 4);
549       int16x4x4_t y = vld4_s16(dst);
550       for (int i = 0; i < 4; ++i) x[i] = vcombine_s16(y.val[i], y.val[i]);
551     } else {
552       LoadSrc<8, 4>(dst, step, 0, x);
553     }
554   }
555 
556   // stage 1.
557   // kBitReverseLookup 0, 2, 1, 3
558   s[0] = x[0];
559   s[1] = x[2];
560   s[2] = x[1];
561   s[3] = x[3];
562 
563   Dct4Stages<butterfly_rotation>(s);
564 
565   if (stage_is_rectangular) {
566     if (transpose) {
567       int16x8x4_t y;
568       for (int i = 0; i < 4; ++i) y.val[i] = s[i];
569       vst4q_s16(dst, y);
570     } else {
571       StoreDst<16, 4>(dst, step, 0, s);
572     }
573   } else {
574     if (transpose) {
575       int16x4x4_t y;
576       for (int i = 0; i < 4; ++i) y.val[i] = vget_low_s16(s[i]);
577       vst4_s16(dst, y);
578     } else {
579       StoreDst<8, 4>(dst, step, 0, s);
580     }
581   }
582 }
583 
584 template <ButterflyRotationFunc butterfly_rotation,
585           bool is_fast_butterfly = false>
Dct8Stages(int16x8_t * s)586 LIBGAV1_ALWAYS_INLINE void Dct8Stages(int16x8_t* s) {
587   // stage 8.
588   if (is_fast_butterfly) {
589     ButterflyRotation_SecondIsZero(&s[4], &s[7], 56, false);
590     ButterflyRotation_FirstIsZero(&s[5], &s[6], 24, false);
591   } else {
592     butterfly_rotation(&s[4], &s[7], 56, false);
593     butterfly_rotation(&s[5], &s[6], 24, false);
594   }
595 
596   // stage 13.
597   HadamardRotation(&s[4], &s[5], false);
598   HadamardRotation(&s[6], &s[7], true);
599 
600   // stage 18.
601   butterfly_rotation(&s[6], &s[5], 32, true);
602 
603   // stage 22.
604   HadamardRotation(&s[0], &s[7], false);
605   HadamardRotation(&s[1], &s[6], false);
606   HadamardRotation(&s[2], &s[5], false);
607   HadamardRotation(&s[3], &s[4], false);
608 }
609 
610 // Process dct8 rows or columns, depending on the transpose flag.
611 template <ButterflyRotationFunc butterfly_rotation, bool stage_is_rectangular>
Dct8_NEON(void * dest,int32_t step,bool transpose)612 LIBGAV1_ALWAYS_INLINE void Dct8_NEON(void* dest, int32_t step, bool transpose) {
613   auto* const dst = static_cast<int16_t*>(dest);
614   int16x8_t s[8], x[8];
615 
616   if (stage_is_rectangular) {
617     if (transpose) {
618       int16x8_t input[4];
619       LoadSrc<16, 4>(dst, step, 0, input);
620       Transpose8x4To4x8(input, x);
621     } else {
622       LoadSrc<8, 8>(dst, step, 0, x);
623     }
624   } else if (transpose) {
625     LoadSrc<16, 8>(dst, step, 0, x);
626     dsp::Transpose8x8(x);
627   } else {
628     LoadSrc<16, 8>(dst, step, 0, x);
629   }
630 
631   // stage 1.
632   // kBitReverseLookup 0, 4, 2, 6, 1, 5, 3, 7,
633   s[0] = x[0];
634   s[1] = x[4];
635   s[2] = x[2];
636   s[3] = x[6];
637   s[4] = x[1];
638   s[5] = x[5];
639   s[6] = x[3];
640   s[7] = x[7];
641 
642   Dct4Stages<butterfly_rotation>(s);
643   Dct8Stages<butterfly_rotation>(s);
644 
645   if (stage_is_rectangular) {
646     if (transpose) {
647       int16x8_t output[4];
648       Transpose4x8To8x4(s, output);
649       StoreDst<16, 4>(dst, step, 0, output);
650     } else {
651       StoreDst<8, 8>(dst, step, 0, s);
652     }
653   } else if (transpose) {
654     dsp::Transpose8x8(s);
655     StoreDst<16, 8>(dst, step, 0, s);
656   } else {
657     StoreDst<16, 8>(dst, step, 0, s);
658   }
659 }
660 
661 template <ButterflyRotationFunc butterfly_rotation,
662           bool is_fast_butterfly = false>
Dct16Stages(int16x8_t * s)663 LIBGAV1_ALWAYS_INLINE void Dct16Stages(int16x8_t* s) {
664   // stage 5.
665   if (is_fast_butterfly) {
666     ButterflyRotation_SecondIsZero(&s[8], &s[15], 60, false);
667     ButterflyRotation_FirstIsZero(&s[9], &s[14], 28, false);
668     ButterflyRotation_SecondIsZero(&s[10], &s[13], 44, false);
669     ButterflyRotation_FirstIsZero(&s[11], &s[12], 12, false);
670   } else {
671     butterfly_rotation(&s[8], &s[15], 60, false);
672     butterfly_rotation(&s[9], &s[14], 28, false);
673     butterfly_rotation(&s[10], &s[13], 44, false);
674     butterfly_rotation(&s[11], &s[12], 12, false);
675   }
676 
677   // stage 9.
678   HadamardRotation(&s[8], &s[9], false);
679   HadamardRotation(&s[10], &s[11], true);
680   HadamardRotation(&s[12], &s[13], false);
681   HadamardRotation(&s[14], &s[15], true);
682 
683   // stage 14.
684   butterfly_rotation(&s[14], &s[9], 48, true);
685   butterfly_rotation(&s[13], &s[10], 112, true);
686 
687   // stage 19.
688   HadamardRotation(&s[8], &s[11], false);
689   HadamardRotation(&s[9], &s[10], false);
690   HadamardRotation(&s[12], &s[15], true);
691   HadamardRotation(&s[13], &s[14], true);
692 
693   // stage 23.
694   butterfly_rotation(&s[13], &s[10], 32, true);
695   butterfly_rotation(&s[12], &s[11], 32, true);
696 
697   // stage 26.
698   HadamardRotation(&s[0], &s[15], false);
699   HadamardRotation(&s[1], &s[14], false);
700   HadamardRotation(&s[2], &s[13], false);
701   HadamardRotation(&s[3], &s[12], false);
702   HadamardRotation(&s[4], &s[11], false);
703   HadamardRotation(&s[5], &s[10], false);
704   HadamardRotation(&s[6], &s[9], false);
705   HadamardRotation(&s[7], &s[8], false);
706 }
707 
708 // Process dct16 rows or columns, depending on the transpose flag.
709 template <ButterflyRotationFunc butterfly_rotation, bool stage_is_rectangular>
Dct16_NEON(void * dest,int32_t step,bool is_row,int row_shift)710 LIBGAV1_ALWAYS_INLINE void Dct16_NEON(void* dest, int32_t step, bool is_row,
711                                       int row_shift) {
712   auto* const dst = static_cast<int16_t*>(dest);
713   int16x8_t s[16], x[16];
714 
715   if (stage_is_rectangular) {
716     if (is_row) {
717       int16x8_t input[4];
718       LoadSrc<16, 4>(dst, step, 0, input);
719       Transpose8x4To4x8(input, x);
720       LoadSrc<16, 4>(dst, step, 8, input);
721       Transpose8x4To4x8(input, &x[8]);
722     } else {
723       LoadSrc<8, 16>(dst, step, 0, x);
724     }
725   } else if (is_row) {
726     for (int idx = 0; idx < 16; idx += 8) {
727       LoadSrc<16, 8>(dst, step, idx, &x[idx]);
728       dsp::Transpose8x8(&x[idx]);
729     }
730   } else {
731     LoadSrc<16, 16>(dst, step, 0, x);
732   }
733 
734   // stage 1
735   // kBitReverseLookup 0, 8, 4, 12, 2, 10, 6, 14, 1, 9, 5, 13, 3, 11, 7, 15,
736   s[0] = x[0];
737   s[1] = x[8];
738   s[2] = x[4];
739   s[3] = x[12];
740   s[4] = x[2];
741   s[5] = x[10];
742   s[6] = x[6];
743   s[7] = x[14];
744   s[8] = x[1];
745   s[9] = x[9];
746   s[10] = x[5];
747   s[11] = x[13];
748   s[12] = x[3];
749   s[13] = x[11];
750   s[14] = x[7];
751   s[15] = x[15];
752 
753   Dct4Stages<butterfly_rotation>(s);
754   Dct8Stages<butterfly_rotation>(s);
755   Dct16Stages<butterfly_rotation>(s);
756 
757   if (is_row) {
758     const int16x8_t v_row_shift = vdupq_n_s16(-row_shift);
759     for (auto& i : s) {
760       i = vqrshlq_s16(i, v_row_shift);
761     }
762   }
763 
764   if (stage_is_rectangular) {
765     if (is_row) {
766       int16x8_t output[4];
767       Transpose4x8To8x4(s, output);
768       StoreDst<16, 4>(dst, step, 0, output);
769       Transpose4x8To8x4(&s[8], output);
770       StoreDst<16, 4>(dst, step, 8, output);
771     } else {
772       StoreDst<8, 16>(dst, step, 0, s);
773     }
774   } else if (is_row) {
775     for (int idx = 0; idx < 16; idx += 8) {
776       dsp::Transpose8x8(&s[idx]);
777       StoreDst<16, 8>(dst, step, idx, &s[idx]);
778     }
779   } else {
780     StoreDst<16, 16>(dst, step, 0, s);
781   }
782 }
783 
784 template <ButterflyRotationFunc butterfly_rotation,
785           bool is_fast_butterfly = false>
Dct32Stages(int16x8_t * s)786 LIBGAV1_ALWAYS_INLINE void Dct32Stages(int16x8_t* s) {
787   // stage 3
788   if (is_fast_butterfly) {
789     ButterflyRotation_SecondIsZero(&s[16], &s[31], 62, false);
790     ButterflyRotation_FirstIsZero(&s[17], &s[30], 30, false);
791     ButterflyRotation_SecondIsZero(&s[18], &s[29], 46, false);
792     ButterflyRotation_FirstIsZero(&s[19], &s[28], 14, false);
793     ButterflyRotation_SecondIsZero(&s[20], &s[27], 54, false);
794     ButterflyRotation_FirstIsZero(&s[21], &s[26], 22, false);
795     ButterflyRotation_SecondIsZero(&s[22], &s[25], 38, false);
796     ButterflyRotation_FirstIsZero(&s[23], &s[24], 6, false);
797   } else {
798     butterfly_rotation(&s[16], &s[31], 62, false);
799     butterfly_rotation(&s[17], &s[30], 30, false);
800     butterfly_rotation(&s[18], &s[29], 46, false);
801     butterfly_rotation(&s[19], &s[28], 14, false);
802     butterfly_rotation(&s[20], &s[27], 54, false);
803     butterfly_rotation(&s[21], &s[26], 22, false);
804     butterfly_rotation(&s[22], &s[25], 38, false);
805     butterfly_rotation(&s[23], &s[24], 6, false);
806   }
807   // stage 6.
808   HadamardRotation(&s[16], &s[17], false);
809   HadamardRotation(&s[18], &s[19], true);
810   HadamardRotation(&s[20], &s[21], false);
811   HadamardRotation(&s[22], &s[23], true);
812   HadamardRotation(&s[24], &s[25], false);
813   HadamardRotation(&s[26], &s[27], true);
814   HadamardRotation(&s[28], &s[29], false);
815   HadamardRotation(&s[30], &s[31], true);
816 
817   // stage 10.
818   butterfly_rotation(&s[30], &s[17], 24 + 32, true);
819   butterfly_rotation(&s[29], &s[18], 24 + 64 + 32, true);
820   butterfly_rotation(&s[26], &s[21], 24, true);
821   butterfly_rotation(&s[25], &s[22], 24 + 64, true);
822 
823   // stage 15.
824   HadamardRotation(&s[16], &s[19], false);
825   HadamardRotation(&s[17], &s[18], false);
826   HadamardRotation(&s[20], &s[23], true);
827   HadamardRotation(&s[21], &s[22], true);
828   HadamardRotation(&s[24], &s[27], false);
829   HadamardRotation(&s[25], &s[26], false);
830   HadamardRotation(&s[28], &s[31], true);
831   HadamardRotation(&s[29], &s[30], true);
832 
833   // stage 20.
834   butterfly_rotation(&s[29], &s[18], 48, true);
835   butterfly_rotation(&s[28], &s[19], 48, true);
836   butterfly_rotation(&s[27], &s[20], 48 + 64, true);
837   butterfly_rotation(&s[26], &s[21], 48 + 64, true);
838 
839   // stage 24.
840   HadamardRotation(&s[16], &s[23], false);
841   HadamardRotation(&s[17], &s[22], false);
842   HadamardRotation(&s[18], &s[21], false);
843   HadamardRotation(&s[19], &s[20], false);
844   HadamardRotation(&s[24], &s[31], true);
845   HadamardRotation(&s[25], &s[30], true);
846   HadamardRotation(&s[26], &s[29], true);
847   HadamardRotation(&s[27], &s[28], true);
848 
849   // stage 27.
850   butterfly_rotation(&s[27], &s[20], 32, true);
851   butterfly_rotation(&s[26], &s[21], 32, true);
852   butterfly_rotation(&s[25], &s[22], 32, true);
853   butterfly_rotation(&s[24], &s[23], 32, true);
854 
855   // stage 29.
856   HadamardRotation(&s[0], &s[31], false);
857   HadamardRotation(&s[1], &s[30], false);
858   HadamardRotation(&s[2], &s[29], false);
859   HadamardRotation(&s[3], &s[28], false);
860   HadamardRotation(&s[4], &s[27], false);
861   HadamardRotation(&s[5], &s[26], false);
862   HadamardRotation(&s[6], &s[25], false);
863   HadamardRotation(&s[7], &s[24], false);
864   HadamardRotation(&s[8], &s[23], false);
865   HadamardRotation(&s[9], &s[22], false);
866   HadamardRotation(&s[10], &s[21], false);
867   HadamardRotation(&s[11], &s[20], false);
868   HadamardRotation(&s[12], &s[19], false);
869   HadamardRotation(&s[13], &s[18], false);
870   HadamardRotation(&s[14], &s[17], false);
871   HadamardRotation(&s[15], &s[16], false);
872 }
873 
874 // Process dct32 rows or columns, depending on the transpose flag.
Dct32_NEON(void * dest,const int32_t step,const bool is_row,int row_shift)875 LIBGAV1_ALWAYS_INLINE void Dct32_NEON(void* dest, const int32_t step,
876                                       const bool is_row, int row_shift) {
877   auto* const dst = static_cast<int16_t*>(dest);
878   int16x8_t s[32], x[32];
879 
880   if (is_row) {
881     for (int idx = 0; idx < 32; idx += 8) {
882       LoadSrc<16, 8>(dst, step, idx, &x[idx]);
883       dsp::Transpose8x8(&x[idx]);
884     }
885   } else {
886     LoadSrc<16, 32>(dst, step, 0, x);
887   }
888 
889   // stage 1
890   // kBitReverseLookup
891   // 0, 16, 8, 24, 4, 20, 12, 28, 2, 18, 10, 26, 6, 22, 14, 30,
892   s[0] = x[0];
893   s[1] = x[16];
894   s[2] = x[8];
895   s[3] = x[24];
896   s[4] = x[4];
897   s[5] = x[20];
898   s[6] = x[12];
899   s[7] = x[28];
900   s[8] = x[2];
901   s[9] = x[18];
902   s[10] = x[10];
903   s[11] = x[26];
904   s[12] = x[6];
905   s[13] = x[22];
906   s[14] = x[14];
907   s[15] = x[30];
908 
909   // 1, 17, 9, 25, 5, 21, 13, 29, 3, 19, 11, 27, 7, 23, 15, 31,
910   s[16] = x[1];
911   s[17] = x[17];
912   s[18] = x[9];
913   s[19] = x[25];
914   s[20] = x[5];
915   s[21] = x[21];
916   s[22] = x[13];
917   s[23] = x[29];
918   s[24] = x[3];
919   s[25] = x[19];
920   s[26] = x[11];
921   s[27] = x[27];
922   s[28] = x[7];
923   s[29] = x[23];
924   s[30] = x[15];
925   s[31] = x[31];
926 
927   Dct4Stages<ButterflyRotation_8>(s);
928   Dct8Stages<ButterflyRotation_8>(s);
929   Dct16Stages<ButterflyRotation_8>(s);
930   Dct32Stages<ButterflyRotation_8>(s);
931 
932   if (is_row) {
933     const int16x8_t v_row_shift = vdupq_n_s16(-row_shift);
934     for (int idx = 0; idx < 32; idx += 8) {
935       int16x8_t output[8];
936       Transpose8x8(&s[idx], output);
937       for (auto& o : output) {
938         o = vqrshlq_s16(o, v_row_shift);
939       }
940       StoreDst<16, 8>(dst, step, idx, output);
941     }
942   } else {
943     StoreDst<16, 32>(dst, step, 0, s);
944   }
945 }
946 
947 // Allow the compiler to call this function instead of force inlining. Tests
948 // show the performance is slightly faster.
Dct64_NEON(void * dest,int32_t step,bool is_row,int row_shift)949 void Dct64_NEON(void* dest, int32_t step, bool is_row, int row_shift) {
950   auto* const dst = static_cast<int16_t*>(dest);
951   int16x8_t s[64], x[32];
952 
953   if (is_row) {
954     // The last 32 values of every row are always zero if the |tx_width| is
955     // 64.
956     for (int idx = 0; idx < 32; idx += 8) {
957       LoadSrc<16, 8>(dst, step, idx, &x[idx]);
958       dsp::Transpose8x8(&x[idx]);
959     }
960   } else {
961     // The last 32 values of every column are always zero if the |tx_height| is
962     // 64.
963     LoadSrc<16, 32>(dst, step, 0, x);
964   }
965 
966   // stage 1
967   // kBitReverseLookup
968   // 0, 32, 16, 48, 8, 40, 24, 56, 4, 36, 20, 52, 12, 44, 28, 60,
969   s[0] = x[0];
970   s[2] = x[16];
971   s[4] = x[8];
972   s[6] = x[24];
973   s[8] = x[4];
974   s[10] = x[20];
975   s[12] = x[12];
976   s[14] = x[28];
977 
978   // 2, 34, 18, 50, 10, 42, 26, 58, 6, 38, 22, 54, 14, 46, 30, 62,
979   s[16] = x[2];
980   s[18] = x[18];
981   s[20] = x[10];
982   s[22] = x[26];
983   s[24] = x[6];
984   s[26] = x[22];
985   s[28] = x[14];
986   s[30] = x[30];
987 
988   // 1, 33, 17, 49, 9, 41, 25, 57, 5, 37, 21, 53, 13, 45, 29, 61,
989   s[32] = x[1];
990   s[34] = x[17];
991   s[36] = x[9];
992   s[38] = x[25];
993   s[40] = x[5];
994   s[42] = x[21];
995   s[44] = x[13];
996   s[46] = x[29];
997 
998   // 3, 35, 19, 51, 11, 43, 27, 59, 7, 39, 23, 55, 15, 47, 31, 63
999   s[48] = x[3];
1000   s[50] = x[19];
1001   s[52] = x[11];
1002   s[54] = x[27];
1003   s[56] = x[7];
1004   s[58] = x[23];
1005   s[60] = x[15];
1006   s[62] = x[31];
1007 
1008   Dct4Stages<ButterflyRotation_8, /*is_fast_butterfly=*/true>(s);
1009   Dct8Stages<ButterflyRotation_8, /*is_fast_butterfly=*/true>(s);
1010   Dct16Stages<ButterflyRotation_8, /*is_fast_butterfly=*/true>(s);
1011   Dct32Stages<ButterflyRotation_8, /*is_fast_butterfly=*/true>(s);
1012 
1013   //-- start dct 64 stages
1014   // stage 2.
1015   ButterflyRotation_SecondIsZero(&s[32], &s[63], 63 - 0, false);
1016   ButterflyRotation_FirstIsZero(&s[33], &s[62], 63 - 32, false);
1017   ButterflyRotation_SecondIsZero(&s[34], &s[61], 63 - 16, false);
1018   ButterflyRotation_FirstIsZero(&s[35], &s[60], 63 - 48, false);
1019   ButterflyRotation_SecondIsZero(&s[36], &s[59], 63 - 8, false);
1020   ButterflyRotation_FirstIsZero(&s[37], &s[58], 63 - 40, false);
1021   ButterflyRotation_SecondIsZero(&s[38], &s[57], 63 - 24, false);
1022   ButterflyRotation_FirstIsZero(&s[39], &s[56], 63 - 56, false);
1023   ButterflyRotation_SecondIsZero(&s[40], &s[55], 63 - 4, false);
1024   ButterflyRotation_FirstIsZero(&s[41], &s[54], 63 - 36, false);
1025   ButterflyRotation_SecondIsZero(&s[42], &s[53], 63 - 20, false);
1026   ButterflyRotation_FirstIsZero(&s[43], &s[52], 63 - 52, false);
1027   ButterflyRotation_SecondIsZero(&s[44], &s[51], 63 - 12, false);
1028   ButterflyRotation_FirstIsZero(&s[45], &s[50], 63 - 44, false);
1029   ButterflyRotation_SecondIsZero(&s[46], &s[49], 63 - 28, false);
1030   ButterflyRotation_FirstIsZero(&s[47], &s[48], 63 - 60, false);
1031 
1032   // stage 4.
1033   HadamardRotation(&s[32], &s[33], false);
1034   HadamardRotation(&s[34], &s[35], true);
1035   HadamardRotation(&s[36], &s[37], false);
1036   HadamardRotation(&s[38], &s[39], true);
1037   HadamardRotation(&s[40], &s[41], false);
1038   HadamardRotation(&s[42], &s[43], true);
1039   HadamardRotation(&s[44], &s[45], false);
1040   HadamardRotation(&s[46], &s[47], true);
1041   HadamardRotation(&s[48], &s[49], false);
1042   HadamardRotation(&s[50], &s[51], true);
1043   HadamardRotation(&s[52], &s[53], false);
1044   HadamardRotation(&s[54], &s[55], true);
1045   HadamardRotation(&s[56], &s[57], false);
1046   HadamardRotation(&s[58], &s[59], true);
1047   HadamardRotation(&s[60], &s[61], false);
1048   HadamardRotation(&s[62], &s[63], true);
1049 
1050   // stage 7.
1051   ButterflyRotation_8(&s[62], &s[33], 60 - 0, true);
1052   ButterflyRotation_8(&s[61], &s[34], 60 - 0 + 64, true);
1053   ButterflyRotation_8(&s[58], &s[37], 60 - 32, true);
1054   ButterflyRotation_8(&s[57], &s[38], 60 - 32 + 64, true);
1055   ButterflyRotation_8(&s[54], &s[41], 60 - 16, true);
1056   ButterflyRotation_8(&s[53], &s[42], 60 - 16 + 64, true);
1057   ButterflyRotation_8(&s[50], &s[45], 60 - 48, true);
1058   ButterflyRotation_8(&s[49], &s[46], 60 - 48 + 64, true);
1059 
1060   // stage 11.
1061   HadamardRotation(&s[32], &s[35], false);
1062   HadamardRotation(&s[33], &s[34], false);
1063   HadamardRotation(&s[36], &s[39], true);
1064   HadamardRotation(&s[37], &s[38], true);
1065   HadamardRotation(&s[40], &s[43], false);
1066   HadamardRotation(&s[41], &s[42], false);
1067   HadamardRotation(&s[44], &s[47], true);
1068   HadamardRotation(&s[45], &s[46], true);
1069   HadamardRotation(&s[48], &s[51], false);
1070   HadamardRotation(&s[49], &s[50], false);
1071   HadamardRotation(&s[52], &s[55], true);
1072   HadamardRotation(&s[53], &s[54], true);
1073   HadamardRotation(&s[56], &s[59], false);
1074   HadamardRotation(&s[57], &s[58], false);
1075   HadamardRotation(&s[60], &s[63], true);
1076   HadamardRotation(&s[61], &s[62], true);
1077 
1078   // stage 16.
1079   ButterflyRotation_8(&s[61], &s[34], 56, true);
1080   ButterflyRotation_8(&s[60], &s[35], 56, true);
1081   ButterflyRotation_8(&s[59], &s[36], 56 + 64, true);
1082   ButterflyRotation_8(&s[58], &s[37], 56 + 64, true);
1083   ButterflyRotation_8(&s[53], &s[42], 56 - 32, true);
1084   ButterflyRotation_8(&s[52], &s[43], 56 - 32, true);
1085   ButterflyRotation_8(&s[51], &s[44], 56 - 32 + 64, true);
1086   ButterflyRotation_8(&s[50], &s[45], 56 - 32 + 64, true);
1087 
1088   // stage 21.
1089   HadamardRotation(&s[32], &s[39], false);
1090   HadamardRotation(&s[33], &s[38], false);
1091   HadamardRotation(&s[34], &s[37], false);
1092   HadamardRotation(&s[35], &s[36], false);
1093   HadamardRotation(&s[40], &s[47], true);
1094   HadamardRotation(&s[41], &s[46], true);
1095   HadamardRotation(&s[42], &s[45], true);
1096   HadamardRotation(&s[43], &s[44], true);
1097   HadamardRotation(&s[48], &s[55], false);
1098   HadamardRotation(&s[49], &s[54], false);
1099   HadamardRotation(&s[50], &s[53], false);
1100   HadamardRotation(&s[51], &s[52], false);
1101   HadamardRotation(&s[56], &s[63], true);
1102   HadamardRotation(&s[57], &s[62], true);
1103   HadamardRotation(&s[58], &s[61], true);
1104   HadamardRotation(&s[59], &s[60], true);
1105 
1106   // stage 25.
1107   ButterflyRotation_8(&s[59], &s[36], 48, true);
1108   ButterflyRotation_8(&s[58], &s[37], 48, true);
1109   ButterflyRotation_8(&s[57], &s[38], 48, true);
1110   ButterflyRotation_8(&s[56], &s[39], 48, true);
1111   ButterflyRotation_8(&s[55], &s[40], 112, true);
1112   ButterflyRotation_8(&s[54], &s[41], 112, true);
1113   ButterflyRotation_8(&s[53], &s[42], 112, true);
1114   ButterflyRotation_8(&s[52], &s[43], 112, true);
1115 
1116   // stage 28.
1117   HadamardRotation(&s[32], &s[47], false);
1118   HadamardRotation(&s[33], &s[46], false);
1119   HadamardRotation(&s[34], &s[45], false);
1120   HadamardRotation(&s[35], &s[44], false);
1121   HadamardRotation(&s[36], &s[43], false);
1122   HadamardRotation(&s[37], &s[42], false);
1123   HadamardRotation(&s[38], &s[41], false);
1124   HadamardRotation(&s[39], &s[40], false);
1125   HadamardRotation(&s[48], &s[63], true);
1126   HadamardRotation(&s[49], &s[62], true);
1127   HadamardRotation(&s[50], &s[61], true);
1128   HadamardRotation(&s[51], &s[60], true);
1129   HadamardRotation(&s[52], &s[59], true);
1130   HadamardRotation(&s[53], &s[58], true);
1131   HadamardRotation(&s[54], &s[57], true);
1132   HadamardRotation(&s[55], &s[56], true);
1133 
1134   // stage 30.
1135   ButterflyRotation_8(&s[55], &s[40], 32, true);
1136   ButterflyRotation_8(&s[54], &s[41], 32, true);
1137   ButterflyRotation_8(&s[53], &s[42], 32, true);
1138   ButterflyRotation_8(&s[52], &s[43], 32, true);
1139   ButterflyRotation_8(&s[51], &s[44], 32, true);
1140   ButterflyRotation_8(&s[50], &s[45], 32, true);
1141   ButterflyRotation_8(&s[49], &s[46], 32, true);
1142   ButterflyRotation_8(&s[48], &s[47], 32, true);
1143 
1144   // stage 31.
1145   for (int i = 0; i < 32; i += 4) {
1146     HadamardRotation(&s[i], &s[63 - i], false);
1147     HadamardRotation(&s[i + 1], &s[63 - i - 1], false);
1148     HadamardRotation(&s[i + 2], &s[63 - i - 2], false);
1149     HadamardRotation(&s[i + 3], &s[63 - i - 3], false);
1150   }
1151   //-- end dct 64 stages
1152 
1153   if (is_row) {
1154     const int16x8_t v_row_shift = vdupq_n_s16(-row_shift);
1155     for (int idx = 0; idx < 64; idx += 8) {
1156       int16x8_t output[8];
1157       Transpose8x8(&s[idx], output);
1158       for (auto& o : output) {
1159         o = vqrshlq_s16(o, v_row_shift);
1160       }
1161       StoreDst<16, 8>(dst, step, idx, output);
1162     }
1163   } else {
1164     StoreDst<16, 64>(dst, step, 0, s);
1165   }
1166 }
1167 
1168 //------------------------------------------------------------------------------
1169 // Asymmetric Discrete Sine Transforms (ADST).
1170 
Adst4_NEON(void * dest,int32_t step,bool transpose)1171 LIBGAV1_ALWAYS_INLINE void Adst4_NEON(void* dest, int32_t step,
1172                                       bool transpose) {
1173   auto* const dst = static_cast<int16_t*>(dest);
1174   int32x4_t s[7];
1175   int16x4_t x[4];
1176 
1177   if (transpose) {
1178     assert(step == 4);
1179     int16x4x4_t y = vld4_s16(dst);
1180     for (int i = 0; i < 4; ++i) x[i] = y.val[i];
1181   } else {
1182     x[0] = vld1_s16(dst);
1183     x[1] = vld1_s16(dst + 1 * step);
1184     x[2] = vld1_s16(dst + 2 * step);
1185     x[3] = vld1_s16(dst + 3 * step);
1186   }
1187 
1188   // stage 1.
1189   s[5] = vmull_n_s16(x[3], kAdst4Multiplier[1]);
1190   s[6] = vmull_n_s16(x[3], kAdst4Multiplier[3]);
1191 
1192   // stage 2.
1193   const int32x4_t a7 = vsubl_s16(x[0], x[2]);
1194   const int32x4_t b7 = vaddw_s16(a7, x[3]);
1195 
1196   // stage 3.
1197   s[0] = vmull_n_s16(x[0], kAdst4Multiplier[0]);
1198   s[1] = vmull_n_s16(x[0], kAdst4Multiplier[1]);
1199   // s[0] = s[0] + s[3]
1200   s[0] = vmlal_n_s16(s[0], x[2], kAdst4Multiplier[3]);
1201   // s[1] = s[1] - s[4]
1202   s[1] = vmlsl_n_s16(s[1], x[2], kAdst4Multiplier[0]);
1203 
1204   s[3] = vmull_n_s16(x[1], kAdst4Multiplier[2]);
1205   s[2] = vmulq_n_s32(b7, kAdst4Multiplier[2]);
1206 
1207   // stage 4.
1208   s[0] = vaddq_s32(s[0], s[5]);
1209   s[1] = vsubq_s32(s[1], s[6]);
1210 
1211   // stages 5 and 6.
1212   const int32x4_t x0 = vaddq_s32(s[0], s[3]);
1213   const int32x4_t x1 = vaddq_s32(s[1], s[3]);
1214   const int32x4_t x3_a = vaddq_s32(s[0], s[1]);
1215   const int32x4_t x3 = vsubq_s32(x3_a, s[3]);
1216   const int16x4_t dst_0 = vqrshrn_n_s32(x0, 12);
1217   const int16x4_t dst_1 = vqrshrn_n_s32(x1, 12);
1218   const int16x4_t dst_2 = vqrshrn_n_s32(s[2], 12);
1219   const int16x4_t dst_3 = vqrshrn_n_s32(x3, 12);
1220 
1221   x[0] = dst_0;
1222   x[1] = dst_1;
1223   x[2] = dst_2;
1224   x[3] = dst_3;
1225 
1226   if (transpose) {
1227     int16x4x4_t y;
1228     for (int i = 0; i < 4; ++i) y.val[i] = x[i];
1229     vst4_s16(dst, y);
1230   } else {
1231     vst1_s16(dst, x[0]);
1232     vst1_s16(dst + 1 * step, x[1]);
1233     vst1_s16(dst + 2 * step, x[2]);
1234     vst1_s16(dst + 3 * step, x[3]);
1235   }
1236 }
1237 
1238 alignas(8) constexpr int16_t kAdst4DcOnlyMultiplier[4] = {1321, 2482, 3344,
1239                                                           2482};
1240 
Adst4DcOnly(void * dest,int adjusted_tx_height,bool should_round,int row_shift)1241 LIBGAV1_ALWAYS_INLINE bool Adst4DcOnly(void* dest, int adjusted_tx_height,
1242                                        bool should_round, int row_shift) {
1243   if (adjusted_tx_height > 1) return false;
1244 
1245   auto* dst = static_cast<int16_t*>(dest);
1246   int32x4_t s[2];
1247 
1248   const int16x4_t v_src0 = vdup_n_s16(dst[0]);
1249   const uint16x4_t v_mask = vdup_n_u16(should_round ? 0xffff : 0);
1250   const int16x4_t v_src_round =
1251       vqrdmulh_n_s16(v_src0, kTransformRowMultiplier << 3);
1252   const int16x4_t v_src = vbsl_s16(v_mask, v_src_round, v_src0);
1253   const int16x4_t kAdst4DcOnlyMultipliers = vld1_s16(kAdst4DcOnlyMultiplier);
1254   s[1] = vdupq_n_s32(0);
1255 
1256   // s0*k0 s0*k1 s0*k2 s0*k1
1257   s[0] = vmull_s16(kAdst4DcOnlyMultipliers, v_src);
1258   // 0     0     0     s0*k0
1259   s[1] = vextq_s32(s[1], s[0], 1);
1260 
1261   const int32x4_t x3 = vaddq_s32(s[0], s[1]);
1262   const int16x4_t dst_0 = vqrshrn_n_s32(x3, 12);
1263 
1264   // vqrshlq_s16 will shift right if shift value is negative.
1265   vst1_s16(dst, vqrshl_s16(dst_0, vdup_n_s16(-row_shift)));
1266 
1267   return true;
1268 }
1269 
Adst4DcOnlyColumn(void * dest,int adjusted_tx_height,int width)1270 LIBGAV1_ALWAYS_INLINE bool Adst4DcOnlyColumn(void* dest, int adjusted_tx_height,
1271                                              int width) {
1272   if (adjusted_tx_height > 1) return false;
1273 
1274   auto* dst = static_cast<int16_t*>(dest);
1275   int32x4_t s[4];
1276 
1277   int i = 0;
1278   do {
1279     const int16x4_t v_src = vld1_s16(&dst[i]);
1280 
1281     s[0] = vmull_n_s16(v_src, kAdst4Multiplier[0]);
1282     s[1] = vmull_n_s16(v_src, kAdst4Multiplier[1]);
1283     s[2] = vmull_n_s16(v_src, kAdst4Multiplier[2]);
1284 
1285     const int32x4_t x0 = s[0];
1286     const int32x4_t x1 = s[1];
1287     const int32x4_t x2 = s[2];
1288     const int32x4_t x3 = vaddq_s32(s[0], s[1]);
1289     const int16x4_t dst_0 = vqrshrn_n_s32(x0, 12);
1290     const int16x4_t dst_1 = vqrshrn_n_s32(x1, 12);
1291     const int16x4_t dst_2 = vqrshrn_n_s32(x2, 12);
1292     const int16x4_t dst_3 = vqrshrn_n_s32(x3, 12);
1293 
1294     vst1_s16(&dst[i], dst_0);
1295     vst1_s16(&dst[i + width * 1], dst_1);
1296     vst1_s16(&dst[i + width * 2], dst_2);
1297     vst1_s16(&dst[i + width * 3], dst_3);
1298 
1299     i += 4;
1300   } while (i < width);
1301 
1302   return true;
1303 }
1304 
1305 template <ButterflyRotationFunc butterfly_rotation, bool stage_is_rectangular>
Adst8_NEON(void * dest,int32_t step,bool transpose)1306 LIBGAV1_ALWAYS_INLINE void Adst8_NEON(void* dest, int32_t step,
1307                                       bool transpose) {
1308   auto* const dst = static_cast<int16_t*>(dest);
1309   int16x8_t s[8], x[8];
1310 
1311   if (stage_is_rectangular) {
1312     if (transpose) {
1313       int16x8_t input[4];
1314       LoadSrc<16, 4>(dst, step, 0, input);
1315       Transpose8x4To4x8(input, x);
1316     } else {
1317       LoadSrc<8, 8>(dst, step, 0, x);
1318     }
1319   } else {
1320     if (transpose) {
1321       LoadSrc<16, 8>(dst, step, 0, x);
1322       dsp::Transpose8x8(x);
1323     } else {
1324       LoadSrc<16, 8>(dst, step, 0, x);
1325     }
1326   }
1327 
1328   // stage 1.
1329   s[0] = x[7];
1330   s[1] = x[0];
1331   s[2] = x[5];
1332   s[3] = x[2];
1333   s[4] = x[3];
1334   s[5] = x[4];
1335   s[6] = x[1];
1336   s[7] = x[6];
1337 
1338   // stage 2.
1339   butterfly_rotation(&s[0], &s[1], 60 - 0, true);
1340   butterfly_rotation(&s[2], &s[3], 60 - 16, true);
1341   butterfly_rotation(&s[4], &s[5], 60 - 32, true);
1342   butterfly_rotation(&s[6], &s[7], 60 - 48, true);
1343 
1344   // stage 3.
1345   HadamardRotation(&s[0], &s[4], false);
1346   HadamardRotation(&s[1], &s[5], false);
1347   HadamardRotation(&s[2], &s[6], false);
1348   HadamardRotation(&s[3], &s[7], false);
1349 
1350   // stage 4.
1351   butterfly_rotation(&s[4], &s[5], 48 - 0, true);
1352   butterfly_rotation(&s[7], &s[6], 48 - 32, true);
1353 
1354   // stage 5.
1355   HadamardRotation(&s[0], &s[2], false);
1356   HadamardRotation(&s[4], &s[6], false);
1357   HadamardRotation(&s[1], &s[3], false);
1358   HadamardRotation(&s[5], &s[7], false);
1359 
1360   // stage 6.
1361   butterfly_rotation(&s[2], &s[3], 32, true);
1362   butterfly_rotation(&s[6], &s[7], 32, true);
1363 
1364   // stage 7.
1365   x[0] = s[0];
1366   x[1] = vqnegq_s16(s[4]);
1367   x[2] = s[6];
1368   x[3] = vqnegq_s16(s[2]);
1369   x[4] = s[3];
1370   x[5] = vqnegq_s16(s[7]);
1371   x[6] = s[5];
1372   x[7] = vqnegq_s16(s[1]);
1373 
1374   if (stage_is_rectangular) {
1375     if (transpose) {
1376       int16x8_t output[4];
1377       Transpose4x8To8x4(x, output);
1378       StoreDst<16, 4>(dst, step, 0, output);
1379     } else {
1380       StoreDst<8, 8>(dst, step, 0, x);
1381     }
1382   } else {
1383     if (transpose) {
1384       dsp::Transpose8x8(x);
1385       StoreDst<16, 8>(dst, step, 0, x);
1386     } else {
1387       StoreDst<16, 8>(dst, step, 0, x);
1388     }
1389   }
1390 }
1391 
Adst8DcOnly(void * dest,int adjusted_tx_height,bool should_round,int row_shift)1392 LIBGAV1_ALWAYS_INLINE bool Adst8DcOnly(void* dest, int adjusted_tx_height,
1393                                        bool should_round, int row_shift) {
1394   if (adjusted_tx_height > 1) return false;
1395 
1396   auto* dst = static_cast<int16_t*>(dest);
1397   int16x8_t s[8];
1398 
1399   const int16x8_t v_src = vdupq_n_s16(dst[0]);
1400   const uint16x8_t v_mask = vdupq_n_u16(should_round ? 0xffff : 0);
1401   const int16x8_t v_src_round =
1402       vqrdmulhq_n_s16(v_src, kTransformRowMultiplier << 3);
1403   // stage 1.
1404   s[1] = vbslq_s16(v_mask, v_src_round, v_src);
1405 
1406   // stage 2.
1407   ButterflyRotation_FirstIsZero(&s[0], &s[1], 60, true);
1408 
1409   // stage 3.
1410   s[4] = s[0];
1411   s[5] = s[1];
1412 
1413   // stage 4.
1414   ButterflyRotation_4(&s[4], &s[5], 48, true);
1415 
1416   // stage 5.
1417   s[2] = s[0];
1418   s[3] = s[1];
1419   s[6] = s[4];
1420   s[7] = s[5];
1421 
1422   // stage 6.
1423   ButterflyRotation_4(&s[2], &s[3], 32, true);
1424   ButterflyRotation_4(&s[6], &s[7], 32, true);
1425 
1426   // stage 7.
1427   int16x8_t x[8];
1428   x[0] = s[0];
1429   x[1] = vqnegq_s16(s[4]);
1430   x[2] = s[6];
1431   x[3] = vqnegq_s16(s[2]);
1432   x[4] = s[3];
1433   x[5] = vqnegq_s16(s[7]);
1434   x[6] = s[5];
1435   x[7] = vqnegq_s16(s[1]);
1436 
1437   for (int i = 0; i < 8; ++i) {
1438     // vqrshlq_s16 will shift right if shift value is negative.
1439     x[i] = vqrshlq_s16(x[i], vdupq_n_s16(-row_shift));
1440     vst1q_lane_s16(&dst[i], x[i], 0);
1441   }
1442 
1443   return true;
1444 }
1445 
Adst8DcOnlyColumn(void * dest,int adjusted_tx_height,int width)1446 LIBGAV1_ALWAYS_INLINE bool Adst8DcOnlyColumn(void* dest, int adjusted_tx_height,
1447                                              int width) {
1448   if (adjusted_tx_height > 1) return false;
1449 
1450   auto* dst = static_cast<int16_t*>(dest);
1451   int16x8_t s[8];
1452 
1453   int i = 0;
1454   do {
1455     const int16x8_t v_src = vld1q_s16(dst);
1456     // stage 1.
1457     s[1] = v_src;
1458 
1459     // stage 2.
1460     ButterflyRotation_FirstIsZero(&s[0], &s[1], 60, true);
1461 
1462     // stage 3.
1463     s[4] = s[0];
1464     s[5] = s[1];
1465 
1466     // stage 4.
1467     ButterflyRotation_4(&s[4], &s[5], 48, true);
1468 
1469     // stage 5.
1470     s[2] = s[0];
1471     s[3] = s[1];
1472     s[6] = s[4];
1473     s[7] = s[5];
1474 
1475     // stage 6.
1476     ButterflyRotation_4(&s[2], &s[3], 32, true);
1477     ButterflyRotation_4(&s[6], &s[7], 32, true);
1478 
1479     // stage 7.
1480     int16x8_t x[8];
1481     x[0] = s[0];
1482     x[1] = vqnegq_s16(s[4]);
1483     x[2] = s[6];
1484     x[3] = vqnegq_s16(s[2]);
1485     x[4] = s[3];
1486     x[5] = vqnegq_s16(s[7]);
1487     x[6] = s[5];
1488     x[7] = vqnegq_s16(s[1]);
1489 
1490     for (int j = 0; j < 8; ++j) {
1491       vst1_s16(&dst[j * width], vget_low_s16(x[j]));
1492     }
1493     i += 4;
1494     dst += 4;
1495   } while (i < width);
1496 
1497   return true;
1498 }
1499 
1500 template <ButterflyRotationFunc butterfly_rotation, bool stage_is_rectangular>
Adst16_NEON(void * dest,int32_t step,bool is_row,int row_shift)1501 LIBGAV1_ALWAYS_INLINE void Adst16_NEON(void* dest, int32_t step, bool is_row,
1502                                        int row_shift) {
1503   auto* const dst = static_cast<int16_t*>(dest);
1504   int16x8_t s[16], x[16];
1505 
1506   if (stage_is_rectangular) {
1507     if (is_row) {
1508       int16x8_t input[4];
1509       LoadSrc<16, 4>(dst, step, 0, input);
1510       Transpose8x4To4x8(input, x);
1511       LoadSrc<16, 4>(dst, step, 8, input);
1512       Transpose8x4To4x8(input, &x[8]);
1513     } else {
1514       LoadSrc<8, 16>(dst, step, 0, x);
1515     }
1516   } else {
1517     if (is_row) {
1518       for (int idx = 0; idx < 16; idx += 8) {
1519         LoadSrc<16, 8>(dst, step, idx, &x[idx]);
1520         dsp::Transpose8x8(&x[idx]);
1521       }
1522     } else {
1523       LoadSrc<16, 16>(dst, step, 0, x);
1524     }
1525   }
1526 
1527   // stage 1.
1528   s[0] = x[15];
1529   s[1] = x[0];
1530   s[2] = x[13];
1531   s[3] = x[2];
1532   s[4] = x[11];
1533   s[5] = x[4];
1534   s[6] = x[9];
1535   s[7] = x[6];
1536   s[8] = x[7];
1537   s[9] = x[8];
1538   s[10] = x[5];
1539   s[11] = x[10];
1540   s[12] = x[3];
1541   s[13] = x[12];
1542   s[14] = x[1];
1543   s[15] = x[14];
1544 
1545   // stage 2.
1546   butterfly_rotation(&s[0], &s[1], 62 - 0, true);
1547   butterfly_rotation(&s[2], &s[3], 62 - 8, true);
1548   butterfly_rotation(&s[4], &s[5], 62 - 16, true);
1549   butterfly_rotation(&s[6], &s[7], 62 - 24, true);
1550   butterfly_rotation(&s[8], &s[9], 62 - 32, true);
1551   butterfly_rotation(&s[10], &s[11], 62 - 40, true);
1552   butterfly_rotation(&s[12], &s[13], 62 - 48, true);
1553   butterfly_rotation(&s[14], &s[15], 62 - 56, true);
1554 
1555   // stage 3.
1556   HadamardRotation(&s[0], &s[8], false);
1557   HadamardRotation(&s[1], &s[9], false);
1558   HadamardRotation(&s[2], &s[10], false);
1559   HadamardRotation(&s[3], &s[11], false);
1560   HadamardRotation(&s[4], &s[12], false);
1561   HadamardRotation(&s[5], &s[13], false);
1562   HadamardRotation(&s[6], &s[14], false);
1563   HadamardRotation(&s[7], &s[15], false);
1564 
1565   // stage 4.
1566   butterfly_rotation(&s[8], &s[9], 56 - 0, true);
1567   butterfly_rotation(&s[13], &s[12], 8 + 0, true);
1568   butterfly_rotation(&s[10], &s[11], 56 - 32, true);
1569   butterfly_rotation(&s[15], &s[14], 8 + 32, true);
1570 
1571   // stage 5.
1572   HadamardRotation(&s[0], &s[4], false);
1573   HadamardRotation(&s[8], &s[12], false);
1574   HadamardRotation(&s[1], &s[5], false);
1575   HadamardRotation(&s[9], &s[13], false);
1576   HadamardRotation(&s[2], &s[6], false);
1577   HadamardRotation(&s[10], &s[14], false);
1578   HadamardRotation(&s[3], &s[7], false);
1579   HadamardRotation(&s[11], &s[15], false);
1580 
1581   // stage 6.
1582   butterfly_rotation(&s[4], &s[5], 48 - 0, true);
1583   butterfly_rotation(&s[12], &s[13], 48 - 0, true);
1584   butterfly_rotation(&s[7], &s[6], 48 - 32, true);
1585   butterfly_rotation(&s[15], &s[14], 48 - 32, true);
1586 
1587   // stage 7.
1588   HadamardRotation(&s[0], &s[2], false);
1589   HadamardRotation(&s[4], &s[6], false);
1590   HadamardRotation(&s[8], &s[10], false);
1591   HadamardRotation(&s[12], &s[14], false);
1592   HadamardRotation(&s[1], &s[3], false);
1593   HadamardRotation(&s[5], &s[7], false);
1594   HadamardRotation(&s[9], &s[11], false);
1595   HadamardRotation(&s[13], &s[15], false);
1596 
1597   // stage 8.
1598   butterfly_rotation(&s[2], &s[3], 32, true);
1599   butterfly_rotation(&s[6], &s[7], 32, true);
1600   butterfly_rotation(&s[10], &s[11], 32, true);
1601   butterfly_rotation(&s[14], &s[15], 32, true);
1602 
1603   // stage 9.
1604   x[0] = s[0];
1605   x[1] = vqnegq_s16(s[8]);
1606   x[2] = s[12];
1607   x[3] = vqnegq_s16(s[4]);
1608   x[4] = s[6];
1609   x[5] = vqnegq_s16(s[14]);
1610   x[6] = s[10];
1611   x[7] = vqnegq_s16(s[2]);
1612   x[8] = s[3];
1613   x[9] = vqnegq_s16(s[11]);
1614   x[10] = s[15];
1615   x[11] = vqnegq_s16(s[7]);
1616   x[12] = s[5];
1617   x[13] = vqnegq_s16(s[13]);
1618   x[14] = s[9];
1619   x[15] = vqnegq_s16(s[1]);
1620 
1621   if (stage_is_rectangular) {
1622     if (is_row) {
1623       const int16x8_t v_row_shift = vdupq_n_s16(-row_shift);
1624       int16x8_t output[4];
1625       Transpose4x8To8x4(x, output);
1626       for (auto& o : output) {
1627         o = vqrshlq_s16(o, v_row_shift);
1628       }
1629       StoreDst<16, 4>(dst, step, 0, output);
1630       Transpose4x8To8x4(&x[8], output);
1631       for (auto& o : output) {
1632         o = vqrshlq_s16(o, v_row_shift);
1633       }
1634       StoreDst<16, 4>(dst, step, 8, output);
1635     } else {
1636       StoreDst<8, 16>(dst, step, 0, x);
1637     }
1638   } else {
1639     if (is_row) {
1640       const int16x8_t v_row_shift = vdupq_n_s16(-row_shift);
1641       for (int idx = 0; idx < 16; idx += 8) {
1642         int16x8_t output[8];
1643         Transpose8x8(&x[idx], output);
1644         for (auto& o : output) {
1645           o = vqrshlq_s16(o, v_row_shift);
1646         }
1647         StoreDst<16, 8>(dst, step, idx, output);
1648       }
1649     } else {
1650       StoreDst<16, 16>(dst, step, 0, x);
1651     }
1652   }
1653 }
1654 
Adst16DcOnlyInternal(int16x8_t * s,int16x8_t * x)1655 LIBGAV1_ALWAYS_INLINE void Adst16DcOnlyInternal(int16x8_t* s, int16x8_t* x) {
1656   // stage 2.
1657   ButterflyRotation_FirstIsZero(&s[0], &s[1], 62, true);
1658 
1659   // stage 3.
1660   s[8] = s[0];
1661   s[9] = s[1];
1662 
1663   // stage 4.
1664   ButterflyRotation_4(&s[8], &s[9], 56, true);
1665 
1666   // stage 5.
1667   s[4] = s[0];
1668   s[12] = s[8];
1669   s[5] = s[1];
1670   s[13] = s[9];
1671 
1672   // stage 6.
1673   ButterflyRotation_4(&s[4], &s[5], 48, true);
1674   ButterflyRotation_4(&s[12], &s[13], 48, true);
1675 
1676   // stage 7.
1677   s[2] = s[0];
1678   s[6] = s[4];
1679   s[10] = s[8];
1680   s[14] = s[12];
1681   s[3] = s[1];
1682   s[7] = s[5];
1683   s[11] = s[9];
1684   s[15] = s[13];
1685 
1686   // stage 8.
1687   ButterflyRotation_4(&s[2], &s[3], 32, true);
1688   ButterflyRotation_4(&s[6], &s[7], 32, true);
1689   ButterflyRotation_4(&s[10], &s[11], 32, true);
1690   ButterflyRotation_4(&s[14], &s[15], 32, true);
1691 
1692   // stage 9.
1693   x[0] = s[0];
1694   x[1] = vqnegq_s16(s[8]);
1695   x[2] = s[12];
1696   x[3] = vqnegq_s16(s[4]);
1697   x[4] = s[6];
1698   x[5] = vqnegq_s16(s[14]);
1699   x[6] = s[10];
1700   x[7] = vqnegq_s16(s[2]);
1701   x[8] = s[3];
1702   x[9] = vqnegq_s16(s[11]);
1703   x[10] = s[15];
1704   x[11] = vqnegq_s16(s[7]);
1705   x[12] = s[5];
1706   x[13] = vqnegq_s16(s[13]);
1707   x[14] = s[9];
1708   x[15] = vqnegq_s16(s[1]);
1709 }
1710 
Adst16DcOnly(void * dest,int adjusted_tx_height,bool should_round,int row_shift)1711 LIBGAV1_ALWAYS_INLINE bool Adst16DcOnly(void* dest, int adjusted_tx_height,
1712                                         bool should_round, int row_shift) {
1713   if (adjusted_tx_height > 1) return false;
1714 
1715   auto* dst = static_cast<int16_t*>(dest);
1716   int16x8_t s[16];
1717   int16x8_t x[16];
1718 
1719   const int16x8_t v_src = vdupq_n_s16(dst[0]);
1720   const uint16x8_t v_mask = vdupq_n_u16(should_round ? 0xffff : 0);
1721   const int16x8_t v_src_round =
1722       vqrdmulhq_n_s16(v_src, kTransformRowMultiplier << 3);
1723   // stage 1.
1724   s[1] = vbslq_s16(v_mask, v_src_round, v_src);
1725 
1726   Adst16DcOnlyInternal(s, x);
1727 
1728   for (int i = 0; i < 16; ++i) {
1729     // vqrshlq_s16 will shift right if shift value is negative.
1730     x[i] = vqrshlq_s16(x[i], vdupq_n_s16(-row_shift));
1731     vst1q_lane_s16(&dst[i], x[i], 0);
1732   }
1733 
1734   return true;
1735 }
1736 
Adst16DcOnlyColumn(void * dest,int adjusted_tx_height,int width)1737 LIBGAV1_ALWAYS_INLINE bool Adst16DcOnlyColumn(void* dest,
1738                                               int adjusted_tx_height,
1739                                               int width) {
1740   if (adjusted_tx_height > 1) return false;
1741 
1742   auto* dst = static_cast<int16_t*>(dest);
1743   int i = 0;
1744   do {
1745     int16x8_t s[16];
1746     int16x8_t x[16];
1747     const int16x8_t v_src = vld1q_s16(dst);
1748     // stage 1.
1749     s[1] = v_src;
1750 
1751     Adst16DcOnlyInternal(s, x);
1752 
1753     for (int j = 0; j < 16; ++j) {
1754       vst1_s16(&dst[j * width], vget_low_s16(x[j]));
1755     }
1756     i += 4;
1757     dst += 4;
1758   } while (i < width);
1759 
1760   return true;
1761 }
1762 
1763 //------------------------------------------------------------------------------
1764 // Identity Transforms.
1765 
1766 template <bool is_row_shift>
Identity4_NEON(void * dest,int32_t step)1767 LIBGAV1_ALWAYS_INLINE void Identity4_NEON(void* dest, int32_t step) {
1768   auto* const dst = static_cast<int16_t*>(dest);
1769 
1770   if (is_row_shift) {
1771     const int shift = 1;
1772     const int32x4_t v_dual_round = vdupq_n_s32((1 + (shift << 1)) << 11);
1773     const int16x4_t v_multiplier = vdup_n_s16(kIdentity4Multiplier);
1774     const int32x4_t v_shift = vdupq_n_s32(-(12 + shift));
1775     for (int i = 0; i < 4; i += 2) {
1776       const int16x8_t v_src = vld1q_s16(&dst[i * step]);
1777       const int32x4_t v_src_mult_lo =
1778           vmlal_s16(v_dual_round, vget_low_s16(v_src), v_multiplier);
1779       const int32x4_t v_src_mult_hi =
1780           vmlal_s16(v_dual_round, vget_high_s16(v_src), v_multiplier);
1781       const int32x4_t shift_lo = vqshlq_s32(v_src_mult_lo, v_shift);
1782       const int32x4_t shift_hi = vqshlq_s32(v_src_mult_hi, v_shift);
1783       vst1q_s16(&dst[i * step],
1784                 vcombine_s16(vqmovn_s32(shift_lo), vqmovn_s32(shift_hi)));
1785     }
1786   } else {
1787     for (int i = 0; i < 4; i += 2) {
1788       const int16x8_t v_src = vld1q_s16(&dst[i * step]);
1789       const int16x8_t a =
1790           vqrdmulhq_n_s16(v_src, kIdentity4MultiplierFraction << 3);
1791       const int16x8_t b = vqaddq_s16(v_src, a);
1792       vst1q_s16(&dst[i * step], b);
1793     }
1794   }
1795 }
1796 
Identity4DcOnly(void * dest,int adjusted_tx_height,bool should_round,int tx_height)1797 LIBGAV1_ALWAYS_INLINE bool Identity4DcOnly(void* dest, int adjusted_tx_height,
1798                                            bool should_round, int tx_height) {
1799   if (adjusted_tx_height > 1) return false;
1800 
1801   auto* dst = static_cast<int16_t*>(dest);
1802   const int16x4_t v_src0 = vdup_n_s16(dst[0]);
1803   const uint16x4_t v_mask = vdup_n_u16(should_round ? 0xffff : 0);
1804   const int16x4_t v_src_round =
1805       vqrdmulh_n_s16(v_src0, kTransformRowMultiplier << 3);
1806   const int16x4_t v_src = vbsl_s16(v_mask, v_src_round, v_src0);
1807   const int shift = tx_height < 16 ? 0 : 1;
1808   const int32x4_t v_dual_round = vdupq_n_s32((1 + (shift << 1)) << 11);
1809   const int16x4_t v_multiplier = vdup_n_s16(kIdentity4Multiplier);
1810   const int32x4_t v_shift = vdupq_n_s32(-(12 + shift));
1811   const int32x4_t v_src_mult_lo = vmlal_s16(v_dual_round, v_src, v_multiplier);
1812   const int32x4_t dst_0 = vqshlq_s32(v_src_mult_lo, v_shift);
1813   vst1_lane_s16(dst, vqmovn_s32(dst_0), 0);
1814   return true;
1815 }
1816 
1817 template <int identity_size>
IdentityColumnStoreToFrame(Array2DView<uint8_t> frame,const int start_x,const int start_y,const int tx_width,const int tx_height,const int16_t * LIBGAV1_RESTRICT source)1818 LIBGAV1_ALWAYS_INLINE void IdentityColumnStoreToFrame(
1819     Array2DView<uint8_t> frame, const int start_x, const int start_y,
1820     const int tx_width, const int tx_height,
1821     const int16_t* LIBGAV1_RESTRICT source) {
1822   const int stride = frame.columns();
1823   uint8_t* LIBGAV1_RESTRICT dst = frame[start_y] + start_x;
1824 
1825   if (identity_size < 32) {
1826     if (tx_width == 4) {
1827       uint8x8_t frame_data = vdup_n_u8(0);
1828       int i = 0;
1829       do {
1830         const int16x4_t v_src = vld1_s16(&source[i * tx_width]);
1831 
1832         int16x4_t v_dst_i;
1833         if (identity_size == 4) {
1834           const int16x4_t v_src_fraction =
1835               vqrdmulh_n_s16(v_src, kIdentity4MultiplierFraction << 3);
1836           v_dst_i = vqadd_s16(v_src, v_src_fraction);
1837         } else if (identity_size == 8) {
1838           v_dst_i = vqadd_s16(v_src, v_src);
1839         } else {  // identity_size == 16
1840           const int16x4_t v_src_mult =
1841               vqrdmulh_n_s16(v_src, kIdentity4MultiplierFraction << 4);
1842           const int16x4_t v_srcx2 = vqadd_s16(v_src, v_src);
1843           v_dst_i = vqadd_s16(v_srcx2, v_src_mult);
1844         }
1845 
1846         frame_data = Load4<0>(dst, frame_data);
1847         const int16x4_t a = vrshr_n_s16(v_dst_i, 4);
1848         const uint16x8_t b =
1849             vaddw_u8(vreinterpretq_u16_s16(vcombine_s16(a, a)), frame_data);
1850         const uint8x8_t d = vqmovun_s16(vreinterpretq_s16_u16(b));
1851         StoreLo4(dst, d);
1852         dst += stride;
1853       } while (++i < tx_height);
1854     } else {
1855       int i = 0;
1856       do {
1857         const int row = i * tx_width;
1858         int j = 0;
1859         do {
1860           const int16x8_t v_src = vld1q_s16(&source[row + j]);
1861 
1862           int16x8_t v_dst_i;
1863           if (identity_size == 4) {
1864             const int16x8_t v_src_fraction =
1865                 vqrdmulhq_n_s16(v_src, kIdentity4MultiplierFraction << 3);
1866             v_dst_i = vqaddq_s16(v_src, v_src_fraction);
1867           } else if (identity_size == 8) {
1868             v_dst_i = vqaddq_s16(v_src, v_src);
1869           } else {  // identity_size == 16
1870             const int16x8_t v_src_mult =
1871                 vqrdmulhq_n_s16(v_src, kIdentity4MultiplierFraction << 4);
1872             const int16x8_t v_srcx2 = vqaddq_s16(v_src, v_src);
1873             v_dst_i = vqaddq_s16(v_src_mult, v_srcx2);
1874           }
1875 
1876           const uint8x8_t frame_data = vld1_u8(dst + j);
1877           const int16x8_t a = vrshrq_n_s16(v_dst_i, 4);
1878           const uint16x8_t b = vaddw_u8(vreinterpretq_u16_s16(a), frame_data);
1879           const uint8x8_t d = vqmovun_s16(vreinterpretq_s16_u16(b));
1880           vst1_u8(dst + j, d);
1881           j += 8;
1882         } while (j < tx_width);
1883         dst += stride;
1884       } while (++i < tx_height);
1885     }
1886   } else {
1887     int i = 0;
1888     do {
1889       const int row = i * tx_width;
1890       int j = 0;
1891       do {
1892         const int16x8_t v_dst_i = vld1q_s16(&source[row + j]);
1893         const uint8x8_t frame_data = vld1_u8(dst + j);
1894         const int16x8_t a = vrshrq_n_s16(v_dst_i, 2);
1895         const uint16x8_t b = vaddw_u8(vreinterpretq_u16_s16(a), frame_data);
1896         const uint8x8_t d = vqmovun_s16(vreinterpretq_s16_u16(b));
1897         vst1_u8(dst + j, d);
1898         j += 8;
1899       } while (j < tx_width);
1900       dst += stride;
1901     } while (++i < tx_height);
1902   }
1903 }
1904 
Identity4RowColumnStoreToFrame(Array2DView<uint8_t> frame,const int start_x,const int start_y,const int tx_width,const int tx_height,const int16_t * LIBGAV1_RESTRICT source)1905 LIBGAV1_ALWAYS_INLINE void Identity4RowColumnStoreToFrame(
1906     Array2DView<uint8_t> frame, const int start_x, const int start_y,
1907     const int tx_width, const int tx_height,
1908     const int16_t* LIBGAV1_RESTRICT source) {
1909   const int stride = frame.columns();
1910   uint8_t* LIBGAV1_RESTRICT dst = frame[start_y] + start_x;
1911 
1912   if (tx_width == 4) {
1913     uint8x8_t frame_data = vdup_n_u8(0);
1914     int i = 0;
1915     do {
1916       const int16x4_t v_src = vld1_s16(&source[i * tx_width]);
1917       const int16x4_t v_src_mult =
1918           vqrdmulh_n_s16(v_src, kIdentity4MultiplierFraction << 3);
1919       const int16x4_t v_dst_row = vqadd_s16(v_src, v_src_mult);
1920       const int16x4_t v_src_mult2 =
1921           vqrdmulh_n_s16(v_dst_row, kIdentity4MultiplierFraction << 3);
1922       const int16x4_t v_dst_col = vqadd_s16(v_dst_row, v_src_mult2);
1923       frame_data = Load4<0>(dst, frame_data);
1924       const int16x4_t a = vrshr_n_s16(v_dst_col, 4);
1925       const uint16x8_t b =
1926           vaddw_u8(vreinterpretq_u16_s16(vcombine_s16(a, a)), frame_data);
1927       const uint8x8_t d = vqmovun_s16(vreinterpretq_s16_u16(b));
1928       StoreLo4(dst, d);
1929       dst += stride;
1930     } while (++i < tx_height);
1931   } else {
1932     int i = 0;
1933     do {
1934       const int row = i * tx_width;
1935       int j = 0;
1936       do {
1937         const int16x8_t v_src = vld1q_s16(&source[row + j]);
1938         const int16x8_t v_src_round =
1939             vqrdmulhq_n_s16(v_src, kTransformRowMultiplier << 3);
1940         const int16x8_t v_dst_row = vqaddq_s16(v_src_round, v_src_round);
1941         const int16x8_t v_src_mult2 =
1942             vqrdmulhq_n_s16(v_dst_row, kIdentity4MultiplierFraction << 3);
1943         const int16x8_t v_dst_col = vqaddq_s16(v_dst_row, v_src_mult2);
1944         const uint8x8_t frame_data = vld1_u8(dst + j);
1945         const int16x8_t a = vrshrq_n_s16(v_dst_col, 4);
1946         const uint16x8_t b = vaddw_u8(vreinterpretq_u16_s16(a), frame_data);
1947         const uint8x8_t d = vqmovun_s16(vreinterpretq_s16_u16(b));
1948         vst1_u8(dst + j, d);
1949         j += 8;
1950       } while (j < tx_width);
1951       dst += stride;
1952     } while (++i < tx_height);
1953   }
1954 }
1955 
Identity8Row32_NEON(void * dest,int32_t step)1956 LIBGAV1_ALWAYS_INLINE void Identity8Row32_NEON(void* dest, int32_t step) {
1957   auto* const dst = static_cast<int16_t*>(dest);
1958 
1959   // When combining the identity8 multiplier with the row shift, the
1960   // calculations for tx_height equal to 32 can be simplified from
1961   // ((A * 2) + 2) >> 2) to ((A + 1) >> 1).
1962   for (int i = 0; i < 4; ++i) {
1963     const int16x8_t v_src = vld1q_s16(&dst[i * step]);
1964     const int16x8_t a = vrshrq_n_s16(v_src, 1);
1965     vst1q_s16(&dst[i * step], a);
1966   }
1967 }
1968 
Identity8Row4_NEON(void * dest,int32_t step)1969 LIBGAV1_ALWAYS_INLINE void Identity8Row4_NEON(void* dest, int32_t step) {
1970   auto* const dst = static_cast<int16_t*>(dest);
1971 
1972   for (int i = 0; i < 4; ++i) {
1973     const int16x8_t v_src = vld1q_s16(&dst[i * step]);
1974     // For bitdepth == 8, the identity row clamps to a signed 16bit value, so
1975     // saturating add here is ok.
1976     const int16x8_t v_srcx2 = vqaddq_s16(v_src, v_src);
1977     vst1q_s16(&dst[i * step], v_srcx2);
1978   }
1979 }
1980 
Identity8DcOnly(void * dest,int adjusted_tx_height,bool should_round,int row_shift)1981 LIBGAV1_ALWAYS_INLINE bool Identity8DcOnly(void* dest, int adjusted_tx_height,
1982                                            bool should_round, int row_shift) {
1983   if (adjusted_tx_height > 1) return false;
1984 
1985   auto* dst = static_cast<int16_t*>(dest);
1986   const int16x4_t v_src0 = vdup_n_s16(dst[0]);
1987   const uint16x4_t v_mask = vdup_n_u16(should_round ? 0xffff : 0);
1988   const int16x4_t v_src_round =
1989       vqrdmulh_n_s16(v_src0, kTransformRowMultiplier << 3);
1990   const int16x4_t v_src = vbsl_s16(v_mask, v_src_round, v_src0);
1991   const int32x4_t v_srcx2 = vaddl_s16(v_src, v_src);
1992   const int32x4_t dst_0 = vqrshlq_s32(v_srcx2, vdupq_n_s32(-row_shift));
1993   vst1_lane_s16(dst, vqmovn_s32(dst_0), 0);
1994   return true;
1995 }
1996 
Identity16Row_NEON(void * dest,int32_t step,int shift)1997 LIBGAV1_ALWAYS_INLINE void Identity16Row_NEON(void* dest, int32_t step,
1998                                               int shift) {
1999   auto* const dst = static_cast<int16_t*>(dest);
2000   const int32x4_t v_dual_round = vdupq_n_s32((1 + (shift << 1)) << 11);
2001   const int32x4_t v_shift = vdupq_n_s32(-(12 + shift));
2002 
2003   for (int i = 0; i < 4; ++i) {
2004     for (int j = 0; j < 2; ++j) {
2005       const int16x8_t v_src = vld1q_s16(&dst[i * step + j * 8]);
2006       const int32x4_t v_src_mult_lo =
2007           vmlal_n_s16(v_dual_round, vget_low_s16(v_src), kIdentity16Multiplier);
2008       const int32x4_t v_src_mult_hi = vmlal_n_s16(
2009           v_dual_round, vget_high_s16(v_src), kIdentity16Multiplier);
2010       const int32x4_t shift_lo = vqshlq_s32(v_src_mult_lo, v_shift);
2011       const int32x4_t shift_hi = vqshlq_s32(v_src_mult_hi, v_shift);
2012       vst1q_s16(&dst[i * step + j * 8],
2013                 vcombine_s16(vqmovn_s32(shift_lo), vqmovn_s32(shift_hi)));
2014     }
2015   }
2016 }
2017 
Identity16DcOnly(void * dest,int adjusted_tx_height,bool should_round,int shift)2018 LIBGAV1_ALWAYS_INLINE bool Identity16DcOnly(void* dest, int adjusted_tx_height,
2019                                             bool should_round, int shift) {
2020   if (adjusted_tx_height > 1) return false;
2021 
2022   auto* dst = static_cast<int16_t*>(dest);
2023   const int16x4_t v_src0 = vdup_n_s16(dst[0]);
2024   const uint16x4_t v_mask = vdup_n_u16(should_round ? 0xffff : 0);
2025   const int16x4_t v_src_round =
2026       vqrdmulh_n_s16(v_src0, kTransformRowMultiplier << 3);
2027   const int16x4_t v_src = vbsl_s16(v_mask, v_src_round, v_src0);
2028   const int32x4_t v_dual_round = vdupq_n_s32((1 + (shift << 1)) << 11);
2029   const int16x4_t v_multiplier = vdup_n_s16(kIdentity16Multiplier);
2030   const int32x4_t v_shift = vdupq_n_s32(-(12 + shift));
2031   const int32x4_t v_src_mult_lo =
2032       vmlal_s16(v_dual_round, (v_src), v_multiplier);
2033   const int32x4_t dst_0 = vqshlq_s32(v_src_mult_lo, v_shift);
2034   vst1_lane_s16(dst, vqmovn_s32(dst_0), 0);
2035   return true;
2036 }
2037 
Identity32Row16_NEON(void * dest,const int32_t step)2038 LIBGAV1_ALWAYS_INLINE void Identity32Row16_NEON(void* dest,
2039                                                 const int32_t step) {
2040   auto* const dst = static_cast<int16_t*>(dest);
2041 
2042   // When combining the identity32 multiplier with the row shift, the
2043   // calculation for tx_height equal to 16 can be simplified from
2044   // ((A * 4) + 1) >> 1) to (A * 2).
2045   for (int i = 0; i < 4; ++i) {
2046     for (int j = 0; j < 32; j += 8) {
2047       const int16x8_t v_src = vld1q_s16(&dst[i * step + j]);
2048       // For bitdepth == 8, the identity row clamps to a signed 16bit value, so
2049       // saturating add here is ok.
2050       const int16x8_t v_dst_i = vqaddq_s16(v_src, v_src);
2051       vst1q_s16(&dst[i * step + j], v_dst_i);
2052     }
2053   }
2054 }
2055 
Identity32DcOnly(void * dest,int adjusted_tx_height)2056 LIBGAV1_ALWAYS_INLINE bool Identity32DcOnly(void* dest,
2057                                             int adjusted_tx_height) {
2058   if (adjusted_tx_height > 1) return false;
2059 
2060   auto* dst = static_cast<int16_t*>(dest);
2061   const int16x4_t v_src0 = vdup_n_s16(dst[0]);
2062   const int16x4_t v_src = vqrdmulh_n_s16(v_src0, kTransformRowMultiplier << 3);
2063   // When combining the identity32 multiplier with the row shift, the
2064   // calculation for tx_height equal to 16 can be simplified from
2065   // ((A * 4) + 1) >> 1) to (A * 2).
2066   const int16x4_t v_dst_0 = vqadd_s16(v_src, v_src);
2067   vst1_lane_s16(dst, v_dst_0, 0);
2068   return true;
2069 }
2070 
2071 //------------------------------------------------------------------------------
2072 // Walsh Hadamard Transform.
2073 
2074 // Transposes a 4x4 matrix and then permutes the rows of the transposed matrix
2075 // for the WHT. The input matrix is in two "wide" int16x8_t variables. The
2076 // output matrix is in four int16x4_t variables.
2077 //
2078 // Input:
2079 // in[0]: 00 01 02 03  10 11 12 13
2080 // in[1]: 20 21 22 23  30 31 32 33
2081 // Output:
2082 // out[0]: 00 10 20 30
2083 // out[1]: 03 13 23 33
2084 // out[2]: 01 11 21 31
2085 // out[3]: 02 12 22 32
TransposeAndPermute4x4WideInput(const int16x8_t in[2],int16x4_t out[4])2086 LIBGAV1_ALWAYS_INLINE void TransposeAndPermute4x4WideInput(
2087     const int16x8_t in[2], int16x4_t out[4]) {
2088   // Swap 32 bit elements. Goes from:
2089   // in[0]: 00 01 02 03  10 11 12 13
2090   // in[1]: 20 21 22 23  30 31 32 33
2091   // to:
2092   // b0.val[0]: 00 01 20 21  10 11 30 31
2093   // b0.val[1]: 02 03 22 23  12 13 32 33
2094 
2095   const int32x4x2_t b0 =
2096       vtrnq_s32(vreinterpretq_s32_s16(in[0]), vreinterpretq_s32_s16(in[1]));
2097 
2098   // Swap 16 bit elements. Goes from:
2099   // vget_low_s32(b0.val[0]):  00 01 20 21
2100   // vget_high_s32(b0.val[0]): 10 11 30 31
2101   // vget_low_s32(b0.val[1]):  02 03 22 23
2102   // vget_high_s32(b0.val[1]): 12 13 32 33
2103   // to:
2104   // c0.val[0]: 00 10 20 30
2105   // c0.val[1]: 01 11 21 32
2106   // c1.val[0]: 02 12 22 32
2107   // c1.val[1]: 03 13 23 33
2108 
2109   const int16x4x2_t c0 =
2110       vtrn_s16(vreinterpret_s16_s32(vget_low_s32(b0.val[0])),
2111                vreinterpret_s16_s32(vget_high_s32(b0.val[0])));
2112   const int16x4x2_t c1 =
2113       vtrn_s16(vreinterpret_s16_s32(vget_low_s32(b0.val[1])),
2114                vreinterpret_s16_s32(vget_high_s32(b0.val[1])));
2115 
2116   out[0] = c0.val[0];
2117   out[1] = c1.val[1];
2118   out[2] = c0.val[1];
2119   out[3] = c1.val[0];
2120 }
2121 
2122 // Process 4 wht4 rows and columns.
Wht4_NEON(uint8_t * LIBGAV1_RESTRICT dst,const int dst_stride,const void * LIBGAV1_RESTRICT source,const int adjusted_tx_height)2123 LIBGAV1_ALWAYS_INLINE void Wht4_NEON(uint8_t* LIBGAV1_RESTRICT dst,
2124                                      const int dst_stride,
2125                                      const void* LIBGAV1_RESTRICT source,
2126                                      const int adjusted_tx_height) {
2127   const auto* const src = static_cast<const int16_t*>(source);
2128   int16x4_t s[4];
2129 
2130   if (adjusted_tx_height == 1) {
2131     // Special case: only src[0] is nonzero.
2132     //   src[0]  0   0   0
2133     //       0   0   0   0
2134     //       0   0   0   0
2135     //       0   0   0   0
2136     //
2137     // After the row and column transforms are applied, we have:
2138     //       f   h   h   h
2139     //       g   i   i   i
2140     //       g   i   i   i
2141     //       g   i   i   i
2142     // where f, g, h, i are computed as follows.
2143     int16_t f = (src[0] >> 2) - (src[0] >> 3);
2144     const int16_t g = f >> 1;
2145     f = f - (f >> 1);
2146     const int16_t h = (src[0] >> 3) - (src[0] >> 4);
2147     const int16_t i = (src[0] >> 4);
2148     s[0] = vdup_n_s16(h);
2149     s[0] = vset_lane_s16(f, s[0], 0);
2150     s[1] = vdup_n_s16(i);
2151     s[1] = vset_lane_s16(g, s[1], 0);
2152     s[2] = s[3] = s[1];
2153   } else {
2154     // Load the 4x4 source in transposed form.
2155     int16x4x4_t columns = vld4_s16(src);
2156     // Shift right and permute the columns for the WHT.
2157     s[0] = vshr_n_s16(columns.val[0], 2);
2158     s[2] = vshr_n_s16(columns.val[1], 2);
2159     s[3] = vshr_n_s16(columns.val[2], 2);
2160     s[1] = vshr_n_s16(columns.val[3], 2);
2161 
2162     // Row transforms.
2163     s[0] = vadd_s16(s[0], s[2]);
2164     s[3] = vsub_s16(s[3], s[1]);
2165     int16x4_t e = vhsub_s16(s[0], s[3]);  // e = (s[0] - s[3]) >> 1
2166     s[1] = vsub_s16(e, s[1]);
2167     s[2] = vsub_s16(e, s[2]);
2168     s[0] = vsub_s16(s[0], s[1]);
2169     s[3] = vadd_s16(s[3], s[2]);
2170 
2171     int16x8_t x[2];
2172     x[0] = vcombine_s16(s[0], s[1]);
2173     x[1] = vcombine_s16(s[2], s[3]);
2174     TransposeAndPermute4x4WideInput(x, s);
2175 
2176     // Column transforms.
2177     s[0] = vadd_s16(s[0], s[2]);
2178     s[3] = vsub_s16(s[3], s[1]);
2179     e = vhsub_s16(s[0], s[3]);  // e = (s[0] - s[3]) >> 1
2180     s[1] = vsub_s16(e, s[1]);
2181     s[2] = vsub_s16(e, s[2]);
2182     s[0] = vsub_s16(s[0], s[1]);
2183     s[3] = vadd_s16(s[3], s[2]);
2184   }
2185 
2186   // Store to frame.
2187   uint8x8_t frame_data = vdup_n_u8(0);
2188   for (int row = 0; row < 4; row += 2) {
2189     frame_data = Load4<0>(dst, frame_data);
2190     frame_data = Load4<1>(dst + dst_stride, frame_data);
2191     const int16x8_t residual = vcombine_s16(s[row], s[row + 1]);
2192     const uint16x8_t b = vaddw_u8(vreinterpretq_u16_s16(residual), frame_data);
2193     frame_data = vqmovun_s16(vreinterpretq_s16_u16(b));
2194     StoreLo4(dst, frame_data);
2195     dst += dst_stride;
2196     StoreHi4(dst, frame_data);
2197     dst += dst_stride;
2198   }
2199 }
2200 
2201 //------------------------------------------------------------------------------
2202 // row/column transform loops
2203 
2204 template <int tx_height>
FlipColumns(int16_t * source,int tx_width)2205 LIBGAV1_ALWAYS_INLINE void FlipColumns(int16_t* source, int tx_width) {
2206   if (tx_width >= 16) {
2207     int i = 0;
2208     do {
2209       const int16x8_t a = vld1q_s16(&source[i]);
2210       const int16x8_t b = vld1q_s16(&source[i + 8]);
2211       const int16x8_t c = vrev64q_s16(a);
2212       const int16x8_t d = vrev64q_s16(b);
2213       vst1q_s16(&source[i], vcombine_s16(vget_high_s16(d), vget_low_s16(d)));
2214       vst1q_s16(&source[i + 8],
2215                 vcombine_s16(vget_high_s16(c), vget_low_s16(c)));
2216       i += 16;
2217     } while (i < tx_width * tx_height);
2218   } else if (tx_width == 8) {
2219     for (int i = 0; i < 8 * tx_height; i += 8) {
2220       const int16x8_t a = vld1q_s16(&source[i]);
2221       const int16x8_t b = vrev64q_s16(a);
2222       vst1q_s16(&source[i], vcombine_s16(vget_high_s16(b), vget_low_s16(b)));
2223     }
2224   } else {
2225     // Process two rows per iteration.
2226     for (int i = 0; i < 4 * tx_height; i += 8) {
2227       const int16x8_t a = vld1q_s16(&source[i]);
2228       vst1q_s16(&source[i], vrev64q_s16(a));
2229     }
2230   }
2231 }
2232 
2233 template <int tx_width>
ApplyRounding(int16_t * source,int num_rows)2234 LIBGAV1_ALWAYS_INLINE void ApplyRounding(int16_t* source, int num_rows) {
2235   if (tx_width == 4) {
2236     // Process two rows per iteration.
2237     int i = 0;
2238     do {
2239       const int16x8_t a = vld1q_s16(&source[i]);
2240       const int16x8_t b = vqrdmulhq_n_s16(a, kTransformRowMultiplier << 3);
2241       vst1q_s16(&source[i], b);
2242       i += 8;
2243     } while (i < tx_width * num_rows);
2244   } else {
2245     int i = 0;
2246     do {
2247       // The last 32 values of every row are always zero if the |tx_width| is
2248       // 64.
2249       const int non_zero_width = (tx_width < 64) ? tx_width : 32;
2250       int j = 0;
2251       do {
2252         const int16x8_t a = vld1q_s16(&source[i * tx_width + j]);
2253         const int16x8_t b = vqrdmulhq_n_s16(a, kTransformRowMultiplier << 3);
2254         vst1q_s16(&source[i * tx_width + j], b);
2255         j += 8;
2256       } while (j < non_zero_width);
2257     } while (++i < num_rows);
2258   }
2259 }
2260 
2261 template <int tx_width>
RowShift(int16_t * source,int num_rows,int row_shift)2262 LIBGAV1_ALWAYS_INLINE void RowShift(int16_t* source, int num_rows,
2263                                     int row_shift) {
2264   // vqrshlq_s16 will shift right if shift value is negative.
2265   row_shift = -row_shift;
2266 
2267   if (tx_width == 4) {
2268     // Process two rows per iteration.
2269     int i = 0;
2270     do {
2271       const int16x8_t residual = vld1q_s16(&source[i]);
2272       vst1q_s16(&source[i], vqrshlq_s16(residual, vdupq_n_s16(row_shift)));
2273       i += 8;
2274     } while (i < tx_width * num_rows);
2275   } else {
2276     int i = 0;
2277     do {
2278       for (int j = 0; j < tx_width; j += 8) {
2279         const int16x8_t residual = vld1q_s16(&source[i * tx_width + j]);
2280         const int16x8_t residual_shifted =
2281             vqrshlq_s16(residual, vdupq_n_s16(row_shift));
2282         vst1q_s16(&source[i * tx_width + j], residual_shifted);
2283       }
2284     } while (++i < num_rows);
2285   }
2286 }
2287 
2288 template <int tx_height, bool enable_flip_rows = false>
StoreToFrameWithRound(Array2DView<uint8_t> frame,const int start_x,const int start_y,const int tx_width,const int16_t * LIBGAV1_RESTRICT source,TransformType tx_type)2289 LIBGAV1_ALWAYS_INLINE void StoreToFrameWithRound(
2290     Array2DView<uint8_t> frame, const int start_x, const int start_y,
2291     const int tx_width, const int16_t* LIBGAV1_RESTRICT source,
2292     TransformType tx_type) {
2293   const bool flip_rows =
2294       enable_flip_rows ? kTransformFlipRowsMask.Contains(tx_type) : false;
2295   const int stride = frame.columns();
2296   uint8_t* LIBGAV1_RESTRICT dst = frame[start_y] + start_x;
2297 
2298   // Enable for 4x4, 4x8, 4x16
2299   if (tx_height < 32 && tx_width == 4) {
2300     uint8x8_t frame_data = vdup_n_u8(0);
2301     for (int i = 0; i < tx_height; ++i) {
2302       const int row = flip_rows ? (tx_height - i - 1) * 4 : i * 4;
2303       const int16x4_t residual = vld1_s16(&source[row]);
2304       frame_data = Load4<0>(dst, frame_data);
2305       const int16x4_t a = vrshr_n_s16(residual, 4);
2306       const uint16x8_t b =
2307           vaddw_u8(vreinterpretq_u16_s16(vcombine_s16(a, a)), frame_data);
2308       const uint8x8_t d = vqmovun_s16(vreinterpretq_s16_u16(b));
2309       StoreLo4(dst, d);
2310       dst += stride;
2311     }
2312     // Enable for 8x4, 8x8, 8x16, 8x32
2313   } else if (tx_height < 64 && tx_width == 8) {
2314     for (int i = 0; i < tx_height; ++i) {
2315       const int row = flip_rows ? (tx_height - i - 1) * 8 : i * 8;
2316       const int16x8_t residual = vld1q_s16(&source[row]);
2317       const uint8x8_t frame_data = vld1_u8(dst);
2318       const int16x8_t a = vrshrq_n_s16(residual, 4);
2319       const uint16x8_t b = vaddw_u8(vreinterpretq_u16_s16(a), frame_data);
2320       const uint8x8_t d = vqmovun_s16(vreinterpretq_s16_u16(b));
2321       vst1_u8(dst, d);
2322       dst += stride;
2323     }
2324     // Remaining widths >= 16.
2325   } else {
2326     for (int i = 0; i < tx_height; ++i) {
2327       const int y = start_y + i;
2328       const int row = flip_rows ? (tx_height - i - 1) * tx_width : i * tx_width;
2329       int j = 0;
2330       do {
2331         const int x = start_x + j;
2332         const int16x8_t residual = vld1q_s16(&source[row + j]);
2333         const int16x8_t residual_hi = vld1q_s16(&source[row + j + 8]);
2334         const uint8x16_t frame_data = vld1q_u8(frame[y] + x);
2335         const int16x8_t a = vrshrq_n_s16(residual, 4);
2336         const int16x8_t a_hi = vrshrq_n_s16(residual_hi, 4);
2337         const uint16x8_t b =
2338             vaddw_u8(vreinterpretq_u16_s16(a), vget_low_u8(frame_data));
2339         const uint16x8_t b_hi =
2340             vaddw_u8(vreinterpretq_u16_s16(a_hi), vget_high_u8(frame_data));
2341         vst1q_u8(frame[y] + x,
2342                  vcombine_u8(vqmovun_s16(vreinterpretq_s16_u16(b)),
2343                              vqmovun_s16(vreinterpretq_s16_u16(b_hi))));
2344         j += 16;
2345       } while (j < tx_width);
2346     }
2347   }
2348 }
2349 
Dct4TransformLoopRow_NEON(TransformType,TransformSize tx_size,int adjusted_tx_height,void * src_buffer,int,int,void *)2350 void Dct4TransformLoopRow_NEON(TransformType /*tx_type*/, TransformSize tx_size,
2351                                int adjusted_tx_height, void* src_buffer,
2352                                int /*start_x*/, int /*start_y*/,
2353                                void* /*dst_frame*/) {
2354   auto* src = static_cast<int16_t*>(src_buffer);
2355   const int tx_height = kTransformHeight[tx_size];
2356   const bool should_round = (tx_height == 8);
2357   const int row_shift = static_cast<int>(tx_height == 16);
2358 
2359   if (DctDcOnly<4>(src, adjusted_tx_height, should_round, row_shift)) {
2360     return;
2361   }
2362 
2363   if (should_round) {
2364     ApplyRounding<4>(src, adjusted_tx_height);
2365   }
2366 
2367   if (adjusted_tx_height == 4) {
2368     // Process 4 1d dct4 rows in parallel.
2369     Dct4_NEON<ButterflyRotation_4, false>(src, /*step=*/4, /*transpose=*/true);
2370   } else {
2371     // Process 8 1d dct4 rows in parallel per iteration.
2372     int i = adjusted_tx_height;
2373     auto* data = src;
2374     do {
2375       Dct4_NEON<ButterflyRotation_8, true>(data, /*step=*/4,
2376                                            /*transpose=*/true);
2377       data += 32;
2378       i -= 8;
2379     } while (i != 0);
2380   }
2381   if (tx_height == 16) {
2382     RowShift<4>(src, adjusted_tx_height, 1);
2383   }
2384 }
2385 
Dct4TransformLoopColumn_NEON(TransformType tx_type,TransformSize tx_size,int adjusted_tx_height,void * LIBGAV1_RESTRICT src_buffer,int start_x,int start_y,void * LIBGAV1_RESTRICT dst_frame)2386 void Dct4TransformLoopColumn_NEON(TransformType tx_type, TransformSize tx_size,
2387                                   int adjusted_tx_height,
2388                                   void* LIBGAV1_RESTRICT src_buffer,
2389                                   int start_x, int start_y,
2390                                   void* LIBGAV1_RESTRICT dst_frame) {
2391   auto* src = static_cast<int16_t*>(src_buffer);
2392   const int tx_width = kTransformWidth[tx_size];
2393 
2394   if (kTransformFlipColumnsMask.Contains(tx_type)) {
2395     FlipColumns<4>(src, tx_width);
2396   }
2397 
2398   if (!DctDcOnlyColumn<4>(src, adjusted_tx_height, tx_width)) {
2399     if (tx_width == 4) {
2400       // Process 4 1d dct4 columns in parallel.
2401       Dct4_NEON<ButterflyRotation_4, false>(src, tx_width, /*transpose=*/false);
2402     } else {
2403       // Process 8 1d dct4 columns in parallel per iteration.
2404       int i = tx_width;
2405       auto* data = src;
2406       do {
2407         Dct4_NEON<ButterflyRotation_8, true>(data, tx_width,
2408                                              /*transpose=*/false);
2409         data += 8;
2410         i -= 8;
2411       } while (i != 0);
2412     }
2413   }
2414 
2415   auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame);
2416   StoreToFrameWithRound<4>(frame, start_x, start_y, tx_width, src, tx_type);
2417 }
2418 
Dct8TransformLoopRow_NEON(TransformType,TransformSize tx_size,int adjusted_tx_height,void * src_buffer,int,int,void *)2419 void Dct8TransformLoopRow_NEON(TransformType /*tx_type*/, TransformSize tx_size,
2420                                int adjusted_tx_height, void* src_buffer,
2421                                int /*start_x*/, int /*start_y*/,
2422                                void* /*dst_frame*/) {
2423   auto* src = static_cast<int16_t*>(src_buffer);
2424   const bool should_round = kShouldRound[tx_size];
2425   const uint8_t row_shift = kTransformRowShift[tx_size];
2426 
2427   if (DctDcOnly<8>(src, adjusted_tx_height, should_round, row_shift)) {
2428     return;
2429   }
2430 
2431   if (should_round) {
2432     ApplyRounding<8>(src, adjusted_tx_height);
2433   }
2434 
2435   if (adjusted_tx_height == 4) {
2436     // Process 4 1d dct8 rows in parallel.
2437     Dct8_NEON<ButterflyRotation_4, true>(src, /*step=*/8, /*transpose=*/true);
2438   } else {
2439     // Process 8 1d dct8 rows in parallel per iteration.
2440     assert(adjusted_tx_height % 8 == 0);
2441     int i = adjusted_tx_height;
2442     auto* data = src;
2443     do {
2444       Dct8_NEON<ButterflyRotation_8, false>(data, /*step=*/8,
2445                                             /*transpose=*/true);
2446       data += 64;
2447       i -= 8;
2448     } while (i != 0);
2449   }
2450   if (row_shift > 0) {
2451     RowShift<8>(src, adjusted_tx_height, row_shift);
2452   }
2453 }
2454 
Dct8TransformLoopColumn_NEON(TransformType tx_type,TransformSize tx_size,int adjusted_tx_height,void * LIBGAV1_RESTRICT src_buffer,int start_x,int start_y,void * LIBGAV1_RESTRICT dst_frame)2455 void Dct8TransformLoopColumn_NEON(TransformType tx_type, TransformSize tx_size,
2456                                   int adjusted_tx_height,
2457                                   void* LIBGAV1_RESTRICT src_buffer,
2458                                   int start_x, int start_y,
2459                                   void* LIBGAV1_RESTRICT dst_frame) {
2460   auto* src = static_cast<int16_t*>(src_buffer);
2461   const int tx_width = kTransformWidth[tx_size];
2462 
2463   if (kTransformFlipColumnsMask.Contains(tx_type)) {
2464     FlipColumns<8>(src, tx_width);
2465   }
2466 
2467   if (!DctDcOnlyColumn<8>(src, adjusted_tx_height, tx_width)) {
2468     if (tx_width == 4) {
2469       // Process 4 1d dct8 columns in parallel.
2470       Dct8_NEON<ButterflyRotation_4, true>(src, 4, /*transpose=*/false);
2471     } else {
2472       // Process 8 1d dct8 columns in parallel per iteration.
2473       int i = tx_width;
2474       auto* data = src;
2475       do {
2476         Dct8_NEON<ButterflyRotation_8, false>(data, tx_width,
2477                                               /*transpose=*/false);
2478         data += 8;
2479         i -= 8;
2480       } while (i != 0);
2481     }
2482   }
2483   auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame);
2484   StoreToFrameWithRound<8>(frame, start_x, start_y, tx_width, src, tx_type);
2485 }
2486 
Dct16TransformLoopRow_NEON(TransformType,TransformSize tx_size,int adjusted_tx_height,void * src_buffer,int,int,void *)2487 void Dct16TransformLoopRow_NEON(TransformType /*tx_type*/,
2488                                 TransformSize tx_size, int adjusted_tx_height,
2489                                 void* src_buffer, int /*start_x*/,
2490                                 int /*start_y*/, void* /*dst_frame*/) {
2491   auto* src = static_cast<int16_t*>(src_buffer);
2492   const bool should_round = kShouldRound[tx_size];
2493   const uint8_t row_shift = kTransformRowShift[tx_size];
2494 
2495   if (DctDcOnly<16>(src, adjusted_tx_height, should_round, row_shift)) {
2496     return;
2497   }
2498 
2499   if (should_round) {
2500     ApplyRounding<16>(src, adjusted_tx_height);
2501   }
2502 
2503   if (adjusted_tx_height == 4) {
2504     // Process 4 1d dct16 rows in parallel.
2505     Dct16_NEON<ButterflyRotation_4, true>(src, 16, /*is_row=*/true, row_shift);
2506   } else {
2507     assert(adjusted_tx_height % 8 == 0);
2508     int i = adjusted_tx_height;
2509     do {
2510       // Process 8 1d dct16 rows in parallel per iteration.
2511       Dct16_NEON<ButterflyRotation_8, false>(src, 16, /*is_row=*/true,
2512                                              row_shift);
2513       src += 128;
2514       i -= 8;
2515     } while (i != 0);
2516   }
2517 }
2518 
Dct16TransformLoopColumn_NEON(TransformType tx_type,TransformSize tx_size,int adjusted_tx_height,void * LIBGAV1_RESTRICT src_buffer,int start_x,int start_y,void * LIBGAV1_RESTRICT dst_frame)2519 void Dct16TransformLoopColumn_NEON(TransformType tx_type, TransformSize tx_size,
2520                                    int adjusted_tx_height,
2521                                    void* LIBGAV1_RESTRICT src_buffer,
2522                                    int start_x, int start_y,
2523                                    void* LIBGAV1_RESTRICT dst_frame) {
2524   auto* src = static_cast<int16_t*>(src_buffer);
2525   const int tx_width = kTransformWidth[tx_size];
2526 
2527   if (kTransformFlipColumnsMask.Contains(tx_type)) {
2528     FlipColumns<16>(src, tx_width);
2529   }
2530 
2531   if (!DctDcOnlyColumn<16>(src, adjusted_tx_height, tx_width)) {
2532     if (tx_width == 4) {
2533       // Process 4 1d dct16 columns in parallel.
2534       Dct16_NEON<ButterflyRotation_4, true>(src, 4, /*is_row=*/false,
2535                                             /*row_shift=*/0);
2536     } else {
2537       int i = tx_width;
2538       auto* data = src;
2539       do {
2540         // Process 8 1d dct16 columns in parallel per iteration.
2541         Dct16_NEON<ButterflyRotation_8, false>(data, tx_width, /*is_row=*/false,
2542                                                /*row_shift=*/0);
2543         data += 8;
2544         i -= 8;
2545       } while (i != 0);
2546     }
2547   }
2548   auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame);
2549   StoreToFrameWithRound<16>(frame, start_x, start_y, tx_width, src, tx_type);
2550 }
2551 
Dct32TransformLoopRow_NEON(TransformType,TransformSize tx_size,int adjusted_tx_height,void * src_buffer,int,int,void *)2552 void Dct32TransformLoopRow_NEON(TransformType /*tx_type*/,
2553                                 TransformSize tx_size, int adjusted_tx_height,
2554                                 void* src_buffer, int /*start_x*/,
2555                                 int /*start_y*/, void* /*dst_frame*/) {
2556   auto* src = static_cast<int16_t*>(src_buffer);
2557   const bool should_round = kShouldRound[tx_size];
2558   const uint8_t row_shift = kTransformRowShift[tx_size];
2559 
2560   if (DctDcOnly<32>(src, adjusted_tx_height, should_round, row_shift)) {
2561     return;
2562   }
2563 
2564   if (should_round) {
2565     ApplyRounding<32>(src, adjusted_tx_height);
2566   }
2567   // Process 8 1d dct32 rows in parallel per iteration.
2568   int i = 0;
2569   do {
2570     Dct32_NEON(&src[i * 32], 32, /*is_row=*/true, row_shift);
2571     i += 8;
2572   } while (i < adjusted_tx_height);
2573 }
2574 
Dct32TransformLoopColumn_NEON(TransformType tx_type,TransformSize tx_size,int adjusted_tx_height,void * LIBGAV1_RESTRICT src_buffer,int start_x,int start_y,void * LIBGAV1_RESTRICT dst_frame)2575 void Dct32TransformLoopColumn_NEON(TransformType tx_type, TransformSize tx_size,
2576                                    int adjusted_tx_height,
2577                                    void* LIBGAV1_RESTRICT src_buffer,
2578                                    int start_x, int start_y,
2579                                    void* LIBGAV1_RESTRICT dst_frame) {
2580   auto* src = static_cast<int16_t*>(src_buffer);
2581   const int tx_width = kTransformWidth[tx_size];
2582 
2583   if (!DctDcOnlyColumn<32>(src, adjusted_tx_height, tx_width)) {
2584     // Process 8 1d dct32 columns in parallel per iteration.
2585     int i = tx_width;
2586     auto* data = src;
2587     do {
2588       Dct32_NEON(data, tx_width, /*is_row=*/false, /*row_shift=*/0);
2589       data += 8;
2590       i -= 8;
2591     } while (i != 0);
2592   }
2593   auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame);
2594   StoreToFrameWithRound<32>(frame, start_x, start_y, tx_width, src, tx_type);
2595 }
2596 
Dct64TransformLoopRow_NEON(TransformType,TransformSize tx_size,int adjusted_tx_height,void * src_buffer,int,int,void *)2597 void Dct64TransformLoopRow_NEON(TransformType /*tx_type*/,
2598                                 TransformSize tx_size, int adjusted_tx_height,
2599                                 void* src_buffer, int /*start_x*/,
2600                                 int /*start_y*/, void* /*dst_frame*/) {
2601   auto* src = static_cast<int16_t*>(src_buffer);
2602   const bool should_round = kShouldRound[tx_size];
2603   const uint8_t row_shift = kTransformRowShift[tx_size];
2604 
2605   if (DctDcOnly<64>(src, adjusted_tx_height, should_round, row_shift)) {
2606     return;
2607   }
2608 
2609   if (should_round) {
2610     ApplyRounding<64>(src, adjusted_tx_height);
2611   }
2612   // Process 8 1d dct64 rows in parallel per iteration.
2613   int i = 0;
2614   do {
2615     Dct64_NEON(&src[i * 64], 64, /*is_row=*/true, row_shift);
2616     i += 8;
2617   } while (i < adjusted_tx_height);
2618 }
2619 
Dct64TransformLoopColumn_NEON(TransformType tx_type,TransformSize tx_size,int adjusted_tx_height,void * LIBGAV1_RESTRICT src_buffer,int start_x,int start_y,void * LIBGAV1_RESTRICT dst_frame)2620 void Dct64TransformLoopColumn_NEON(TransformType tx_type, TransformSize tx_size,
2621                                    int adjusted_tx_height,
2622                                    void* LIBGAV1_RESTRICT src_buffer,
2623                                    int start_x, int start_y,
2624                                    void* LIBGAV1_RESTRICT dst_frame) {
2625   auto* src = static_cast<int16_t*>(src_buffer);
2626   const int tx_width = kTransformWidth[tx_size];
2627 
2628   if (!DctDcOnlyColumn<64>(src, adjusted_tx_height, tx_width)) {
2629     // Process 8 1d dct64 columns in parallel per iteration.
2630     int i = tx_width;
2631     auto* data = src;
2632     do {
2633       Dct64_NEON(data, tx_width, /*is_row=*/false, /*row_shift=*/0);
2634       data += 8;
2635       i -= 8;
2636     } while (i != 0);
2637   }
2638   auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame);
2639   StoreToFrameWithRound<64>(frame, start_x, start_y, tx_width, src, tx_type);
2640 }
2641 
Adst4TransformLoopRow_NEON(TransformType,TransformSize tx_size,int adjusted_tx_height,void * src_buffer,int,int,void *)2642 void Adst4TransformLoopRow_NEON(TransformType /*tx_type*/,
2643                                 TransformSize tx_size, int adjusted_tx_height,
2644                                 void* src_buffer, int /*start_x*/,
2645                                 int /*start_y*/, void* /*dst_frame*/) {
2646   auto* src = static_cast<int16_t*>(src_buffer);
2647   const int tx_height = kTransformHeight[tx_size];
2648   const int row_shift = static_cast<int>(tx_height == 16);
2649   const bool should_round = (tx_height == 8);
2650 
2651   if (Adst4DcOnly(src, adjusted_tx_height, should_round, row_shift)) {
2652     return;
2653   }
2654 
2655   if (should_round) {
2656     ApplyRounding<4>(src, adjusted_tx_height);
2657   }
2658 
2659   // Process 4 1d adst4 rows in parallel per iteration.
2660   int i = adjusted_tx_height;
2661   auto* data = src;
2662   do {
2663     Adst4_NEON(data, /*step=*/4, /*transpose=*/true);
2664     data += 16;
2665     i -= 4;
2666   } while (i != 0);
2667 
2668   if (tx_height == 16) {
2669     RowShift<4>(src, adjusted_tx_height, 1);
2670   }
2671 }
2672 
Adst4TransformLoopColumn_NEON(TransformType tx_type,TransformSize tx_size,int adjusted_tx_height,void * LIBGAV1_RESTRICT src_buffer,int start_x,int start_y,void * LIBGAV1_RESTRICT dst_frame)2673 void Adst4TransformLoopColumn_NEON(TransformType tx_type, TransformSize tx_size,
2674                                    int adjusted_tx_height,
2675                                    void* LIBGAV1_RESTRICT src_buffer,
2676                                    int start_x, int start_y,
2677                                    void* LIBGAV1_RESTRICT dst_frame) {
2678   auto* src = static_cast<int16_t*>(src_buffer);
2679   const int tx_width = kTransformWidth[tx_size];
2680 
2681   if (kTransformFlipColumnsMask.Contains(tx_type)) {
2682     FlipColumns<4>(src, tx_width);
2683   }
2684 
2685   if (!Adst4DcOnlyColumn(src, adjusted_tx_height, tx_width)) {
2686     // Process 4 1d adst4 columns in parallel per iteration.
2687     int i = tx_width;
2688     auto* data = src;
2689     do {
2690       Adst4_NEON(data, tx_width, /*transpose=*/false);
2691       data += 4;
2692       i -= 4;
2693     } while (i != 0);
2694   }
2695 
2696   auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame);
2697   StoreToFrameWithRound<4, /*enable_flip_rows=*/true>(frame, start_x, start_y,
2698                                                       tx_width, src, tx_type);
2699 }
2700 
Adst8TransformLoopRow_NEON(TransformType,TransformSize tx_size,int adjusted_tx_height,void * src_buffer,int,int,void *)2701 void Adst8TransformLoopRow_NEON(TransformType /*tx_type*/,
2702                                 TransformSize tx_size, int adjusted_tx_height,
2703                                 void* src_buffer, int /*start_x*/,
2704                                 int /*start_y*/, void* /*dst_frame*/) {
2705   auto* src = static_cast<int16_t*>(src_buffer);
2706   const bool should_round = kShouldRound[tx_size];
2707   const uint8_t row_shift = kTransformRowShift[tx_size];
2708 
2709   if (Adst8DcOnly(src, adjusted_tx_height, should_round, row_shift)) {
2710     return;
2711   }
2712 
2713   if (should_round) {
2714     ApplyRounding<8>(src, adjusted_tx_height);
2715   }
2716 
2717   if (adjusted_tx_height == 4) {
2718     // Process 4 1d adst8 rows in parallel.
2719     Adst8_NEON<ButterflyRotation_4, true>(src, /*step=*/8, /*transpose=*/true);
2720   } else {
2721     // Process 8 1d adst8 rows in parallel per iteration.
2722     assert(adjusted_tx_height % 8 == 0);
2723     int i = adjusted_tx_height;
2724     auto* data = src;
2725     do {
2726       Adst8_NEON<ButterflyRotation_8, false>(data, /*step=*/8,
2727                                              /*transpose=*/true);
2728       data += 64;
2729       i -= 8;
2730     } while (i != 0);
2731   }
2732   if (row_shift > 0) {
2733     RowShift<8>(src, adjusted_tx_height, row_shift);
2734   }
2735 }
2736 
Adst8TransformLoopColumn_NEON(TransformType tx_type,TransformSize tx_size,int adjusted_tx_height,void * LIBGAV1_RESTRICT src_buffer,int start_x,int start_y,void * LIBGAV1_RESTRICT dst_frame)2737 void Adst8TransformLoopColumn_NEON(TransformType tx_type, TransformSize tx_size,
2738                                    int adjusted_tx_height,
2739                                    void* LIBGAV1_RESTRICT src_buffer,
2740                                    int start_x, int start_y,
2741                                    void* LIBGAV1_RESTRICT dst_frame) {
2742   auto* src = static_cast<int16_t*>(src_buffer);
2743   const int tx_width = kTransformWidth[tx_size];
2744 
2745   if (kTransformFlipColumnsMask.Contains(tx_type)) {
2746     FlipColumns<8>(src, tx_width);
2747   }
2748 
2749   if (!Adst8DcOnlyColumn(src, adjusted_tx_height, tx_width)) {
2750     if (tx_width == 4) {
2751       // Process 4 1d adst8 columns in parallel.
2752       Adst8_NEON<ButterflyRotation_4, true>(src, 4, /*transpose=*/false);
2753     } else {
2754       // Process 8 1d adst8 columns in parallel per iteration.
2755       int i = tx_width;
2756       auto* data = src;
2757       do {
2758         Adst8_NEON<ButterflyRotation_8, false>(data, tx_width,
2759                                                /*transpose=*/false);
2760         data += 8;
2761         i -= 8;
2762       } while (i != 0);
2763     }
2764   }
2765   auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame);
2766   StoreToFrameWithRound<8, /*enable_flip_rows=*/true>(frame, start_x, start_y,
2767                                                       tx_width, src, tx_type);
2768 }
2769 
Adst16TransformLoopRow_NEON(TransformType,TransformSize tx_size,int adjusted_tx_height,void * src_buffer,int,int,void *)2770 void Adst16TransformLoopRow_NEON(TransformType /*tx_type*/,
2771                                  TransformSize tx_size, int adjusted_tx_height,
2772                                  void* src_buffer, int /*start_x*/,
2773                                  int /*start_y*/, void* /*dst_frame*/) {
2774   auto* src = static_cast<int16_t*>(src_buffer);
2775   const bool should_round = kShouldRound[tx_size];
2776   const uint8_t row_shift = kTransformRowShift[tx_size];
2777 
2778   if (Adst16DcOnly(src, adjusted_tx_height, should_round, row_shift)) {
2779     return;
2780   }
2781 
2782   if (should_round) {
2783     ApplyRounding<16>(src, adjusted_tx_height);
2784   }
2785 
2786   if (adjusted_tx_height == 4) {
2787     // Process 4 1d adst16 rows in parallel.
2788     Adst16_NEON<ButterflyRotation_4, true>(src, 16, /*is_row=*/true, row_shift);
2789   } else {
2790     assert(adjusted_tx_height % 8 == 0);
2791     int i = adjusted_tx_height;
2792     do {
2793       // Process 8 1d adst16 rows in parallel per iteration.
2794       Adst16_NEON<ButterflyRotation_8, false>(src, 16, /*is_row=*/true,
2795                                               row_shift);
2796       src += 128;
2797       i -= 8;
2798     } while (i != 0);
2799   }
2800 }
2801 
Adst16TransformLoopColumn_NEON(TransformType tx_type,TransformSize tx_size,int adjusted_tx_height,void * LIBGAV1_RESTRICT src_buffer,int start_x,int start_y,void * LIBGAV1_RESTRICT dst_frame)2802 void Adst16TransformLoopColumn_NEON(TransformType tx_type,
2803                                     TransformSize tx_size,
2804                                     int adjusted_tx_height,
2805                                     void* LIBGAV1_RESTRICT src_buffer,
2806                                     int start_x, int start_y,
2807                                     void* LIBGAV1_RESTRICT dst_frame) {
2808   auto* src = static_cast<int16_t*>(src_buffer);
2809   const int tx_width = kTransformWidth[tx_size];
2810 
2811   if (kTransformFlipColumnsMask.Contains(tx_type)) {
2812     FlipColumns<16>(src, tx_width);
2813   }
2814 
2815   if (!Adst16DcOnlyColumn(src, adjusted_tx_height, tx_width)) {
2816     if (tx_width == 4) {
2817       // Process 4 1d adst16 columns in parallel.
2818       Adst16_NEON<ButterflyRotation_4, true>(src, 4, /*is_row=*/false,
2819                                              /*row_shift=*/0);
2820     } else {
2821       int i = tx_width;
2822       auto* data = src;
2823       do {
2824         // Process 8 1d adst16 columns in parallel per iteration.
2825         Adst16_NEON<ButterflyRotation_8, false>(
2826             data, tx_width, /*is_row=*/false, /*row_shift=*/0);
2827         data += 8;
2828         i -= 8;
2829       } while (i != 0);
2830     }
2831   }
2832   auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame);
2833   StoreToFrameWithRound<16, /*enable_flip_rows=*/true>(frame, start_x, start_y,
2834                                                        tx_width, src, tx_type);
2835 }
2836 
Identity4TransformLoopRow_NEON(TransformType tx_type,TransformSize tx_size,int adjusted_tx_height,void * src_buffer,int,int,void *)2837 void Identity4TransformLoopRow_NEON(TransformType tx_type,
2838                                     TransformSize tx_size,
2839                                     int adjusted_tx_height, void* src_buffer,
2840                                     int /*start_x*/, int /*start_y*/,
2841                                     void* /*dst_frame*/) {
2842   // Special case: Process row calculations during column transform call.
2843   // Improves performance.
2844   if (tx_type == kTransformTypeIdentityIdentity &&
2845       tx_size == kTransformSize4x4) {
2846     return;
2847   }
2848 
2849   auto* src = static_cast<int16_t*>(src_buffer);
2850   const int tx_height = kTransformHeight[tx_size];
2851   const bool should_round = (tx_height == 8);
2852 
2853   if (Identity4DcOnly(src, adjusted_tx_height, should_round, tx_height)) {
2854     return;
2855   }
2856 
2857   if (should_round) {
2858     ApplyRounding<4>(src, adjusted_tx_height);
2859   }
2860   if (tx_height < 16) {
2861     int i = adjusted_tx_height;
2862     do {
2863       Identity4_NEON<false>(src, /*step=*/4);
2864       src += 16;
2865       i -= 4;
2866     } while (i != 0);
2867   } else {
2868     int i = adjusted_tx_height;
2869     do {
2870       Identity4_NEON<true>(src, /*step=*/4);
2871       src += 16;
2872       i -= 4;
2873     } while (i != 0);
2874   }
2875 }
2876 
Identity4TransformLoopColumn_NEON(TransformType tx_type,TransformSize tx_size,int adjusted_tx_height,void * LIBGAV1_RESTRICT src_buffer,int start_x,int start_y,void * LIBGAV1_RESTRICT dst_frame)2877 void Identity4TransformLoopColumn_NEON(TransformType tx_type,
2878                                        TransformSize tx_size,
2879                                        int adjusted_tx_height,
2880                                        void* LIBGAV1_RESTRICT src_buffer,
2881                                        int start_x, int start_y,
2882                                        void* LIBGAV1_RESTRICT dst_frame) {
2883   auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame);
2884   auto* src = static_cast<int16_t*>(src_buffer);
2885   const int tx_width = kTransformWidth[tx_size];
2886 
2887   // Special case: Process row calculations during column transform call.
2888   if (tx_type == kTransformTypeIdentityIdentity &&
2889       (tx_size == kTransformSize4x4 || tx_size == kTransformSize8x4)) {
2890     Identity4RowColumnStoreToFrame(frame, start_x, start_y, tx_width,
2891                                    adjusted_tx_height, src);
2892     return;
2893   }
2894 
2895   if (kTransformFlipColumnsMask.Contains(tx_type)) {
2896     FlipColumns<4>(src, tx_width);
2897   }
2898 
2899   IdentityColumnStoreToFrame<4>(frame, start_x, start_y, tx_width,
2900                                 adjusted_tx_height, src);
2901 }
2902 
Identity8TransformLoopRow_NEON(TransformType tx_type,TransformSize tx_size,int adjusted_tx_height,void * src_buffer,int,int,void *)2903 void Identity8TransformLoopRow_NEON(TransformType tx_type,
2904                                     TransformSize tx_size,
2905                                     int adjusted_tx_height, void* src_buffer,
2906                                     int /*start_x*/, int /*start_y*/,
2907                                     void* /*dst_frame*/) {
2908   // Special case: Process row calculations during column transform call.
2909   // Improves performance.
2910   if (tx_type == kTransformTypeIdentityIdentity &&
2911       tx_size == kTransformSize8x4) {
2912     return;
2913   }
2914 
2915   auto* src = static_cast<int16_t*>(src_buffer);
2916   const int tx_height = kTransformHeight[tx_size];
2917   const bool should_round = kShouldRound[tx_size];
2918   const uint8_t row_shift = kTransformRowShift[tx_size];
2919 
2920   if (Identity8DcOnly(src, adjusted_tx_height, should_round, row_shift)) {
2921     return;
2922   }
2923 
2924   if (should_round) {
2925     ApplyRounding<8>(src, adjusted_tx_height);
2926   }
2927 
2928   // When combining the identity8 multiplier with the row shift, the
2929   // calculations for tx_height == 8 and tx_height == 16 can be simplified
2930   // from ((A * 2) + 1) >> 1) to A.
2931   if ((tx_height & 0x18) != 0) {
2932     return;
2933   }
2934   if (tx_height == 32) {
2935     int i = adjusted_tx_height;
2936     do {
2937       Identity8Row32_NEON(src, /*step=*/8);
2938       src += 32;
2939       i -= 4;
2940     } while (i != 0);
2941     return;
2942   }
2943 
2944   assert(tx_size == kTransformSize8x4);
2945   int i = adjusted_tx_height;
2946   do {
2947     Identity8Row4_NEON(src, /*step=*/8);
2948     src += 32;
2949     i -= 4;
2950   } while (i != 0);
2951 }
2952 
Identity8TransformLoopColumn_NEON(TransformType tx_type,TransformSize tx_size,int adjusted_tx_height,void * LIBGAV1_RESTRICT src_buffer,int start_x,int start_y,void * LIBGAV1_RESTRICT dst_frame)2953 void Identity8TransformLoopColumn_NEON(TransformType tx_type,
2954                                        TransformSize tx_size,
2955                                        int adjusted_tx_height,
2956                                        void* LIBGAV1_RESTRICT src_buffer,
2957                                        int start_x, int start_y,
2958                                        void* LIBGAV1_RESTRICT dst_frame) {
2959   auto* src = static_cast<int16_t*>(src_buffer);
2960   const int tx_width = kTransformWidth[tx_size];
2961 
2962   if (kTransformFlipColumnsMask.Contains(tx_type)) {
2963     FlipColumns<8>(src, tx_width);
2964   }
2965   auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame);
2966   IdentityColumnStoreToFrame<8>(frame, start_x, start_y, tx_width,
2967                                 adjusted_tx_height, src);
2968 }
2969 
Identity16TransformLoopRow_NEON(TransformType,TransformSize tx_size,int adjusted_tx_height,void * src_buffer,int,int,void *)2970 void Identity16TransformLoopRow_NEON(TransformType /*tx_type*/,
2971                                      TransformSize tx_size,
2972                                      int adjusted_tx_height, void* src_buffer,
2973                                      int /*start_x*/, int /*start_y*/,
2974                                      void* /*dst_frame*/) {
2975   auto* src = static_cast<int16_t*>(src_buffer);
2976   const bool should_round = kShouldRound[tx_size];
2977   const uint8_t row_shift = kTransformRowShift[tx_size];
2978 
2979   if (Identity16DcOnly(src, adjusted_tx_height, should_round, row_shift)) {
2980     return;
2981   }
2982 
2983   if (should_round) {
2984     ApplyRounding<16>(src, adjusted_tx_height);
2985   }
2986   int i = adjusted_tx_height;
2987   do {
2988     Identity16Row_NEON(src, /*step=*/16, kTransformRowShift[tx_size]);
2989     src += 64;
2990     i -= 4;
2991   } while (i != 0);
2992 }
2993 
Identity16TransformLoopColumn_NEON(TransformType tx_type,TransformSize tx_size,int adjusted_tx_height,void * LIBGAV1_RESTRICT src_buffer,int start_x,int start_y,void * LIBGAV1_RESTRICT dst_frame)2994 void Identity16TransformLoopColumn_NEON(TransformType tx_type,
2995                                         TransformSize tx_size,
2996                                         int adjusted_tx_height,
2997                                         void* LIBGAV1_RESTRICT src_buffer,
2998                                         int start_x, int start_y,
2999                                         void* LIBGAV1_RESTRICT dst_frame) {
3000   auto* src = static_cast<int16_t*>(src_buffer);
3001   const int tx_width = kTransformWidth[tx_size];
3002 
3003   if (kTransformFlipColumnsMask.Contains(tx_type)) {
3004     FlipColumns<16>(src, tx_width);
3005   }
3006   auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame);
3007   IdentityColumnStoreToFrame<16>(frame, start_x, start_y, tx_width,
3008                                  adjusted_tx_height, src);
3009 }
3010 
Identity32TransformLoopRow_NEON(TransformType,TransformSize tx_size,int adjusted_tx_height,void * src_buffer,int,int,void *)3011 void Identity32TransformLoopRow_NEON(TransformType /*tx_type*/,
3012                                      TransformSize tx_size,
3013                                      int adjusted_tx_height, void* src_buffer,
3014                                      int /*start_x*/, int /*start_y*/,
3015                                      void* /*dst_frame*/) {
3016   const int tx_height = kTransformHeight[tx_size];
3017 
3018   // When combining the identity32 multiplier with the row shift, the
3019   // calculations for tx_height == 8 and tx_height == 32 can be simplified
3020   // from ((A * 4) + 2) >> 2) to A.
3021   if ((tx_height & 0x28) != 0) {
3022     return;
3023   }
3024 
3025   // Process kTransformSize32x16.  The src is always rounded before the
3026   // identity transform and shifted by 1 afterwards.
3027   auto* src = static_cast<int16_t*>(src_buffer);
3028   if (Identity32DcOnly(src, adjusted_tx_height)) {
3029     return;
3030   }
3031 
3032   assert(tx_size == kTransformSize32x16);
3033   ApplyRounding<32>(src, adjusted_tx_height);
3034   int i = adjusted_tx_height;
3035   do {
3036     Identity32Row16_NEON(src, /*step=*/32);
3037     src += 128;
3038     i -= 4;
3039   } while (i != 0);
3040 }
3041 
Identity32TransformLoopColumn_NEON(TransformType,TransformSize tx_size,int adjusted_tx_height,void * LIBGAV1_RESTRICT src_buffer,int start_x,int start_y,void * LIBGAV1_RESTRICT dst_frame)3042 void Identity32TransformLoopColumn_NEON(TransformType /*tx_type*/,
3043                                         TransformSize tx_size,
3044                                         int adjusted_tx_height,
3045                                         void* LIBGAV1_RESTRICT src_buffer,
3046                                         int start_x, int start_y,
3047                                         void* LIBGAV1_RESTRICT dst_frame) {
3048   auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame);
3049   auto* src = static_cast<int16_t*>(src_buffer);
3050   const int tx_width = kTransformWidth[tx_size];
3051 
3052   IdentityColumnStoreToFrame<32>(frame, start_x, start_y, tx_width,
3053                                  adjusted_tx_height, src);
3054 }
3055 
Wht4TransformLoopRow_NEON(TransformType tx_type,TransformSize tx_size,int,void *,int,int,void *)3056 void Wht4TransformLoopRow_NEON(TransformType tx_type, TransformSize tx_size,
3057                                int /*adjusted_tx_height*/, void* /*src_buffer*/,
3058                                int /*start_x*/, int /*start_y*/,
3059                                void* /*dst_frame*/) {
3060   assert(tx_type == kTransformTypeDctDct);
3061   assert(tx_size == kTransformSize4x4);
3062   static_cast<void>(tx_type);
3063   static_cast<void>(tx_size);
3064   // Do both row and column transforms in the column-transform pass.
3065 }
3066 
Wht4TransformLoopColumn_NEON(TransformType tx_type,TransformSize tx_size,int adjusted_tx_height,void * LIBGAV1_RESTRICT src_buffer,int start_x,int start_y,void * LIBGAV1_RESTRICT dst_frame)3067 void Wht4TransformLoopColumn_NEON(TransformType tx_type, TransformSize tx_size,
3068                                   int adjusted_tx_height,
3069                                   void* LIBGAV1_RESTRICT src_buffer,
3070                                   int start_x, int start_y,
3071                                   void* LIBGAV1_RESTRICT dst_frame) {
3072   assert(tx_type == kTransformTypeDctDct);
3073   assert(tx_size == kTransformSize4x4);
3074   static_cast<void>(tx_type);
3075   static_cast<void>(tx_size);
3076 
3077   // Process 4 1d wht4 rows and columns in parallel.
3078   const auto* src = static_cast<int16_t*>(src_buffer);
3079   auto& frame = *static_cast<Array2DView<uint8_t>*>(dst_frame);
3080   uint8_t* dst = frame[start_y] + start_x;
3081   const int dst_stride = frame.columns();
3082   Wht4_NEON(dst, dst_stride, src, adjusted_tx_height);
3083 }
3084 
3085 //------------------------------------------------------------------------------
3086 
Init8bpp()3087 void Init8bpp() {
3088   Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
3089   assert(dsp != nullptr);
3090   // Maximum transform size for Dct is 64.
3091   dsp->inverse_transforms[kTransform1dDct][kTransform1dSize4][kRow] =
3092       Dct4TransformLoopRow_NEON;
3093   dsp->inverse_transforms[kTransform1dDct][kTransform1dSize4][kColumn] =
3094       Dct4TransformLoopColumn_NEON;
3095   dsp->inverse_transforms[kTransform1dDct][kTransform1dSize8][kRow] =
3096       Dct8TransformLoopRow_NEON;
3097   dsp->inverse_transforms[kTransform1dDct][kTransform1dSize8][kColumn] =
3098       Dct8TransformLoopColumn_NEON;
3099   dsp->inverse_transforms[kTransform1dDct][kTransform1dSize16][kRow] =
3100       Dct16TransformLoopRow_NEON;
3101   dsp->inverse_transforms[kTransform1dDct][kTransform1dSize16][kColumn] =
3102       Dct16TransformLoopColumn_NEON;
3103   dsp->inverse_transforms[kTransform1dDct][kTransform1dSize32][kRow] =
3104       Dct32TransformLoopRow_NEON;
3105   dsp->inverse_transforms[kTransform1dDct][kTransform1dSize32][kColumn] =
3106       Dct32TransformLoopColumn_NEON;
3107   dsp->inverse_transforms[kTransform1dDct][kTransform1dSize64][kRow] =
3108       Dct64TransformLoopRow_NEON;
3109   dsp->inverse_transforms[kTransform1dDct][kTransform1dSize64][kColumn] =
3110       Dct64TransformLoopColumn_NEON;
3111 
3112   // Maximum transform size for Adst is 16.
3113   dsp->inverse_transforms[kTransform1dAdst][kTransform1dSize4][kRow] =
3114       Adst4TransformLoopRow_NEON;
3115   dsp->inverse_transforms[kTransform1dAdst][kTransform1dSize4][kColumn] =
3116       Adst4TransformLoopColumn_NEON;
3117   dsp->inverse_transforms[kTransform1dAdst][kTransform1dSize8][kRow] =
3118       Adst8TransformLoopRow_NEON;
3119   dsp->inverse_transforms[kTransform1dAdst][kTransform1dSize8][kColumn] =
3120       Adst8TransformLoopColumn_NEON;
3121   dsp->inverse_transforms[kTransform1dAdst][kTransform1dSize16][kRow] =
3122       Adst16TransformLoopRow_NEON;
3123   dsp->inverse_transforms[kTransform1dAdst][kTransform1dSize16][kColumn] =
3124       Adst16TransformLoopColumn_NEON;
3125 
3126   // Maximum transform size for Identity transform is 32.
3127   dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize4][kRow] =
3128       Identity4TransformLoopRow_NEON;
3129   dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize4][kColumn] =
3130       Identity4TransformLoopColumn_NEON;
3131   dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize8][kRow] =
3132       Identity8TransformLoopRow_NEON;
3133   dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize8][kColumn] =
3134       Identity8TransformLoopColumn_NEON;
3135   dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize16][kRow] =
3136       Identity16TransformLoopRow_NEON;
3137   dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize16][kColumn] =
3138       Identity16TransformLoopColumn_NEON;
3139   dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize32][kRow] =
3140       Identity32TransformLoopRow_NEON;
3141   dsp->inverse_transforms[kTransform1dIdentity][kTransform1dSize32][kColumn] =
3142       Identity32TransformLoopColumn_NEON;
3143 
3144   // Maximum transform size for Wht is 4.
3145   dsp->inverse_transforms[kTransform1dWht][kTransform1dSize4][kRow] =
3146       Wht4TransformLoopRow_NEON;
3147   dsp->inverse_transforms[kTransform1dWht][kTransform1dSize4][kColumn] =
3148       Wht4TransformLoopColumn_NEON;
3149 }
3150 
3151 }  // namespace
3152 }  // namespace low_bitdepth
3153 
InverseTransformInit_NEON()3154 void InverseTransformInit_NEON() { low_bitdepth::Init8bpp(); }
3155 
3156 }  // namespace dsp
3157 }  // namespace libgav1
3158 #else   // !LIBGAV1_ENABLE_NEON
3159 namespace libgav1 {
3160 namespace dsp {
3161 
InverseTransformInit_NEON()3162 void InverseTransformInit_NEON() {}
3163 
3164 }  // namespace dsp
3165 }  // namespace libgav1
3166 #endif  // LIBGAV1_ENABLE_NEON
3167