1 // Copyright 2019 The libgav1 Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 // http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14
15 #include "src/dsp/distance_weighted_blend.h"
16 #include "src/utils/cpu.h"
17
18 #if LIBGAV1_TARGETING_SSE4_1
19
20 #include <xmmintrin.h>
21
22 #include <cassert>
23 #include <cstddef>
24 #include <cstdint>
25
26 #include "src/dsp/constants.h"
27 #include "src/dsp/dsp.h"
28 #include "src/dsp/x86/common_sse4.h"
29 #include "src/utils/common.h"
30
31 namespace libgav1 {
32 namespace dsp {
33 namespace low_bitdepth {
34 namespace {
35
// Rounding shift applied to the blended prediction after the 4 bits of weight
// precision have been removed; the total shift on the weighted sum is
// kInterPostRoundBit + 4.
constexpr int kInterPostRoundBit = 4;
// Multiplier that turns _mm_mulhrs_epi16 into a rounding right shift by
// kInterPostRoundBit:
//   mulhrs(x, 1 << (15 - n)) == (x + (1 << (n - 1))) >> n
constexpr int kInterPostRhsAdjust = 1 << (15 - kInterPostRoundBit);
38
ComputeWeightedAverage8(const __m128i & pred0,const __m128i & pred1,const __m128i & weight)39 inline __m128i ComputeWeightedAverage8(const __m128i& pred0,
40 const __m128i& pred1,
41 const __m128i& weight) {
42 // Given: p0,p1 in range [-5132,9212] and w0 = 16 - w1, w1 = 16 - w0
43 // Output: (p0 * w0 + p1 * w1 + 128(=rounding bit)) >>
44 // 8(=kInterPostRoundBit + 4)
45 // The formula is manipulated to avoid lengthening to 32 bits.
46 // p0 * w0 + p1 * w1 = p0 * w0 + (16 - w0) * p1
47 // = (p0 - p1) * w0 + 16 * p1
48 // Maximum value of p0 - p1 is 9212 + 5132 = 0x3808.
49 const __m128i diff = _mm_slli_epi16(_mm_sub_epi16(pred0, pred1), 1);
50 // (((p0 - p1) * (w0 << 12) >> 16) + ((16 * p1) >> 4)
51 const __m128i weighted_diff = _mm_mulhi_epi16(diff, weight);
52 // ((p0 - p1) * w0 >> 4) + p1
53 const __m128i upscaled_average = _mm_add_epi16(weighted_diff, pred1);
54 // (x << 11) >> 15 == x >> 4
55 const __m128i right_shift_prep = _mm_set1_epi16(kInterPostRhsAdjust);
56 // (((p0 - p1) * w0 >> 4) + p1 + (128 >> 4)) >> 4
57 return _mm_mulhrs_epi16(upscaled_average, right_shift_prep);
58 }
59
60 template <int height>
DistanceWeightedBlend4xH_SSE4_1(const int16_t * LIBGAV1_RESTRICT pred_0,const int16_t * LIBGAV1_RESTRICT pred_1,const uint8_t weight,void * LIBGAV1_RESTRICT const dest,const ptrdiff_t dest_stride)61 inline void DistanceWeightedBlend4xH_SSE4_1(
62 const int16_t* LIBGAV1_RESTRICT pred_0,
63 const int16_t* LIBGAV1_RESTRICT pred_1, const uint8_t weight,
64 void* LIBGAV1_RESTRICT const dest, const ptrdiff_t dest_stride) {
65 auto* dst = static_cast<uint8_t*>(dest);
66 // Upscale the weight for mulhi.
67 const __m128i weights = _mm_set1_epi16(weight << 11);
68
69 for (int y = 0; y < height; y += 4) {
70 const __m128i src_00 = LoadAligned16(pred_0);
71 const __m128i src_10 = LoadAligned16(pred_1);
72 pred_0 += 8;
73 pred_1 += 8;
74 const __m128i res0 = ComputeWeightedAverage8(src_00, src_10, weights);
75
76 const __m128i src_01 = LoadAligned16(pred_0);
77 const __m128i src_11 = LoadAligned16(pred_1);
78 pred_0 += 8;
79 pred_1 += 8;
80 const __m128i res1 = ComputeWeightedAverage8(src_01, src_11, weights);
81
82 const __m128i result_pixels = _mm_packus_epi16(res0, res1);
83 Store4(dst, result_pixels);
84 dst += dest_stride;
85 const int result_1 = _mm_extract_epi32(result_pixels, 1);
86 memcpy(dst, &result_1, sizeof(result_1));
87 dst += dest_stride;
88 const int result_2 = _mm_extract_epi32(result_pixels, 2);
89 memcpy(dst, &result_2, sizeof(result_2));
90 dst += dest_stride;
91 const int result_3 = _mm_extract_epi32(result_pixels, 3);
92 memcpy(dst, &result_3, sizeof(result_3));
93 dst += dest_stride;
94 }
95 }
96
97 template <int height>
DistanceWeightedBlend8xH_SSE4_1(const int16_t * LIBGAV1_RESTRICT pred_0,const int16_t * LIBGAV1_RESTRICT pred_1,const uint8_t weight,void * LIBGAV1_RESTRICT const dest,const ptrdiff_t dest_stride)98 inline void DistanceWeightedBlend8xH_SSE4_1(
99 const int16_t* LIBGAV1_RESTRICT pred_0,
100 const int16_t* LIBGAV1_RESTRICT pred_1, const uint8_t weight,
101 void* LIBGAV1_RESTRICT const dest, const ptrdiff_t dest_stride) {
102 auto* dst = static_cast<uint8_t*>(dest);
103 // Upscale the weight for mulhi.
104 const __m128i weights = _mm_set1_epi16(weight << 11);
105
106 for (int y = 0; y < height; y += 2) {
107 const __m128i src_00 = LoadAligned16(pred_0);
108 const __m128i src_10 = LoadAligned16(pred_1);
109 pred_0 += 8;
110 pred_1 += 8;
111 const __m128i res0 = ComputeWeightedAverage8(src_00, src_10, weights);
112
113 const __m128i src_01 = LoadAligned16(pred_0);
114 const __m128i src_11 = LoadAligned16(pred_1);
115 pred_0 += 8;
116 pred_1 += 8;
117 const __m128i res1 = ComputeWeightedAverage8(src_01, src_11, weights);
118
119 const __m128i result_pixels = _mm_packus_epi16(res0, res1);
120 StoreLo8(dst, result_pixels);
121 dst += dest_stride;
122 StoreHi8(dst, result_pixels);
123 dst += dest_stride;
124 }
125 }
126
DistanceWeightedBlendLarge_SSE4_1(const int16_t * LIBGAV1_RESTRICT pred_0,const int16_t * LIBGAV1_RESTRICT pred_1,const uint8_t weight,const int width,const int height,void * LIBGAV1_RESTRICT const dest,const ptrdiff_t dest_stride)127 inline void DistanceWeightedBlendLarge_SSE4_1(
128 const int16_t* LIBGAV1_RESTRICT pred_0,
129 const int16_t* LIBGAV1_RESTRICT pred_1, const uint8_t weight,
130 const int width, const int height, void* LIBGAV1_RESTRICT const dest,
131 const ptrdiff_t dest_stride) {
132 auto* dst = static_cast<uint8_t*>(dest);
133 // Upscale the weight for mulhi.
134 const __m128i weights = _mm_set1_epi16(weight << 11);
135
136 int y = height;
137 do {
138 int x = 0;
139 do {
140 const __m128i src_0_lo = LoadAligned16(pred_0 + x);
141 const __m128i src_1_lo = LoadAligned16(pred_1 + x);
142 const __m128i res_lo =
143 ComputeWeightedAverage8(src_0_lo, src_1_lo, weights);
144
145 const __m128i src_0_hi = LoadAligned16(pred_0 + x + 8);
146 const __m128i src_1_hi = LoadAligned16(pred_1 + x + 8);
147 const __m128i res_hi =
148 ComputeWeightedAverage8(src_0_hi, src_1_hi, weights);
149
150 StoreUnaligned16(dst + x, _mm_packus_epi16(res_lo, res_hi));
151 x += 16;
152 } while (x < width);
153 dst += dest_stride;
154 pred_0 += width;
155 pred_1 += width;
156 } while (--y != 0);
157 }
158
DistanceWeightedBlend_SSE4_1(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,const uint8_t weight_0,const uint8_t,const int width,const int height,void * LIBGAV1_RESTRICT const dest,const ptrdiff_t dest_stride)159 void DistanceWeightedBlend_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
160 const void* LIBGAV1_RESTRICT prediction_1,
161 const uint8_t weight_0,
162 const uint8_t /*weight_1*/, const int width,
163 const int height,
164 void* LIBGAV1_RESTRICT const dest,
165 const ptrdiff_t dest_stride) {
166 const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
167 const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
168 const uint8_t weight = weight_0;
169 if (width == 4) {
170 if (height == 4) {
171 DistanceWeightedBlend4xH_SSE4_1<4>(pred_0, pred_1, weight, dest,
172 dest_stride);
173 } else if (height == 8) {
174 DistanceWeightedBlend4xH_SSE4_1<8>(pred_0, pred_1, weight, dest,
175 dest_stride);
176 } else {
177 assert(height == 16);
178 DistanceWeightedBlend4xH_SSE4_1<16>(pred_0, pred_1, weight, dest,
179 dest_stride);
180 }
181 return;
182 }
183
184 if (width == 8) {
185 switch (height) {
186 case 4:
187 DistanceWeightedBlend8xH_SSE4_1<4>(pred_0, pred_1, weight, dest,
188 dest_stride);
189 return;
190 case 8:
191 DistanceWeightedBlend8xH_SSE4_1<8>(pred_0, pred_1, weight, dest,
192 dest_stride);
193 return;
194 case 16:
195 DistanceWeightedBlend8xH_SSE4_1<16>(pred_0, pred_1, weight, dest,
196 dest_stride);
197 return;
198 default:
199 assert(height == 32);
200 DistanceWeightedBlend8xH_SSE4_1<32>(pred_0, pred_1, weight, dest,
201 dest_stride);
202
203 return;
204 }
205 }
206
207 DistanceWeightedBlendLarge_SSE4_1(pred_0, pred_1, weight, width, height, dest,
208 dest_stride);
209 }
210
Init8bpp()211 void Init8bpp() {
212 Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
213 assert(dsp != nullptr);
214 #if DSP_ENABLED_8BPP_SSE4_1(DistanceWeightedBlend)
215 dsp->distance_weighted_blend = DistanceWeightedBlend_SSE4_1;
216 #endif
217 }
218
219 } // namespace
220 } // namespace low_bitdepth
221
222 #if LIBGAV1_MAX_BITDEPTH >= 10
223 namespace high_bitdepth {
224 namespace {
225
// Post-blend rounding shift. The weights contribute 4 more bits, so the total
// shift applied to the weighted sum is kInterPostRoundBit + 4.
constexpr int kInterPostRoundBit = 4;
// Largest 10-bit sample value; blended 10bpp output is clamped to it.
constexpr int kMax10bppSample = (1 << 10) - 1;
228
ComputeWeightedAverage8(const __m128i & pred0,const __m128i & pred1,const __m128i & weight0,const __m128i & weight1)229 inline __m128i ComputeWeightedAverage8(const __m128i& pred0,
230 const __m128i& pred1,
231 const __m128i& weight0,
232 const __m128i& weight1) {
233 // This offset is a combination of round_factor and round_offset
234 // which are to be added and subtracted respectively.
235 // Here kInterPostRoundBit + 4 is considering bitdepth=10.
236 constexpr int offset =
237 (1 << ((kInterPostRoundBit + 4) - 1)) - (kCompoundOffset << 4);
238 const __m128i zero = _mm_setzero_si128();
239 const __m128i bias = _mm_set1_epi32(offset);
240 const __m128i clip_high = _mm_set1_epi16(kMax10bppSample);
241
242 __m128i prediction0 = _mm_cvtepu16_epi32(pred0);
243 __m128i mult0 = _mm_mullo_epi32(prediction0, weight0);
244 __m128i prediction1 = _mm_cvtepu16_epi32(pred1);
245 __m128i mult1 = _mm_mullo_epi32(prediction1, weight1);
246 __m128i sum = _mm_add_epi32(mult0, mult1);
247 sum = _mm_add_epi32(sum, bias);
248 const __m128i result0 = _mm_srai_epi32(sum, kInterPostRoundBit + 4);
249
250 prediction0 = _mm_unpackhi_epi16(pred0, zero);
251 mult0 = _mm_mullo_epi32(prediction0, weight0);
252 prediction1 = _mm_unpackhi_epi16(pred1, zero);
253 mult1 = _mm_mullo_epi32(prediction1, weight1);
254 sum = _mm_add_epi32(mult0, mult1);
255 sum = _mm_add_epi32(sum, bias);
256 const __m128i result1 = _mm_srai_epi32(sum, kInterPostRoundBit + 4);
257 const __m128i pack = _mm_packus_epi32(result0, result1);
258
259 return _mm_min_epi16(pack, clip_high);
260 }
261
262 template <int height>
DistanceWeightedBlend4xH_SSE4_1(const uint16_t * LIBGAV1_RESTRICT pred_0,const uint16_t * LIBGAV1_RESTRICT pred_1,const uint8_t weight_0,const uint8_t weight_1,void * LIBGAV1_RESTRICT const dest,const ptrdiff_t dest_stride)263 inline void DistanceWeightedBlend4xH_SSE4_1(
264 const uint16_t* LIBGAV1_RESTRICT pred_0,
265 const uint16_t* LIBGAV1_RESTRICT pred_1, const uint8_t weight_0,
266 const uint8_t weight_1, void* LIBGAV1_RESTRICT const dest,
267 const ptrdiff_t dest_stride) {
268 auto* dst = static_cast<uint16_t*>(dest);
269 const __m128i weight0 = _mm_set1_epi32(weight_0);
270 const __m128i weight1 = _mm_set1_epi32(weight_1);
271
272 int y = height;
273 do {
274 const __m128i src_00 = LoadAligned16(pred_0);
275 const __m128i src_10 = LoadAligned16(pred_1);
276 pred_0 += 8;
277 pred_1 += 8;
278 const __m128i res0 =
279 ComputeWeightedAverage8(src_00, src_10, weight0, weight1);
280
281 const __m128i src_01 = LoadAligned16(pred_0);
282 const __m128i src_11 = LoadAligned16(pred_1);
283 pred_0 += 8;
284 pred_1 += 8;
285 const __m128i res1 =
286 ComputeWeightedAverage8(src_01, src_11, weight0, weight1);
287
288 StoreLo8(dst, res0);
289 dst += dest_stride;
290 StoreHi8(dst, res0);
291 dst += dest_stride;
292 StoreLo8(dst, res1);
293 dst += dest_stride;
294 StoreHi8(dst, res1);
295 dst += dest_stride;
296 y -= 4;
297 } while (y != 0);
298 }
299
300 template <int height>
DistanceWeightedBlend8xH_SSE4_1(const uint16_t * LIBGAV1_RESTRICT pred_0,const uint16_t * LIBGAV1_RESTRICT pred_1,const uint8_t weight_0,const uint8_t weight_1,void * LIBGAV1_RESTRICT const dest,const ptrdiff_t dest_stride)301 inline void DistanceWeightedBlend8xH_SSE4_1(
302 const uint16_t* LIBGAV1_RESTRICT pred_0,
303 const uint16_t* LIBGAV1_RESTRICT pred_1, const uint8_t weight_0,
304 const uint8_t weight_1, void* LIBGAV1_RESTRICT const dest,
305 const ptrdiff_t dest_stride) {
306 auto* dst = static_cast<uint16_t*>(dest);
307 const __m128i weight0 = _mm_set1_epi32(weight_0);
308 const __m128i weight1 = _mm_set1_epi32(weight_1);
309
310 int y = height;
311 do {
312 const __m128i src_00 = LoadAligned16(pred_0);
313 const __m128i src_10 = LoadAligned16(pred_1);
314 pred_0 += 8;
315 pred_1 += 8;
316 const __m128i res0 =
317 ComputeWeightedAverage8(src_00, src_10, weight0, weight1);
318
319 const __m128i src_01 = LoadAligned16(pred_0);
320 const __m128i src_11 = LoadAligned16(pred_1);
321 pred_0 += 8;
322 pred_1 += 8;
323 const __m128i res1 =
324 ComputeWeightedAverage8(src_01, src_11, weight0, weight1);
325
326 StoreUnaligned16(dst, res0);
327 dst += dest_stride;
328 StoreUnaligned16(dst, res1);
329 dst += dest_stride;
330 y -= 2;
331 } while (y != 0);
332 }
333
DistanceWeightedBlendLarge_SSE4_1(const uint16_t * LIBGAV1_RESTRICT pred_0,const uint16_t * LIBGAV1_RESTRICT pred_1,const uint8_t weight_0,const uint8_t weight_1,const int width,const int height,void * LIBGAV1_RESTRICT const dest,const ptrdiff_t dest_stride)334 inline void DistanceWeightedBlendLarge_SSE4_1(
335 const uint16_t* LIBGAV1_RESTRICT pred_0,
336 const uint16_t* LIBGAV1_RESTRICT pred_1, const uint8_t weight_0,
337 const uint8_t weight_1, const int width, const int height,
338 void* LIBGAV1_RESTRICT const dest, const ptrdiff_t dest_stride) {
339 auto* dst = static_cast<uint16_t*>(dest);
340 const __m128i weight0 = _mm_set1_epi32(weight_0);
341 const __m128i weight1 = _mm_set1_epi32(weight_1);
342
343 int y = height;
344 do {
345 int x = 0;
346 do {
347 const __m128i src_0_lo = LoadAligned16(pred_0 + x);
348 const __m128i src_1_lo = LoadAligned16(pred_1 + x);
349 const __m128i res_lo =
350 ComputeWeightedAverage8(src_0_lo, src_1_lo, weight0, weight1);
351
352 const __m128i src_0_hi = LoadAligned16(pred_0 + x + 8);
353 const __m128i src_1_hi = LoadAligned16(pred_1 + x + 8);
354 const __m128i res_hi =
355 ComputeWeightedAverage8(src_0_hi, src_1_hi, weight0, weight1);
356
357 StoreUnaligned16(dst + x, res_lo);
358 x += 8;
359 StoreUnaligned16(dst + x, res_hi);
360 x += 8;
361 } while (x < width);
362 dst += dest_stride;
363 pred_0 += width;
364 pred_1 += width;
365 } while (--y != 0);
366 }
367
DistanceWeightedBlend_SSE4_1(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,const uint8_t weight_0,const uint8_t weight_1,const int width,const int height,void * LIBGAV1_RESTRICT const dest,const ptrdiff_t dest_stride)368 void DistanceWeightedBlend_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
369 const void* LIBGAV1_RESTRICT prediction_1,
370 const uint8_t weight_0,
371 const uint8_t weight_1, const int width,
372 const int height,
373 void* LIBGAV1_RESTRICT const dest,
374 const ptrdiff_t dest_stride) {
375 const auto* pred_0 = static_cast<const uint16_t*>(prediction_0);
376 const auto* pred_1 = static_cast<const uint16_t*>(prediction_1);
377 const ptrdiff_t dst_stride = dest_stride / sizeof(*pred_0);
378 if (width == 4) {
379 if (height == 4) {
380 DistanceWeightedBlend4xH_SSE4_1<4>(pred_0, pred_1, weight_0, weight_1,
381 dest, dst_stride);
382 } else if (height == 8) {
383 DistanceWeightedBlend4xH_SSE4_1<8>(pred_0, pred_1, weight_0, weight_1,
384 dest, dst_stride);
385 } else {
386 assert(height == 16);
387 DistanceWeightedBlend4xH_SSE4_1<16>(pred_0, pred_1, weight_0, weight_1,
388 dest, dst_stride);
389 }
390 return;
391 }
392
393 if (width == 8) {
394 switch (height) {
395 case 4:
396 DistanceWeightedBlend8xH_SSE4_1<4>(pred_0, pred_1, weight_0, weight_1,
397 dest, dst_stride);
398 return;
399 case 8:
400 DistanceWeightedBlend8xH_SSE4_1<8>(pred_0, pred_1, weight_0, weight_1,
401 dest, dst_stride);
402 return;
403 case 16:
404 DistanceWeightedBlend8xH_SSE4_1<16>(pred_0, pred_1, weight_0, weight_1,
405 dest, dst_stride);
406 return;
407 default:
408 assert(height == 32);
409 DistanceWeightedBlend8xH_SSE4_1<32>(pred_0, pred_1, weight_0, weight_1,
410 dest, dst_stride);
411
412 return;
413 }
414 }
415
416 DistanceWeightedBlendLarge_SSE4_1(pred_0, pred_1, weight_0, weight_1, width,
417 height, dest, dst_stride);
418 }
419
Init10bpp()420 void Init10bpp() {
421 Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10);
422 assert(dsp != nullptr);
423 #if DSP_ENABLED_10BPP_SSE4_1(DistanceWeightedBlend)
424 dsp->distance_weighted_blend = DistanceWeightedBlend_SSE4_1;
425 #endif
426 }
427
428 } // namespace
429 } // namespace high_bitdepth
430 #endif // LIBGAV1_MAX_BITDEPTH >= 10
431
// Runs the 8bpp initializer and, when LIBGAV1_MAX_BITDEPTH >= 10, the 10bpp
// initializer. Each one registers its SSE4.1 DistanceWeightedBlend only if the
// corresponding DSP_ENABLED_*_SSE4_1 check passes.
void DistanceWeightedBlendInit_SSE4_1() {
  low_bitdepth::Init8bpp();
#if LIBGAV1_MAX_BITDEPTH >= 10
  high_bitdepth::Init10bpp();
#endif
}
438
439 } // namespace dsp
440 } // namespace libgav1
441
442 #else // !LIBGAV1_TARGETING_SSE4_1
443
444 namespace libgav1 {
445 namespace dsp {
446
// SSE4.1 is not targeted in this build (!LIBGAV1_TARGETING_SSE4_1), so there
// is nothing to register.
void DistanceWeightedBlendInit_SSE4_1() {}
448
449 } // namespace dsp
450 } // namespace libgav1
451 #endif // LIBGAV1_TARGETING_SSE4_1
452