1 /*
2 * Copyright (c) 2017 The WebM project authors. All Rights Reserved.
3 *
4 * Use of this source code is governed by a BSD-style license
5 * that can be found in the LICENSE file in the root of the source
6 * tree. An additional intellectual property rights grant can be found
7 * in the file PATENTS. All contributing project authors may
8 * be found in the AUTHORS file in the root of the source tree.
9 */
10
11 #include <arm_neon.h>
12
13 #include "./vpx_dsp_rtcd.h"
14 #include "./vpx_config.h"
15 #include "vpx_dsp/arm/mem_neon.h"
16 #include "vpx_dsp/arm/sum_neon.h"
17
// Sums a 4x4 block of int16 samples and stores twice that sum.
//
// input:  pointer to the first sample of the block; 4 samples are read from
//         each of 4 rows.
// output: output[0] receives 2 * (sum of the 16 samples); output[1] is set
//         to 0. No other element is written.
// stride: distance, in int16_t elements, between the starts of adjacent rows.
void vpx_fdct4x4_1_neon(const int16_t *input, tran_low_t *output, int stride) {
  const int16x4_t row0 = vld1_s16(input);
  const int16x4_t row1 = vld1_s16(input + stride);
  const int16x4_t row2 = vld1_s16(input + 2 * stride);
  const int16x4_t row3 = vld1_s16(input + 3 * stride);

  // Pack two rows per 128-bit register so one vector add folds the block
  // down to eight 16-bit partial sums.
  const int16x8_t top = vcombine_s16(row0, row1);
  const int16x8_t bottom = vcombine_s16(row2, row3);
  const int16x8_t partial = vaddq_s16(top, bottom);

  // Multiply instead of left-shifting: the reduced sum can be negative, and
  // left-shifting a negative signed value is undefined behaviour in C.
  // NOTE(review): horizontal_add_int16x8() is a project helper, presumably
  // returning the sum of all eight lanes - confirm in sum_neon.h.
  output[0] = (tran_low_t)(horizontal_add_int16x8(partial) * 2);
  output[1] = 0;
}
39
40 // Visual Studio 2022 (cl.exe) targeting AArch64 with optimizations enabled
41 // will fail with an internal compiler error.
42 // See:
43 // https://developercommunity.visualstudio.com/t/Compiler-crash-C1001-when-building-a-for/10346110
44 // TODO(jzern): check the compiler version after a fix for the issue is
45 // released.
46 #if defined(_MSC_VER) && defined(_M_ARM64) && !defined(__clang__)
47 #pragma optimize("", off)
48 #endif
// Sums an 8x8 block of int16 samples.
//
// input:  pointer to the first sample of the block; 8 samples are read from
//         each of 8 rows.
// output: output[0] receives the reduced sum; output[1] is set to 0.
// stride: distance, in int16_t elements, between the starts of adjacent rows.
//
// Rows are accumulated column-wise in 16-bit lanes before the final
// reduction. NOTE(review): assumes each column sum fits in int16 - confirm
// against the callers' input range.
void vpx_fdct8x8_1_neon(const int16_t *input, tran_low_t *output, int stride) {
  const int16_t *row = input;
  int16x8_t acc = vld1q_s16(row);
  int rows_left;

  // Row 0 is already in the accumulator; fold in the remaining seven rows.
  for (rows_left = 7; rows_left > 0; --rows_left) {
    row += stride;
    acc = vaddq_s16(acc, vld1q_s16(row));
  }

  output[0] = (tran_low_t)horizontal_add_int16x8(acc);
  output[1] = 0;
}
61 #if defined(_MSC_VER) && defined(_M_ARM64) && !defined(__clang__)
62 #pragma optimize("", on)
63 #endif
64
// Sums a 16x16 block of int16 samples and stores the sum shifted right by 1.
//
// input:  pointer to the first sample of the block; 16 samples are read from
//         each of 16 rows.
// output: output[0] receives (sum >> 1); output[1] is set to 0.
// stride: distance, in int16_t elements, between the starts of adjacent rows.
//
// Each row is split into two 8-lane halves that are accumulated column-wise
// in 16-bit lanes. NOTE(review): assumes each column sum fits in int16 -
// confirm against the callers' input range.
void vpx_fdct16x16_1_neon(const int16_t *input, tran_low_t *output,
                          int stride) {
  const int16_t *row = input;
  int16x8_t acc_lo = vld1q_s16(row);
  int16x8_t acc_hi = vld1q_s16(row + 8);
  int remaining = 15;
  int32_t total;

  // Row 0 seeds the accumulators; add the other fifteen rows.
  while (remaining-- > 0) {
    row += stride;
    acc_lo = vaddq_s16(acc_lo, vld1q_s16(row));
    acc_hi = vaddq_s16(acc_hi, vld1q_s16(row + 8));
  }

  total = horizontal_add_int16x8(acc_lo) + horizontal_add_int16x8(acc_hi);

  output[0] = (tran_low_t)(total >> 1);
  output[1] = 0;
}
86
// Sums a 32x32 block of int16 samples and stores the sum shifted right by 3.
//
// input:  pointer to the first sample of the block; 32 samples are read from
//         each of 32 rows.
// output: output[0] receives (sum >> 3); output[1] is set to 0.
// stride: distance, in int16_t elements, between the starts of adjacent rows.
//
// Each row is handled as four 8-lane groups, accumulated column-wise in
// 16-bit lanes. NOTE(review): assumes each column sum fits in int16 -
// confirm against the callers' input range.
void vpx_fdct32x32_1_neon(const int16_t *input, tran_low_t *output,
                          int stride) {
  const int16_t *row = input;
  int16x8_t acc[4];
  int32_t total = 0;
  int i, y;

  // Seed the four accumulators with row 0.
  for (i = 0; i < 4; ++i) {
    acc[i] = vld1q_s16(row + 8 * i);
  }

  // Fold in rows 1..31.
  for (y = 1; y < 32; ++y) {
    row += stride;
    for (i = 0; i < 4; ++i) {
      acc[i] = vaddq_s16(acc[i], vld1q_s16(row + 8 * i));
    }
  }

  // Reduce each accumulator and combine in 32-bit arithmetic.
  for (i = 0; i < 4; ++i) {
    total += horizontal_add_int16x8(acc[i]);
  }

  output[0] = (tran_low_t)(total >> 3);
  output[1] = 0;
}
116
117 #if CONFIG_VP9_HIGHBITDEPTH
118
// Sums a 16x16 block of int16 samples using 32-bit accumulators and stores
// the sum shifted right by 1.
//
// input:  pointer to the first sample of the block; 16 samples are read from
//         each of 16 rows.
// output: output[0] receives (sum >> 1); output[1] is set to 0.
// stride: distance, in int16_t elements, between the starts of adjacent rows.
//
// Samples are widened to 32 bits as they are accumulated (vaddw_s16), so the
// per-column sums are kept in int32 lanes rather than int16 lanes.
void vpx_highbd_fdct16x16_1_neon(const int16_t *input, tran_low_t *output,
                                 int stride) {
  int32x4_t acc0 = vdupq_n_s32(0);
  int32x4_t acc1 = vdupq_n_s32(0);
  int32x4_t acc2 = vdupq_n_s32(0);
  int32x4_t acc3 = vdupq_n_s32(0);
  int32x4_t combined;
  int32_t total;
  int row;

  for (row = 0; row < 16; ++row) {
    const int16x8_t left = vld1q_s16(input);
    const int16x8_t right = vld1q_s16(input + 8);
    input += stride;
    // One accumulator per group of four columns.
    acc0 = vaddw_s16(acc0, vget_low_s16(left));
    acc1 = vaddw_s16(acc1, vget_high_s16(left));
    acc2 = vaddw_s16(acc2, vget_low_s16(right));
    acc3 = vaddw_s16(acc3, vget_high_s16(right));
  }

  // Pairwise tree: (acc0 + acc1) + (acc2 + acc3).
  combined = vaddq_s32(vaddq_s32(acc0, acc1), vaddq_s32(acc2, acc3));
  total = horizontal_add_int32x4(combined);

  output[0] = (tran_low_t)(total >> 1);
  output[1] = 0;
}
145
// Sums a 32x32 block of int16 samples using 32-bit accumulators and stores
// the sum shifted right by 3.
//
// input:  pointer to the first sample of the block; 32 samples are read from
//         each of 32 rows.
// output: output[0] receives (sum >> 3); output[1] is set to 0.
// stride: distance, in int16_t elements, between the starts of adjacent rows.
//
// Samples are widened to 32 bits as they are accumulated (vaddw_s16). Each
// 8-sample group of a row has its own accumulator, which takes both the low
// and the high half of that group.
void vpx_highbd_fdct32x32_1_neon(const int16_t *input, tran_low_t *output,
                                 int stride) {
  int32x4_t acc0 = vdupq_n_s32(0);
  int32x4_t acc1 = vdupq_n_s32(0);
  int32x4_t acc2 = vdupq_n_s32(0);
  int32x4_t acc3 = vdupq_n_s32(0);
  int32x4_t combined;
  int32_t total;
  int row;

  for (row = 0; row < 32; ++row) {
    const int16x8_t g0 = vld1q_s16(input);
    const int16x8_t g1 = vld1q_s16(input + 8);
    const int16x8_t g2 = vld1q_s16(input + 16);
    const int16x8_t g3 = vld1q_s16(input + 24);
    input += stride;
    // Add the low half first, then the high half, into the same accumulator.
    acc0 = vaddw_s16(vaddw_s16(acc0, vget_low_s16(g0)), vget_high_s16(g0));
    acc1 = vaddw_s16(vaddw_s16(acc1, vget_low_s16(g1)), vget_high_s16(g1));
    acc2 = vaddw_s16(vaddw_s16(acc2, vget_low_s16(g2)), vget_high_s16(g2));
    acc3 = vaddw_s16(vaddw_s16(acc3, vget_low_s16(g3)), vget_high_s16(g3));
  }

  // Pairwise tree: (acc0 + acc1) + (acc2 + acc3).
  combined = vaddq_s32(vaddq_s32(acc0, acc1), vaddq_s32(acc2, acc3));
  total = horizontal_add_int32x4(combined);

  output[0] = (tran_low_t)(total >> 3);
  output[1] = 0;
}
179
180 #endif // CONFIG_VP9_HIGHBITDEPTH
181