xref: /aosp_15_r20/external/libgav1/src/dsp/arm/distance_weighted_blend_neon.cc (revision 095378508e87ed692bf8dfeb34008b65b3735891)
1 // Copyright 2019 The libgav1 Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "src/dsp/distance_weighted_blend.h"
16 #include "src/utils/cpu.h"
17 
18 #if LIBGAV1_ENABLE_NEON
19 
20 #include <arm_neon.h>
21 
22 #include <cassert>
23 #include <cstddef>
24 #include <cstdint>
25 
26 #include "src/dsp/arm/common_neon.h"
27 #include "src/dsp/constants.h"
28 #include "src/dsp/dsp.h"
29 #include "src/utils/common.h"
30 
31 namespace libgav1 {
32 namespace dsp {
33 
// Rounding shift applied after blending compound predictions. The kernels in
// this file shift by this amount plus another 4 bits, because the two blend
// weights add up to 16 (= 1 << 4).
constexpr int kInterPostRoundBit{4};
35 
36 namespace low_bitdepth {
37 namespace {
38 
// Blends eight pairs of 16-bit compound predictions and rounds the result to
// 8-bit pixels.
//
// Reference formula, with w1 == 16 - w0 and p0, p1 in [-5132, 9212]:
//   (p0 * w0 + p1 * w1 + 128) >> 8        // 8 == kInterPostRoundBit + 4
// Rewriting the sum as
//   p0 * w0 + p1 * w1 == (p0 - p1) * w0 + 16 * p1
// keeps every intermediate in 16 bits: |p0 - p1| <= 9212 + 5132 = 0x3808.
//
// Every lane of |weight| must hold (w0 << 11). vqdmulhq_s16 then produces
//   ((p0 - p1) * (w0 << 11) * 2) >> 16 == ((p0 - p1) * w0) >> 4.
// Adding p1 (== (16 * p1) >> 4) and then applying a rounding,
// unsigned-saturating shift by kInterPostRoundBit gives the reference result.
inline uint8x8_t ComputeWeightedAverage8(const int16x8_t pred0,
                                         const int16x8_t pred1,
                                         const int16x8_t weight) {
  const int16x8_t scaled_delta = vqdmulhq_s16(vsubq_s16(pred0, pred1), weight);
  const int16x8_t sum_over_16 = vaddq_s16(pred1, scaled_delta);
  return vqrshrun_n_s16(sum_over_16, kInterPostRoundBit);
}
57 
// Blends a block that is 4 or 8 pixels wide.
//
// Each iteration reads 16 packed values from each prediction buffer (two
// 8-lane vectors), which covers 16 / width output rows. The row counter has
// to reach exactly zero, so |height| must be a positive multiple of
// 16 / width.
template <int width>
inline void DistanceWeightedBlendSmall_NEON(
    const int16_t* LIBGAV1_RESTRICT prediction_0,
    const int16_t* LIBGAV1_RESTRICT prediction_1, const int height,
    const int16x8_t weight, void* LIBGAV1_RESTRICT const dest,
    const ptrdiff_t dest_stride) {
  constexpr int kRowsPerIteration = 16 / width;
  auto* out = static_cast<uint8_t*>(dest);

  int rows_left = height;
  do {
    const uint8x8_t first = ComputeWeightedAverage8(
        vld1q_s16(prediction_0), vld1q_s16(prediction_1), weight);
    const uint8x8_t second = ComputeWeightedAverage8(
        vld1q_s16(prediction_0 + 8), vld1q_s16(prediction_1 + 8), weight);
    prediction_0 += 16;
    prediction_1 += 16;

    if (width == 4) {
      // The low and high halves of each 8-lane result go to consecutive rows.
      StoreLo4(out, first);
      out += dest_stride;
      StoreHi4(out, first);
      out += dest_stride;
      StoreLo4(out, second);
      out += dest_stride;
      StoreHi4(out, second);
      out += dest_stride;
    } else {
      assert(width == 8);
      vst1_u8(out, first);
      out += dest_stride;
      vst1_u8(out, second);
      out += dest_stride;
    }
    rows_left -= kRowsPerIteration;
  } while (rows_left != 0);
}
100 
// Blends a block at least 16 pixels wide. Each inner iteration writes 16 output
// pixels, so |width| is expected to be a multiple of 16. The prediction rows
// are packed back to back with a stride of |width| values.
inline void DistanceWeightedBlendLarge_NEON(
    const int16_t* LIBGAV1_RESTRICT prediction_0,
    const int16_t* LIBGAV1_RESTRICT prediction_1, const int16x8_t weight,
    const int width, const int height, void* LIBGAV1_RESTRICT const dest,
    const ptrdiff_t dest_stride) {
  auto* row = static_cast<uint8_t*>(dest);

  int rows_left = height;
  do {
    int col = 0;
    do {
      const int16_t* const p0 = prediction_0 + col;
      const int16_t* const p1 = prediction_1 + col;
      const uint8x8_t left =
          ComputeWeightedAverage8(vld1q_s16(p0), vld1q_s16(p1), weight);
      const uint8x8_t right = ComputeWeightedAverage8(
          vld1q_s16(p0 + 8), vld1q_s16(p1 + 8), weight);
      vst1q_u8(row + col, vcombine_u8(left, right));
      col += 16;
    } while (col < width);
    prediction_0 += width;
    prediction_1 += width;
    row += dest_stride;
  } while (--rows_left != 0);
}
131 
// 8-bit distance-weighted blend of two 16-bit compound predictions into |dest|.
//
// Only |weight_0| feeds the arithmetic. The 16-bit kernel depends on two
// preconditions, both checked in debug builds:
//   * weight_0 + weight_1 == 16, so that w1 can be replaced by 16 - w0;
//   * weight_0 < 16, so that weight_0 << 11 (w0 / 16 in Q15) fits in an
//     int16_t lane for vqdmulhq_s16.
// Widths 4 and 8 use the small-block kernel. Every other width takes the
// 16-pixel-per-step path.
inline void DistanceWeightedBlend_NEON(
    const void* LIBGAV1_RESTRICT prediction_0,
    const void* LIBGAV1_RESTRICT prediction_1, const uint8_t weight_0,
    const uint8_t weight_1, const int width, const int height,
    void* LIBGAV1_RESTRICT const dest, const ptrdiff_t dest_stride) {
  assert(weight_0 + weight_1 == 16);
  assert(weight_0 < 16);
  static_cast<void>(weight_1);  // Only read by the assert in debug builds.
  const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
  const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
  // Upscale the weight to Q15 (w0 / 16) for vqdmulhq_s16. The cast is exact
  // because weight_0 < 16.
  const int16x8_t weight = vdupq_n_s16(static_cast<int16_t>(weight_0 << 11));
  if (width == 4) {
    DistanceWeightedBlendSmall_NEON<4>(pred_0, pred_1, height, weight, dest,
                                       dest_stride);
    return;
  }

  if (width == 8) {
    DistanceWeightedBlendSmall_NEON<8>(pred_0, pred_1, height, weight, dest,
                                       dest_stride);
    return;
  }

  DistanceWeightedBlendLarge_NEON(pred_0, pred_1, weight, width, height, dest,
                                  dest_stride);
}
156 
Init8bpp()157 void Init8bpp() {
158   Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
159   assert(dsp != nullptr);
160   dsp->distance_weighted_blend = DistanceWeightedBlend_NEON;
161 }
162 
163 }  // namespace
164 }  // namespace low_bitdepth
165 
166 //------------------------------------------------------------------------------
167 #if LIBGAV1_MAX_BITDEPTH >= 10
168 namespace high_bitdepth {
169 namespace {
170 
// Blends two 4-lane halves (8 values) of 10-bit compound predictions.
// For each lane:
//   t = (w0 * p0 + w1 * p1 - 16 * kCompoundOffset + 128) >> 8
//   result = min(max(t, 0), (1 << kBitdepth10) - 1)
// The shift of 8 is kInterPostRoundBit + 4. Products are widened to 32 bits.
inline uint16x4x2_t ComputeWeightedAverage8(const uint16x4x2_t pred0,
                                            const uint16x4x2_t pred1,
                                            const uint16x4_t weights[2]) {
  const int32x4_t bias = vdupq_n_s32(kCompoundOffset * 16);
  const uint16x4_t pixel_max = vdup_n_u16((1 << kBitdepth10) - 1);
  const auto blend4 = [&](const uint16x4_t p0,
                          const uint16x4_t p1) -> uint16x4_t {
    const uint32x4_t weighted =
        vmlal_u16(vmull_u16(weights[0], p0), weights[1], p1);
    const int32x4_t unbiased =
        vsubq_s32(vreinterpretq_s32_u32(weighted), bias);
    // The rounding shift saturates negatives to 0; vmin clips to the 10-bit
    // maximum.
    return vmin_u16(vqrshrun_n_s32(unbiased, kInterPostRoundBit + 4),
                    pixel_max);
  };
  uint16x4x2_t out;
  out.val[0] = blend4(pred0.val[0], pred1.val[0]);
  out.val[1] = blend4(pred0.val[1], pred1.val[1]);
  return out;
}
190 
// Blends four 4-lane quarters (16 values) of 10-bit compound predictions.
// For each lane:
//   t = (w0 * p0 + w1 * p1 - 16 * kCompoundOffset + 128) >> 8
//   result = min(max(t, 0), (1 << kBitdepth10) - 1)
// The shift of 8 is kInterPostRoundBit + 4. Products are widened to 32 bits.
inline uint16x4x4_t ComputeWeightedAverage8(const uint16x4x4_t pred0,
                                            const uint16x4x4_t pred1,
                                            const uint16x4_t weights[2]) {
  const int32x4_t bias = vdupq_n_s32(kCompoundOffset * 16);
  const uint16x4_t pixel_max = vdup_n_u16((1 << kBitdepth10) - 1);
  const auto blend4 = [&](const uint16x4_t p0,
                          const uint16x4_t p1) -> uint16x4_t {
    const uint32x4_t weighted =
        vmlal_u16(vmull_u16(weights[0], p0), weights[1], p1);
    const int32x4_t unbiased =
        vsubq_s32(vreinterpretq_s32_u32(weighted), bias);
    // The rounding shift saturates negatives to 0; vmin clips to the 10-bit
    // maximum.
    return vmin_u16(vqrshrun_n_s32(unbiased, kInterPostRoundBit + 4),
                    pixel_max);
  };
  uint16x4x4_t out;
  out.val[0] = blend4(pred0.val[0], pred1.val[0]);
  out.val[1] = blend4(pred0.val[1], pred1.val[1]);
  out.val[2] = blend4(pred0.val[2], pred1.val[2]);
  out.val[3] = blend4(pred0.val[3], pred1.val[3]);
  return out;
}
221 
// Loads eight consecutive uint16_t values as two 4-lane halves.
// This stands in for vld1_u16_x2, which is avoided for compatibility reasons.
// 64-bit gcc/clang merge the two loads into a single ldp.
inline uint16x4x2_t LoadU16x4_x2(uint16_t const* ptr) {
  uint16x4x2_t halves;
  halves.val[0] = vld1_u16(ptr + 0);
  halves.val[1] = vld1_u16(ptr + 4);
  return halves;
}
231 
// Loads sixteen consecutive uint16_t values as four 4-lane quarters.
// This stands in for vld1_u16_x4, which is avoided for compatibility reasons.
// Compilers lower it to a pair of vld1_u16_x2-style loads, which measured
// faster in the speed tests.
inline uint16x4x4_t LoadU16x4_x4(uint16_t const* ptr) {
  uint16x4x4_t quarters;
  quarters.val[0] = vld1_u16(ptr + 0);
  quarters.val[1] = vld1_u16(ptr + 4);
  quarters.val[2] = vld1_u16(ptr + 8);
  quarters.val[3] = vld1_u16(ptr + 12);
  return quarters;
}
243 
// 10-bit distance-weighted blend of two compound predictions into |dest|.
//
// |prediction_0| and |prediction_1| hold width * height packed uint16_t
// values. They are presumably biased by kCompoundOffset: the kernel removes
// 16 * kCompoundOffset from the weighted sum. That removal is only exact when
// weight_0 + weight_1 == 16, which is checked in debug builds.
// |dest_stride| is given in bytes.
void DistanceWeightedBlend_NEON(const void* LIBGAV1_RESTRICT prediction_0,
                                const void* LIBGAV1_RESTRICT prediction_1,
                                const uint8_t weight_0, const uint8_t weight_1,
                                const int width, const int height,
                                void* LIBGAV1_RESTRICT const dest,
                                const ptrdiff_t dest_stride) {
  assert(weight_0 + weight_1 == 16);
  // All loops below are do/while loops that count rows down to exactly zero.
  assert(height > 0);
  const auto* pred_0 = static_cast<const uint16_t*>(prediction_0);
  const auto* pred_1 = static_cast<const uint16_t*>(prediction_1);
  auto* dst = static_cast<uint16_t*>(dest);
  // Convert the byte stride into a stride in pixels.
  const ptrdiff_t dst_stride = dest_stride / sizeof(dst[0]);
  const uint16x4_t weights[2] = {vdup_n_u16(weight_0), vdup_n_u16(weight_1)};

  if (width == 4) {
    // Two 4-pixel rows (8 packed prediction values) per iteration.
    assert(height % 2 == 0);
    int y = height;
    do {
      const uint16x4x2_t src0 = LoadU16x4_x2(pred_0);
      const uint16x4x2_t src1 = LoadU16x4_x2(pred_1);
      const uint16x4x2_t res = ComputeWeightedAverage8(src0, src1, weights);
      vst1_u16(dst, res.val[0]);
      vst1_u16(dst + dst_stride, res.val[1]);
      dst += dst_stride << 1;
      pred_0 += 8;
      pred_1 += 8;
      y -= 2;
    } while (y != 0);
  } else if (width == 8) {
    // Two 8-pixel rows (16 packed prediction values) per iteration.
    assert(height % 2 == 0);
    int y = height;
    do {
      const uint16x4x4_t src0 = LoadU16x4_x4(pred_0);
      const uint16x4x4_t src1 = LoadU16x4_x4(pred_1);
      const uint16x4x4_t res = ComputeWeightedAverage8(src0, src1, weights);
      vst1_u16(dst, res.val[0]);
      vst1_u16(dst + 4, res.val[1]);
      vst1_u16(dst + dst_stride, res.val[2]);
      vst1_u16(dst + dst_stride + 4, res.val[3]);
      dst += dst_stride << 1;
      pred_0 += 16;
      pred_1 += 16;
      y -= 2;
    } while (y != 0);
  } else {
    // 16 pixels per inner iteration; |width| is expected to be a multiple of
    // 16.
    int y = height;
    do {
      int x = 0;
      do {
        const uint16x4x4_t src0 = LoadU16x4_x4(pred_0 + x);
        const uint16x4x4_t src1 = LoadU16x4_x4(pred_1 + x);
        const uint16x4x4_t res = ComputeWeightedAverage8(src0, src1, weights);
        vst1_u16(dst + x, res.val[0]);
        vst1_u16(dst + x + 4, res.val[1]);
        vst1_u16(dst + x + 8, res.val[2]);
        vst1_u16(dst + x + 12, res.val[3]);
        x += 16;
      } while (x < width);
      dst += dst_stride;
      pred_0 += width;
      pred_1 += width;
    } while (--y != 0);
  }
}
304 
Init10bpp()305 void Init10bpp() {
306   Dsp* dsp = dsp_internal::GetWritableDspTable(kBitdepth10);
307   assert(dsp != nullptr);
308   dsp->distance_weighted_blend = DistanceWeightedBlend_NEON;
309 }
310 
311 }  // namespace
312 }  // namespace high_bitdepth
313 #endif  // LIBGAV1_MAX_BITDEPTH >= 10
314 
DistanceWeightedBlendInit_NEON()315 void DistanceWeightedBlendInit_NEON() {
316   low_bitdepth::Init8bpp();
317 #if LIBGAV1_MAX_BITDEPTH >= 10
318   high_bitdepth::Init10bpp();
319 #endif
320 }
321 
322 }  // namespace dsp
323 }  // namespace libgav1
324 
325 #else   // !LIBGAV1_ENABLE_NEON
326 
327 namespace libgav1 {
328 namespace dsp {
329 
// NEON support is compiled out, so there is nothing to register.
void DistanceWeightedBlendInit_NEON() {
  // Intentionally empty.
}
331 
332 }  // namespace dsp
333 }  // namespace libgav1
334 #endif  // LIBGAV1_ENABLE_NEON
335