1 // Copyright 2019 The libgav1 Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "src/dsp/distance_weighted_blend.h"
16 #include "src/utils/cpu.h"
17
18 #if LIBGAV1_ENABLE_NEON
19
20 #include <arm_neon.h>
21
22 #include <cassert>
23 #include <cstddef>
24 #include <cstdint>
25
26 #include "src/dsp/arm/common_neon.h"
27 #include "src/dsp/constants.h"
28 #include "src/dsp/dsp.h"
29 #include "src/utils/common.h"
30
31 namespace libgav1 {
32 namespace dsp {
33
34 constexpr int kInterPostRoundBit = 4;
35
36 namespace low_bitdepth {
37 namespace {
38
inline uint8x8_t ComputeWeightedAverage8(const int16x8_t pred0,
                                         const int16x8_t pred1,
                                         const int16x8_t weight) {
  // Computes (p0 * w0 + p1 * w1 + 128) >> 8 with w0 + w1 == 16, without
  // widening to 32 bits. Rearrange the weighted sum:
  //   p0 * w0 + p1 * w1 == (p0 - p1) * w0 + 16 * p1
  // p0, p1 lie in [-5132, 9212], so |p0 - p1| <= 9212 + 5132 == 0x3808 and
  // the difference fits comfortably in 16 bits.
  const int16x8_t difference = vsubq_s16(pred0, pred1);
  // |weight| holds w0 << 11, so vqdmulh yields
  //   ((p0 - p1) * (w0 << 11) * 2) >> 16 == (p0 - p1) * w0 >> 4.
  const int16x8_t scaled_difference = vqdmulhq_s16(difference, weight);
  // ((p0 - p1) * w0 >> 4) + p1 == (p0 * w0 + p1 * w1) >> 4.
  const int16x8_t upscaled_sum = vaddq_s16(scaled_difference, pred1);
  // Apply rounding, the remaining shift of kInterPostRoundBit, and narrow to
  // u8 with saturation.
  return vqrshrun_n_s16(upscaled_sum, kInterPostRoundBit);
}
57
template <int width>
inline void DistanceWeightedBlendSmall_NEON(
    const int16_t* LIBGAV1_RESTRICT prediction_0,
    const int16_t* LIBGAV1_RESTRICT prediction_1, const int height,
    const int16x8_t weight, void* LIBGAV1_RESTRICT const dest,
    const ptrdiff_t dest_stride) {
  static_assert(width == 4 || width == 8, "width must be 4 or 8");
  auto* dst = static_cast<uint8_t*>(dest);
  // Each iteration consumes 16 prediction values, i.e. this many output rows.
  constexpr int rows_per_iteration = 16 / width;

  for (int y = height; y != 0; y -= rows_per_iteration) {
    const int16x8_t pred0_a = vld1q_s16(prediction_0);
    const int16x8_t pred1_a = vld1q_s16(prediction_1);
    const uint8x8_t blend_a = ComputeWeightedAverage8(pred0_a, pred1_a, weight);

    const int16x8_t pred0_b = vld1q_s16(prediction_0 + 8);
    const int16x8_t pred1_b = vld1q_s16(prediction_1 + 8);
    const uint8x8_t blend_b = ComputeWeightedAverage8(pred0_b, pred1_b, weight);

    prediction_0 += 16;
    prediction_1 += 16;

    if (width == 4) {
      // Each 8-byte blend result holds two 4-wide rows.
      StoreLo4(dst, blend_a);
      dst += dest_stride;
      StoreHi4(dst, blend_a);
      dst += dest_stride;
      StoreLo4(dst, blend_b);
      dst += dest_stride;
      StoreHi4(dst, blend_b);
      dst += dest_stride;
    } else {
      // width == 8: one full row per 8-byte result.
      vst1_u8(dst, blend_a);
      dst += dest_stride;
      vst1_u8(dst, blend_b);
      dst += dest_stride;
    }
  }
}
100
inline void DistanceWeightedBlendLarge_NEON(
    const int16_t* LIBGAV1_RESTRICT prediction_0,
    const int16_t* LIBGAV1_RESTRICT prediction_1, const int16x8_t weight,
    const int width, const int height, void* LIBGAV1_RESTRICT const dest,
    const ptrdiff_t dest_stride) {
  auto* dst = static_cast<uint8_t*>(dest);

  int rows_remaining = height;
  do {
    // Blend one row in 16-pixel chunks; this path handles width >= 16.
    for (int x = 0; x < width; x += 16) {
      const int16x8_t pred0_lo = vld1q_s16(prediction_0 + x);
      const int16x8_t pred1_lo = vld1q_s16(prediction_1 + x);
      const uint8x8_t blended_lo =
          ComputeWeightedAverage8(pred0_lo, pred1_lo, weight);

      const int16x8_t pred0_hi = vld1q_s16(prediction_0 + x + 8);
      const int16x8_t pred1_hi = vld1q_s16(prediction_1 + x + 8);
      const uint8x8_t blended_hi =
          ComputeWeightedAverage8(pred0_hi, pred1_hi, weight);

      vst1q_u8(dst + x, vcombine_u8(blended_lo, blended_hi));
    }
    dst += dest_stride;
    prediction_0 += width;
    prediction_1 += width;
  } while (--rows_remaining != 0);
}
131
inline void DistanceWeightedBlend_NEON(
    const void* LIBGAV1_RESTRICT prediction_0,
    const void* LIBGAV1_RESTRICT prediction_1, const uint8_t weight_0,
    const uint8_t /*weight_1*/, const int width, const int height,
    void* LIBGAV1_RESTRICT const dest, const ptrdiff_t dest_stride) {
  const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
  const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
  // Pre-shift the weight so that the vqdmulh inside ComputeWeightedAverage8
  // produces (diff * weight_0) >> 4. weight_1 is implied by 16 - weight_0.
  const int16x8_t weight = vdupq_n_s16(weight_0 << 11);
  // Dispatch on block width; widths 4 and 8 use the interleaved small path.
  switch (width) {
    case 4:
      DistanceWeightedBlendSmall_NEON<4>(pred_0, pred_1, height, weight, dest,
                                         dest_stride);
      break;
    case 8:
      DistanceWeightedBlendSmall_NEON<8>(pred_0, pred_1, height, weight, dest,
                                         dest_stride);
      break;
    default:
      DistanceWeightedBlendLarge_NEON(pred_0, pred_1, weight, width, height,
                                      dest, dest_stride);
      break;
  }
}
156
// Registers the 8-bit NEON distance-weighted-blend implementation in the
// writable dsp function table for kBitdepth8.
void Init8bpp() {
  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
  assert(dsp != nullptr);
  dsp->distance_weighted_blend = DistanceWeightedBlend_NEON;
}
162
163 } // namespace
164 } // namespace low_bitdepth
165
166 //------------------------------------------------------------------------------
167 #if LIBGAV1_MAX_BITDEPTH >= 10
168 namespace high_bitdepth {
169 namespace {
170
inline uint16x4x2_t ComputeWeightedAverage8(const uint16x4x2_t pred0,
                                            const uint16x4x2_t pred1,
                                            const uint16x4_t weights[2]) {
  // The compound offset is scaled by the total weight (16) before removal.
  const int32x4_t compound_offset = vdupq_n_s32(kCompoundOffset * 16);
  const uint16x4_t max_pixel = vdup_n_u16((1 << kBitdepth10) - 1);
  uint16x4x2_t result;
  for (int i = 0; i < 2; ++i) {
    // Widening multiply-accumulate: w0 * p0 + w1 * p1 in 32 bits.
    const uint32x4_t weighted0 = vmull_u16(weights[0], pred0.val[i]);
    const uint32x4_t blended = vmlal_u16(weighted0, weights[1], pred1.val[i]);
    const int32x4_t unclamped =
        vsubq_s32(vreinterpretq_s32_u32(blended), compound_offset);
    // Round, shift out the post-round and weight bits, and clip the result
    // at (1 << bd) - 1.
    result.val[i] =
        vmin_u16(vqrshrun_n_s32(unclamped, kInterPostRoundBit + 4), max_pixel);
  }
  return result;
}
190
inline uint16x4x4_t ComputeWeightedAverage8(const uint16x4x4_t pred0,
                                            const uint16x4x4_t pred1,
                                            const uint16x4_t weights[2]) {
  // The compound offset is scaled by the total weight (16) before removal.
  const int32x4_t compound_offset = vdupq_n_s32(kCompoundOffset * 16);
  const uint16x4_t max_pixel = vdup_n_u16((1 << kBitdepth10) - 1);
  uint16x4x4_t result;
  for (int i = 0; i < 4; ++i) {
    // Widening multiply-accumulate: w0 * p0 + w1 * p1 in 32 bits.
    const uint32x4_t weighted0 = vmull_u16(weights[0], pred0.val[i]);
    const uint32x4_t blended = vmlal_u16(weighted0, weights[1], pred1.val[i]);
    const int32x4_t unclamped =
        vsubq_s32(vreinterpretq_s32_u32(blended), compound_offset);
    // Round, shift out the post-round and weight bits, and clip the result
    // at (1 << bd) - 1.
    result.val[i] =
        vmin_u16(vqrshrun_n_s32(unclamped, kInterPostRoundBit + 4), max_pixel);
  }
  return result;
}
221
// We could use vld1_u16_x2, but for compatibility reasons, use this function
// instead. The compiler optimizes to the correct instruction.
inline uint16x4x2_t LoadU16x4_x2(uint16_t const* ptr) {
  // gcc/clang (64 bit) optimize this pair of loads to a single ldp.
  const uint16x4x2_t loaded = {{vld1_u16(ptr), vld1_u16(ptr + 4)}};
  return loaded;
}
231
// We could use vld1_u16_x4, but for compatibility reasons, use this function
// instead. The compiler optimizes to a pair of vld1_u16_x2, which showed better
// performance in the speed tests.
inline uint16x4x4_t LoadU16x4_x4(uint16_t const* ptr) {
  uint16x4x4_t loaded;
  for (int i = 0; i < 4; ++i) {
    loaded.val[i] = vld1_u16(ptr + i * 4);
  }
  return loaded;
}
243
void DistanceWeightedBlend_NEON(const void* LIBGAV1_RESTRICT prediction_0,
                                const void* LIBGAV1_RESTRICT prediction_1,
                                const uint8_t weight_0, const uint8_t weight_1,
                                const int width, const int height,
                                void* LIBGAV1_RESTRICT const dest,
                                const ptrdiff_t dest_stride) {
  const auto* pred_0 = static_cast<const uint16_t*>(prediction_0);
  const auto* pred_1 = static_cast<const uint16_t*>(prediction_1);
  auto* dst = static_cast<uint16_t*>(dest);
  // |dest_stride| is given in bytes; convert it to uint16_t elements.
  const ptrdiff_t dst_stride = dest_stride / sizeof(dst[0]);
  const uint16x4_t weights[2] = {vdup_n_u16(weight_0), vdup_n_u16(weight_1)};

  if (width == 4) {
    // Blend two 4-wide rows (8 values) per iteration.
    for (int y = height; y != 0; y -= 2) {
      const uint16x4x2_t row_pair0 = LoadU16x4_x2(pred_0);
      const uint16x4x2_t row_pair1 = LoadU16x4_x2(pred_1);
      const uint16x4x2_t blended =
          ComputeWeightedAverage8(row_pair0, row_pair1, weights);
      vst1_u16(dst, blended.val[0]);
      vst1_u16(dst + dst_stride, blended.val[1]);
      dst += dst_stride * 2;
      pred_0 += 8;
      pred_1 += 8;
    }
  } else if (width == 8) {
    // Blend two 8-wide rows (16 values) per iteration.
    for (int y = height; y != 0; y -= 2) {
      const uint16x4x4_t row_pair0 = LoadU16x4_x4(pred_0);
      const uint16x4x4_t row_pair1 = LoadU16x4_x4(pred_1);
      const uint16x4x4_t blended =
          ComputeWeightedAverage8(row_pair0, row_pair1, weights);
      vst1_u16(dst, blended.val[0]);
      vst1_u16(dst + 4, blended.val[1]);
      vst1_u16(dst + dst_stride, blended.val[2]);
      vst1_u16(dst + dst_stride + 4, blended.val[3]);
      dst += dst_stride * 2;
      pred_0 += 16;
      pred_1 += 16;
    }
  } else {
    // One row at a time in 16-pixel chunks; this path handles width >= 16.
    for (int y = height; y != 0; --y) {
      for (int x = 0; x < width; x += 16) {
        const uint16x4x4_t chunk0 = LoadU16x4_x4(pred_0 + x);
        const uint16x4x4_t chunk1 = LoadU16x4_x4(pred_1 + x);
        const uint16x4x4_t blended =
            ComputeWeightedAverage8(chunk0, chunk1, weights);
        vst1_u16(dst + x, blended.val[0]);
        vst1_u16(dst + x + 4, blended.val[1]);
        vst1_u16(dst + x + 8, blended.val[2]);
        vst1_u16(dst + x + 12, blended.val[3]);
      }
      dst += dst_stride;
      pred_0 += width;
      pred_1 += width;
    }
  }
}
304
Init10bpp()305 void Init10bpp() {
306 Dsp* dsp = dsp_internal::GetWritableDspTable(kBitdepth10);
307 assert(dsp != nullptr);
308 dsp->distance_weighted_blend = DistanceWeightedBlend_NEON;
309 }
310
311 } // namespace
312 } // namespace high_bitdepth
313 #endif // LIBGAV1_MAX_BITDEPTH >= 10
314
// Public entry point: installs the NEON distance-weighted-blend functions for
// 8-bit and, when built with high-bitdepth support, 10-bit.
void DistanceWeightedBlendInit_NEON() {
  low_bitdepth::Init8bpp();
#if LIBGAV1_MAX_BITDEPTH >= 10
  high_bitdepth::Init10bpp();
#endif
}
321
322 } // namespace dsp
323 } // namespace libgav1
324
325 #else // !LIBGAV1_ENABLE_NEON
326
327 namespace libgav1 {
328 namespace dsp {
329
DistanceWeightedBlendInit_NEON()330 void DistanceWeightedBlendInit_NEON() {}
331
332 } // namespace dsp
333 } // namespace libgav1
334 #endif // LIBGAV1_ENABLE_NEON
335