xref: /aosp_15_r20/external/libgav1/src/dsp/x86/convolve_sse4.cc (revision 095378508e87ed692bf8dfeb34008b65b3735891)
1 // Copyright 2019 The libgav1 Authors
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //      http://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 #include "src/dsp/convolve.h"
16 #include "src/utils/constants.h"
17 #include "src/utils/cpu.h"
18 
19 #if LIBGAV1_TARGETING_SSE4_1
20 #include <smmintrin.h>
21 
22 #include <algorithm>
23 #include <cassert>
24 #include <cstdint>
25 #include <cstring>
26 
27 #include "src/dsp/constants.h"
28 #include "src/dsp/dsp.h"
29 #include "src/dsp/x86/common_sse4.h"
30 #include "src/utils/common.h"
31 #include "src/utils/compiler_attributes.h"
32 
33 namespace libgav1 {
34 namespace dsp {
35 namespace low_bitdepth {
36 namespace {
37 
38 #include "src/dsp/x86/convolve_sse4.inc"
39 
// Computes the |num_taps|-tap horizontal filter sum for 8 consecutive output
// positions starting at |src|. A single 16-byte load provides all required
// context. Each source byte is duplicated (unpacked against itself) so that
// _mm_alignr_epi8 can extract the interleaved pixel pairs consumed by the
// multiply-add in SumOnePassTaps. |v_tap| holds the taps pre-arranged in
// pairs by SetupTaps.
template <int num_taps>
__m128i SumHorizontalTaps(const uint8_t* LIBGAV1_RESTRICT const src,
                          const __m128i* const v_tap) {
  __m128i v_src[4];
  const __m128i src_long = LoadUnaligned16(src);
  // Duplicate every byte: 00 00 01 01 02 02 ... Shifted windows of this
  // doubled stream yield adjacent-pixel pairs for _mm_maddubs_epi16.
  const __m128i src_long_dup_lo = _mm_unpacklo_epi8(src_long, src_long);
  const __m128i src_long_dup_hi = _mm_unpackhi_epi8(src_long, src_long);

  // The alignr byte offsets select which pixel pair lines up with each tap
  // pair; the _NM suffix names the tap pair each window feeds.
  if (num_taps == 6) {
    // 6 taps.
    v_src[0] = _mm_alignr_epi8(src_long_dup_hi, src_long_dup_lo, 3);   // _21
    v_src[1] = _mm_alignr_epi8(src_long_dup_hi, src_long_dup_lo, 7);   // _43
    v_src[2] = _mm_alignr_epi8(src_long_dup_hi, src_long_dup_lo, 11);  // _65
  } else if (num_taps == 8) {
    // 8 taps.
    v_src[0] = _mm_alignr_epi8(src_long_dup_hi, src_long_dup_lo, 1);   // _10
    v_src[1] = _mm_alignr_epi8(src_long_dup_hi, src_long_dup_lo, 5);   // _32
    v_src[2] = _mm_alignr_epi8(src_long_dup_hi, src_long_dup_lo, 9);   // _54
    v_src[3] = _mm_alignr_epi8(src_long_dup_hi, src_long_dup_lo, 13);  // _76
  } else if (num_taps == 2) {
    // 2 taps.
    v_src[0] = _mm_alignr_epi8(src_long_dup_hi, src_long_dup_lo, 7);  // _43
  } else {
    // 4 taps.
    v_src[0] = _mm_alignr_epi8(src_long_dup_hi, src_long_dup_lo, 5);  // _32
    v_src[1] = _mm_alignr_epi8(src_long_dup_hi, src_long_dup_lo, 9);  // _54
  }
  const __m128i sum = SumOnePassTaps<num_taps>(v_src, v_tap);
  return sum;
}
70 
71 template <int num_taps>
SimpleHorizontalTaps(const uint8_t * LIBGAV1_RESTRICT const src,const __m128i * const v_tap)72 __m128i SimpleHorizontalTaps(const uint8_t* LIBGAV1_RESTRICT const src,
73                              const __m128i* const v_tap) {
74   __m128i sum = SumHorizontalTaps<num_taps>(src, v_tap);
75 
76   // Normally the Horizontal pass does the downshift in two passes:
77   // kInterRoundBitsHorizontal - 1 and then (kFilterBits -
78   // kInterRoundBitsHorizontal). Each one uses a rounding shift. Combining them
79   // requires adding the rounding offset from the skipped shift.
80   constexpr int first_shift_rounding_bit = 1 << (kInterRoundBitsHorizontal - 2);
81 
82   sum = _mm_add_epi16(sum, _mm_set1_epi16(first_shift_rounding_bit));
83   sum = RightShiftWithRounding_S16(sum, kFilterBits - 1);
84   return _mm_packus_epi16(sum, sum);
85 }
86 
87 template <int num_taps>
HorizontalTaps8To16(const uint8_t * LIBGAV1_RESTRICT const src,const __m128i * const v_tap)88 __m128i HorizontalTaps8To16(const uint8_t* LIBGAV1_RESTRICT const src,
89                             const __m128i* const v_tap) {
90   const __m128i sum = SumHorizontalTaps<num_taps>(src, v_tap);
91 
92   return RightShiftWithRounding_S16(sum, kInterRoundBitsHorizontal - 1);
93 }
94 
// Runs the horizontal pass over a |width| x |height| block.
// |dest| is written as uint8_t for the single-pass path and as uint16_t
// intermediate values when |is_2d| or |is_compound|. For |is_2d| the
// horizontal pass also produces the extra context rows the vertical pass
// needs (|height| includes them, supplied by the caller).
template <int num_taps, bool is_2d = false, bool is_compound = false>
void FilterHorizontal(const uint8_t* LIBGAV1_RESTRICT src,
                      const ptrdiff_t src_stride,
                      void* LIBGAV1_RESTRICT const dest,
                      const ptrdiff_t pred_stride, const int width,
                      const int height, const __m128i* const v_tap) {
  auto* dest8 = static_cast<uint8_t*>(dest);
  auto* dest16 = static_cast<uint16_t*>(dest);

  // 4 tap filters are never used when width > 4.
  if (num_taps != 4 && width > 4) {
    // Wide path: 8 output pixels per inner iteration.
    int y = height;
    do {
      int x = 0;
      do {
        if (is_2d || is_compound) {
          const __m128i v_sum = HorizontalTaps8To16<num_taps>(&src[x], v_tap);
          if (is_2d) {
            // 2D writes into the aligned intermediate buffer whose stride is
            // |width|, so a 16-byte aligned store is safe here.
            StoreAligned16(&dest16[x], v_sum);
          } else {
            StoreUnaligned16(&dest16[x], v_sum);
          }
        } else {
          const __m128i result = SimpleHorizontalTaps<num_taps>(&src[x], v_tap);
          StoreLo8(&dest8[x], result);
        }
        x += 8;
      } while (x < width);
      src += src_stride;
      dest8 += pred_stride;
      dest16 += pred_stride;
    } while (--y != 0);
    return;
  }

  // Horizontal passes only needs to account for |num_taps| 2 and 4 when
  // |width| <= 4.
  assert(width <= 4);
  assert(num_taps <= 4);
  if (num_taps <= 4) {
    if (width == 4) {
      // 4-wide path: one 4-pixel row per iteration.
      int y = height;
      do {
        if (is_2d || is_compound) {
          const __m128i v_sum = HorizontalTaps8To16<num_taps>(src, v_tap);
          StoreLo8(dest16, v_sum);
        } else {
          const __m128i result = SimpleHorizontalTaps<num_taps>(src, v_tap);
          Store4(&dest8[0], result);
        }
        src += src_stride;
        dest8 += pred_stride;
        dest16 += pred_stride;
      } while (--y != 0);
      return;
    }

    // 2-wide path. Note the compound 2-wide case is never reached, so only
    // the plain and 2D variants are handled; two rows are produced per
    // iteration.
    if (!is_compound) {
      int y = height;
      if (is_2d) y -= 1;
      do {
        if (is_2d) {
          const __m128i sum =
              HorizontalTaps8To16_2x2<num_taps>(src, src_stride, v_tap);
          Store4(&dest16[0], sum);
          dest16 += pred_stride;
          Store4(&dest16[0], _mm_srli_si128(sum, 8));
          dest16 += pred_stride;
        } else {
          const __m128i sum =
              SimpleHorizontalTaps2x2<num_taps>(src, src_stride, v_tap);
          Store2(dest8, sum);
          dest8 += pred_stride;
          Store2(dest8, _mm_srli_si128(sum, 4));
          dest8 += pred_stride;
        }

        src += src_stride << 1;
        y -= 2;
      } while (y != 0);

      // The 2d filters have an odd |height| because the horizontal pass
      // generates context for the vertical pass.
      if (is_2d) {
        assert(height % 2 == 1);
        // Filter the final row directly with scalar-style shuffles since
        // there is no second row to pair it with.
        __m128i sum;
        const __m128i input = LoadLo8(&src[2]);
        if (num_taps == 2) {
          // 03 04 04 05 05 06 06 07 ....
          const __m128i v_src_43 =
              _mm_srli_si128(_mm_unpacklo_epi8(input, input), 3);
          sum = _mm_maddubs_epi16(v_src_43, v_tap[0]);  // k4k3
        } else {
          // 02 03 03 04 04 05 05 06 06 07 ....
          const __m128i v_src_32 =
              _mm_srli_si128(_mm_unpacklo_epi8(input, input), 1);
          // 04 05 05 06 06 07 07 08 ...
          const __m128i v_src_54 = _mm_srli_si128(v_src_32, 4);
          const __m128i v_madd_32 =
              _mm_maddubs_epi16(v_src_32, v_tap[0]);  // k3k2
          const __m128i v_madd_54 =
              _mm_maddubs_epi16(v_src_54, v_tap[1]);  // k5k4
          sum = _mm_add_epi16(v_madd_54, v_madd_32);
        }
        sum = RightShiftWithRounding_S16(sum, kInterRoundBitsHorizontal - 1);
        Store4(dest16, sum);
      }
    }
  }
}
205 
206 template <bool is_2d = false, bool is_compound = false>
DoHorizontalPass(const uint8_t * LIBGAV1_RESTRICT const src,const ptrdiff_t src_stride,void * LIBGAV1_RESTRICT const dst,const ptrdiff_t dst_stride,const int width,const int height,const int filter_id,const int filter_index)207 LIBGAV1_ALWAYS_INLINE void DoHorizontalPass(
208     const uint8_t* LIBGAV1_RESTRICT const src, const ptrdiff_t src_stride,
209     void* LIBGAV1_RESTRICT const dst, const ptrdiff_t dst_stride,
210     const int width, const int height, const int filter_id,
211     const int filter_index) {
212   assert(filter_id != 0);
213   __m128i v_tap[4];
214   const __m128i v_horizontal_filter =
215       LoadLo8(kHalfSubPixelFilters[filter_index][filter_id]);
216 
217   if (filter_index == 2) {  // 8 tap.
218     SetupTaps<8>(&v_horizontal_filter, v_tap);
219     FilterHorizontal<8, is_2d, is_compound>(src, src_stride, dst, dst_stride,
220                                             width, height, v_tap);
221   } else if (filter_index == 1) {  // 6 tap.
222     SetupTaps<6>(&v_horizontal_filter, v_tap);
223     FilterHorizontal<6, is_2d, is_compound>(src, src_stride, dst, dst_stride,
224                                             width, height, v_tap);
225   } else if (filter_index == 0) {  // 6 tap.
226     SetupTaps<6>(&v_horizontal_filter, v_tap);
227     FilterHorizontal<6, is_2d, is_compound>(src, src_stride, dst, dst_stride,
228                                             width, height, v_tap);
229   } else if ((filter_index & 0x4) != 0) {  // 4 tap.
230     // ((filter_index == 4) | (filter_index == 5))
231     SetupTaps<4>(&v_horizontal_filter, v_tap);
232     FilterHorizontal<4, is_2d, is_compound>(src, src_stride, dst, dst_stride,
233                                             width, height, v_tap);
234   } else {  // 2 tap.
235     SetupTaps<2>(&v_horizontal_filter, v_tap);
236     FilterHorizontal<2, is_2d, is_compound>(src, src_stride, dst, dst_stride,
237                                             width, height, v_tap);
238   }
239 }
240 
Convolve2D_SSE4_1(const void * LIBGAV1_RESTRICT const reference,const ptrdiff_t reference_stride,const int horizontal_filter_index,const int vertical_filter_index,const int horizontal_filter_id,const int vertical_filter_id,const int width,const int height,void * LIBGAV1_RESTRICT prediction,const ptrdiff_t pred_stride)241 void Convolve2D_SSE4_1(const void* LIBGAV1_RESTRICT const reference,
242                        const ptrdiff_t reference_stride,
243                        const int horizontal_filter_index,
244                        const int vertical_filter_index,
245                        const int horizontal_filter_id,
246                        const int vertical_filter_id, const int width,
247                        const int height, void* LIBGAV1_RESTRICT prediction,
248                        const ptrdiff_t pred_stride) {
249   const int horiz_filter_index = GetFilterIndex(horizontal_filter_index, width);
250   const int vert_filter_index = GetFilterIndex(vertical_filter_index, height);
251   const int vertical_taps =
252       GetNumTapsInFilter(vert_filter_index, vertical_filter_id);
253 
254   // The output of the horizontal filter is guaranteed to fit in 16 bits.
255   alignas(16) uint16_t
256       intermediate_result[kMaxSuperBlockSizeInPixels *
257                           (kMaxSuperBlockSizeInPixels + kSubPixelTaps - 1)];
258 #if LIBGAV1_MSAN
259   // Quiet msan warnings. Set with random non-zero value to aid in debugging.
260   memset(intermediate_result, 0x33, sizeof(intermediate_result));
261 #endif
262   const int intermediate_height = height + vertical_taps - 1;
263 
264   const ptrdiff_t src_stride = reference_stride;
265   const auto* src = static_cast<const uint8_t*>(reference) -
266                     (vertical_taps / 2 - 1) * src_stride - kHorizontalOffset;
267 
268   DoHorizontalPass</*is_2d=*/true>(src, src_stride, intermediate_result, width,
269                                    width, intermediate_height,
270                                    horizontal_filter_id, horiz_filter_index);
271 
272   // Vertical filter.
273   auto* dest = static_cast<uint8_t*>(prediction);
274   const ptrdiff_t dest_stride = pred_stride;
275   assert(vertical_filter_id != 0);
276 
277   __m128i taps[4];
278   const __m128i v_filter =
279       LoadLo8(kHalfSubPixelFilters[vert_filter_index][vertical_filter_id]);
280 
281   if (vertical_taps == 8) {
282     SetupTaps<8, /*is_2d_vertical=*/true>(&v_filter, taps);
283     if (width == 2) {
284       Filter2DVertical2xH<8>(intermediate_result, dest, dest_stride, height,
285                              taps);
286     } else if (width == 4) {
287       Filter2DVertical4xH<8>(intermediate_result, dest, dest_stride, height,
288                              taps);
289     } else {
290       Filter2DVertical<8>(intermediate_result, dest, dest_stride, width, height,
291                           taps);
292     }
293   } else if (vertical_taps == 6) {
294     SetupTaps<6, /*is_2d_vertical=*/true>(&v_filter, taps);
295     if (width == 2) {
296       Filter2DVertical2xH<6>(intermediate_result, dest, dest_stride, height,
297                              taps);
298     } else if (width == 4) {
299       Filter2DVertical4xH<6>(intermediate_result, dest, dest_stride, height,
300                              taps);
301     } else {
302       Filter2DVertical<6>(intermediate_result, dest, dest_stride, width, height,
303                           taps);
304     }
305   } else if (vertical_taps == 4) {
306     SetupTaps<4, /*is_2d_vertical=*/true>(&v_filter, taps);
307     if (width == 2) {
308       Filter2DVertical2xH<4>(intermediate_result, dest, dest_stride, height,
309                              taps);
310     } else if (width == 4) {
311       Filter2DVertical4xH<4>(intermediate_result, dest, dest_stride, height,
312                              taps);
313     } else {
314       Filter2DVertical<4>(intermediate_result, dest, dest_stride, width, height,
315                           taps);
316     }
317   } else {  // |vertical_taps| == 2
318     SetupTaps<2, /*is_2d_vertical=*/true>(&v_filter, taps);
319     if (width == 2) {
320       Filter2DVertical2xH<2>(intermediate_result, dest, dest_stride, height,
321                              taps);
322     } else if (width == 4) {
323       Filter2DVertical4xH<2>(intermediate_result, dest, dest_stride, height,
324                              taps);
325     } else {
326       Filter2DVertical<2>(intermediate_result, dest, dest_stride, width, height,
327                           taps);
328     }
329   }
330 }
331 
// Vertical filter for blocks of |width| >= 8. For each 8-pixel column group
// it primes a sliding window of the first |num_taps| - 1 source rows in
// |srcs|, then emits one output row per loop iteration, loading one new row
// and shifting the window down each time.
template <int num_taps, bool is_compound = false>
void FilterVertical(const uint8_t* LIBGAV1_RESTRICT src,
                    const ptrdiff_t src_stride,
                    void* LIBGAV1_RESTRICT const dst,
                    const ptrdiff_t dst_stride, const int width,
                    const int height, const __m128i* const v_tap) {
  // Index of the newest row in the window; filled inside the y loop.
  const int next_row = num_taps - 1;
  auto* dst8 = static_cast<uint8_t*>(dst);
  auto* dst16 = static_cast<uint16_t*>(dst);
  assert(width >= 8);

  int x = 0;
  do {
    const uint8_t* src_x = src + x;
    __m128i srcs[8];
    // Preload the first num_taps - 1 rows (the nested ifs instantiate only
    // the loads each tap count needs).
    srcs[0] = LoadLo8(src_x);
    src_x += src_stride;
    if (num_taps >= 4) {
      srcs[1] = LoadLo8(src_x);
      src_x += src_stride;
      srcs[2] = LoadLo8(src_x);
      src_x += src_stride;
      if (num_taps >= 6) {
        srcs[3] = LoadLo8(src_x);
        src_x += src_stride;
        srcs[4] = LoadLo8(src_x);
        src_x += src_stride;
        if (num_taps == 8) {
          srcs[5] = LoadLo8(src_x);
          src_x += src_stride;
          srcs[6] = LoadLo8(src_x);
          src_x += src_stride;
        }
      }
    }

    auto* dst8_x = dst8 + x;
    auto* dst16_x = dst16 + x;
    int y = height;
    do {
      // Complete the window with the newest row, then filter.
      srcs[next_row] = LoadLo8(src_x);
      src_x += src_stride;

      const __m128i sums = SumVerticalTaps<num_taps>(srcs, v_tap);
      if (is_compound) {
        // Compound output stays at 16-bit intermediate precision.
        const __m128i results = Compound1DShift(sums);
        StoreUnaligned16(dst16_x, results);
        dst16_x += dst_stride;
      } else {
        // Single-pass output: round and pack down to 8 bits.
        const __m128i results =
            RightShiftWithRounding_S16(sums, kFilterBits - 1);
        StoreLo8(dst8_x, _mm_packus_epi16(results, results));
        dst8_x += dst_stride;
      }

      // Slide the window down one row.
      srcs[0] = srcs[1];
      if (num_taps >= 4) {
        srcs[1] = srcs[2];
        srcs[2] = srcs[3];
        if (num_taps >= 6) {
          srcs[3] = srcs[4];
          srcs[4] = srcs[5];
          if (num_taps == 8) {
            srcs[5] = srcs[6];
            srcs[6] = srcs[7];
          }
        }
      }
    } while (--y != 0);
    x += 8;
  } while (x < width);
}
404 
ConvolveVertical_SSE4_1(const void * LIBGAV1_RESTRICT const reference,const ptrdiff_t reference_stride,const int,const int vertical_filter_index,const int,const int vertical_filter_id,const int width,const int height,void * LIBGAV1_RESTRICT prediction,const ptrdiff_t pred_stride)405 void ConvolveVertical_SSE4_1(
406     const void* LIBGAV1_RESTRICT const reference,
407     const ptrdiff_t reference_stride, const int /*horizontal_filter_index*/,
408     const int vertical_filter_index, const int /*horizontal_filter_id*/,
409     const int vertical_filter_id, const int width, const int height,
410     void* LIBGAV1_RESTRICT prediction, const ptrdiff_t pred_stride) {
411   const int filter_index = GetFilterIndex(vertical_filter_index, height);
412   const int vertical_taps =
413       GetNumTapsInFilter(filter_index, vertical_filter_id);
414   const ptrdiff_t src_stride = reference_stride;
415   const auto* src = static_cast<const uint8_t*>(reference) -
416                     (vertical_taps / 2 - 1) * src_stride;
417   auto* dest = static_cast<uint8_t*>(prediction);
418   const ptrdiff_t dest_stride = pred_stride;
419   assert(vertical_filter_id != 0);
420 
421   __m128i taps[4];
422   const __m128i v_filter =
423       LoadLo8(kHalfSubPixelFilters[filter_index][vertical_filter_id]);
424 
425   if (vertical_taps == 6) {  // 6 tap.
426     SetupTaps<6>(&v_filter, taps);
427     if (width == 2) {
428       FilterVertical2xH<6>(src, src_stride, dest, dest_stride, height, taps);
429     } else if (width == 4) {
430       FilterVertical4xH<6>(src, src_stride, dest, dest_stride, height, taps);
431     } else {
432       FilterVertical<6>(src, src_stride, dest, dest_stride, width, height,
433                         taps);
434     }
435   } else if (vertical_taps == 8) {  // 8 tap.
436     SetupTaps<8>(&v_filter, taps);
437     if (width == 2) {
438       FilterVertical2xH<8>(src, src_stride, dest, dest_stride, height, taps);
439     } else if (width == 4) {
440       FilterVertical4xH<8>(src, src_stride, dest, dest_stride, height, taps);
441     } else {
442       FilterVertical<8>(src, src_stride, dest, dest_stride, width, height,
443                         taps);
444     }
445   } else if (vertical_taps == 2) {  // 2 tap.
446     SetupTaps<2>(&v_filter, taps);
447     if (width == 2) {
448       FilterVertical2xH<2>(src, src_stride, dest, dest_stride, height, taps);
449     } else if (width == 4) {
450       FilterVertical4xH<2>(src, src_stride, dest, dest_stride, height, taps);
451     } else {
452       FilterVertical<2>(src, src_stride, dest, dest_stride, width, height,
453                         taps);
454     }
455   } else {  // 4 tap
456     SetupTaps<4>(&v_filter, taps);
457     if (width == 2) {
458       FilterVertical2xH<4>(src, src_stride, dest, dest_stride, height, taps);
459     } else if (width == 4) {
460       FilterVertical4xH<4>(src, src_stride, dest, dest_stride, height, taps);
461     } else {
462       FilterVertical<4>(src, src_stride, dest, dest_stride, width, height,
463                         taps);
464     }
465   }
466 }
467 
ConvolveCompoundCopy_SSE4_1(const void * LIBGAV1_RESTRICT const reference,const ptrdiff_t reference_stride,const int,const int,const int,const int,const int width,const int height,void * LIBGAV1_RESTRICT prediction,const ptrdiff_t pred_stride)468 void ConvolveCompoundCopy_SSE4_1(
469     const void* LIBGAV1_RESTRICT const reference,
470     const ptrdiff_t reference_stride, const int /*horizontal_filter_index*/,
471     const int /*vertical_filter_index*/, const int /*horizontal_filter_id*/,
472     const int /*vertical_filter_id*/, const int width, const int height,
473     void* LIBGAV1_RESTRICT prediction, const ptrdiff_t pred_stride) {
474   const auto* src = static_cast<const uint8_t*>(reference);
475   const ptrdiff_t src_stride = reference_stride;
476   auto* dest = static_cast<uint16_t*>(prediction);
477   constexpr int kRoundBitsVertical =
478       kInterRoundBitsVertical - kInterRoundBitsCompoundVertical;
479   if (width >= 16) {
480     int y = height;
481     do {
482       int x = 0;
483       do {
484         const __m128i v_src = LoadUnaligned16(&src[x]);
485         const __m128i v_src_ext_lo = _mm_cvtepu8_epi16(v_src);
486         const __m128i v_src_ext_hi =
487             _mm_cvtepu8_epi16(_mm_srli_si128(v_src, 8));
488         const __m128i v_dest_lo =
489             _mm_slli_epi16(v_src_ext_lo, kRoundBitsVertical);
490         const __m128i v_dest_hi =
491             _mm_slli_epi16(v_src_ext_hi, kRoundBitsVertical);
492         StoreUnaligned16(&dest[x], v_dest_lo);
493         StoreUnaligned16(&dest[x + 8], v_dest_hi);
494         x += 16;
495       } while (x < width);
496       src += src_stride;
497       dest += pred_stride;
498     } while (--y != 0);
499   } else if (width == 8) {
500     int y = height;
501     do {
502       const __m128i v_src = LoadLo8(&src[0]);
503       const __m128i v_src_ext = _mm_cvtepu8_epi16(v_src);
504       const __m128i v_dest = _mm_slli_epi16(v_src_ext, kRoundBitsVertical);
505       StoreUnaligned16(&dest[0], v_dest);
506       src += src_stride;
507       dest += pred_stride;
508     } while (--y != 0);
509   } else { /* width == 4 */
510     int y = height;
511     do {
512       const __m128i v_src0 = Load4(&src[0]);
513       const __m128i v_src1 = Load4(&src[src_stride]);
514       const __m128i v_src = _mm_unpacklo_epi32(v_src0, v_src1);
515       const __m128i v_src_ext = _mm_cvtepu8_epi16(v_src);
516       const __m128i v_dest = _mm_slli_epi16(v_src_ext, kRoundBitsVertical);
517       StoreLo8(&dest[0], v_dest);
518       StoreHi8(&dest[pred_stride], v_dest);
519       src += src_stride * 2;
520       dest += pred_stride * 2;
521       y -= 2;
522     } while (y != 0);
523   }
524 }
525 
ConvolveCompoundVertical_SSE4_1(const void * LIBGAV1_RESTRICT const reference,const ptrdiff_t reference_stride,const int,const int vertical_filter_index,const int,const int vertical_filter_id,const int width,const int height,void * LIBGAV1_RESTRICT prediction,const ptrdiff_t)526 void ConvolveCompoundVertical_SSE4_1(
527     const void* LIBGAV1_RESTRICT const reference,
528     const ptrdiff_t reference_stride, const int /*horizontal_filter_index*/,
529     const int vertical_filter_index, const int /*horizontal_filter_id*/,
530     const int vertical_filter_id, const int width, const int height,
531     void* LIBGAV1_RESTRICT prediction, const ptrdiff_t /*pred_stride*/) {
532   const int filter_index = GetFilterIndex(vertical_filter_index, height);
533   const int vertical_taps =
534       GetNumTapsInFilter(filter_index, vertical_filter_id);
535   const ptrdiff_t src_stride = reference_stride;
536   const auto* src = static_cast<const uint8_t*>(reference) -
537                     (vertical_taps / 2 - 1) * src_stride;
538   auto* dest = static_cast<uint16_t*>(prediction);
539   assert(vertical_filter_id != 0);
540 
541   __m128i taps[4];
542   const __m128i v_filter =
543       LoadLo8(kHalfSubPixelFilters[filter_index][vertical_filter_id]);
544 
545   if (vertical_taps == 6) {  // 6 tap.
546     SetupTaps<6>(&v_filter, taps);
547     if (width == 4) {
548       FilterVertical4xH<6, /*is_compound=*/true>(src, src_stride, dest, 4,
549                                                  height, taps);
550     } else {
551       FilterVertical<6, /*is_compound=*/true>(src, src_stride, dest, width,
552                                               width, height, taps);
553     }
554   } else if (vertical_taps == 8) {  // 8 tap.
555     SetupTaps<8>(&v_filter, taps);
556     if (width == 4) {
557       FilterVertical4xH<8, /*is_compound=*/true>(src, src_stride, dest, 4,
558                                                  height, taps);
559     } else {
560       FilterVertical<8, /*is_compound=*/true>(src, src_stride, dest, width,
561                                               width, height, taps);
562     }
563   } else if (vertical_taps == 2) {  // 2 tap.
564     SetupTaps<2>(&v_filter, taps);
565     if (width == 4) {
566       FilterVertical4xH<2, /*is_compound=*/true>(src, src_stride, dest, 4,
567                                                  height, taps);
568     } else {
569       FilterVertical<2, /*is_compound=*/true>(src, src_stride, dest, width,
570                                               width, height, taps);
571     }
572   } else {  // 4 tap
573     SetupTaps<4>(&v_filter, taps);
574     if (width == 4) {
575       FilterVertical4xH<4, /*is_compound=*/true>(src, src_stride, dest, 4,
576                                                  height, taps);
577     } else {
578       FilterVertical<4, /*is_compound=*/true>(src, src_stride, dest, width,
579                                               width, height, taps);
580     }
581   }
582 }
583 
ConvolveHorizontal_SSE4_1(const void * LIBGAV1_RESTRICT const reference,const ptrdiff_t reference_stride,const int horizontal_filter_index,const int,const int horizontal_filter_id,const int,const int width,const int height,void * LIBGAV1_RESTRICT prediction,const ptrdiff_t pred_stride)584 void ConvolveHorizontal_SSE4_1(
585     const void* LIBGAV1_RESTRICT const reference,
586     const ptrdiff_t reference_stride, const int horizontal_filter_index,
587     const int /*vertical_filter_index*/, const int horizontal_filter_id,
588     const int /*vertical_filter_id*/, const int width, const int height,
589     void* LIBGAV1_RESTRICT prediction, const ptrdiff_t pred_stride) {
590   const int filter_index = GetFilterIndex(horizontal_filter_index, width);
591   // Set |src| to the outermost tap.
592   const auto* src = static_cast<const uint8_t*>(reference) - kHorizontalOffset;
593   auto* dest = static_cast<uint8_t*>(prediction);
594 
595   DoHorizontalPass(src, reference_stride, dest, pred_stride, width, height,
596                    horizontal_filter_id, filter_index);
597 }
598 
ConvolveCompoundHorizontal_SSE4_1(const void * LIBGAV1_RESTRICT const reference,const ptrdiff_t reference_stride,const int horizontal_filter_index,const int,const int horizontal_filter_id,const int,const int width,const int height,void * LIBGAV1_RESTRICT prediction,const ptrdiff_t)599 void ConvolveCompoundHorizontal_SSE4_1(
600     const void* LIBGAV1_RESTRICT const reference,
601     const ptrdiff_t reference_stride, const int horizontal_filter_index,
602     const int /*vertical_filter_index*/, const int horizontal_filter_id,
603     const int /*vertical_filter_id*/, const int width, const int height,
604     void* LIBGAV1_RESTRICT prediction, const ptrdiff_t /*pred_stride*/) {
605   const int filter_index = GetFilterIndex(horizontal_filter_index, width);
606   const auto* src = static_cast<const uint8_t*>(reference) - kHorizontalOffset;
607   auto* dest = static_cast<uint16_t*>(prediction);
608 
609   DoHorizontalPass</*is_2d=*/false, /*is_compound=*/true>(
610       src, reference_stride, dest, width, width, height, horizontal_filter_id,
611       filter_index);
612 }
613 
ConvolveCompound2D_SSE4_1(const void * LIBGAV1_RESTRICT const reference,const ptrdiff_t reference_stride,const int horizontal_filter_index,const int vertical_filter_index,const int horizontal_filter_id,const int vertical_filter_id,const int width,const int height,void * LIBGAV1_RESTRICT prediction,const ptrdiff_t)614 void ConvolveCompound2D_SSE4_1(
615     const void* LIBGAV1_RESTRICT const reference,
616     const ptrdiff_t reference_stride, const int horizontal_filter_index,
617     const int vertical_filter_index, const int horizontal_filter_id,
618     const int vertical_filter_id, const int width, const int height,
619     void* LIBGAV1_RESTRICT prediction, const ptrdiff_t /*pred_stride*/) {
620   // The output of the horizontal filter, i.e. the intermediate_result, is
621   // guaranteed to fit in int16_t.
622   alignas(16) uint16_t
623       intermediate_result[kMaxSuperBlockSizeInPixels *
624                           (kMaxSuperBlockSizeInPixels + kSubPixelTaps - 1)];
625 #if LIBGAV1_MSAN
626   // Quiet msan warnings. Set with random non-zero value to aid in debugging.
627   memset(intermediate_result, 0x33, sizeof(intermediate_result));
628 #endif
629 
630   // Horizontal filter.
631   // Filter types used for width <= 4 are different from those for width > 4.
632   // When width > 4, the valid filter index range is always [0, 3].
633   // When width <= 4, the valid filter index range is always [4, 5].
634   // Similarly for height.
635   const int horiz_filter_index = GetFilterIndex(horizontal_filter_index, width);
636   const int vert_filter_index = GetFilterIndex(vertical_filter_index, height);
637   const int vertical_taps =
638       GetNumTapsInFilter(vert_filter_index, vertical_filter_id);
639   const int intermediate_height = height + vertical_taps - 1;
640   const ptrdiff_t src_stride = reference_stride;
641   const auto* const src = static_cast<const uint8_t*>(reference) -
642                           (vertical_taps / 2 - 1) * src_stride -
643                           kHorizontalOffset;
644 
645   DoHorizontalPass</*is_2d=*/true, /*is_compound=*/true>(
646       src, src_stride, intermediate_result, width, width, intermediate_height,
647       horizontal_filter_id, horiz_filter_index);
648 
649   // Vertical filter.
650   auto* dest = static_cast<uint16_t*>(prediction);
651   assert(vertical_filter_id != 0);
652 
653   const ptrdiff_t dest_stride = width;
654   __m128i taps[4];
655   const __m128i v_filter =
656       LoadLo8(kHalfSubPixelFilters[vert_filter_index][vertical_filter_id]);
657 
658   if (vertical_taps == 8) {
659     SetupTaps<8, /*is_2d_vertical=*/true>(&v_filter, taps);
660     if (width == 4) {
661       Filter2DVertical4xH<8, /*is_compound=*/true>(intermediate_result, dest,
662                                                    dest_stride, height, taps);
663     } else {
664       Filter2DVertical<8, /*is_compound=*/true>(
665           intermediate_result, dest, dest_stride, width, height, taps);
666     }
667   } else if (vertical_taps == 6) {
668     SetupTaps<6, /*is_2d_vertical=*/true>(&v_filter, taps);
669     if (width == 4) {
670       Filter2DVertical4xH<6, /*is_compound=*/true>(intermediate_result, dest,
671                                                    dest_stride, height, taps);
672     } else {
673       Filter2DVertical<6, /*is_compound=*/true>(
674           intermediate_result, dest, dest_stride, width, height, taps);
675     }
676   } else if (vertical_taps == 4) {
677     SetupTaps<4, /*is_2d_vertical=*/true>(&v_filter, taps);
678     if (width == 4) {
679       Filter2DVertical4xH<4, /*is_compound=*/true>(intermediate_result, dest,
680                                                    dest_stride, height, taps);
681     } else {
682       Filter2DVertical<4, /*is_compound=*/true>(
683           intermediate_result, dest, dest_stride, width, height, taps);
684     }
685   } else {  // |vertical_taps| == 2
686     SetupTaps<2, /*is_2d_vertical=*/true>(&v_filter, taps);
687     if (width == 4) {
688       Filter2DVertical4xH<2, /*is_compound=*/true>(intermediate_result, dest,
689                                                    dest_stride, height, taps);
690     } else {
691       Filter2DVertical<2, /*is_compound=*/true>(
692           intermediate_result, dest, dest_stride, width, height, taps);
693     }
694   }
695 }
696 
697 // Pre-transposed filters.
698 template <int filter_index>
GetHalfSubPixelFilter(__m128i * output)699 inline void GetHalfSubPixelFilter(__m128i* output) {
700   // Filter 0
701   alignas(
702       16) static constexpr int8_t kHalfSubPixel6TapSignedFilterColumns[6][16] =
703       {{0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 0},
704        {0, -3, -5, -6, -7, -7, -8, -7, -7, -6, -6, -6, -5, -4, -2, -1},
705        {64, 63, 61, 58, 55, 51, 47, 42, 38, 33, 29, 24, 19, 14, 9, 4},
706        {0, 4, 9, 14, 19, 24, 29, 33, 38, 42, 47, 51, 55, 58, 61, 63},
707        {0, -1, -2, -4, -5, -6, -6, -6, -7, -7, -8, -7, -7, -6, -5, -3},
708        {0, 0, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1}};
709   // Filter 1
710   alignas(16) static constexpr int8_t
711       kHalfSubPixel6TapMixedSignedFilterColumns[6][16] = {
712           {0, 1, 0, 0, 0, 0, 0, -1, -1, 0, 0, 0, 0, 0, 0, 0},
713           {0, 14, 13, 11, 10, 9, 8, 8, 7, 6, 5, 4, 3, 2, 2, 1},
714           {64, 31, 31, 31, 30, 29, 28, 27, 26, 24, 23, 22, 21, 20, 18, 17},
715           {0, 17, 18, 20, 21, 22, 23, 24, 26, 27, 28, 29, 30, 31, 31, 31},
716           {0, 1, 2, 2, 3, 4, 5, 6, 7, 8, 8, 9, 10, 11, 13, 14},
717           {0, 0, 0, 0, 0, 0, 0, 0, -1, -1, 0, 0, 0, 0, 0, 1}};
718   // Filter 2
719   alignas(
720       16) static constexpr int8_t kHalfSubPixel8TapSignedFilterColumns[8][16] =
721       {{0, -1, -1, -1, -2, -2, -2, -2, -2, -1, -1, -1, -1, -1, -1, 0},
722        {0, 1, 3, 4, 5, 5, 5, 5, 6, 5, 4, 4, 3, 3, 2, 1},
723        {0, -3, -6, -9, -11, -11, -12, -12, -12, -11, -10, -9, -7, -5, -3, -1},
724        {64, 63, 62, 60, 58, 54, 50, 45, 40, 35, 30, 24, 19, 13, 8, 4},
725        {0, 4, 8, 13, 19, 24, 30, 35, 40, 45, 50, 54, 58, 60, 62, 63},
726        {0, -1, -3, -5, -7, -9, -10, -11, -12, -12, -12, -11, -11, -9, -6, -3},
727        {0, 1, 2, 3, 3, 4, 4, 5, 6, 5, 5, 5, 5, 4, 3, 1},
728        {0, 0, -1, -1, -1, -1, -1, -1, -2, -2, -2, -2, -2, -1, -1, -1}};
729   // Filter 3
730   alignas(16) static constexpr uint8_t kHalfSubPixel2TapFilterColumns[2][16] = {
731       {64, 60, 56, 52, 48, 44, 40, 36, 32, 28, 24, 20, 16, 12, 8, 4},
732       {0, 4, 8, 12, 16, 20, 24, 28, 32, 36, 40, 44, 48, 52, 56, 60}};
733   // Filter 4
734   alignas(
735       16) static constexpr int8_t kHalfSubPixel4TapSignedFilterColumns[4][16] =
736       {{0, -2, -4, -5, -6, -6, -7, -6, -6, -5, -5, -5, -4, -3, -2, -1},
737        {64, 63, 61, 58, 55, 51, 47, 42, 38, 33, 29, 24, 19, 14, 9, 4},
738        {0, 4, 9, 14, 19, 24, 29, 33, 38, 42, 47, 51, 55, 58, 61, 63},
739        {0, -1, -2, -3, -4, -5, -5, -5, -6, -6, -7, -6, -6, -5, -4, -2}};
740   // Filter 5
741   alignas(
742       16) static constexpr uint8_t kSubPixel4TapPositiveFilterColumns[4][16] = {
743       {0, 15, 13, 11, 10, 9, 8, 7, 6, 6, 5, 4, 3, 2, 2, 1},
744       {64, 31, 31, 31, 30, 29, 28, 27, 26, 24, 23, 22, 21, 20, 18, 17},
745       {0, 17, 18, 20, 21, 22, 23, 24, 26, 27, 28, 29, 30, 31, 31, 31},
746       {0, 1, 2, 2, 3, 4, 5, 6, 6, 7, 8, 9, 10, 11, 13, 15}};
747   switch (filter_index) {
748     case 0:
749       output[0] = LoadAligned16(kHalfSubPixel6TapSignedFilterColumns[0]);
750       output[1] = LoadAligned16(kHalfSubPixel6TapSignedFilterColumns[1]);
751       output[2] = LoadAligned16(kHalfSubPixel6TapSignedFilterColumns[2]);
752       output[3] = LoadAligned16(kHalfSubPixel6TapSignedFilterColumns[3]);
753       output[4] = LoadAligned16(kHalfSubPixel6TapSignedFilterColumns[4]);
754       output[5] = LoadAligned16(kHalfSubPixel6TapSignedFilterColumns[5]);
755       break;
756     case 1:
757       // The term "mixed" refers to the fact that the outer taps have a mix of
758       // negative and positive values.
759       output[0] = LoadAligned16(kHalfSubPixel6TapMixedSignedFilterColumns[0]);
760       output[1] = LoadAligned16(kHalfSubPixel6TapMixedSignedFilterColumns[1]);
761       output[2] = LoadAligned16(kHalfSubPixel6TapMixedSignedFilterColumns[2]);
762       output[3] = LoadAligned16(kHalfSubPixel6TapMixedSignedFilterColumns[3]);
763       output[4] = LoadAligned16(kHalfSubPixel6TapMixedSignedFilterColumns[4]);
764       output[5] = LoadAligned16(kHalfSubPixel6TapMixedSignedFilterColumns[5]);
765       break;
766     case 2:
767       output[0] = LoadAligned16(kHalfSubPixel8TapSignedFilterColumns[0]);
768       output[1] = LoadAligned16(kHalfSubPixel8TapSignedFilterColumns[1]);
769       output[2] = LoadAligned16(kHalfSubPixel8TapSignedFilterColumns[2]);
770       output[3] = LoadAligned16(kHalfSubPixel8TapSignedFilterColumns[3]);
771       output[4] = LoadAligned16(kHalfSubPixel8TapSignedFilterColumns[4]);
772       output[5] = LoadAligned16(kHalfSubPixel8TapSignedFilterColumns[5]);
773       output[6] = LoadAligned16(kHalfSubPixel8TapSignedFilterColumns[6]);
774       output[7] = LoadAligned16(kHalfSubPixel8TapSignedFilterColumns[7]);
775       break;
776     case 3:
777       output[0] = LoadAligned16(kHalfSubPixel2TapFilterColumns[0]);
778       output[1] = LoadAligned16(kHalfSubPixel2TapFilterColumns[1]);
779       break;
780     case 4:
781       output[0] = LoadAligned16(kHalfSubPixel4TapSignedFilterColumns[0]);
782       output[1] = LoadAligned16(kHalfSubPixel4TapSignedFilterColumns[1]);
783       output[2] = LoadAligned16(kHalfSubPixel4TapSignedFilterColumns[2]);
784       output[3] = LoadAligned16(kHalfSubPixel4TapSignedFilterColumns[3]);
785       break;
786     default:
787       assert(filter_index == 5);
788       output[0] = LoadAligned16(kSubPixel4TapPositiveFilterColumns[0]);
789       output[1] = LoadAligned16(kSubPixel4TapPositiveFilterColumns[1]);
790       output[2] = LoadAligned16(kSubPixel4TapPositiveFilterColumns[2]);
791       output[3] = LoadAligned16(kSubPixel4TapPositiveFilterColumns[3]);
792       break;
793   }
794 }
795 
// There are many opportunities for overreading in scaled convolve, because
// the range of starting points for filter windows is anywhere from 0 to 16
// for 8 destination pixels, and the window sizes range from 2 to 8. To
// accommodate this range concisely, we use |grade_x| to mean the most steps
// in src that can be traversed in a single |step_x| increment, i.e. 1 or 2.
// More importantly, |grade_x| answers the question "how many vector loads are
// needed to cover the source values?"
// When |grade_x| == 1, the maximum number of source values needed is 8 separate
// starting positions plus 7 more to cover taps, all fitting into 16 bytes.
// When |grade_x| > 1, we are guaranteed to exceed 8 whole steps in src for
// every 8 |step_x| increments, on top of 8 possible taps. The first load covers
// the starting sources for each kernel, while the final load covers the taps.
// Since the offset value of src_x cannot exceed 8 and |num_taps| does not
// exceed 4 when width <= 4, |grade_x| is set to 1 regardless of the value of
// |step_x|.
//
// Gathers source pixels for one row of the scaled horizontal filter.
// |src_indices| holds interleaved byte offsets (position, position + 1) for
// the 8 output columns; each source[k] receives the pixel pair feeding taps
// 2k and 2k+1 of every column, lined up for maddubs_epi16.
template <int num_taps, int grade_x>
inline void PrepareSourceVectors(const uint8_t* LIBGAV1_RESTRICT src,
                                 const __m128i src_indices,
                                 __m128i* const source /*[num_taps >> 1]*/) {
  // |used_bytes| is only computed in msan builds. Mask away unused bytes for
  // msan because it incorrectly models the outcome of the shuffles in some
  // cases. This has not been reproduced out of context.
  // Highest byte offset referenced, plus the extra taps read past it.
  const int used_bytes = _mm_extract_epi8(src_indices, 15) + 1 + num_taps - 2;
  const __m128i src_vals = LoadUnaligned16Msan(src, 16 - used_bytes);
  // Taps 0 and 1 for all 8 columns.
  source[0] = _mm_shuffle_epi8(src_vals, src_indices);
  if (grade_x == 1) {
    // All inputs fit in the single 16-byte load; shifting the source right
    // by 2 bytes repositions it for each successive tap pair.
    if (num_taps > 2) {
      source[1] = _mm_shuffle_epi8(_mm_srli_si128(src_vals, 2), src_indices);
    }
    if (num_taps > 4) {
      source[2] = _mm_shuffle_epi8(_mm_srli_si128(src_vals, 4), src_indices);
    }
    if (num_taps > 6) {
      source[3] = _mm_shuffle_epi8(_mm_srli_si128(src_vals, 6), src_indices);
    }
  } else {
    assert(grade_x > 1);
    assert(num_taps != 4);
    // grade_x > 1 also means width >= 8 && num_taps != 4
    // A second 8-byte load supplies the bytes that alignr shifts in from
    // beyond the first 16.
    const __m128i src_vals_ext = LoadLo8Msan(src + 16, 24 - used_bytes);
    if (num_taps > 2) {
      source[1] = _mm_shuffle_epi8(_mm_alignr_epi8(src_vals_ext, src_vals, 2),
                                   src_indices);
      source[2] = _mm_shuffle_epi8(_mm_alignr_epi8(src_vals_ext, src_vals, 4),
                                   src_indices);
    }
    if (num_taps > 6) {
      source[3] = _mm_shuffle_epi8(_mm_alignr_epi8(src_vals_ext, src_vals, 6),
                                   src_indices);
    }
  }
}
848 
849 template <int num_taps>
PrepareHorizontalTaps(const __m128i subpel_indices,const __m128i * filter_taps,__m128i * out_taps)850 inline void PrepareHorizontalTaps(const __m128i subpel_indices,
851                                   const __m128i* filter_taps,
852                                   __m128i* out_taps) {
853   const __m128i scale_index_offsets =
854       _mm_srli_epi16(subpel_indices, kFilterIndexShift);
855   const __m128i filter_index_mask = _mm_set1_epi8(kSubPixelMask);
856   const __m128i filter_indices =
857       _mm_and_si128(_mm_packus_epi16(scale_index_offsets, scale_index_offsets),
858                     filter_index_mask);
859   // Line up taps for maddubs_epi16.
860   // The unpack is also assumed to be lighter than shift+alignr.
861   for (int k = 0; k < (num_taps >> 1); ++k) {
862     const __m128i taps0 = _mm_shuffle_epi8(filter_taps[2 * k], filter_indices);
863     const __m128i taps1 =
864         _mm_shuffle_epi8(filter_taps[2 * k + 1], filter_indices);
865     out_taps[k] = _mm_unpacklo_epi8(taps0, taps1);
866   }
867 }
868 
HorizontalScaleIndices(const __m128i subpel_indices)869 inline __m128i HorizontalScaleIndices(const __m128i subpel_indices) {
870   const __m128i src_indices16 =
871       _mm_srli_epi16(subpel_indices, kScaleSubPixelBits);
872   const __m128i src_indices = _mm_packus_epi16(src_indices16, src_indices16);
873   return _mm_unpacklo_epi8(src_indices,
874                            _mm_add_epi8(src_indices, _mm_set1_epi8(1)));
875 }
876 
877 template <int grade_x, int filter_index, int num_taps>
ConvolveHorizontalScale(const uint8_t * LIBGAV1_RESTRICT src,ptrdiff_t src_stride,int width,int subpixel_x,int step_x,int intermediate_height,int16_t * LIBGAV1_RESTRICT intermediate)878 inline void ConvolveHorizontalScale(const uint8_t* LIBGAV1_RESTRICT src,
879                                     ptrdiff_t src_stride, int width,
880                                     int subpixel_x, int step_x,
881                                     int intermediate_height,
882                                     int16_t* LIBGAV1_RESTRICT intermediate) {
883   // Account for the 0-taps that precede the 2 nonzero taps.
884   const int kernel_offset = (8 - num_taps) >> 1;
885   const int ref_x = subpixel_x >> kScaleSubPixelBits;
886   const int step_x8 = step_x << 3;
887   __m128i filter_taps[num_taps];
888   GetHalfSubPixelFilter<filter_index>(filter_taps);
889   const __m128i index_steps =
890       _mm_mullo_epi16(_mm_set_epi16(7, 6, 5, 4, 3, 2, 1, 0),
891                       _mm_set1_epi16(static_cast<int16_t>(step_x)));
892 
893   __m128i taps[num_taps >> 1];
894   __m128i source[num_taps >> 1];
895   int p = subpixel_x;
896   // Case when width <= 4 is possible.
897   if (filter_index >= 3) {
898     if (filter_index > 3 || width <= 4) {
899       const uint8_t* src_x =
900           &src[(p >> kScaleSubPixelBits) - ref_x + kernel_offset];
901       // Only add steps to the 10-bit truncated p to avoid overflow.
902       const __m128i p_fraction = _mm_set1_epi16(p & 1023);
903       const __m128i subpel_indices = _mm_add_epi16(index_steps, p_fraction);
904       PrepareHorizontalTaps<num_taps>(subpel_indices, filter_taps, taps);
905       const __m128i packed_indices = HorizontalScaleIndices(subpel_indices);
906 
907       int y = intermediate_height;
908       do {
909         // Load and line up source values with the taps. Width 4 means no need
910         // to load extended source.
911         PrepareSourceVectors<num_taps, /*grade_x=*/1>(src_x, packed_indices,
912                                                       source);
913 
914         StoreLo8(intermediate, RightShiftWithRounding_S16(
915                                    SumOnePassTaps<num_taps>(source, taps),
916                                    kInterRoundBitsHorizontal - 1));
917         src_x += src_stride;
918         intermediate += kIntermediateStride;
919       } while (--y != 0);
920       return;
921     }
922   }
923 
924   // |width| >= 8
925   int16_t* intermediate_x = intermediate;
926   int x = 0;
927   do {
928     const uint8_t* src_x =
929         &src[(p >> kScaleSubPixelBits) - ref_x + kernel_offset];
930     // Only add steps to the 10-bit truncated p to avoid overflow.
931     const __m128i p_fraction = _mm_set1_epi16(p & 1023);
932     const __m128i subpel_indices = _mm_add_epi16(index_steps, p_fraction);
933     PrepareHorizontalTaps<num_taps>(subpel_indices, filter_taps, taps);
934     const __m128i packed_indices = HorizontalScaleIndices(subpel_indices);
935 
936     int y = intermediate_height;
937     do {
938       // For each x, a lane of src_k[k] contains src_x[k].
939       PrepareSourceVectors<num_taps, grade_x>(src_x, packed_indices, source);
940 
941       // Shift by one less because the taps are halved.
942       StoreAligned16(intermediate_x, RightShiftWithRounding_S16(
943                                          SumOnePassTaps<num_taps>(source, taps),
944                                          kInterRoundBitsHorizontal - 1));
945       src_x += src_stride;
946       intermediate_x += kIntermediateStride;
947     } while (--y != 0);
948     x += 8;
949     p += step_x8;
950   } while (x < width);
951 }
952 
953 template <int num_taps>
PrepareVerticalTaps(const int8_t * LIBGAV1_RESTRICT taps,__m128i * output)954 inline void PrepareVerticalTaps(const int8_t* LIBGAV1_RESTRICT taps,
955                                 __m128i* output) {
956   // Avoid overreading the filter due to starting at kernel_offset.
957   // The only danger of overread is in the final filter, which has 4 taps.
958   const __m128i filter =
959       _mm_cvtepi8_epi16((num_taps > 4) ? LoadLo8(taps) : Load4(taps));
960   output[0] = _mm_shuffle_epi32(filter, 0);
961   if (num_taps > 2) {
962     output[1] = _mm_shuffle_epi32(filter, 0x55);
963   }
964   if (num_taps > 4) {
965     output[2] = _mm_shuffle_epi32(filter, 0xAA);
966   }
967   if (num_taps > 6) {
968     output[3] = _mm_shuffle_epi32(filter, 0xFF);
969   }
970 }
971 
972 // Process eight 16 bit inputs and output eight 16 bit values.
973 template <int num_taps, bool is_compound>
Sum2DVerticalTaps(const __m128i * const src,const __m128i * taps)974 inline __m128i Sum2DVerticalTaps(const __m128i* const src,
975                                  const __m128i* taps) {
976   const __m128i src_lo_01 = _mm_unpacklo_epi16(src[0], src[1]);
977   __m128i sum_lo = _mm_madd_epi16(src_lo_01, taps[0]);
978   const __m128i src_hi_01 = _mm_unpackhi_epi16(src[0], src[1]);
979   __m128i sum_hi = _mm_madd_epi16(src_hi_01, taps[0]);
980   if (num_taps > 2) {
981     const __m128i src_lo_23 = _mm_unpacklo_epi16(src[2], src[3]);
982     sum_lo = _mm_add_epi32(sum_lo, _mm_madd_epi16(src_lo_23, taps[1]));
983     const __m128i src_hi_23 = _mm_unpackhi_epi16(src[2], src[3]);
984     sum_hi = _mm_add_epi32(sum_hi, _mm_madd_epi16(src_hi_23, taps[1]));
985   }
986   if (num_taps > 4) {
987     const __m128i src_lo_45 = _mm_unpacklo_epi16(src[4], src[5]);
988     sum_lo = _mm_add_epi32(sum_lo, _mm_madd_epi16(src_lo_45, taps[2]));
989     const __m128i src_hi_45 = _mm_unpackhi_epi16(src[4], src[5]);
990     sum_hi = _mm_add_epi32(sum_hi, _mm_madd_epi16(src_hi_45, taps[2]));
991   }
992   if (num_taps > 6) {
993     const __m128i src_lo_67 = _mm_unpacklo_epi16(src[6], src[7]);
994     sum_lo = _mm_add_epi32(sum_lo, _mm_madd_epi16(src_lo_67, taps[3]));
995     const __m128i src_hi_67 = _mm_unpackhi_epi16(src[6], src[7]);
996     sum_hi = _mm_add_epi32(sum_hi, _mm_madd_epi16(src_hi_67, taps[3]));
997   }
998   if (is_compound) {
999     return _mm_packs_epi32(
1000         RightShiftWithRounding_S32(sum_lo, kInterRoundBitsCompoundVertical - 1),
1001         RightShiftWithRounding_S32(sum_hi,
1002                                    kInterRoundBitsCompoundVertical - 1));
1003   }
1004   return _mm_packs_epi32(
1005       RightShiftWithRounding_S32(sum_lo, kInterRoundBitsVertical - 1),
1006       RightShiftWithRounding_S32(sum_hi, kInterRoundBitsVertical - 1));
1007 }
1008 
1009 // Bottom half of each src[k] is the source for one filter, and the top half
1010 // is the source for the other filter, for the next destination row.
1011 template <int num_taps, bool is_compound>
Sum2DVerticalTaps4x2(const __m128i * const src,const __m128i * taps_lo,const __m128i * taps_hi)1012 __m128i Sum2DVerticalTaps4x2(const __m128i* const src, const __m128i* taps_lo,
1013                              const __m128i* taps_hi) {
1014   const __m128i src_lo_01 = _mm_unpacklo_epi16(src[0], src[1]);
1015   __m128i sum_lo = _mm_madd_epi16(src_lo_01, taps_lo[0]);
1016   const __m128i src_hi_01 = _mm_unpackhi_epi16(src[0], src[1]);
1017   __m128i sum_hi = _mm_madd_epi16(src_hi_01, taps_hi[0]);
1018   if (num_taps > 2) {
1019     const __m128i src_lo_23 = _mm_unpacklo_epi16(src[2], src[3]);
1020     sum_lo = _mm_add_epi32(sum_lo, _mm_madd_epi16(src_lo_23, taps_lo[1]));
1021     const __m128i src_hi_23 = _mm_unpackhi_epi16(src[2], src[3]);
1022     sum_hi = _mm_add_epi32(sum_hi, _mm_madd_epi16(src_hi_23, taps_hi[1]));
1023   }
1024   if (num_taps > 4) {
1025     const __m128i src_lo_45 = _mm_unpacklo_epi16(src[4], src[5]);
1026     sum_lo = _mm_add_epi32(sum_lo, _mm_madd_epi16(src_lo_45, taps_lo[2]));
1027     const __m128i src_hi_45 = _mm_unpackhi_epi16(src[4], src[5]);
1028     sum_hi = _mm_add_epi32(sum_hi, _mm_madd_epi16(src_hi_45, taps_hi[2]));
1029   }
1030   if (num_taps > 6) {
1031     const __m128i src_lo_67 = _mm_unpacklo_epi16(src[6], src[7]);
1032     sum_lo = _mm_add_epi32(sum_lo, _mm_madd_epi16(src_lo_67, taps_lo[3]));
1033     const __m128i src_hi_67 = _mm_unpackhi_epi16(src[6], src[7]);
1034     sum_hi = _mm_add_epi32(sum_hi, _mm_madd_epi16(src_hi_67, taps_hi[3]));
1035   }
1036 
1037   if (is_compound) {
1038     return _mm_packs_epi32(
1039         RightShiftWithRounding_S32(sum_lo, kInterRoundBitsCompoundVertical - 1),
1040         RightShiftWithRounding_S32(sum_hi,
1041                                    kInterRoundBitsCompoundVertical - 1));
1042   }
1043   return _mm_packs_epi32(
1044       RightShiftWithRounding_S32(sum_lo, kInterRoundBitsVertical - 1),
1045       RightShiftWithRounding_S32(sum_hi, kInterRoundBitsVertical - 1));
1046 }
1047 
// |width_class| is 2, 4, or 8, according to the Store function that should be
// used.
//
// Vertical pass of the scaled 2D convolution. Reads 16-bit rows from the
// horizontal pass output |src| (stride kIntermediateStride), steps the row
// position by |step_y| (1/1024 pel units) per output row, and writes either
// uint16_t compound predictions or 8-bit pixels to |dest|.
template <int num_taps, int width_class, bool is_compound>
inline void ConvolveVerticalScale(const int16_t* LIBGAV1_RESTRICT src,
                                  const int intermediate_height,
                                  const int width, const int subpixel_y,
                                  const int filter_index, const int step_y,
                                  const int height, void* LIBGAV1_RESTRICT dest,
                                  const ptrdiff_t dest_stride) {
  constexpr ptrdiff_t src_stride = kIntermediateStride;
  // Skip the zero taps that precede the nonzero taps in an 8-tap slot.
  constexpr int kernel_offset = (8 - num_taps) / 2;
  const int16_t* src_y = src;
  // |dest| is 16-bit in compound mode, Pixel otherwise.
  auto* dest16_y = static_cast<uint16_t*>(dest);
  auto* dest_y = static_cast<uint8_t*>(dest);
  __m128i s[num_taps];

  int p = subpixel_y & 1023;
  int y = height;
  if (width_class <= 4) {
    // Narrow case: two destination rows are filtered per iteration. The low
    // half of each s[i] holds the sources for the first row and the high half
    // those for the second, with a separate tap set per row since each row may
    // use a different sub-pixel phase.
    __m128i filter_taps_lo[num_taps >> 1];
    __m128i filter_taps_hi[num_taps >> 1];
    do {  // y > 0
      // First row's sources in the low halves.
      for (int i = 0; i < num_taps; ++i) {
        s[i] = LoadLo8(src_y + i * src_stride);
      }
      int filter_id = (p >> 6) & kSubPixelMask;
      const int8_t* filter0 =
          kHalfSubPixelFilters[filter_index][filter_id] + kernel_offset;
      PrepareVerticalTaps<num_taps>(filter0, filter_taps_lo);
      p += step_y;
      src_y = src + (p >> kScaleSubPixelBits) * src_stride;

      // Second row's sources in the high halves.
      for (int i = 0; i < num_taps; ++i) {
        s[i] = LoadHi8(s[i], src_y + i * src_stride);
      }
      filter_id = (p >> 6) & kSubPixelMask;
      const int8_t* filter1 =
          kHalfSubPixelFilters[filter_index][filter_id] + kernel_offset;
      PrepareVerticalTaps<num_taps>(filter1, filter_taps_hi);
      p += step_y;
      src_y = src + (p >> kScaleSubPixelBits) * src_stride;

      const __m128i sums = Sum2DVerticalTaps4x2<num_taps, is_compound>(
          s, filter_taps_lo, filter_taps_hi);
      if (is_compound) {
        // Compound output is 16-bit; width 2 is not used in compound mode.
        assert(width_class > 2);
        StoreLo8(dest16_y, sums);
        dest16_y += dest_stride;
        StoreHi8(dest16_y, sums);
        dest16_y += dest_stride;
      } else {
        const __m128i result = _mm_packus_epi16(sums, sums);
        if (width_class == 2) {
          Store2(dest_y, result);
          dest_y += dest_stride;
          Store2(dest_y, _mm_srli_si128(result, 4));
        } else {
          Store4(dest_y, result);
          dest_y += dest_stride;
          Store4(dest_y, _mm_srli_si128(result, 4));
        }
        dest_y += dest_stride;
      }
      y -= 2;
    } while (y != 0);
    return;
  }

  // |width_class| >= 8
  // Wide case: one 8-column panel at a time, one destination row per
  // iteration, restarting the sub-pixel position for each panel.
  __m128i filter_taps[num_taps >> 1];
  int x = 0;
  do {  // x < width
    auto* dest_y = static_cast<uint8_t*>(dest) + x;
    auto* dest16_y = static_cast<uint16_t*>(dest) + x;
    int p = subpixel_y & 1023;
    int y = height;
    do {  // y > 0
      const int filter_id = (p >> 6) & kSubPixelMask;
      const int8_t* filter =
          kHalfSubPixelFilters[filter_index][filter_id] + kernel_offset;
      PrepareVerticalTaps<num_taps>(filter, filter_taps);

      src_y = src + (p >> kScaleSubPixelBits) * src_stride;
      for (int i = 0; i < num_taps; ++i) {
        s[i] = LoadUnaligned16(src_y + i * src_stride);
      }

      const __m128i sums =
          Sum2DVerticalTaps<num_taps, is_compound>(s, filter_taps);
      if (is_compound) {
        StoreUnaligned16(dest16_y, sums);
      } else {
        StoreLo8(dest_y, _mm_packus_epi16(sums, sums));
      }
      p += step_y;
      dest_y += dest_stride;
      dest16_y += dest_stride;
    } while (--y != 0);
    // Advance to the intermediate columns for the next 8-wide panel.
    src += kIntermediateStride * intermediate_height;
    x += 8;
  } while (x < width);
}
1151 
// Scaled 2D convolution: a horizontal pass into a 16-bit intermediate buffer
// followed by a vertical pass, each stepping through the source at a
// sub-pixel rate (|step_x|, |step_y| in 1/1024 pel units). Both passes
// dispatch to template instantiations selected by filter index / tap count,
// which must be compile-time constants.
template <bool is_compound>
void ConvolveScale2D_SSE4_1(const void* LIBGAV1_RESTRICT const reference,
                            const ptrdiff_t reference_stride,
                            const int horizontal_filter_index,
                            const int vertical_filter_index,
                            const int subpixel_x, const int subpixel_y,
                            const int step_x, const int step_y, const int width,
                            const int height, void* LIBGAV1_RESTRICT prediction,
                            const ptrdiff_t pred_stride) {
  const int horiz_filter_index = GetFilterIndex(horizontal_filter_index, width);
  const int vert_filter_index = GetFilterIndex(vertical_filter_index, height);
  assert(step_x <= 2048);
  // The output of the horizontal filter, i.e. the intermediate_result, is
  // guaranteed to fit in int16_t.
  alignas(16) int16_t
      intermediate_result[kIntermediateAllocWidth *
                          (2 * kIntermediateAllocWidth + kSubPixelTaps)];
#if LIBGAV1_MSAN
  // Quiet msan warnings. Set with random non-zero value to aid in debugging.
  memset(intermediate_result, 0x44, sizeof(intermediate_result));
#endif
  const int num_vert_taps = dsp::GetNumTapsInFilter(vert_filter_index);
  // Number of intermediate rows the vertical pass will consume: the span of
  // sampled positions plus the vertical filter's support.
  const int intermediate_height =
      (((height - 1) * step_y + (1 << kScaleSubPixelBits) - 1) >>
       kScaleSubPixelBits) +
      num_vert_taps;

  // Horizontal filter.
  // Filter types used for width <= 4 are different from those for width > 4.
  // When width > 4, the valid filter index range is always [0, 3].
  // When width <= 4, the valid filter index range is always [3, 5].
  // Similarly for height.
  int16_t* intermediate = intermediate_result;
  const ptrdiff_t src_stride = reference_stride;
  const auto* src = static_cast<const uint8_t*>(reference);
  // Back up to the first row the vertical filter's leading taps will read.
  const int vert_kernel_offset = (8 - num_vert_taps) / 2;
  src += vert_kernel_offset * src_stride;

  // Derive the maximum value of |step_x| at which all source values fit in one
  // 16-byte load. Final index is src_x + |num_taps| - 1 < 16
  // step_x*7 is the final base sub-pixel index for the shuffle mask for filter
  // inputs in each iteration on large blocks. When step_x is large, we need a
  // second register and alignr in order to gather all filter inputs.
  // |num_taps| - 1 is the offset for the shuffle of inputs to the final tap.
  const int num_horiz_taps = dsp::GetNumTapsInFilter(horiz_filter_index);
  const int kernel_start_ceiling = 16 - num_horiz_taps;
  // This truncated quotient |grade_x_threshold| selects |step_x| such that:
  // (step_x * 7) >> kScaleSubPixelBits < single load limit
  const int grade_x_threshold =
      (kernel_start_ceiling << kScaleSubPixelBits) / 7;
  // Dispatch on <grade_x, filter_index, num_taps>; all three are template
  // parameters of ConvolveHorizontalScale and must be literals here.
  switch (horiz_filter_index) {
    case 0:
      if (step_x > grade_x_threshold) {
        ConvolveHorizontalScale<2, 0, 6>(src, src_stride, width, subpixel_x,
                                         step_x, intermediate_height,
                                         intermediate);
      } else {
        ConvolveHorizontalScale<1, 0, 6>(src, src_stride, width, subpixel_x,
                                         step_x, intermediate_height,
                                         intermediate);
      }
      break;
    case 1:
      if (step_x > grade_x_threshold) {
        ConvolveHorizontalScale<2, 1, 6>(src, src_stride, width, subpixel_x,
                                         step_x, intermediate_height,
                                         intermediate);

      } else {
        ConvolveHorizontalScale<1, 1, 6>(src, src_stride, width, subpixel_x,
                                         step_x, intermediate_height,
                                         intermediate);
      }
      break;
    case 2:
      if (step_x > grade_x_threshold) {
        ConvolveHorizontalScale<2, 2, 8>(src, src_stride, width, subpixel_x,
                                         step_x, intermediate_height,
                                         intermediate);
      } else {
        ConvolveHorizontalScale<1, 2, 8>(src, src_stride, width, subpixel_x,
                                         step_x, intermediate_height,
                                         intermediate);
      }
      break;
    case 3:
      if (step_x > grade_x_threshold) {
        ConvolveHorizontalScale<2, 3, 2>(src, src_stride, width, subpixel_x,
                                         step_x, intermediate_height,
                                         intermediate);
      } else {
        ConvolveHorizontalScale<1, 3, 2>(src, src_stride, width, subpixel_x,
                                         step_x, intermediate_height,
                                         intermediate);
      }
      break;
    case 4:
      // Filter indices 4 and 5 are only used for narrow blocks, where a
      // single 16-byte load always suffices (grade_x == 1).
      assert(width <= 4);
      ConvolveHorizontalScale<1, 4, 4>(src, src_stride, width, subpixel_x,
                                       step_x, intermediate_height,
                                       intermediate);
      break;
    default:
      assert(horiz_filter_index == 5);
      assert(width <= 4);
      ConvolveHorizontalScale<1, 5, 4>(src, src_stride, width, subpixel_x,
                                       step_x, intermediate_height,
                                       intermediate);
  }

  // Vertical filter.
  // Dispatch on <num_taps, width_class, is_compound>; width class 2 only
  // exists in non-compound mode.
  intermediate = intermediate_result;
  switch (vert_filter_index) {
    case 0:
    case 1:
      if (!is_compound && width == 2) {
        ConvolveVerticalScale<6, 2, is_compound>(
            intermediate, intermediate_height, width, subpixel_y,
            vert_filter_index, step_y, height, prediction, pred_stride);
      } else if (width == 4) {
        ConvolveVerticalScale<6, 4, is_compound>(
            intermediate, intermediate_height, width, subpixel_y,
            vert_filter_index, step_y, height, prediction, pred_stride);
      } else {
        ConvolveVerticalScale<6, 8, is_compound>(
            intermediate, intermediate_height, width, subpixel_y,
            vert_filter_index, step_y, height, prediction, pred_stride);
      }
      break;
    case 2:
      if (!is_compound && width == 2) {
        ConvolveVerticalScale<8, 2, is_compound>(
            intermediate, intermediate_height, width, subpixel_y,
            vert_filter_index, step_y, height, prediction, pred_stride);
      } else if (width == 4) {
        ConvolveVerticalScale<8, 4, is_compound>(
            intermediate, intermediate_height, width, subpixel_y,
            vert_filter_index, step_y, height, prediction, pred_stride);
      } else {
        ConvolveVerticalScale<8, 8, is_compound>(
            intermediate, intermediate_height, width, subpixel_y,
            vert_filter_index, step_y, height, prediction, pred_stride);
      }
      break;
    case 3:
      if (!is_compound && width == 2) {
        ConvolveVerticalScale<2, 2, is_compound>(
            intermediate, intermediate_height, width, subpixel_y,
            vert_filter_index, step_y, height, prediction, pred_stride);
      } else if (width == 4) {
        ConvolveVerticalScale<2, 4, is_compound>(
            intermediate, intermediate_height, width, subpixel_y,
            vert_filter_index, step_y, height, prediction, pred_stride);
      } else {
        ConvolveVerticalScale<2, 8, is_compound>(
            intermediate, intermediate_height, width, subpixel_y,
            vert_filter_index, step_y, height, prediction, pred_stride);
      }
      break;
    default:
      assert(vert_filter_index == 4 || vert_filter_index == 5);
      if (!is_compound && width == 2) {
        ConvolveVerticalScale<4, 2, is_compound>(
            intermediate, intermediate_height, width, subpixel_y,
            vert_filter_index, step_y, height, prediction, pred_stride);
      } else if (width == 4) {
        ConvolveVerticalScale<4, 4, is_compound>(
            intermediate, intermediate_height, width, subpixel_y,
            vert_filter_index, step_y, height, prediction, pred_stride);
      } else {
        ConvolveVerticalScale<4, 8, is_compound>(
            intermediate, intermediate_height, width, subpixel_y,
            vert_filter_index, step_y, height, prediction, pred_stride);
      }
  }
}
1328 
HalfAddHorizontal(const uint8_t * LIBGAV1_RESTRICT src,uint8_t * LIBGAV1_RESTRICT dst)1329 inline void HalfAddHorizontal(const uint8_t* LIBGAV1_RESTRICT src,
1330                               uint8_t* LIBGAV1_RESTRICT dst) {
1331   const __m128i left = LoadUnaligned16(src);
1332   const __m128i right = LoadUnaligned16(src + 1);
1333   StoreUnaligned16(dst, _mm_avg_epu8(left, right));
1334 }
1335 
1336 template <int width>
IntraBlockCopyHorizontal(const uint8_t * LIBGAV1_RESTRICT src,const ptrdiff_t src_stride,const int height,uint8_t * LIBGAV1_RESTRICT dst,const ptrdiff_t dst_stride)1337 inline void IntraBlockCopyHorizontal(const uint8_t* LIBGAV1_RESTRICT src,
1338                                      const ptrdiff_t src_stride,
1339                                      const int height,
1340                                      uint8_t* LIBGAV1_RESTRICT dst,
1341                                      const ptrdiff_t dst_stride) {
1342   const ptrdiff_t src_remainder_stride = src_stride - (width - 16);
1343   const ptrdiff_t dst_remainder_stride = dst_stride - (width - 16);
1344 
1345   int y = height;
1346   do {
1347     HalfAddHorizontal(src, dst);
1348     if (width >= 32) {
1349       src += 16;
1350       dst += 16;
1351       HalfAddHorizontal(src, dst);
1352       if (width >= 64) {
1353         src += 16;
1354         dst += 16;
1355         HalfAddHorizontal(src, dst);
1356         src += 16;
1357         dst += 16;
1358         HalfAddHorizontal(src, dst);
1359         if (width == 128) {
1360           src += 16;
1361           dst += 16;
1362           HalfAddHorizontal(src, dst);
1363           src += 16;
1364           dst += 16;
1365           HalfAddHorizontal(src, dst);
1366           src += 16;
1367           dst += 16;
1368           HalfAddHorizontal(src, dst);
1369           src += 16;
1370           dst += 16;
1371           HalfAddHorizontal(src, dst);
1372         }
1373       }
1374     }
1375     src += src_remainder_stride;
1376     dst += dst_remainder_stride;
1377   } while (--y != 0);
1378 }
1379 
ConvolveIntraBlockCopyHorizontal_SSE4_1(const void * LIBGAV1_RESTRICT const reference,const ptrdiff_t reference_stride,const int,const int,const int,const int,const int width,const int height,void * LIBGAV1_RESTRICT const prediction,const ptrdiff_t pred_stride)1380 void ConvolveIntraBlockCopyHorizontal_SSE4_1(
1381     const void* LIBGAV1_RESTRICT const reference,
1382     const ptrdiff_t reference_stride, const int /*horizontal_filter_index*/,
1383     const int /*vertical_filter_index*/, const int /*subpixel_x*/,
1384     const int /*subpixel_y*/, const int width, const int height,
1385     void* LIBGAV1_RESTRICT const prediction, const ptrdiff_t pred_stride) {
1386   const auto* src = static_cast<const uint8_t*>(reference);
1387   auto* dest = static_cast<uint8_t*>(prediction);
1388 
1389   if (width == 128) {
1390     IntraBlockCopyHorizontal<128>(src, reference_stride, height, dest,
1391                                   pred_stride);
1392   } else if (width == 64) {
1393     IntraBlockCopyHorizontal<64>(src, reference_stride, height, dest,
1394                                  pred_stride);
1395   } else if (width == 32) {
1396     IntraBlockCopyHorizontal<32>(src, reference_stride, height, dest,
1397                                  pred_stride);
1398   } else if (width == 16) {
1399     IntraBlockCopyHorizontal<16>(src, reference_stride, height, dest,
1400                                  pred_stride);
1401   } else if (width == 8) {
1402     int y = height;
1403     do {
1404       const __m128i left = LoadLo8(src);
1405       const __m128i right = LoadLo8(src + 1);
1406       StoreLo8(dest, _mm_avg_epu8(left, right));
1407 
1408       src += reference_stride;
1409       dest += pred_stride;
1410     } while (--y != 0);
1411   } else if (width == 4) {
1412     int y = height;
1413     do {
1414       __m128i left = Load4(src);
1415       __m128i right = Load4(src + 1);
1416       src += reference_stride;
1417       left = _mm_unpacklo_epi32(left, Load4(src));
1418       right = _mm_unpacklo_epi32(right, Load4(src + 1));
1419       src += reference_stride;
1420 
1421       const __m128i result = _mm_avg_epu8(left, right);
1422 
1423       Store4(dest, result);
1424       dest += pred_stride;
1425       Store4(dest, _mm_srli_si128(result, 4));
1426       dest += pred_stride;
1427       y -= 2;
1428     } while (y != 0);
1429   } else {
1430     assert(width == 2);
1431     __m128i left = _mm_setzero_si128();
1432     __m128i right = _mm_setzero_si128();
1433     int y = height;
1434     do {
1435       left = Load2<0>(src, left);
1436       right = Load2<0>(src + 1, right);
1437       src += reference_stride;
1438       left = Load2<1>(src, left);
1439       right = Load2<1>(src + 1, right);
1440       src += reference_stride;
1441 
1442       const __m128i result = _mm_avg_epu8(left, right);
1443 
1444       Store2(dest, result);
1445       dest += pred_stride;
1446       Store2(dest, _mm_srli_si128(result, 2));
1447       dest += pred_stride;
1448       y -= 2;
1449     } while (y != 0);
1450   }
1451 }
1452 
1453 template <int width>
IntraBlockCopyVertical(const uint8_t * LIBGAV1_RESTRICT src,const ptrdiff_t src_stride,const int height,uint8_t * LIBGAV1_RESTRICT dst,const ptrdiff_t dst_stride)1454 inline void IntraBlockCopyVertical(const uint8_t* LIBGAV1_RESTRICT src,
1455                                    const ptrdiff_t src_stride, const int height,
1456                                    uint8_t* LIBGAV1_RESTRICT dst,
1457                                    const ptrdiff_t dst_stride) {
1458   const ptrdiff_t src_remainder_stride = src_stride - (width - 16);
1459   const ptrdiff_t dst_remainder_stride = dst_stride - (width - 16);
1460   __m128i row[8], below[8];
1461 
1462   row[0] = LoadUnaligned16(src);
1463   if (width >= 32) {
1464     src += 16;
1465     row[1] = LoadUnaligned16(src);
1466     if (width >= 64) {
1467       src += 16;
1468       row[2] = LoadUnaligned16(src);
1469       src += 16;
1470       row[3] = LoadUnaligned16(src);
1471       if (width == 128) {
1472         src += 16;
1473         row[4] = LoadUnaligned16(src);
1474         src += 16;
1475         row[5] = LoadUnaligned16(src);
1476         src += 16;
1477         row[6] = LoadUnaligned16(src);
1478         src += 16;
1479         row[7] = LoadUnaligned16(src);
1480       }
1481     }
1482   }
1483   src += src_remainder_stride;
1484 
1485   int y = height;
1486   do {
1487     below[0] = LoadUnaligned16(src);
1488     if (width >= 32) {
1489       src += 16;
1490       below[1] = LoadUnaligned16(src);
1491       if (width >= 64) {
1492         src += 16;
1493         below[2] = LoadUnaligned16(src);
1494         src += 16;
1495         below[3] = LoadUnaligned16(src);
1496         if (width == 128) {
1497           src += 16;
1498           below[4] = LoadUnaligned16(src);
1499           src += 16;
1500           below[5] = LoadUnaligned16(src);
1501           src += 16;
1502           below[6] = LoadUnaligned16(src);
1503           src += 16;
1504           below[7] = LoadUnaligned16(src);
1505         }
1506       }
1507     }
1508     src += src_remainder_stride;
1509 
1510     StoreUnaligned16(dst, _mm_avg_epu8(row[0], below[0]));
1511     row[0] = below[0];
1512     if (width >= 32) {
1513       dst += 16;
1514       StoreUnaligned16(dst, _mm_avg_epu8(row[1], below[1]));
1515       row[1] = below[1];
1516       if (width >= 64) {
1517         dst += 16;
1518         StoreUnaligned16(dst, _mm_avg_epu8(row[2], below[2]));
1519         row[2] = below[2];
1520         dst += 16;
1521         StoreUnaligned16(dst, _mm_avg_epu8(row[3], below[3]));
1522         row[3] = below[3];
1523         if (width >= 128) {
1524           dst += 16;
1525           StoreUnaligned16(dst, _mm_avg_epu8(row[4], below[4]));
1526           row[4] = below[4];
1527           dst += 16;
1528           StoreUnaligned16(dst, _mm_avg_epu8(row[5], below[5]));
1529           row[5] = below[5];
1530           dst += 16;
1531           StoreUnaligned16(dst, _mm_avg_epu8(row[6], below[6]));
1532           row[6] = below[6];
1533           dst += 16;
1534           StoreUnaligned16(dst, _mm_avg_epu8(row[7], below[7]));
1535           row[7] = below[7];
1536         }
1537       }
1538     }
1539     dst += dst_remainder_stride;
1540   } while (--y != 0);
1541 }
1542 
ConvolveIntraBlockCopyVertical_SSE4_1(const void * LIBGAV1_RESTRICT const reference,const ptrdiff_t reference_stride,const int,const int,const int,const int,const int width,const int height,void * LIBGAV1_RESTRICT const prediction,const ptrdiff_t pred_stride)1543 void ConvolveIntraBlockCopyVertical_SSE4_1(
1544     const void* LIBGAV1_RESTRICT const reference,
1545     const ptrdiff_t reference_stride, const int /*horizontal_filter_index*/,
1546     const int /*vertical_filter_index*/, const int /*horizontal_filter_id*/,
1547     const int /*vertical_filter_id*/, const int width, const int height,
1548     void* LIBGAV1_RESTRICT const prediction, const ptrdiff_t pred_stride) {
1549   const auto* src = static_cast<const uint8_t*>(reference);
1550   auto* dest = static_cast<uint8_t*>(prediction);
1551 
1552   if (width == 128) {
1553     IntraBlockCopyVertical<128>(src, reference_stride, height, dest,
1554                                 pred_stride);
1555   } else if (width == 64) {
1556     IntraBlockCopyVertical<64>(src, reference_stride, height, dest,
1557                                pred_stride);
1558   } else if (width == 32) {
1559     IntraBlockCopyVertical<32>(src, reference_stride, height, dest,
1560                                pred_stride);
1561   } else if (width == 16) {
1562     IntraBlockCopyVertical<16>(src, reference_stride, height, dest,
1563                                pred_stride);
1564   } else if (width == 8) {
1565     __m128i row, below;
1566     row = LoadLo8(src);
1567     src += reference_stride;
1568 
1569     int y = height;
1570     do {
1571       below = LoadLo8(src);
1572       src += reference_stride;
1573 
1574       StoreLo8(dest, _mm_avg_epu8(row, below));
1575       dest += pred_stride;
1576 
1577       row = below;
1578     } while (--y != 0);
1579   } else if (width == 4) {
1580     __m128i row = Load4(src);
1581     src += reference_stride;
1582 
1583     int y = height;
1584     do {
1585       __m128i below = Load4(src);
1586       src += reference_stride;
1587 
1588       Store4(dest, _mm_avg_epu8(row, below));
1589       dest += pred_stride;
1590 
1591       row = below;
1592     } while (--y != 0);
1593   } else {
1594     assert(width == 2);
1595     __m128i row = Load2(src);
1596     __m128i below = _mm_setzero_si128();
1597     src += reference_stride;
1598 
1599     int y = height;
1600     do {
1601       below = Load2<0>(src, below);
1602       src += reference_stride;
1603 
1604       Store2(dest, _mm_avg_epu8(row, below));
1605       dest += pred_stride;
1606 
1607       row = below;
1608     } while (--y != 0);
1609   }
1610 }
1611 
1612 // Load then add two uint8_t vectors. Return the uint16_t vector result.
LoadU8AndAddLong(const uint8_t * LIBGAV1_RESTRICT src,const uint8_t * LIBGAV1_RESTRICT src1)1613 inline __m128i LoadU8AndAddLong(const uint8_t* LIBGAV1_RESTRICT src,
1614                                 const uint8_t* LIBGAV1_RESTRICT src1) {
1615   const __m128i a = _mm_cvtepu8_epi16(LoadLo8(src));
1616   const __m128i b = _mm_cvtepu8_epi16(LoadLo8(src1));
1617   return _mm_add_epi16(a, b);
1618 }
1619 
// Computes (v0 + v1 + 2) >> 2 per uint16_t lane and packs the results to
// uint8_t with saturation; both 64-bit halves of the return value hold the
// packed bytes. The rounding shift is split into a plain >> 1 followed by
// _mm_avg_epu16 against zero, which performs the final (x + 1) >> 1.
inline __m128i AddU16RightShift2AndPack(__m128i v0, __m128i v1) {
  const __m128i sum = _mm_add_epi16(v0, v1);
  const __m128i half = _mm_srli_epi16(sum, 1);
  const __m128i rounded = _mm_avg_epu16(half, _mm_setzero_si128());
  return _mm_packus_epi16(rounded, rounded);
}
1627 
1628 template <int width>
IntraBlockCopy2D(const uint8_t * LIBGAV1_RESTRICT src,const ptrdiff_t src_stride,const int height,uint8_t * LIBGAV1_RESTRICT dst,const ptrdiff_t dst_stride)1629 inline void IntraBlockCopy2D(const uint8_t* LIBGAV1_RESTRICT src,
1630                              const ptrdiff_t src_stride, const int height,
1631                              uint8_t* LIBGAV1_RESTRICT dst,
1632                              const ptrdiff_t dst_stride) {
1633   const ptrdiff_t src_remainder_stride = src_stride - (width - 8);
1634   const ptrdiff_t dst_remainder_stride = dst_stride - (width - 8);
1635   __m128i row[16];
1636   row[0] = LoadU8AndAddLong(src, src + 1);
1637   if (width >= 16) {
1638     src += 8;
1639     row[1] = LoadU8AndAddLong(src, src + 1);
1640     if (width >= 32) {
1641       src += 8;
1642       row[2] = LoadU8AndAddLong(src, src + 1);
1643       src += 8;
1644       row[3] = LoadU8AndAddLong(src, src + 1);
1645       if (width >= 64) {
1646         src += 8;
1647         row[4] = LoadU8AndAddLong(src, src + 1);
1648         src += 8;
1649         row[5] = LoadU8AndAddLong(src, src + 1);
1650         src += 8;
1651         row[6] = LoadU8AndAddLong(src, src + 1);
1652         src += 8;
1653         row[7] = LoadU8AndAddLong(src, src + 1);
1654         if (width == 128) {
1655           src += 8;
1656           row[8] = LoadU8AndAddLong(src, src + 1);
1657           src += 8;
1658           row[9] = LoadU8AndAddLong(src, src + 1);
1659           src += 8;
1660           row[10] = LoadU8AndAddLong(src, src + 1);
1661           src += 8;
1662           row[11] = LoadU8AndAddLong(src, src + 1);
1663           src += 8;
1664           row[12] = LoadU8AndAddLong(src, src + 1);
1665           src += 8;
1666           row[13] = LoadU8AndAddLong(src, src + 1);
1667           src += 8;
1668           row[14] = LoadU8AndAddLong(src, src + 1);
1669           src += 8;
1670           row[15] = LoadU8AndAddLong(src, src + 1);
1671         }
1672       }
1673     }
1674   }
1675   src += src_remainder_stride;
1676 
1677   int y = height;
1678   do {
1679     const __m128i below_0 = LoadU8AndAddLong(src, src + 1);
1680     StoreLo8(dst, AddU16RightShift2AndPack(row[0], below_0));
1681     row[0] = below_0;
1682     if (width >= 16) {
1683       src += 8;
1684       dst += 8;
1685 
1686       const __m128i below_1 = LoadU8AndAddLong(src, src + 1);
1687       StoreLo8(dst, AddU16RightShift2AndPack(row[1], below_1));
1688       row[1] = below_1;
1689       if (width >= 32) {
1690         src += 8;
1691         dst += 8;
1692 
1693         const __m128i below_2 = LoadU8AndAddLong(src, src + 1);
1694         StoreLo8(dst, AddU16RightShift2AndPack(row[2], below_2));
1695         row[2] = below_2;
1696         src += 8;
1697         dst += 8;
1698 
1699         const __m128i below_3 = LoadU8AndAddLong(src, src + 1);
1700         StoreLo8(dst, AddU16RightShift2AndPack(row[3], below_3));
1701         row[3] = below_3;
1702         if (width >= 64) {
1703           src += 8;
1704           dst += 8;
1705 
1706           const __m128i below_4 = LoadU8AndAddLong(src, src + 1);
1707           StoreLo8(dst, AddU16RightShift2AndPack(row[4], below_4));
1708           row[4] = below_4;
1709           src += 8;
1710           dst += 8;
1711 
1712           const __m128i below_5 = LoadU8AndAddLong(src, src + 1);
1713           StoreLo8(dst, AddU16RightShift2AndPack(row[5], below_5));
1714           row[5] = below_5;
1715           src += 8;
1716           dst += 8;
1717 
1718           const __m128i below_6 = LoadU8AndAddLong(src, src + 1);
1719           StoreLo8(dst, AddU16RightShift2AndPack(row[6], below_6));
1720           row[6] = below_6;
1721           src += 8;
1722           dst += 8;
1723 
1724           const __m128i below_7 = LoadU8AndAddLong(src, src + 1);
1725           StoreLo8(dst, AddU16RightShift2AndPack(row[7], below_7));
1726           row[7] = below_7;
1727           if (width == 128) {
1728             src += 8;
1729             dst += 8;
1730 
1731             const __m128i below_8 = LoadU8AndAddLong(src, src + 1);
1732             StoreLo8(dst, AddU16RightShift2AndPack(row[8], below_8));
1733             row[8] = below_8;
1734             src += 8;
1735             dst += 8;
1736 
1737             const __m128i below_9 = LoadU8AndAddLong(src, src + 1);
1738             StoreLo8(dst, AddU16RightShift2AndPack(row[9], below_9));
1739             row[9] = below_9;
1740             src += 8;
1741             dst += 8;
1742 
1743             const __m128i below_10 = LoadU8AndAddLong(src, src + 1);
1744             StoreLo8(dst, AddU16RightShift2AndPack(row[10], below_10));
1745             row[10] = below_10;
1746             src += 8;
1747             dst += 8;
1748 
1749             const __m128i below_11 = LoadU8AndAddLong(src, src + 1);
1750             StoreLo8(dst, AddU16RightShift2AndPack(row[11], below_11));
1751             row[11] = below_11;
1752             src += 8;
1753             dst += 8;
1754 
1755             const __m128i below_12 = LoadU8AndAddLong(src, src + 1);
1756             StoreLo8(dst, AddU16RightShift2AndPack(row[12], below_12));
1757             row[12] = below_12;
1758             src += 8;
1759             dst += 8;
1760 
1761             const __m128i below_13 = LoadU8AndAddLong(src, src + 1);
1762             StoreLo8(dst, AddU16RightShift2AndPack(row[13], below_13));
1763             row[13] = below_13;
1764             src += 8;
1765             dst += 8;
1766 
1767             const __m128i below_14 = LoadU8AndAddLong(src, src + 1);
1768             StoreLo8(dst, AddU16RightShift2AndPack(row[14], below_14));
1769             row[14] = below_14;
1770             src += 8;
1771             dst += 8;
1772 
1773             const __m128i below_15 = LoadU8AndAddLong(src, src + 1);
1774             StoreLo8(dst, AddU16RightShift2AndPack(row[15], below_15));
1775             row[15] = below_15;
1776           }
1777         }
1778       }
1779     }
1780     src += src_remainder_stride;
1781     dst += dst_remainder_stride;
1782   } while (--y != 0);
1783 }
1784 
// Entry point for the 2D (horizontal + vertical) 1/2-pel intra block copy
// filter (8 bit). Each output pixel is the rounded average of its 2x2 source
// neighborhood: (s[0][0] + s[0][1] + s[1][0] + s[1][1] + 2) >> 2. Only
// |width| and |height| of the filter arguments are used.
void ConvolveIntraBlockCopy2D_SSE4_1(
    const void* LIBGAV1_RESTRICT const reference,
    const ptrdiff_t reference_stride, const int /*horizontal_filter_index*/,
    const int /*vertical_filter_index*/, const int /*horizontal_filter_id*/,
    const int /*vertical_filter_id*/, const int width, const int height,
    void* LIBGAV1_RESTRICT const prediction, const ptrdiff_t pred_stride) {
  const auto* src = static_cast<const uint8_t*>(reference);
  auto* dest = static_cast<uint8_t*>(prediction);
  // Note: allow vertical access to height + 1. Because this function is only
  // for u/v plane of intra block copy, such access is guaranteed to be within
  // the prediction block.

  if (width == 128) {
    IntraBlockCopy2D<128>(src, reference_stride, height, dest, pred_stride);
  } else if (width == 64) {
    IntraBlockCopy2D<64>(src, reference_stride, height, dest, pred_stride);
  } else if (width == 32) {
    IntraBlockCopy2D<32>(src, reference_stride, height, dest, pred_stride);
  } else if (width == 16) {
    IntraBlockCopy2D<16>(src, reference_stride, height, dest, pred_stride);
  } else if (width == 8) {
    IntraBlockCopy2D<8>(src, reference_stride, height, dest, pred_stride);
  } else if (width == 4) {
    // Two output rows per iteration; assumes |height| is even. |row| carries
    // the horizontal sums (src[x] + src[x + 1], uint16_t) of the row above
    // the current pair.
    __m128i left = _mm_cvtepu8_epi16(Load4(src));
    __m128i right = _mm_cvtepu8_epi16(Load4(src + 1));
    src += reference_stride;

    __m128i row = _mm_add_epi16(left, right);

    int y = height;
    do {
      // Gather two source rows — the first in bytes 0..3, the second in
      // bytes 4..7 — so one widening add yields both rows' horizontal sums.
      left = Load4(src);
      right = Load4(src + 1);
      src += reference_stride;
      left = _mm_unpacklo_epi32(left, Load4(src));
      right = _mm_unpacklo_epi32(right, Load4(src + 1));
      src += reference_stride;

      const __m128i below =
          _mm_add_epi16(_mm_cvtepu8_epi16(left), _mm_cvtepu8_epi16(right));
      // Pair each row's sums with the sums of the row beneath it: the low
      // half combines |row| with the first gathered row, the high half
      // combines the two gathered rows.
      const __m128i result =
          AddU16RightShift2AndPack(_mm_unpacklo_epi64(row, below), below);

      Store4(dest, result);
      dest += pred_stride;
      // Second output row is in bytes 4..7 of |result|.
      Store4(dest, _mm_srli_si128(result, 4));
      dest += pred_stride;

      // The second gathered row's sums become the "row above" next time.
      row = _mm_srli_si128(below, 8);
      y -= 2;
    } while (y != 0);
  } else {
    // width == 2; same two-rows-per-iteration scheme as the width == 4 path.
    __m128i left = Load2(src);
    __m128i right = Load2(src + 1);
    src += reference_stride;

    __m128i row =
        _mm_add_epi16(_mm_cvtepu8_epi16(left), _mm_cvtepu8_epi16(right));

    int y = height;
    do {
      // Load2<0>/Load2<2> insert the two source rows at 16-bit lanes 0 and 2
      // (bytes 0..1 and 4..5) so the widening add produces both rows' sums
      // in word lanes 0..1 and 4..5.
      left = Load2<0>(src, left);
      right = Load2<0>(src + 1, right);
      src += reference_stride;
      left = Load2<2>(src, left);
      right = Load2<2>(src + 1, right);
      src += reference_stride;

      const __m128i below =
          _mm_add_epi16(_mm_cvtepu8_epi16(left), _mm_cvtepu8_epi16(right));
      const __m128i result =
          AddU16RightShift2AndPack(_mm_unpacklo_epi64(row, below), below);

      Store2(dest, result);
      dest += pred_stride;
      // Second output row is in bytes 4..5 of |result|.
      Store2(dest, _mm_srli_si128(result, 4));
      dest += pred_stride;

      row = _mm_srli_si128(below, 8);
      y -= 2;
    } while (y != 0);
  }
}
1868 
// Registers the SSE4.1 8-bit convolve and convolve_scale implementations in
// the writable dsp table for kBitdepth8.
void Init8bpp() {
  Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
  assert(dsp != nullptr);
  // NOTE(review): the convolve indices appear to be
  // [is_intra_block_copy][is_compound][has_vertical_filter]
  // [has_horizontal_filter] -- confirm against src/dsp/dsp.h.
  dsp->convolve[0][0][0][1] = ConvolveHorizontal_SSE4_1;
  dsp->convolve[0][0][1][0] = ConvolveVertical_SSE4_1;
  dsp->convolve[0][0][1][1] = Convolve2D_SSE4_1;

  dsp->convolve[0][1][0][0] = ConvolveCompoundCopy_SSE4_1;
  dsp->convolve[0][1][0][1] = ConvolveCompoundHorizontal_SSE4_1;
  dsp->convolve[0][1][1][0] = ConvolveCompoundVertical_SSE4_1;
  dsp->convolve[0][1][1][1] = ConvolveCompound2D_SSE4_1;

  dsp->convolve[1][0][0][1] = ConvolveIntraBlockCopyHorizontal_SSE4_1;
  dsp->convolve[1][0][1][0] = ConvolveIntraBlockCopyVertical_SSE4_1;
  dsp->convolve[1][0][1][1] = ConvolveIntraBlockCopy2D_SSE4_1;

  // convolve_scale is presumably indexed by is_compound (false -> 0,
  // true -> 1, matching the template arguments).
  dsp->convolve_scale[0] = ConvolveScale2D_SSE4_1<false>;
  dsp->convolve_scale[1] = ConvolveScale2D_SSE4_1<true>;
}
1888 
1889 }  // namespace
1890 }  // namespace low_bitdepth
1891 
// Public init: installs the SSE4.1 convolve implementations in the dsp table.
void ConvolveInit_SSE4_1() { low_bitdepth::Init8bpp(); }
1893 
1894 }  // namespace dsp
1895 }  // namespace libgav1
1896 
1897 #else   // !LIBGAV1_TARGETING_SSE4_1
1898 namespace libgav1 {
1899 namespace dsp {
1900 
// SSE4.1 targeting is disabled: provide an empty init so callers do not need
// conditional compilation.
void ConvolveInit_SSE4_1() {}
1902 
1903 }  // namespace dsp
1904 }  // namespace libgav1
1905 #endif  // LIBGAV1_TARGETING_SSE4_1
1906