// Copyright 2019 The libgav1 Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "src/dsp/average_blend.h"
#include "src/utils/cpu.h"

#if LIBGAV1_TARGETING_SSE4_1

#include <smmintrin.h>

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>

#include "src/dsp/constants.h"
#include "src/dsp/dsp.h"
#include "src/dsp/x86/common_sse4.h"
#include "src/utils/common.h"

namespace libgav1 {
namespace dsp {
namespace low_bitdepth {
namespace {

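// The blends below shift right by kInterPostRoundBit + 1: kInterPostRoundBit
// to undo the extra precision carried by the compound predictions, plus one
// for the divide-by-2 of the average.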
constexpr int kInterPostRoundBit = 4;

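// Blends a 4x4 block per call. The 16 pixels of each prediction fit in two
// 8-lane loads; the packed 8-bit result is written out four bytes per
// destination row.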
inline void AverageBlend4x4Row(const int16_t* LIBGAV1_RESTRICT prediction_0,
                               const int16_t* LIBGAV1_RESTRICT prediction_1,
                               uint8_t* LIBGAV1_RESTRICT dest,
                               const ptrdiff_t dest_stride) {
  const __m128i pred_00 = LoadAligned16(prediction_0);
  const __m128i pred_10 = LoadAligned16(prediction_1);
  __m128i res_0 = _mm_add_epi16(pred_00, pred_10);
  res_0 = RightShiftWithRounding_S16(res_0, kInterPostRoundBit + 1);
  const __m128i pred_01 = LoadAligned16(prediction_0 + 8);
  const __m128i pred_11 = LoadAligned16(prediction_1 + 8);
  __m128i res_1 = _mm_add_epi16(pred_01, pred_11);
  res_1 = RightShiftWithRounding_S16(res_1, kInterPostRoundBit + 1);
  const __m128i result_pixels = _mm_packus_epi16(res_0, res_1);
  Store4(dest, result_pixels);
  dest += dest_stride;
  const int result_1 = _mm_extract_epi32(result_pixels, 1);
  memcpy(dest, &result_1, sizeof(result_1));
  dest += dest_stride;
  const int result_2 = _mm_extract_epi32(result_pixels, 2);
  memcpy(dest, &result_2, sizeof(result_2));
  dest += dest_stride;
  const int result_3 = _mm_extract_epi32(result_pixels, 3);
  memcpy(dest, &result_3, sizeof(result_3));
}

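// Blends two rows of width 8 per call; both rows of each prediction are
// contiguous in memory.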
inline void AverageBlend8Row(const int16_t* LIBGAV1_RESTRICT prediction_0,
                             const int16_t* LIBGAV1_RESTRICT prediction_1,
                             uint8_t* LIBGAV1_RESTRICT dest,
                             const ptrdiff_t dest_stride) {
  const __m128i pred_00 = LoadAligned16(prediction_0);
  const __m128i pred_10 = LoadAligned16(prediction_1);
  __m128i res_0 = _mm_add_epi16(pred_00, pred_10);
  res_0 = RightShiftWithRounding_S16(res_0, kInterPostRoundBit + 1);
  const __m128i pred_01 = LoadAligned16(prediction_0 + 8);
  const __m128i pred_11 = LoadAligned16(prediction_1 + 8);
  __m128i res_1 = _mm_add_epi16(pred_01, pred_11);
  res_1 = RightShiftWithRounding_S16(res_1, kInterPostRoundBit + 1);
  const __m128i result_pixels = _mm_packus_epi16(res_0, res_1);
  StoreLo8(dest, result_pixels);
  StoreHi8(dest + dest_stride, result_pixels);
}

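// Blends a single row of width >= 16, processing 16 pixels per iteration.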
inline void AverageBlendLargeRow(const int16_t* LIBGAV1_RESTRICT prediction_0,
                                 const int16_t* LIBGAV1_RESTRICT prediction_1,
                                 const int width,
                                 uint8_t* LIBGAV1_RESTRICT dest) {
  int x = 0;
  do {
    const __m128i pred_00 = LoadAligned16(&prediction_0[x]);
    const __m128i pred_01 = LoadAligned16(&prediction_1[x]);
    __m128i res0 = _mm_add_epi16(pred_00, pred_01);
    res0 = RightShiftWithRounding_S16(res0, kInterPostRoundBit + 1);
    const __m128i pred_10 = LoadAligned16(&prediction_0[x + 8]);
    const __m128i pred_11 = LoadAligned16(&prediction_1[x + 8]);
    __m128i res1 = _mm_add_epi16(pred_10, pred_11);
    res1 = RightShiftWithRounding_S16(res1, kInterPostRoundBit + 1);
    StoreUnaligned16(dest + x, _mm_packus_epi16(res0, res1));
    x += 16;
  } while (x < width);
}

void AverageBlend_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
                         const void* LIBGAV1_RESTRICT prediction_1,
                         const int width, const int height,
                         void* LIBGAV1_RESTRICT const dest,
                         const ptrdiff_t dest_stride) {
  auto* dst = static_cast<uint8_t*>(dest);
  const auto* pred_0 = static_cast<const int16_t*>(prediction_0);
  const auto* pred_1 = static_cast<const int16_t*>(prediction_1);
  int y = height;

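  // The prediction buffers are packed with stride == width, so four width-4
  // rows or two width-8 rows advance the pointers by 16 int16_t values.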
  if (width == 4) {
    const ptrdiff_t dest_stride4 = dest_stride << 2;
    constexpr ptrdiff_t width4 = 4 << 2;
    do {
      AverageBlend4x4Row(pred_0, pred_1, dst, dest_stride);
      dst += dest_stride4;
      pred_0 += width4;
      pred_1 += width4;

      y -= 4;
    } while (y != 0);
    return;
  }

  if (width == 8) {
    const ptrdiff_t dest_stride2 = dest_stride << 1;
    constexpr ptrdiff_t width2 = 8 << 1;
    do {
      AverageBlend8Row(pred_0, pred_1, dst, dest_stride);
      dst += dest_stride2;
      pred_0 += width2;
      pred_1 += width2;

      y -= 2;
    } while (y != 0);
    return;
  }

  do {
    AverageBlendLargeRow(pred_0, pred_1, width, dst);
    dst += dest_stride;
    pred_0 += width;
    pred_1 += width;

    AverageBlendLargeRow(pred_0, pred_1, width, dst);
    dst += dest_stride;
    pred_0 += width;
    pred_1 += width;

    y -= 2;
  } while (y != 0);
}

void Init8bpp() {
  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
  assert(dsp != nullptr);
#if DSP_ENABLED_8BPP_SSE4_1(AverageBlend)
  dsp->average_blend = AverageBlend_SSE4_1;
#endif
}

}  // namespace
}  // namespace low_bitdepth

#if LIBGAV1_MAX_BITDEPTH >= 10
namespace high_bitdepth {
namespace {

constexpr int kInterPostRoundBitPlusOne = 5;

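// Blends eight pixels from each prediction in 32-bit precision: the 16-bit
// inputs are widened, summed, and the doubled compound offset is subtracted
// before rounding and clipping to the 10-bit maximum. For width == 4 the
// eight values cover two rows, stored via StoreLo8/StoreHi8.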
template <const int width, const int offset>
inline void AverageBlendRow(const uint16_t* LIBGAV1_RESTRICT prediction_0,
                            const uint16_t* LIBGAV1_RESTRICT prediction_1,
                            const __m128i& compound_offset,
                            const __m128i& round_offset, const __m128i& max,
                            const __m128i& zero, uint16_t* LIBGAV1_RESTRICT dst,
                            const ptrdiff_t dest_stride) {
  // pred_0/1 max range is 16b.
  const __m128i pred_0 = LoadUnaligned16(prediction_0 + offset);
  const __m128i pred_1 = LoadUnaligned16(prediction_1 + offset);
  const __m128i pred_00 = _mm_cvtepu16_epi32(pred_0);
  const __m128i pred_01 = _mm_unpackhi_epi16(pred_0, zero);
  const __m128i pred_10 = _mm_cvtepu16_epi32(pred_1);
  const __m128i pred_11 = _mm_unpackhi_epi16(pred_1, zero);

  const __m128i pred_add_0 = _mm_add_epi32(pred_00, pred_10);
  const __m128i pred_add_1 = _mm_add_epi32(pred_01, pred_11);
  const __m128i compound_offset_0 = _mm_sub_epi32(pred_add_0, compound_offset);
  const __m128i compound_offset_1 = _mm_sub_epi32(pred_add_1, compound_offset);
  // RightShiftWithRounding and Clip3.
  const __m128i round_0 = _mm_add_epi32(compound_offset_0, round_offset);
  const __m128i round_1 = _mm_add_epi32(compound_offset_1, round_offset);
  const __m128i res_0 = _mm_srai_epi32(round_0, kInterPostRoundBitPlusOne);
  const __m128i res_1 = _mm_srai_epi32(round_1, kInterPostRoundBitPlusOne);
  const __m128i result = _mm_min_epi16(_mm_packus_epi32(res_0, res_1), max);
  if (width != 4) {
    // Store width=8/16/32/64/128.
    StoreUnaligned16(dst + offset, result);
    return;
  }
  assert(width == 4);
  StoreLo8(dst, result);
  StoreHi8(dst + dest_stride, result);
}

void AverageBlend10bpp_SSE4_1(const void* LIBGAV1_RESTRICT prediction_0,
                              const void* LIBGAV1_RESTRICT prediction_1,
                              const int width, const int height,
                              void* LIBGAV1_RESTRICT const dest,
                              const ptrdiff_t dst_stride) {
  auto* dst = static_cast<uint16_t*>(dest);
  // Convert the byte stride to a count of uint16_t elements.
  const ptrdiff_t dest_stride = dst_stride / sizeof(dst[0]);
  const auto* pred_0 = static_cast<const uint16_t*>(prediction_0);
  const auto* pred_1 = static_cast<const uint16_t*>(prediction_1);
  const __m128i compound_offset =
      _mm_set1_epi32(kCompoundOffset + kCompoundOffset);
  const __m128i round_offset =
      _mm_set1_epi32((1 << kInterPostRoundBitPlusOne) >> 1);
  const __m128i max = _mm_set1_epi16((1 << kBitdepth10) - 1);
  const __m128i zero = _mm_setzero_si128();
  int y = height;

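  // Dispatch on block width: widths 4, 8, and 16 blend two rows per loop
  // iteration; widths 32, 64, and 128 blend one full row per iteration.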
  if (width == 4) {
    const ptrdiff_t dest_stride2 = dest_stride << 1;
    const ptrdiff_t width2 = width << 1;
    do {
      // row0,1.
      AverageBlendRow<4, 0>(pred_0, pred_1, compound_offset, round_offset, max,
                            zero, dst, dest_stride);
      dst += dest_stride2;
      pred_0 += width2;
      pred_1 += width2;
      y -= 2;
    } while (y != 0);
    return;
  }
  if (width == 8) {
    const ptrdiff_t dest_stride2 = dest_stride << 1;
    const ptrdiff_t width2 = width << 1;
    do {
      // row0.
      AverageBlendRow<8, 0>(pred_0, pred_1, compound_offset, round_offset, max,
                            zero, dst, dest_stride);
      // row1.
      AverageBlendRow<8, 0>(pred_0 + width, pred_1 + width, compound_offset,
                            round_offset, max, zero, dst + dest_stride,
                            dest_stride);
      dst += dest_stride2;
      pred_0 += width2;
      pred_1 += width2;
      y -= 2;
    } while (y != 0);
    return;
  }
  if (width == 16) {
    const ptrdiff_t dest_stride2 = dest_stride << 1;
    const ptrdiff_t width2 = width << 1;
    do {
      // row0.
      AverageBlendRow<8, 0>(pred_0, pred_1, compound_offset, round_offset, max,
                            zero, dst, dest_stride);
      AverageBlendRow<8, 8>(pred_0, pred_1, compound_offset, round_offset, max,
                            zero, dst, dest_stride);
      // row1.
      AverageBlendRow<8, 0>(pred_0 + width, pred_1 + width, compound_offset,
                            round_offset, max, zero, dst + dest_stride,
                            dest_stride);
      AverageBlendRow<8, 8>(pred_0 + width, pred_1 + width, compound_offset,
                            round_offset, max, zero, dst + dest_stride,
                            dest_stride);
      dst += dest_stride2;
      pred_0 += width2;
      pred_1 += width2;
      y -= 2;
    } while (y != 0);
    return;
  }
  if (width == 32) {
    do {
      // pred [0 - 15].
      AverageBlendRow<8, 0>(pred_0, pred_1, compound_offset, round_offset, max,
                            zero, dst, dest_stride);
      AverageBlendRow<8, 8>(pred_0, pred_1, compound_offset, round_offset, max,
                            zero, dst, dest_stride);
      // pred [16 - 31].
      AverageBlendRow<8, 16>(pred_0, pred_1, compound_offset, round_offset,
                             max, zero, dst, dest_stride);
      AverageBlendRow<8, 24>(pred_0, pred_1, compound_offset, round_offset,
                             max, zero, dst, dest_stride);
      dst += dest_stride;
      pred_0 += width;
      pred_1 += width;
    } while (--y != 0);
    return;
  }
  if (width == 64) {
    do {
      // pred [0 - 31].
      AverageBlendRow<8, 0>(pred_0, pred_1, compound_offset, round_offset, max,
                            zero, dst, dest_stride);
      AverageBlendRow<8, 8>(pred_0, pred_1, compound_offset, round_offset, max,
                            zero, dst, dest_stride);
      AverageBlendRow<8, 16>(pred_0, pred_1, compound_offset, round_offset,
                             max, zero, dst, dest_stride);
      AverageBlendRow<8, 24>(pred_0, pred_1, compound_offset, round_offset,
                             max, zero, dst, dest_stride);
      // pred [32 - 63].
      AverageBlendRow<8, 32>(pred_0, pred_1, compound_offset, round_offset,
                             max, zero, dst, dest_stride);
      AverageBlendRow<8, 40>(pred_0, pred_1, compound_offset, round_offset,
                             max, zero, dst, dest_stride);
      AverageBlendRow<8, 48>(pred_0, pred_1, compound_offset, round_offset,
                             max, zero, dst, dest_stride);
      AverageBlendRow<8, 56>(pred_0, pred_1, compound_offset, round_offset,
                             max, zero, dst, dest_stride);
      dst += dest_stride;
      pred_0 += width;
      pred_1 += width;
    } while (--y != 0);
    return;
  }
  assert(width == 128);
  do {
    // pred [0 - 31].
    AverageBlendRow<8, 0>(pred_0, pred_1, compound_offset, round_offset, max,
                          zero, dst, dest_stride);
    AverageBlendRow<8, 8>(pred_0, pred_1, compound_offset, round_offset, max,
                          zero, dst, dest_stride);
    AverageBlendRow<8, 16>(pred_0, pred_1, compound_offset, round_offset, max,
                           zero, dst, dest_stride);
    AverageBlendRow<8, 24>(pred_0, pred_1, compound_offset, round_offset, max,
                           zero, dst, dest_stride);
    // pred [32 - 63].
    AverageBlendRow<8, 32>(pred_0, pred_1, compound_offset, round_offset, max,
                           zero, dst, dest_stride);
    AverageBlendRow<8, 40>(pred_0, pred_1, compound_offset, round_offset, max,
                           zero, dst, dest_stride);
    AverageBlendRow<8, 48>(pred_0, pred_1, compound_offset, round_offset, max,
                           zero, dst, dest_stride);
    AverageBlendRow<8, 56>(pred_0, pred_1, compound_offset, round_offset, max,
                           zero, dst, dest_stride);

    // pred [64 - 95].
    AverageBlendRow<8, 64>(pred_0, pred_1, compound_offset, round_offset, max,
                           zero, dst, dest_stride);
    AverageBlendRow<8, 72>(pred_0, pred_1, compound_offset, round_offset, max,
                           zero, dst, dest_stride);
    AverageBlendRow<8, 80>(pred_0, pred_1, compound_offset, round_offset, max,
                           zero, dst, dest_stride);
    AverageBlendRow<8, 88>(pred_0, pred_1, compound_offset, round_offset, max,
                           zero, dst, dest_stride);
    // pred [96 - 127].
    AverageBlendRow<8, 96>(pred_0, pred_1, compound_offset, round_offset, max,
                           zero, dst, dest_stride);
    AverageBlendRow<8, 104>(pred_0, pred_1, compound_offset, round_offset, max,
                            zero, dst, dest_stride);
    AverageBlendRow<8, 112>(pred_0, pred_1, compound_offset, round_offset, max,
                            zero, dst, dest_stride);
    AverageBlendRow<8, 120>(pred_0, pred_1, compound_offset, round_offset, max,
                            zero, dst, dest_stride);
    dst += dest_stride;
    pred_0 += width;
    pred_1 += width;
  } while (--y != 0);
}

void Init10bpp() {
  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10);
  assert(dsp != nullptr);
#if DSP_ENABLED_10BPP_SSE4_1(AverageBlend)
  dsp->average_blend = AverageBlend10bpp_SSE4_1;
#endif
}

}  // namespace
}  // namespace high_bitdepth
#endif  // LIBGAV1_MAX_BITDEPTH >= 10

void AverageBlendInit_SSE4_1() {
  low_bitdepth::Init8bpp();
#if LIBGAV1_MAX_BITDEPTH >= 10
  high_bitdepth::Init10bpp();
#endif  // LIBGAV1_MAX_BITDEPTH >= 10
}

}  // namespace dsp
}  // namespace libgav1

#else  // !LIBGAV1_TARGETING_SSE4_1

namespace libgav1 {
namespace dsp {

void AverageBlendInit_SSE4_1() {}

}  // namespace dsp
}  // namespace libgav1
#endif  // LIBGAV1_TARGETING_SSE4_1