xref: /aosp_15_r20/external/libgav1/src/dsp/arm/super_res_neon.cc (revision 095378508e87ed692bf8dfeb34008b65b3735891)
1 // Copyright 2020 The libgav1 Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "src/dsp/super_res.h"
16 #include "src/utils/cpu.h"
17 
18 #if LIBGAV1_ENABLE_NEON
19 
#include <arm_neon.h>

#include <cassert>

#include "src/dsp/arm/common_neon.h"
#include "src/dsp/constants.h"
#include "src/dsp/dsp.h"
#include "src/utils/common.h"
#include "src/utils/compiler_attributes.h"
#include "src/utils/constants.h"
28 
29 namespace libgav1 {
30 namespace dsp {
31 
32 namespace low_bitdepth {
33 namespace {
34 
SuperResCoefficients_NEON(const int upscaled_width,const int initial_subpixel_x,const int step,void * const coefficients)35 void SuperResCoefficients_NEON(const int upscaled_width,
36                                const int initial_subpixel_x, const int step,
37                                void* const coefficients) {
38   auto* dst = static_cast<uint8_t*>(coefficients);
39   int subpixel_x = initial_subpixel_x;
40   int x = RightShiftWithCeiling(upscaled_width, 3);
41   do {
42     uint8x8_t filter[8];
43     uint8x16_t d[kSuperResFilterTaps / 2];
44     for (int i = 0; i < 8; ++i, subpixel_x += step) {
45       filter[i] =
46           vld1_u8(kUpscaleFilterUnsigned[(subpixel_x & kSuperResScaleMask) >>
47                                          kSuperResExtraBits]);
48     }
49     Transpose8x8(filter, d);
50     vst1q_u8(dst, d[0]);
51     dst += 16;
52     vst1q_u8(dst, d[1]);
53     dst += 16;
54     vst1q_u8(dst, d[2]);
55     dst += 16;
56     vst1q_u8(dst, d[3]);
57     dst += 16;
58   } while (--x != 0);
59 }
60 
61 // Maximum sum of positive taps: 171 = 7 + 86 + 71 + 7
62 // Maximum sum: 255*171 == 0xAA55
63 // The sum is clipped to [0, 255], so adding all positive and then
64 // subtracting all negative with saturation is sufficient.
65 //           0 1 2 3 4 5 6 7
66 // tap sign: - + - + + - + -
SuperRes(const uint8x8_t src[kSuperResFilterTaps],const uint8_t ** coefficients)67 inline uint8x8_t SuperRes(const uint8x8_t src[kSuperResFilterTaps],
68                           const uint8_t** coefficients) {
69   uint8x16_t f[kSuperResFilterTaps / 2];
70   for (int i = 0; i < kSuperResFilterTaps / 2; ++i, *coefficients += 16) {
71     f[i] = vld1q_u8(*coefficients);
72   }
73   uint16x8_t res = vmull_u8(src[1], vget_high_u8(f[0]));
74   res = vmlal_u8(res, src[3], vget_high_u8(f[1]));
75   res = vmlal_u8(res, src[4], vget_low_u8(f[2]));
76   res = vmlal_u8(res, src[6], vget_low_u8(f[3]));
77   uint16x8_t temp = vmull_u8(src[0], vget_low_u8(f[0]));
78   temp = vmlal_u8(temp, src[2], vget_low_u8(f[1]));
79   temp = vmlal_u8(temp, src[5], vget_high_u8(f[2]));
80   temp = vmlal_u8(temp, src[7], vget_high_u8(f[3]));
81   res = vqsubq_u16(res, temp);
82   return vqrshrn_n_u16(res, kFilterBits);
83 }
84 
SuperRes_NEON(const void * LIBGAV1_RESTRICT const coefficients,void * LIBGAV1_RESTRICT const source,const ptrdiff_t source_stride,const int height,const int downscaled_width,const int upscaled_width,const int initial_subpixel_x,const int step,void * LIBGAV1_RESTRICT const dest,const ptrdiff_t dest_stride)85 void SuperRes_NEON(const void* LIBGAV1_RESTRICT const coefficients,
86                    void* LIBGAV1_RESTRICT const source,
87                    const ptrdiff_t source_stride, const int height,
88                    const int downscaled_width, const int upscaled_width,
89                    const int initial_subpixel_x, const int step,
90                    void* LIBGAV1_RESTRICT const dest,
91                    const ptrdiff_t dest_stride) {
92   auto* src = static_cast<uint8_t*>(source) - DivideBy2(kSuperResFilterTaps);
93   auto* dst = static_cast<uint8_t*>(dest);
94   int y = height;
95   do {
96     const auto* filter = static_cast<const uint8_t*>(coefficients);
97     uint8_t* dst_ptr = dst;
98 #if LIBGAV1_MSAN
99     // Initialize the padding area to prevent msan warnings.
100     const int super_res_right_border = kSuperResHorizontalPadding;
101 #else
102     const int super_res_right_border = kSuperResHorizontalBorder;
103 #endif
104     ExtendLine<uint8_t>(src + DivideBy2(kSuperResFilterTaps), downscaled_width,
105                         kSuperResHorizontalBorder, super_res_right_border);
106     int subpixel_x = initial_subpixel_x;
107     uint8x8_t sr[8];
108     uint8x16_t s[8];
109     int x = RightShiftWithCeiling(upscaled_width, 4);
110     // The below code calculates up to 15 extra upscaled
111     // pixels which will over-read up to 15 downscaled pixels in the end of each
112     // row. kSuperResHorizontalPadding accounts for this.
113     do {
114       for (int i = 0; i < 8; ++i, subpixel_x += step) {
115         sr[i] = vld1_u8(&src[subpixel_x >> kSuperResScaleBits]);
116       }
117       for (int i = 0; i < 8; ++i, subpixel_x += step) {
118         const uint8x8_t s_hi = vld1_u8(&src[subpixel_x >> kSuperResScaleBits]);
119         s[i] = vcombine_u8(sr[i], s_hi);
120       }
121       Transpose8x16(s);
122       // Do not use loop for the following 8 instructions, since the compiler
123       // will generate redundant code.
124       sr[0] = vget_low_u8(s[0]);
125       sr[1] = vget_low_u8(s[1]);
126       sr[2] = vget_low_u8(s[2]);
127       sr[3] = vget_low_u8(s[3]);
128       sr[4] = vget_low_u8(s[4]);
129       sr[5] = vget_low_u8(s[5]);
130       sr[6] = vget_low_u8(s[6]);
131       sr[7] = vget_low_u8(s[7]);
132       const uint8x8_t d0 = SuperRes(sr, &filter);
133       // Do not use loop for the following 8 instructions, since the compiler
134       // will generate redundant code.
135       sr[0] = vget_high_u8(s[0]);
136       sr[1] = vget_high_u8(s[1]);
137       sr[2] = vget_high_u8(s[2]);
138       sr[3] = vget_high_u8(s[3]);
139       sr[4] = vget_high_u8(s[4]);
140       sr[5] = vget_high_u8(s[5]);
141       sr[6] = vget_high_u8(s[6]);
142       sr[7] = vget_high_u8(s[7]);
143       const uint8x8_t d1 = SuperRes(sr, &filter);
144       vst1q_u8(dst_ptr, vcombine_u8(d0, d1));
145       dst_ptr += 16;
146     } while (--x != 0);
147     src += source_stride;
148     dst += dest_stride;
149   } while (--y != 0);
150 }
151 
Init8bpp()152 void Init8bpp() {
153   Dsp* dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
154   dsp->super_res_coefficients = SuperResCoefficients_NEON;
155   dsp->super_res = SuperRes_NEON;
156 }
157 
158 }  // namespace
159 }  // namespace low_bitdepth
160 
161 //------------------------------------------------------------------------------
162 #if LIBGAV1_MAX_BITDEPTH >= 10
163 namespace high_bitdepth {
164 namespace {
165 
SuperResCoefficients_NEON(const int upscaled_width,const int initial_subpixel_x,const int step,void * const coefficients)166 void SuperResCoefficients_NEON(const int upscaled_width,
167                                const int initial_subpixel_x, const int step,
168                                void* const coefficients) {
169   auto* dst = static_cast<uint16_t*>(coefficients);
170   int subpixel_x = initial_subpixel_x;
171   int x = RightShiftWithCeiling(upscaled_width, 3);
172   do {
173     uint16x8_t filter[8];
174     for (int i = 0; i < 8; ++i, subpixel_x += step) {
175       const uint8x8_t filter_8 =
176           vld1_u8(kUpscaleFilterUnsigned[(subpixel_x & kSuperResScaleMask) >>
177                                          kSuperResExtraBits]);
178       // uint8_t -> uint16_t
179       filter[i] = vmovl_u8(filter_8);
180     }
181 
182     Transpose8x8(filter);
183 
184     vst1q_u16(dst, filter[0]);
185     dst += 8;
186     vst1q_u16(dst, filter[1]);
187     dst += 8;
188     vst1q_u16(dst, filter[2]);
189     dst += 8;
190     vst1q_u16(dst, filter[3]);
191     dst += 8;
192     vst1q_u16(dst, filter[4]);
193     dst += 8;
194     vst1q_u16(dst, filter[5]);
195     dst += 8;
196     vst1q_u16(dst, filter[6]);
197     dst += 8;
198     vst1q_u16(dst, filter[7]);
199     dst += 8;
200   } while (--x != 0);
201 }
202 
203 // The sum is clipped to [0, ((1 << bitdepth) -1)]. Adding all positive and then
204 // subtracting all negative with saturation will clip to zero.
205 //           0 1 2 3 4 5 6 7
206 // tap sign: - + - + + - + -
SuperRes(const uint16x8_t src[kSuperResFilterTaps],const uint16_t ** coefficients,int bitdepth)207 inline uint16x8_t SuperRes(const uint16x8_t src[kSuperResFilterTaps],
208                            const uint16_t** coefficients, int bitdepth) {
209   uint16x8_t f[kSuperResFilterTaps];
210   for (int i = 0; i < kSuperResFilterTaps; ++i, *coefficients += 8) {
211     f[i] = vld1q_u16(*coefficients);
212   }
213 
214   uint32x4_t res_lo = vmull_u16(vget_low_u16(src[1]), vget_low_u16(f[1]));
215   res_lo = vmlal_u16(res_lo, vget_low_u16(src[3]), vget_low_u16(f[3]));
216   res_lo = vmlal_u16(res_lo, vget_low_u16(src[4]), vget_low_u16(f[4]));
217   res_lo = vmlal_u16(res_lo, vget_low_u16(src[6]), vget_low_u16(f[6]));
218 
219   uint32x4_t temp_lo = vmull_u16(vget_low_u16(src[0]), vget_low_u16(f[0]));
220   temp_lo = vmlal_u16(temp_lo, vget_low_u16(src[2]), vget_low_u16(f[2]));
221   temp_lo = vmlal_u16(temp_lo, vget_low_u16(src[5]), vget_low_u16(f[5]));
222   temp_lo = vmlal_u16(temp_lo, vget_low_u16(src[7]), vget_low_u16(f[7]));
223 
224   res_lo = vqsubq_u32(res_lo, temp_lo);
225 
226   uint32x4_t res_hi = vmull_u16(vget_high_u16(src[1]), vget_high_u16(f[1]));
227   res_hi = vmlal_u16(res_hi, vget_high_u16(src[3]), vget_high_u16(f[3]));
228   res_hi = vmlal_u16(res_hi, vget_high_u16(src[4]), vget_high_u16(f[4]));
229   res_hi = vmlal_u16(res_hi, vget_high_u16(src[6]), vget_high_u16(f[6]));
230 
231   uint32x4_t temp_hi = vmull_u16(vget_high_u16(src[0]), vget_high_u16(f[0]));
232   temp_hi = vmlal_u16(temp_hi, vget_high_u16(src[2]), vget_high_u16(f[2]));
233   temp_hi = vmlal_u16(temp_hi, vget_high_u16(src[5]), vget_high_u16(f[5]));
234   temp_hi = vmlal_u16(temp_hi, vget_high_u16(src[7]), vget_high_u16(f[7]));
235 
236   res_hi = vqsubq_u32(res_hi, temp_hi);
237 
238   const uint16x8_t res = vcombine_u16(vqrshrn_n_u32(res_lo, kFilterBits),
239                                       vqrshrn_n_u32(res_hi, kFilterBits));
240 
241   // Clip the result at (1 << bd) - 1.
242   return vminq_u16(res, vdupq_n_u16((1 << bitdepth) - 1));
243 }
244 
245 template <int bitdepth>
SuperRes_NEON(const void * LIBGAV1_RESTRICT const coefficients,void * LIBGAV1_RESTRICT const source,const ptrdiff_t source_stride,const int height,const int downscaled_width,const int upscaled_width,const int initial_subpixel_x,const int step,void * LIBGAV1_RESTRICT const dest,const ptrdiff_t dest_stride)246 void SuperRes_NEON(const void* LIBGAV1_RESTRICT const coefficients,
247                    void* LIBGAV1_RESTRICT const source,
248                    const ptrdiff_t source_stride, const int height,
249                    const int downscaled_width, const int upscaled_width,
250                    const int initial_subpixel_x, const int step,
251                    void* LIBGAV1_RESTRICT const dest,
252                    const ptrdiff_t dest_stride) {
253   auto* src = static_cast<uint16_t*>(source) - DivideBy2(kSuperResFilterTaps);
254   auto* dst = static_cast<uint16_t*>(dest);
255   int y = height;
256   do {
257     const auto* filter = static_cast<const uint16_t*>(coefficients);
258     uint16_t* dst_ptr = dst;
259 #if LIBGAV1_MSAN
260     // Initialize the padding area to prevent msan warnings.
261     const int super_res_right_border = kSuperResHorizontalPadding;
262 #else
263     const int super_res_right_border = kSuperResHorizontalBorder;
264 #endif
265     ExtendLine<uint16_t>(src + DivideBy2(kSuperResFilterTaps), downscaled_width,
266                          kSuperResHorizontalBorder, super_res_right_border);
267     int subpixel_x = initial_subpixel_x;
268     uint16x8_t sr[8];
269     int x = RightShiftWithCeiling(upscaled_width, 3);
270     // The below code calculates up to 7 extra upscaled
271     // pixels which will over-read up to 7 downscaled pixels in the end of each
272     // row. kSuperResHorizontalBorder accounts for this.
273     do {
274       for (int i = 0; i < 8; ++i, subpixel_x += step) {
275         sr[i] = vld1q_u16(&src[subpixel_x >> kSuperResScaleBits]);
276       }
277 
278       Transpose8x8(sr);
279 
280       const uint16x8_t d0 = SuperRes(sr, &filter, bitdepth);
281       vst1q_u16(dst_ptr, d0);
282       dst_ptr += 8;
283     } while (--x != 0);
284     src += source_stride;
285     dst += dest_stride;
286   } while (--y != 0);
287 }
288 
Init10bpp()289 void Init10bpp() {
290   Dsp* dsp = dsp_internal::GetWritableDspTable(kBitdepth10);
291   assert(dsp != nullptr);
292   dsp->super_res_coefficients = SuperResCoefficients_NEON;
293   dsp->super_res = SuperRes_NEON<10>;
294 }
295 
296 }  // namespace
297 }  // namespace high_bitdepth
298 #endif  // LIBGAV1_MAX_BITDEPTH >= 10
299 
SuperResInit_NEON()300 void SuperResInit_NEON() {
301   low_bitdepth::Init8bpp();
302 #if LIBGAV1_MAX_BITDEPTH >= 10
303   high_bitdepth::Init10bpp();
304 #endif
305 }
306 }  // namespace dsp
307 }  // namespace libgav1
308 
309 #else   // !LIBGAV1_ENABLE_NEON
310 
311 namespace libgav1 {
312 namespace dsp {
313 
SuperResInit_NEON()314 void SuperResInit_NEON() {}
315 
316 }  // namespace dsp
317 }  // namespace libgav1
318 #endif  // LIBGAV1_ENABLE_NEON
319