xref: /aosp_15_r20/external/libvpx/vpx_dsp/arm/fdct_partial_neon.c (revision fb1b10ab9aebc7c7068eedab379b749d7e3900be)
1 /*
2  *  Copyright (c) 2017 The WebM project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #include <arm_neon.h>
12 
13 #include "./vpx_dsp_rtcd.h"
14 #include "./vpx_config.h"
15 #include "vpx_dsp/arm/mem_neon.h"
16 #include "vpx_dsp/arm/sum_neon.h"
17 
// Sum-only ("_1") 4x4 forward transform.
//
// Reads the 4x4 block of int16_t samples at |input| (consecutive rows start
// |stride| elements apart), writes twice the block sum to output[0] and
// writes 0 to output[1]. No other output entries are touched.
void vpx_fdct4x4_1_neon(const int16_t *input, tran_low_t *output, int stride) {
  const int16x4_t row0 = vld1_s16(input);
  const int16x4_t row1 = vld1_s16(input + stride);
  const int16x4_t row2 = vld1_s16(input + 2 * stride);
  const int16x4_t row3 = vld1_s16(input + 3 * stride);
  // Pack the rows pairwise into 8-lane vectors and add them, so each 16-bit
  // lane holds the sum of two samples. Lanes wrap on int16 overflow;
  // presumably the callers' residual range keeps this in bounds (TODO
  // confirm).
  const int16x8_t pair_sums =
      vaddq_s16(vcombine_s16(row0, row1), vcombine_s16(row2, row3));
  // The block sum is signed and may be negative. Left-shifting a negative
  // signed value is undefined behavior in C, so scale with a multiplication
  // instead of `<< 1`. The product stays far inside the int32_t range because
  // 8 lanes of int16 add up to at most about 2^18 in magnitude.
  const int32_t total = horizontal_add_int16x8(pair_sums);

  output[0] = (tran_low_t)(total * 2);
  output[1] = 0;
}
39 
40 // Visual Studio 2022 (cl.exe) targeting AArch64 with optimizations enabled
41 // will fail with an internal compiler error.
42 // See:
43 // https://developercommunity.visualstudio.com/t/Compiler-crash-C1001-when-building-a-for/10346110
44 // TODO(jzern): check the compiler version after a fix for the issue is
45 // released.
46 #if defined(_MSC_VER) && defined(_M_ARM64) && !defined(__clang__)
47 #pragma optimize("", off)
48 #endif
// Sum-only ("_1") 8x8 forward transform.
//
// Adds up the 64 int16_t samples of the 8x8 block at |input| (rows are
// |stride| elements apart) and stores the unscaled total in output[0];
// output[1] is set to 0. Accumulation uses one 16-bit lane per column, so
// each lane sums eight samples and wraps on int16 overflow.
void vpx_fdct8x8_1_neon(const int16_t *input, tran_low_t *output, int stride) {
  int16x8_t column_sums = vdupq_n_s16(0);
  int row;

  for (row = 0; row < 8; ++row) {
    column_sums = vaddq_s16(column_sums, vld1q_s16(input + row * stride));
  }

  output[0] = (tran_low_t)horizontal_add_int16x8(column_sums);
  output[1] = 0;
}
61 #if defined(_MSC_VER) && defined(_M_ARM64) && !defined(__clang__)
62 #pragma optimize("", on)
63 #endif
64 
// Sum-only ("_1") 16x16 forward transform.
//
// Sums the 256 int16_t samples of the 16x16 block at |input| (rows are
// |stride| elements apart), stores `total >> 1` in output[0] and sets
// output[1] to 0. The left and right 8-column halves are accumulated in
// separate 16-bit-lane vectors; each lane sums 16 samples and wraps on int16
// overflow.
void vpx_fdct16x16_1_neon(const int16_t *input, tran_low_t *output,
                          int stride) {
  int16x8_t cols_0_7 = vdupq_n_s16(0);
  int16x8_t cols_8_15 = vdupq_n_s16(0);
  int32_t total;
  int row;

  for (row = 0; row < 16; ++row) {
    const int16_t *const line = input + row * stride;
    cols_0_7 = vaddq_s16(cols_0_7, vld1q_s16(line));
    cols_8_15 = vaddq_s16(cols_8_15, vld1q_s16(line + 8));
  }

  total = horizontal_add_int16x8(cols_0_7);
  total += horizontal_add_int16x8(cols_8_15);

  output[0] = (tran_low_t)(total >> 1);
  output[1] = 0;
}
86 
// Sum-only ("_1") 32x32 forward transform.
//
// Sums the 1024 int16_t samples of the 32x32 block at |input| (rows are
// |stride| elements apart), stores `total >> 3` in output[0] and sets
// output[1] to 0. Each 8-column strip has its own 16-bit-lane accumulator;
// each lane sums 32 samples and wraps on int16 overflow.
void vpx_fdct32x32_1_neon(const int16_t *input, tran_low_t *output,
                          int stride) {
  int16x8_t cols_0_7 = vdupq_n_s16(0);
  int16x8_t cols_8_15 = vdupq_n_s16(0);
  int16x8_t cols_16_23 = vdupq_n_s16(0);
  int16x8_t cols_24_31 = vdupq_n_s16(0);
  int32_t total;
  int row;

  for (row = 0; row < 32; ++row) {
    const int16_t *const line = input + row * stride;
    cols_0_7 = vaddq_s16(cols_0_7, vld1q_s16(line));
    cols_8_15 = vaddq_s16(cols_8_15, vld1q_s16(line + 8));
    cols_16_23 = vaddq_s16(cols_16_23, vld1q_s16(line + 16));
    cols_24_31 = vaddq_s16(cols_24_31, vld1q_s16(line + 24));
  }

  total = horizontal_add_int16x8(cols_0_7) + horizontal_add_int16x8(cols_8_15) +
          horizontal_add_int16x8(cols_16_23) +
          horizontal_add_int16x8(cols_24_31);

  output[0] = (tran_low_t)(total >> 3);
  output[1] = 0;
}
116 
117 #if CONFIG_VP9_HIGHBITDEPTH
118 
// High-bitdepth sum-only ("_1") 16x16 forward transform.
//
// Sums the 256 int16_t samples of the 16x16 block at |input| (rows are
// |stride| elements apart), stores `total >> 1` in output[0] and sets
// output[1] to 0. Every sample is widened to 32 bits (vaddw_s16) before it is
// accumulated, so the per-lane sums of 16 samples stay within int32 range.
void vpx_highbd_fdct16x16_1_neon(const int16_t *input, tran_low_t *output,
                                 int stride) {
  int32x4_t acc_0_3 = vdupq_n_s32(0);
  int32x4_t acc_4_7 = vdupq_n_s32(0);
  int32x4_t acc_8_11 = vdupq_n_s32(0);
  int32x4_t acc_12_15 = vdupq_n_s32(0);
  int32x4_t combined;
  int32_t total;
  int row;

  for (row = 0; row < 16; ++row) {
    const int16_t *const line = input + row * stride;
    const int16x8_t first8 = vld1q_s16(line);
    const int16x8_t last8 = vld1q_s16(line + 8);
    acc_0_3 = vaddw_s16(acc_0_3, vget_low_s16(first8));
    acc_4_7 = vaddw_s16(acc_4_7, vget_high_s16(first8));
    acc_8_11 = vaddw_s16(acc_8_11, vget_low_s16(last8));
    acc_12_15 = vaddw_s16(acc_12_15, vget_high_s16(last8));
  }

  // Pairwise reduction: (0-3 + 4-7) + (8-11 + 12-15).
  combined =
      vaddq_s32(vaddq_s32(acc_0_3, acc_4_7), vaddq_s32(acc_8_11, acc_12_15));
  total = horizontal_add_int32x4(combined);

  output[0] = (tran_low_t)(total >> 1);
  output[1] = 0;
}
145 
// High-bitdepth sum-only ("_1") 32x32 forward transform.
//
// Sums the 1024 int16_t samples of the 32x32 block at |input| (rows are
// |stride| elements apart), stores `total >> 3` in output[0] and sets
// output[1] to 0. Each 8-column strip feeds one 32-bit accumulator: both
// halves of the strip are widened with vaddw_s16 and added, low half first.
void vpx_highbd_fdct32x32_1_neon(const int16_t *input, tran_low_t *output,
                                 int stride) {
  int32x4_t strip0 = vdupq_n_s32(0);
  int32x4_t strip1 = vdupq_n_s32(0);
  int32x4_t strip2 = vdupq_n_s32(0);
  int32x4_t strip3 = vdupq_n_s32(0);
  int32x4_t combined;
  int32_t total;
  int row;

  for (row = 0; row < 32; ++row) {
    const int16_t *const line = input + row * stride;
    const int16x8_t v0 = vld1q_s16(line);
    const int16x8_t v1 = vld1q_s16(line + 8);
    const int16x8_t v2 = vld1q_s16(line + 16);
    const int16x8_t v3 = vld1q_s16(line + 24);
    strip0 = vaddw_s16(vaddw_s16(strip0, vget_low_s16(v0)), vget_high_s16(v0));
    strip1 = vaddw_s16(vaddw_s16(strip1, vget_low_s16(v1)), vget_high_s16(v1));
    strip2 = vaddw_s16(vaddw_s16(strip2, vget_low_s16(v2)), vget_high_s16(v2));
    strip3 = vaddw_s16(vaddw_s16(strip3, vget_low_s16(v3)), vget_high_s16(v3));
  }

  // Pairwise reduction: (strip0 + strip1) + (strip2 + strip3).
  combined = vaddq_s32(vaddq_s32(strip0, strip1), vaddq_s32(strip2, strip3));
  total = horizontal_add_int32x4(combined);

  output[0] = (tran_low_t)(total >> 3);
  output[1] = 0;
}
179 
180 #endif  // CONFIG_VP9_HIGHBITDEPTH
181