xref: /aosp_15_r20/external/libgav1/src/dsp/x86/distance_weighted_blend_sse4.cc (revision 095378508e87ed692bf8dfeb34008b65b3735891)
1 // Copyright 2019 The libgav1 Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "src/dsp/distance_weighted_blend.h"
16 #include "src/utils/cpu.h"
17 
18 #if LIBGAV1_TARGETING_SSE4_1
19 
#include <xmmintrin.h>

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>
25 
26 #include "src/dsp/constants.h"
27 #include "src/dsp/dsp.h"
28 #include "src/dsp/x86/common_sse4.h"
29 #include "src/utils/common.h"
30 
31 namespace libgav1 {
32 namespace dsp {
33 namespace low_bitdepth {
34 namespace {
35 
// Final rounding shift applied after distance weighting (8bpp path).
constexpr int kInterPostRoundBit = 4;
// Multiplier for _mm_mulhrs_epi16 that turns its (x * m + (1 << 14)) >> 15
// into a rounded right shift by kInterPostRoundBit:
// m = 1 << (15 - kInterPostRoundBit) gives (x + 8) >> 4.
constexpr int kInterPostRhsAdjust = 1 << (15 - kInterPostRoundBit);
38 
ComputeWeightedAverage8(const __m128i & pred0,const __m128i & pred1,const __m128i & weight)39 inline __m128i ComputeWeightedAverage8(const __m128i& pred0,
40                                        const __m128i& pred1,
41                                        const __m128i& weight) {
42   // Given: p0,p1 in range [-5132,9212] and w0 = 16 - w1, w1 = 16 - w0
43   // Output: (p0 * w0 + p1 * w1 + 128(=rounding bit)) >>
44   //    8(=kInterPostRoundBit + 4)
45   // The formula is manipulated to avoid lengthening to 32 bits.
46   // p0 * w0 + p1 * w1 = p0 * w0 + (16 - w0) * p1
47   // = (p0 - p1) * w0 + 16 * p1
48   // Maximum value of p0 - p1 is 9212 + 5132 = 0x3808.
49   const __m128i diff = _mm_slli_epi16(_mm_sub_epi16(pred0, pred1), 1);
50   // (((p0 - p1) * (w0 << 12) >> 16) + ((16 * p1) >> 4)
51   const __m128i weighted_diff = _mm_mulhi_epi16(diff, weight);
52   // ((p0 - p1) * w0 >> 4) + p1
53   const __m128i upscaled_average = _mm_add_epi16(weighted_diff, pred1);
54   // (x << 11) >> 15 == x >> 4
55   const __m128i right_shift_prep = _mm_set1_epi16(kInterPostRhsAdjust);
56   // (((p0 - p1) * w0 >> 4) + p1 + (128 >> 4)) >> 4
57   return _mm_mulhrs_epi16(upscaled_average, right_shift_prep);
58 }
59 
60 template <int height>
DistanceWeightedBlend4xH_SSE4_1(const int16_t * LIBGAV1_RESTRICT pred_0,const int16_t * LIBGAV1_RESTRICT pred_1,const uint8_t weight,void * LIBGAV1_RESTRICT const dest,const ptrdiff_t dest_stride)61 inline void DistanceWeightedBlend4xH_SSE4_1(
62     const int16_t* LIBGAV1_RESTRICT pred_0,
63     const int16_t* LIBGAV1_RESTRICT pred_1, const uint8_t weight,
64     void* LIBGAV1_RESTRICT const dest, const ptrdiff_t dest_stride) {
65   auto* dst = static_cast<uint8_t*>(dest);
66   // Upscale the weight for mulhi.
67   const __m128i weights = _mm_set1_epi16(weight << 11);
68 
69   for (int y = 0; y < height; y += 4) {
70     const __m128i src_00 = LoadAligned16(pred_0);
71     const __m128i src_10 = LoadAligned16(pred_1);
72     pred_0 += 8;
73     pred_1 += 8;
74     const __m128i res0 = ComputeWeightedAverage8(src_00, src_10, weights);
75 
76     const __m128i src_01 = LoadAligned16(pred_0);
77     const __m128i src_11 = LoadAligned16(pred_1);
78     pred_0 += 8;
79     pred_1 += 8;
80     const __m128i res1 = ComputeWeightedAverage8(src_01, src_11, weights);
81 
82     const __m128i result_pixels = _mm_packus_epi16(res0, res1);
83     Store4(dst, result_pixels);
84     dst += dest_stride;
85     const int result_1 = _mm_extract_epi32(result_pixels, 1);
86     memcpy(dst, &result_1, sizeof(result_1));
87     dst += dest_stride;
88     const int result_2 = _mm_extract_epi32(result_pixels, 2);
89     memcpy(dst, &result_2, sizeof(result_2));
90     dst += dest_stride;
91     const int result_3 = _mm_extract_epi32(result_pixels, 3);
92     memcpy(dst, &result_3, sizeof(result_3));
93     dst += dest_stride;
94   }
95 }
96 
97 template <int height>
DistanceWeightedBlend8xH_SSE4_1(const int16_t * LIBGAV1_RESTRICT pred_0,const int16_t * LIBGAV1_RESTRICT pred_1,const uint8_t weight,void * LIBGAV1_RESTRICT const dest,const ptrdiff_t dest_stride)98 inline void DistanceWeightedBlend8xH_SSE4_1(
99     const int16_t* LIBGAV1_RESTRICT pred_0,
100     const int16_t* LIBGAV1_RESTRICT pred_1, const uint8_t weight,
101     void* LIBGAV1_RESTRICT const dest, const ptrdiff_t dest_stride) {
102   auto* dst = static_cast<uint8_t*>(dest);
103   // Upscale the weight for mulhi.
104   const __m128i weights = _mm_set1_epi16(weight << 11);
105 
106   for (int y = 0; y < height; y += 2) {
107     const __m128i src_00 = LoadAligned16(pred_0);
108     const __m128i src_10 = LoadAligned16(pred_1);
109     pred_0 += 8;
110     pred_1 += 8;
111     const __m128i res0 = ComputeWeightedAverage8(src_00, src_10, weights);
112 
113     const __m128i src_01 = LoadAligned16(pred_0);
114     const __m128i src_11 = LoadAligned16(pred_1);
115     pred_0 += 8;
116     pred_1 += 8;
117     const __m128i res1 = ComputeWeightedAverage8(src_01, src_11, weights);
118 
119     const __m128i result_pixels = _mm_packus_epi16(res0, res1);
120     StoreLo8(dst, result_pixels);
121     dst += dest_stride;
122     StoreHi8(dst, result_pixels);
123     dst += dest_stride;
124   }
125 }
126 
DistanceWeightedBlendLarge_SSE4_1(const int16_t * LIBGAV1_RESTRICT pred_0,const int16_t * LIBGAV1_RESTRICT pred_1,const uint8_t weight,const int width,const int height,void * LIBGAV1_RESTRICT const dest,const ptrdiff_t dest_stride)127 inline void DistanceWeightedBlendLarge_SSE4_1(
128     const int16_t* LIBGAV1_RESTRICT pred_0,
129     const int16_t* LIBGAV1_RESTRICT pred_1, const uint8_t weight,
130     const int width, const int height, void* LIBGAV1_RESTRICT const dest,
131     const ptrdiff_t dest_stride) {
132   auto* dst = static_cast<uint8_t*>(dest);
133   // Upscale the weight for mulhi.
134   const __m128i weights = _mm_set1_epi16(weight << 11);
135 
136   int y = height;
137   do {
138     int x = 0;
139     do {
140       const __m128i src_0_lo = LoadAligned16(pred_0 + x);
141       const __m128i src_1_lo = LoadAligned16(pred_1 + x);
142       const __m128i res_lo =
143           ComputeWeightedAverage8(src_0_lo, src_1_lo, weights);
144 
145       const __m128i src_0_hi = LoadAligned16(pred_0 + x + 8);
146       const __m128i src_1_hi = LoadAligned16(pred_1 + x + 8);
147       const __m128i res_hi =
148           ComputeWeightedAverage8(src_0_hi, src_1_hi, weights);
149 
150       StoreUnaligned16(dst + x, _mm_packus_epi16(res_lo, res_hi));
151       x += 16;
152     } while (x < width);
153     dst += dest_stride;
154     pred_0 += width;
155     pred_1 += width;
156   } while (--y != 0);
157 }
158 
DistanceWeightedBlend_SSE4_1(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,const uint8_t weight_0,const uint8_t,const int width,const int height,void * LIBGAV1_RESTRICT const dest,const ptrdiff_t dest_stride)159 void DistanceWeightedBlend_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
160                                   const void* LIBGAV1_RESTRICT prediction_1,
161                                   const uint8_t weight_0,
162                                   const uint8_t /*weight_1*/, const int width,
163                                   const int height,
164                                   void* LIBGAV1_RESTRICT const dest,
165                                   const ptrdiff_t dest_stride) {
166   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
167   const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
168   const uint8_t weight = weight_0;
169   if (width == 4) {
170     if (height == 4) {
171       DistanceWeightedBlend4xH_SSE4_1<4>(pred_0, pred_1, weight, dest,
172                                          dest_stride);
173     } else if (height == 8) {
174       DistanceWeightedBlend4xH_SSE4_1<8>(pred_0, pred_1, weight, dest,
175                                          dest_stride);
176     } else {
177       assert(height == 16);
178       DistanceWeightedBlend4xH_SSE4_1<16>(pred_0, pred_1, weight, dest,
179                                           dest_stride);
180     }
181     return;
182   }
183 
184   if (width == 8) {
185     switch (height) {
186       case 4:
187         DistanceWeightedBlend8xH_SSE4_1<4>(pred_0, pred_1, weight, dest,
188                                            dest_stride);
189         return;
190       case 8:
191         DistanceWeightedBlend8xH_SSE4_1<8>(pred_0, pred_1, weight, dest,
192                                            dest_stride);
193         return;
194       case 16:
195         DistanceWeightedBlend8xH_SSE4_1<16>(pred_0, pred_1, weight, dest,
196                                             dest_stride);
197         return;
198       default:
199         assert(height == 32);
200         DistanceWeightedBlend8xH_SSE4_1<32>(pred_0, pred_1, weight, dest,
201                                             dest_stride);
202 
203         return;
204     }
205   }
206 
207   DistanceWeightedBlendLarge_SSE4_1(pred_0, pred_1, weight, width, height, dest,
208                                     dest_stride);
209 }
210 
Init8bpp()211 void Init8bpp() {
212   Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
213   assert(dsp != nullptr);
214 #if DSP_ENABLED_8BPP_SSE4_1(DistanceWeightedBlend)
215   dsp->distance_weighted_blend = DistanceWeightedBlend_SSE4_1;
216 #endif
217 }
218 
219 }  // namespace
220 }  // namespace low_bitdepth
221 
222 #if LIBGAV1_MAX_BITDEPTH >= 10
223 namespace high_bitdepth {
224 namespace {
225 
// Post-rounding shift for the 10bpp path. ComputeWeightedAverage8 shifts by
// this plus 4.
constexpr int kInterPostRoundBit = 4;
// Largest pixel value representable with 10 bits; output is clamped to it.
constexpr int kMax10bppSample = (1 << 10) - 1;
228 
ComputeWeightedAverage8(const __m128i & pred0,const __m128i & pred1,const __m128i & weight0,const __m128i & weight1)229 inline __m128i ComputeWeightedAverage8(const __m128i& pred0,
230                                        const __m128i& pred1,
231                                        const __m128i& weight0,
232                                        const __m128i& weight1) {
233   // This offset is a combination of round_factor and round_offset
234   // which are to be added and subtracted respectively.
235   // Here kInterPostRoundBit + 4 is considering bitdepth=10.
236   constexpr int offset =
237       (1 << ((kInterPostRoundBit + 4) - 1)) - (kCompoundOffset << 4);
238   const __m128i zero = _mm_setzero_si128();
239   const __m128i bias = _mm_set1_epi32(offset);
240   const __m128i clip_high = _mm_set1_epi16(kMax10bppSample);
241 
242   __m128i prediction0 = _mm_cvtepu16_epi32(pred0);
243   __m128i mult0 = _mm_mullo_epi32(prediction0, weight0);
244   __m128i prediction1 = _mm_cvtepu16_epi32(pred1);
245   __m128i mult1 = _mm_mullo_epi32(prediction1, weight1);
246   __m128i sum = _mm_add_epi32(mult0, mult1);
247   sum = _mm_add_epi32(sum, bias);
248   const __m128i result0 = _mm_srai_epi32(sum, kInterPostRoundBit + 4);
249 
250   prediction0 = _mm_unpackhi_epi16(pred0, zero);
251   mult0 = _mm_mullo_epi32(prediction0, weight0);
252   prediction1 = _mm_unpackhi_epi16(pred1, zero);
253   mult1 = _mm_mullo_epi32(prediction1, weight1);
254   sum = _mm_add_epi32(mult0, mult1);
255   sum = _mm_add_epi32(sum, bias);
256   const __m128i result1 = _mm_srai_epi32(sum, kInterPostRoundBit + 4);
257   const __m128i pack = _mm_packus_epi32(result0, result1);
258 
259   return _mm_min_epi16(pack, clip_high);
260 }
261 
262 template <int height>
DistanceWeightedBlend4xH_SSE4_1(const uint16_t * LIBGAV1_RESTRICT pred_0,const uint16_t * LIBGAV1_RESTRICT pred_1,const uint8_t weight_0,const uint8_t weight_1,void * LIBGAV1_RESTRICT const dest,const ptrdiff_t dest_stride)263 inline void DistanceWeightedBlend4xH_SSE4_1(
264     const uint16_t* LIBGAV1_RESTRICT pred_0,
265     const uint16_t* LIBGAV1_RESTRICT pred_1, const uint8_t weight_0,
266     const uint8_t weight_1, void* LIBGAV1_RESTRICT const dest,
267     const ptrdiff_t dest_stride) {
268   auto* dst = static_cast<uint16_t*>(dest);
269   const __m128i weight0 = _mm_set1_epi32(weight_0);
270   const __m128i weight1 = _mm_set1_epi32(weight_1);
271 
272   int y = height;
273   do {
274     const __m128i src_00 = LoadAligned16(pred_0);
275     const __m128i src_10 = LoadAligned16(pred_1);
276     pred_0 += 8;
277     pred_1 += 8;
278     const __m128i res0 =
279         ComputeWeightedAverage8(src_00, src_10, weight0, weight1);
280 
281     const __m128i src_01 = LoadAligned16(pred_0);
282     const __m128i src_11 = LoadAligned16(pred_1);
283     pred_0 += 8;
284     pred_1 += 8;
285     const __m128i res1 =
286         ComputeWeightedAverage8(src_01, src_11, weight0, weight1);
287 
288     StoreLo8(dst, res0);
289     dst += dest_stride;
290     StoreHi8(dst, res0);
291     dst += dest_stride;
292     StoreLo8(dst, res1);
293     dst += dest_stride;
294     StoreHi8(dst, res1);
295     dst += dest_stride;
296     y -= 4;
297   } while (y != 0);
298 }
299 
300 template <int height>
DistanceWeightedBlend8xH_SSE4_1(const uint16_t * LIBGAV1_RESTRICT pred_0,const uint16_t * LIBGAV1_RESTRICT pred_1,const uint8_t weight_0,const uint8_t weight_1,void * LIBGAV1_RESTRICT const dest,const ptrdiff_t dest_stride)301 inline void DistanceWeightedBlend8xH_SSE4_1(
302     const uint16_t* LIBGAV1_RESTRICT pred_0,
303     const uint16_t* LIBGAV1_RESTRICT pred_1, const uint8_t weight_0,
304     const uint8_t weight_1, void* LIBGAV1_RESTRICT const dest,
305     const ptrdiff_t dest_stride) {
306   auto* dst = static_cast<uint16_t*>(dest);
307   const __m128i weight0 = _mm_set1_epi32(weight_0);
308   const __m128i weight1 = _mm_set1_epi32(weight_1);
309 
310   int y = height;
311   do {
312     const __m128i src_00 = LoadAligned16(pred_0);
313     const __m128i src_10 = LoadAligned16(pred_1);
314     pred_0 += 8;
315     pred_1 += 8;
316     const __m128i res0 =
317         ComputeWeightedAverage8(src_00, src_10, weight0, weight1);
318 
319     const __m128i src_01 = LoadAligned16(pred_0);
320     const __m128i src_11 = LoadAligned16(pred_1);
321     pred_0 += 8;
322     pred_1 += 8;
323     const __m128i res1 =
324         ComputeWeightedAverage8(src_01, src_11, weight0, weight1);
325 
326     StoreUnaligned16(dst, res0);
327     dst += dest_stride;
328     StoreUnaligned16(dst, res1);
329     dst += dest_stride;
330     y -= 2;
331   } while (y != 0);
332 }
333 
DistanceWeightedBlendLarge_SSE4_1(const uint16_t * LIBGAV1_RESTRICT pred_0,const uint16_t * LIBGAV1_RESTRICT pred_1,const uint8_t weight_0,const uint8_t weight_1,const int width,const int height,void * LIBGAV1_RESTRICT const dest,const ptrdiff_t dest_stride)334 inline void DistanceWeightedBlendLarge_SSE4_1(
335     const uint16_t* LIBGAV1_RESTRICT pred_0,
336     const uint16_t* LIBGAV1_RESTRICT pred_1, const uint8_t weight_0,
337     const uint8_t weight_1, const int width, const int height,
338     void* LIBGAV1_RESTRICT const dest, const ptrdiff_t dest_stride) {
339   auto* dst = static_cast<uint16_t*>(dest);
340   const __m128i weight0 = _mm_set1_epi32(weight_0);
341   const __m128i weight1 = _mm_set1_epi32(weight_1);
342 
343   int y = height;
344   do {
345     int x = 0;
346     do {
347       const __m128i src_0_lo = LoadAligned16(pred_0 + x);
348       const __m128i src_1_lo = LoadAligned16(pred_1 + x);
349       const __m128i res_lo =
350           ComputeWeightedAverage8(src_0_lo, src_1_lo, weight0, weight1);
351 
352       const __m128i src_0_hi = LoadAligned16(pred_0 + x + 8);
353       const __m128i src_1_hi = LoadAligned16(pred_1 + x + 8);
354       const __m128i res_hi =
355           ComputeWeightedAverage8(src_0_hi, src_1_hi, weight0, weight1);
356 
357       StoreUnaligned16(dst + x, res_lo);
358       x += 8;
359       StoreUnaligned16(dst + x, res_hi);
360       x += 8;
361     } while (x < width);
362     dst += dest_stride;
363     pred_0 += width;
364     pred_1 += width;
365   } while (--y != 0);
366 }
367 
DistanceWeightedBlend_SSE4_1(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,const uint8_t weight_0,const uint8_t weight_1,const int width,const int height,void * LIBGAV1_RESTRICT const dest,const ptrdiff_t dest_stride)368 void DistanceWeightedBlend_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
369                                   const void* LIBGAV1_RESTRICT prediction_1,
370                                   const uint8_t weight_0,
371                                   const uint8_t weight_1, const int width,
372                                   const int height,
373                                   void* LIBGAV1_RESTRICT const dest,
374                                   const ptrdiff_t dest_stride) {
375   const auto* pred_0 = static_cast<const uint16_t*>(prediction_0);
376   const auto* pred_1 = static_cast<const uint16_t*>(prediction_1);
377   const ptrdiff_t dst_stride = dest_stride / sizeof(*pred_0);
378   if (width == 4) {
379     if (height == 4) {
380       DistanceWeightedBlend4xH_SSE4_1<4>(pred_0, pred_1, weight_0, weight_1,
381                                          dest, dst_stride);
382     } else if (height == 8) {
383       DistanceWeightedBlend4xH_SSE4_1<8>(pred_0, pred_1, weight_0, weight_1,
384                                          dest, dst_stride);
385     } else {
386       assert(height == 16);
387       DistanceWeightedBlend4xH_SSE4_1<16>(pred_0, pred_1, weight_0, weight_1,
388                                           dest, dst_stride);
389     }
390     return;
391   }
392 
393   if (width == 8) {
394     switch (height) {
395       case 4:
396         DistanceWeightedBlend8xH_SSE4_1<4>(pred_0, pred_1, weight_0, weight_1,
397                                            dest, dst_stride);
398         return;
399       case 8:
400         DistanceWeightedBlend8xH_SSE4_1<8>(pred_0, pred_1, weight_0, weight_1,
401                                            dest, dst_stride);
402         return;
403       case 16:
404         DistanceWeightedBlend8xH_SSE4_1<16>(pred_0, pred_1, weight_0, weight_1,
405                                             dest, dst_stride);
406         return;
407       default:
408         assert(height == 32);
409         DistanceWeightedBlend8xH_SSE4_1<32>(pred_0, pred_1, weight_0, weight_1,
410                                             dest, dst_stride);
411 
412         return;
413     }
414   }
415 
416   DistanceWeightedBlendLarge_SSE4_1(pred_0, pred_1, weight_0, weight_1, width,
417                                     height, dest, dst_stride);
418 }
419 
Init10bpp()420 void Init10bpp() {
421   Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10);
422   assert(dsp != nullptr);
423 #if DSP_ENABLED_10BPP_SSE4_1(DistanceWeightedBlend)
424   dsp->distance_weighted_blend = DistanceWeightedBlend_SSE4_1;
425 #endif
426 }
427 
428 }  // namespace
429 }  // namespace high_bitdepth
430 #endif  // LIBGAV1_MAX_BITDEPTH >= 10
431 
DistanceWeightedBlendInit_SSE4_1()432 void DistanceWeightedBlendInit_SSE4_1() {
433   low_bitdepth::Init8bpp();
434 #if LIBGAV1_MAX_BITDEPTH >= 10
435   high_bitdepth::Init10bpp();
436 #endif
437 }
438 
439 }  // namespace dsp
440 }  // namespace libgav1
441 
442 #else   // !LIBGAV1_TARGETING_SSE4_1
443 
444 namespace libgav1 {
445 namespace dsp {
446 
void DistanceWeightedBlendInit_SSE4_1() {
  // This build does not target SSE4.1, so there is nothing to register.
}
448 
449 }  // namespace dsp
450 }  // namespace libgav1
451 #endif  // LIBGAV1_TARGETING_SSE4_1
452