xref: /aosp_15_r20/external/libgav1/src/dsp/x86/average_blend_sse4.cc (revision 095378508e87ed692bf8dfeb34008b65b3735891)
// Copyright 2019 The libgav1 Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "src/dsp/average_blend.h"
#include "src/utils/cpu.h"

#if LIBGAV1_TARGETING_SSE4_1

#include <smmintrin.h>
#include <xmmintrin.h>

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>

#include "src/dsp/constants.h"
#include "src/dsp/dsp.h"
#include "src/dsp/x86/common_sse4.h"
#include "src/utils/common.h"

31 namespace libgav1 {
32 namespace dsp {
33 namespace low_bitdepth {
34 namespace {
35 
36 constexpr int kInterPostRoundBit = 4;
37 
AverageBlend4x4Row(const int16_t * LIBGAV1_RESTRICT prediction_0,const int16_t * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT dest,const ptrdiff_t dest_stride)38 inline void AverageBlend4x4Row(const int16_t* LIBGAV1_RESTRICT prediction_0,
39                                const int16_t* LIBGAV1_RESTRICT prediction_1,
40                                uint8_t* LIBGAV1_RESTRICT dest,
41                                const ptrdiff_t dest_stride) {
42   const __m128i pred_00 = LoadAligned16(prediction_0);
43   const __m128i pred_10 = LoadAligned16(prediction_1);
44   __m128i res_0 = _mm_add_epi16(pred_00, pred_10);
45   res_0 = RightShiftWithRounding_S16(res_0, kInterPostRoundBit + 1);
46   const __m128i pred_01 = LoadAligned16(prediction_0 + 8);
47   const __m128i pred_11 = LoadAligned16(prediction_1 + 8);
48   __m128i res_1 = _mm_add_epi16(pred_01, pred_11);
49   res_1 = RightShiftWithRounding_S16(res_1, kInterPostRoundBit + 1);
50   const __m128i result_pixels = _mm_packus_epi16(res_0, res_1);
51   Store4(dest, result_pixels);
52   dest += dest_stride;
53   const int result_1 = _mm_extract_epi32(result_pixels, 1);
54   memcpy(dest, &result_1, sizeof(result_1));
55   dest += dest_stride;
56   const int result_2 = _mm_extract_epi32(result_pixels, 2);
57   memcpy(dest, &result_2, sizeof(result_2));
58   dest += dest_stride;
59   const int result_3 = _mm_extract_epi32(result_pixels, 3);
60   memcpy(dest, &result_3, sizeof(result_3));
61 }
62 
AverageBlend8Row(const int16_t * LIBGAV1_RESTRICT prediction_0,const int16_t * LIBGAV1_RESTRICT prediction_1,uint8_t * LIBGAV1_RESTRICT dest,const ptrdiff_t dest_stride)63 inline void AverageBlend8Row(const int16_t* LIBGAV1_RESTRICT prediction_0,
64                              const int16_t* LIBGAV1_RESTRICT prediction_1,
65                              uint8_t* LIBGAV1_RESTRICT dest,
66                              const ptrdiff_t dest_stride) {
67   const __m128i pred_00 = LoadAligned16(prediction_0);
68   const __m128i pred_10 = LoadAligned16(prediction_1);
69   __m128i res_0 = _mm_add_epi16(pred_00, pred_10);
70   res_0 = RightShiftWithRounding_S16(res_0, kInterPostRoundBit + 1);
71   const __m128i pred_01 = LoadAligned16(prediction_0 + 8);
72   const __m128i pred_11 = LoadAligned16(prediction_1 + 8);
73   __m128i res_1 = _mm_add_epi16(pred_01, pred_11);
74   res_1 = RightShiftWithRounding_S16(res_1, kInterPostRoundBit + 1);
75   const __m128i result_pixels = _mm_packus_epi16(res_0, res_1);
76   StoreLo8(dest, result_pixels);
77   StoreHi8(dest + dest_stride, result_pixels);
78 }
79 
AverageBlendLargeRow(const int16_t * LIBGAV1_RESTRICT prediction_0,const int16_t * LIBGAV1_RESTRICT prediction_1,const int width,uint8_t * LIBGAV1_RESTRICT dest)80 inline void AverageBlendLargeRow(const int16_t* LIBGAV1_RESTRICT prediction_0,
81                                  const int16_t* LIBGAV1_RESTRICT prediction_1,
82                                  const int width,
83                                  uint8_t* LIBGAV1_RESTRICT dest) {
84   int x = 0;
85   do {
86     const __m128i pred_00 = LoadAligned16(&prediction_0[x]);
87     const __m128i pred_01 = LoadAligned16(&prediction_1[x]);
88     __m128i res0 = _mm_add_epi16(pred_00, pred_01);
89     res0 = RightShiftWithRounding_S16(res0, kInterPostRoundBit + 1);
90     const __m128i pred_10 = LoadAligned16(&prediction_0[x + 8]);
91     const __m128i pred_11 = LoadAligned16(&prediction_1[x + 8]);
92     __m128i res1 = _mm_add_epi16(pred_10, pred_11);
93     res1 = RightShiftWithRounding_S16(res1, kInterPostRoundBit + 1);
94     StoreUnaligned16(dest + x, _mm_packus_epi16(res0, res1));
95     x += 16;
96   } while (x < width);
97 }
98 
AverageBlend_SSE4_1(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,const int width,const int height,void * LIBGAV1_RESTRICT const dest,const ptrdiff_t dest_stride)99 void AverageBlend_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
100                          const void* LIBGAV1_RESTRICT prediction_1,
101                          const int width, const int height,
102                          void* LIBGAV1_RESTRICT const dest,
103                          const ptrdiff_t dest_stride) {
104   auto* dst = static_cast<uint8_t*>(dest);
105   const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
106   const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
107   int y = height;
108 
109   if (width == 4) {
110     const ptrdiff_t dest_stride4 = dest_stride << 2;
111     constexpr ptrdiff_t width4 = 4 << 2;
112     do {
113       AverageBlend4x4Row(pred_0, pred_1, dst, dest_stride);
114       dst += dest_stride4;
115       pred_0 += width4;
116       pred_1 += width4;
117 
118       y -= 4;
119     } while (y != 0);
120     return;
121   }
122 
123   if (width == 8) {
124     const ptrdiff_t dest_stride2 = dest_stride << 1;
125     constexpr ptrdiff_t width2 = 8 << 1;
126     do {
127       AverageBlend8Row(pred_0, pred_1, dst, dest_stride);
128       dst += dest_stride2;
129       pred_0 += width2;
130       pred_1 += width2;
131 
132       y -= 2;
133     } while (y != 0);
134     return;
135   }
136 
137   do {
138     AverageBlendLargeRow(pred_0, pred_1, width, dst);
139     dst += dest_stride;
140     pred_0 += width;
141     pred_1 += width;
142 
143     AverageBlendLargeRow(pred_0, pred_1, width, dst);
144     dst += dest_stride;
145     pred_0 += width;
146     pred_1 += width;
147 
148     y -= 2;
149   } while (y != 0);
150 }
151 
// Registers the SSE4.1 8bpp average-blend implementation in the writable
// dsp table for kBitdepth8 (when enabled by the DSP_ENABLED macro).
void Init8bpp() {
  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
  assert(dsp != nullptr);
#if DSP_ENABLED_8BPP_SSE4_1(AverageBlend)
  dsp->average_blend = AverageBlend_SSE4_1;
#endif
}
159 
160 }  // namespace
161 }  // namespace low_bitdepth
162 
163 #if LIBGAV1_MAX_BITDEPTH >= 10
164 namespace high_bitdepth {
165 namespace {
166 
167 constexpr int kInterPostRoundBitPlusOne = 5;
168 
169 template <const int width, const int offset>
AverageBlendRow(const uint16_t * LIBGAV1_RESTRICT prediction_0,const uint16_t * LIBGAV1_RESTRICT prediction_1,const __m128i & compound_offset,const __m128i & round_offset,const __m128i & max,const __m128i & zero,uint16_t * LIBGAV1_RESTRICT dst,const ptrdiff_t dest_stride)170 inline void AverageBlendRow(const uint16_t* LIBGAV1_RESTRICT prediction_0,
171                             const uint16_t* LIBGAV1_RESTRICT prediction_1,
172                             const __m128i& compound_offset,
173                             const __m128i& round_offset, const __m128i& max,
174                             const __m128i& zero, uint16_t* LIBGAV1_RESTRICT dst,
175                             const ptrdiff_t dest_stride) {
176   // pred_0/1 max range is 16b.
177   const __m128i pred_0 = LoadUnaligned16(prediction_0 + offset);
178   const __m128i pred_1 = LoadUnaligned16(prediction_1 + offset);
179   const __m128i pred_00 = _mm_cvtepu16_epi32(pred_0);
180   const __m128i pred_01 = _mm_unpackhi_epi16(pred_0, zero);
181   const __m128i pred_10 = _mm_cvtepu16_epi32(pred_1);
182   const __m128i pred_11 = _mm_unpackhi_epi16(pred_1, zero);
183 
184   const __m128i pred_add_0 = _mm_add_epi32(pred_00, pred_10);
185   const __m128i pred_add_1 = _mm_add_epi32(pred_01, pred_11);
186   const __m128i compound_offset_0 = _mm_sub_epi32(pred_add_0, compound_offset);
187   const __m128i compound_offset_1 = _mm_sub_epi32(pred_add_1, compound_offset);
188   // RightShiftWithRounding and Clip3.
189   const __m128i round_0 = _mm_add_epi32(compound_offset_0, round_offset);
190   const __m128i round_1 = _mm_add_epi32(compound_offset_1, round_offset);
191   const __m128i res_0 = _mm_srai_epi32(round_0, kInterPostRoundBitPlusOne);
192   const __m128i res_1 = _mm_srai_epi32(round_1, kInterPostRoundBitPlusOne);
193   const __m128i result = _mm_min_epi16(_mm_packus_epi32(res_0, res_1), max);
194   if (width != 4) {
195     // Store width=8/16/32/64/128.
196     StoreUnaligned16(dst + offset, result);
197     return;
198   }
199   assert(width == 4);
200   StoreLo8(dst, result);
201   StoreHi8(dst + dest_stride, result);
202 }
203 
AverageBlend10bpp_SSE4_1(const void * LIBGAV1_RESTRICT prediction_0,const void * LIBGAV1_RESTRICT prediction_1,const int width,const int height,void * LIBGAV1_RESTRICT const dest,const ptrdiff_t dst_stride)204 void AverageBlend10bpp_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
205                               const void* LIBGAV1_RESTRICT prediction_1,
206                               const int width, const int height,
207                               void* LIBGAV1_RESTRICT const dest,
208                               const ptrdiff_t dst_stride) {
209   auto* dst = static_cast<uint16_t*>(dest);
210   const ptrdiff_t dest_stride = dst_stride / sizeof(dst[0]);
211   const auto* pred_0 = static_cast<const uint16_t*>(prediction_0);
212   const auto* pred_1 = static_cast<const uint16_t*>(prediction_1);
213   const __m128i compound_offset =
214       _mm_set1_epi32(kCompoundOffset + kCompoundOffset);
215   const __m128i round_offset =
216       _mm_set1_epi32((1 << kInterPostRoundBitPlusOne) >> 1);
217   const __m128i max = _mm_set1_epi16((1 << kBitdepth10) - 1);
218   const __m128i zero = _mm_setzero_si128();
219   int y = height;
220 
221   if (width == 4) {
222     const ptrdiff_t dest_stride2 = dest_stride << 1;
223     const ptrdiff_t width2 = width << 1;
224     do {
225       // row0,1
226       AverageBlendRow<4, 0>(pred_0, pred_1, compound_offset, round_offset, max,
227                             zero, dst, dest_stride);
228       dst += dest_stride2;
229       pred_0 += width2;
230       pred_1 += width2;
231       y -= 2;
232     } while (y != 0);
233     return;
234   }
235   if (width == 8) {
236     const ptrdiff_t dest_stride2 = dest_stride << 1;
237     const ptrdiff_t width2 = width << 1;
238     do {
239       // row0.
240       AverageBlendRow<8, 0>(pred_0, pred_1, compound_offset, round_offset, max,
241                             zero, dst, dest_stride);
242       // row1.
243       AverageBlendRow<8, 0>(pred_0 + width, pred_1 + width, compound_offset,
244                             round_offset, max, zero, dst + dest_stride,
245                             dest_stride);
246       dst += dest_stride2;
247       pred_0 += width2;
248       pred_1 += width2;
249       y -= 2;
250     } while (y != 0);
251     return;
252   }
253   if (width == 16) {
254     const ptrdiff_t dest_stride2 = dest_stride << 1;
255     const ptrdiff_t width2 = width << 1;
256     do {
257       // row0.
258       AverageBlendRow<8, 0>(pred_0, pred_1, compound_offset, round_offset, max,
259                             zero, dst, dest_stride);
260       AverageBlendRow<8, 8>(pred_0, pred_1, compound_offset, round_offset, max,
261                             zero, dst, dest_stride);
262       // row1.
263       AverageBlendRow<8, 0>(pred_0 + width, pred_1 + width, compound_offset,
264                             round_offset, max, zero, dst + dest_stride,
265                             dest_stride);
266       AverageBlendRow<8, 8>(pred_0 + width, pred_1 + width, compound_offset,
267                             round_offset, max, zero, dst + dest_stride,
268                             dest_stride);
269       dst += dest_stride2;
270       pred_0 += width2;
271       pred_1 += width2;
272       y -= 2;
273     } while (y != 0);
274     return;
275   }
276   if (width == 32) {
277     do {
278       // pred [0 - 15].
279       AverageBlendRow<8, 0>(pred_0, pred_1, compound_offset, round_offset, max,
280                             zero, dst, dest_stride);
281       AverageBlendRow<8, 8>(pred_0, pred_1, compound_offset, round_offset, max,
282                             zero, dst, dest_stride);
283       // pred [16 - 31].
284       AverageBlendRow<8, 16>(pred_0, pred_1, compound_offset, round_offset, max,
285                              zero, dst, dest_stride);
286       AverageBlendRow<8, 24>(pred_0, pred_1, compound_offset, round_offset, max,
287                              zero, dst, dest_stride);
288       dst += dest_stride;
289       pred_0 += width;
290       pred_1 += width;
291     } while (--y != 0);
292     return;
293   }
294   if (width == 64) {
295     do {
296       // pred [0 - 31].
297       AverageBlendRow<8, 0>(pred_0, pred_1, compound_offset, round_offset, max,
298                             zero, dst, dest_stride);
299       AverageBlendRow<8, 8>(pred_0, pred_1, compound_offset, round_offset, max,
300                             zero, dst, dest_stride);
301       AverageBlendRow<8, 16>(pred_0, pred_1, compound_offset, round_offset, max,
302                              zero, dst, dest_stride);
303       AverageBlendRow<8, 24>(pred_0, pred_1, compound_offset, round_offset, max,
304                              zero, dst, dest_stride);
305       // pred [31 - 63].
306       AverageBlendRow<8, 32>(pred_0, pred_1, compound_offset, round_offset, max,
307                              zero, dst, dest_stride);
308       AverageBlendRow<8, 40>(pred_0, pred_1, compound_offset, round_offset, max,
309                              zero, dst, dest_stride);
310       AverageBlendRow<8, 48>(pred_0, pred_1, compound_offset, round_offset, max,
311                              zero, dst, dest_stride);
312       AverageBlendRow<8, 56>(pred_0, pred_1, compound_offset, round_offset, max,
313                              zero, dst, dest_stride);
314       dst += dest_stride;
315       pred_0 += width;
316       pred_1 += width;
317     } while (--y != 0);
318     return;
319   }
320   assert(width == 128);
321   do {
322     // pred [0 - 31].
323     AverageBlendRow<8, 0>(pred_0, pred_1, compound_offset, round_offset, max,
324                           zero, dst, dest_stride);
325     AverageBlendRow<8, 8>(pred_0, pred_1, compound_offset, round_offset, max,
326                           zero, dst, dest_stride);
327     AverageBlendRow<8, 16>(pred_0, pred_1, compound_offset, round_offset, max,
328                            zero, dst, dest_stride);
329     AverageBlendRow<8, 24>(pred_0, pred_1, compound_offset, round_offset, max,
330                            zero, dst, dest_stride);
331     // pred [31 - 63].
332     AverageBlendRow<8, 32>(pred_0, pred_1, compound_offset, round_offset, max,
333                            zero, dst, dest_stride);
334     AverageBlendRow<8, 40>(pred_0, pred_1, compound_offset, round_offset, max,
335                            zero, dst, dest_stride);
336     AverageBlendRow<8, 48>(pred_0, pred_1, compound_offset, round_offset, max,
337                            zero, dst, dest_stride);
338     AverageBlendRow<8, 56>(pred_0, pred_1, compound_offset, round_offset, max,
339                            zero, dst, dest_stride);
340 
341     // pred [64 - 95].
342     AverageBlendRow<8, 64>(pred_0, pred_1, compound_offset, round_offset, max,
343                            zero, dst, dest_stride);
344     AverageBlendRow<8, 72>(pred_0, pred_1, compound_offset, round_offset, max,
345                            zero, dst, dest_stride);
346     AverageBlendRow<8, 80>(pred_0, pred_1, compound_offset, round_offset, max,
347                            zero, dst, dest_stride);
348     AverageBlendRow<8, 88>(pred_0, pred_1, compound_offset, round_offset, max,
349                            zero, dst, dest_stride);
350     // pred [96 - 127].
351     AverageBlendRow<8, 96>(pred_0, pred_1, compound_offset, round_offset, max,
352                            zero, dst, dest_stride);
353     AverageBlendRow<8, 104>(pred_0, pred_1, compound_offset, round_offset, max,
354                             zero, dst, dest_stride);
355     AverageBlendRow<8, 112>(pred_0, pred_1, compound_offset, round_offset, max,
356                             zero, dst, dest_stride);
357     AverageBlendRow<8, 120>(pred_0, pred_1, compound_offset, round_offset, max,
358                             zero, dst, dest_stride);
359     dst += dest_stride;
360     pred_0 += width;
361     pred_1 += width;
362   } while (--y != 0);
363 }
364 
// Registers the SSE4.1 10bpp average-blend implementation in the writable
// dsp table for kBitdepth10 (when enabled by the DSP_ENABLED macro).
void Init10bpp() {
  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10);
  assert(dsp != nullptr);
#if DSP_ENABLED_10BPP_SSE4_1(AverageBlend)
  dsp->average_blend = AverageBlend10bpp_SSE4_1;
#endif
}
372 
373 }  // namespace
374 }  // namespace high_bitdepth
375 #endif  // LIBGAV1_MAX_BITDEPTH >= 10
376 
// Public init: installs the SSE4.1 average-blend kernels for all supported
// bitdepths. Called from the dsp initialization when SSE4.1 is targeted.
void AverageBlendInit_SSE4_1() {
  low_bitdepth::Init8bpp();
#if LIBGAV1_MAX_BITDEPTH >= 10
  high_bitdepth::Init10bpp();
#endif  // LIBGAV1_MAX_BITDEPTH >= 10
}
383 
384 }  // namespace dsp
385 }  // namespace libgav1
386 
387 #else   // !LIBGAV1_TARGETING_SSE4_1
388 
389 namespace libgav1 {
390 namespace dsp {
391 
AverageBlendInit_SSE4_1()392 void AverageBlendInit_SSE4_1() {}
393 
394 }  // namespace dsp
395 }  // namespace libgav1
396 #endif  // LIBGAV1_TARGETING_SSE4_1
397