// Copyright 2020 The libgav1 Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "src/dsp/loop_restoration.h"
#include "src/utils/cpu.h"

#if LIBGAV1_TARGETING_AVX2
#include <immintrin.h>

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>

#include "src/dsp/common.h"
#include "src/dsp/constants.h"
#include "src/dsp/dsp.h"
#include "src/dsp/x86/common_avx2.h"
#include "src/utils/common.h"
#include "src/utils/constants.h"

namespace libgav1 {
namespace dsp {
namespace low_bitdepth {
namespace {

inline void WienerHorizontalClip(const __m256i s[2], const __m256i s_3x128,
                                 int16_t* const wiener_buffer) {
  constexpr int offset =
      1 << (8 + kWienerFilterBits - kInterRoundBitsHorizontal - 1);
  constexpr int limit =
      (1 << (8 + 1 + kWienerFilterBits - kInterRoundBitsHorizontal)) - 1;
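  // Assuming kWienerFilterBits == 7 and kInterRoundBitsHorizontal == 3 (the
  // 8-bit path), |offset| is 2048 and |limit| is 8191, so the stored values
  // fit in 13 bits as noted in WienerFilter_AVX2().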
  const __m256i offsets = _mm256_set1_epi16(-offset);
  const __m256i limits = _mm256_set1_epi16(limit - offset);
  const __m256i round = _mm256_set1_epi16(1 << (kInterRoundBitsHorizontal - 1));
  // The sum range here is [-128 * 255, 90 * 255].
  const __m256i madd = _mm256_add_epi16(s[0], s[1]);
  const __m256i sum = _mm256_add_epi16(madd, round);
  const __m256i rounded_sum0 =
      _mm256_srai_epi16(sum, kInterRoundBitsHorizontal);
  // Add back scaled down offset correction.
  const __m256i rounded_sum1 = _mm256_add_epi16(rounded_sum0, s_3x128);
  const __m256i d0 = _mm256_max_epi16(rounded_sum1, offsets);
  const __m256i d1 = _mm256_min_epi16(d0, limits);
  StoreAligned32(wiener_buffer, d1);
}

// Using _mm256_alignr_epi8() is about 8% faster than loading all and unpacking,
// because the compiler generates redundant code when loading all and unpacking.
inline void WienerHorizontalTap7Kernel(const __m256i s[2],
                                       const __m256i filter[4],
                                       int16_t* const wiener_buffer) {
  const auto s01 = _mm256_alignr_epi8(s[1], s[0], 1);
  const auto s23 = _mm256_alignr_epi8(s[1], s[0], 5);
  const auto s45 = _mm256_alignr_epi8(s[1], s[0], 9);
  const auto s67 = _mm256_alignr_epi8(s[1], s[0], 13);
  __m256i madds[4];
  madds[0] = _mm256_maddubs_epi16(s01, filter[0]);
  madds[1] = _mm256_maddubs_epi16(s23, filter[1]);
  madds[2] = _mm256_maddubs_epi16(s45, filter[2]);
  madds[3] = _mm256_maddubs_epi16(s67, filter[3]);
  madds[0] = _mm256_add_epi16(madds[0], madds[2]);
  madds[1] = _mm256_add_epi16(madds[1], madds[3]);
  const __m256i s_3x128 = _mm256_slli_epi16(_mm256_srli_epi16(s23, 8),
                                            7 - kInterRoundBitsHorizontal);
  WienerHorizontalClip(madds, s_3x128, wiener_buffer);
}

inline void WienerHorizontalTap5Kernel(const __m256i s[2],
                                       const __m256i filter[3],
                                       int16_t* const wiener_buffer) {
  const auto s01 = _mm256_alignr_epi8(s[1], s[0], 1);
  const auto s23 = _mm256_alignr_epi8(s[1], s[0], 5);
  const auto s45 = _mm256_alignr_epi8(s[1], s[0], 9);
  __m256i madds[3];
  madds[0] = _mm256_maddubs_epi16(s01, filter[0]);
  madds[1] = _mm256_maddubs_epi16(s23, filter[1]);
  madds[2] = _mm256_maddubs_epi16(s45, filter[2]);
  madds[0] = _mm256_add_epi16(madds[0], madds[2]);
  const __m256i s_3x128 = _mm256_srli_epi16(_mm256_slli_epi16(s23, 8),
                                            kInterRoundBitsHorizontal + 1);
  WienerHorizontalClip(madds, s_3x128, wiener_buffer);
}

inline void WienerHorizontalTap3Kernel(const __m256i s[2],
                                       const __m256i filter[2],
                                       int16_t* const wiener_buffer) {
  const auto s01 = _mm256_alignr_epi8(s[1], s[0], 1);
  const auto s23 = _mm256_alignr_epi8(s[1], s[0], 5);
  __m256i madds[2];
  madds[0] = _mm256_maddubs_epi16(s01, filter[0]);
  madds[1] = _mm256_maddubs_epi16(s23, filter[1]);
  const __m256i s_3x128 = _mm256_slli_epi16(_mm256_srli_epi16(s01, 8),
                                            7 - kInterRoundBitsHorizontal);
  WienerHorizontalClip(madds, s_3x128, wiener_buffer);
}

inline void WienerHorizontalTap7(const uint8_t* src, const ptrdiff_t src_stride,
                                 const ptrdiff_t width, const int height,
                                 const __m256i coefficients,
                                 int16_t** const wiener_buffer) {
  __m256i filter[4];
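  // Each _mm256_shuffle_epi8() below replicates one pair of 8-bit coefficients
  // into every 16-bit lane, so that _mm256_maddubs_epi16() applies two taps to
  // two neighboring pixels at once. A shuffle index with its high bit set
  // (0x80) produces a zero byte for the unused tap of the last pair.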
  filter[0] = _mm256_shuffle_epi8(coefficients, _mm256_set1_epi16(0x0100));
  filter[1] = _mm256_shuffle_epi8(coefficients, _mm256_set1_epi16(0x0302));
  filter[2] = _mm256_shuffle_epi8(coefficients, _mm256_set1_epi16(0x0102));
  filter[3] = _mm256_shuffle_epi8(
      coefficients, _mm256_set1_epi16(static_cast<int16_t>(0x8000)));
  for (int y = height; y != 0; --y) {
    __m256i s = LoadUnaligned32(src);
    __m256i ss[4];
    ss[0] = _mm256_unpacklo_epi8(s, s);
    ptrdiff_t x = 0;
    do {
      ss[1] = _mm256_unpackhi_epi8(s, s);
      s = LoadUnaligned32(src + x + 32);
      ss[3] = _mm256_unpacklo_epi8(s, s);
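      // imm8 0x21 takes the high 128-bit lane of ss[0] and the low lane of
      // ss[3], i.e. the pixels that follow ss[1] in each lane, so the
      // alignr-based kernels can read across the lane boundary into the next
      // 32-byte load.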
      ss[2] = _mm256_permute2x128_si256(ss[0], ss[3], 0x21);
      WienerHorizontalTap7Kernel(ss + 0, filter, *wiener_buffer + x + 0);
      WienerHorizontalTap7Kernel(ss + 1, filter, *wiener_buffer + x + 16);
      ss[0] = ss[3];
      x += 32;
    } while (x < width);
    src += src_stride;
    *wiener_buffer += width;
  }
}

inline void WienerHorizontalTap5(const uint8_t* src, const ptrdiff_t src_stride,
                                 const ptrdiff_t width, const int height,
                                 const __m256i coefficients,
                                 int16_t** const wiener_buffer) {
  __m256i filter[3];
  filter[0] = _mm256_shuffle_epi8(coefficients, _mm256_set1_epi16(0x0201));
  filter[1] = _mm256_shuffle_epi8(coefficients, _mm256_set1_epi16(0x0203));
  filter[2] = _mm256_shuffle_epi8(
      coefficients, _mm256_set1_epi16(static_cast<int16_t>(0x8001)));
  for (int y = height; y != 0; --y) {
    __m256i s = LoadUnaligned32(src);
    __m256i ss[4];
    ss[0] = _mm256_unpacklo_epi8(s, s);
    ptrdiff_t x = 0;
    do {
      ss[1] = _mm256_unpackhi_epi8(s, s);
      s = LoadUnaligned32(src + x + 32);
      ss[3] = _mm256_unpacklo_epi8(s, s);
      ss[2] = _mm256_permute2x128_si256(ss[0], ss[3], 0x21);
      WienerHorizontalTap5Kernel(ss + 0, filter, *wiener_buffer + x + 0);
      WienerHorizontalTap5Kernel(ss + 1, filter, *wiener_buffer + x + 16);
      ss[0] = ss[3];
      x += 32;
    } while (x < width);
    src += src_stride;
    *wiener_buffer += width;
  }
}

inline void WienerHorizontalTap3(const uint8_t* src, const ptrdiff_t src_stride,
                                 const ptrdiff_t width, const int height,
                                 const __m256i coefficients,
                                 int16_t** const wiener_buffer) {
  __m256i filter[2];
  filter[0] = _mm256_shuffle_epi8(coefficients, _mm256_set1_epi16(0x0302));
  filter[1] = _mm256_shuffle_epi8(
      coefficients, _mm256_set1_epi16(static_cast<int16_t>(0x8002)));
  for (int y = height; y != 0; --y) {
    __m256i s = LoadUnaligned32(src);
    __m256i ss[4];
    ss[0] = _mm256_unpacklo_epi8(s, s);
    ptrdiff_t x = 0;
    do {
      ss[1] = _mm256_unpackhi_epi8(s, s);
      s = LoadUnaligned32(src + x + 32);
      ss[3] = _mm256_unpacklo_epi8(s, s);
      ss[2] = _mm256_permute2x128_si256(ss[0], ss[3], 0x21);
      WienerHorizontalTap3Kernel(ss + 0, filter, *wiener_buffer + x + 0);
      WienerHorizontalTap3Kernel(ss + 1, filter, *wiener_buffer + x + 16);
      ss[0] = ss[3];
      x += 32;
    } while (x < width);
    src += src_stride;
    *wiener_buffer += width;
  }
}

inline void WienerHorizontalTap1(const uint8_t* src, const ptrdiff_t src_stride,
                                 const ptrdiff_t width, const int height,
                                 int16_t** const wiener_buffer) {
  for (int y = height; y != 0; --y) {
    ptrdiff_t x = 0;
    do {
      const __m256i s = LoadUnaligned32(src + x);
      const __m256i s0 = _mm256_unpacklo_epi8(s, _mm256_setzero_si256());
      const __m256i s1 = _mm256_unpackhi_epi8(s, _mm256_setzero_si256());
      __m256i d[2];
      d[0] = _mm256_slli_epi16(s0, 4);
      d[1] = _mm256_slli_epi16(s1, 4);
      StoreAligned64(*wiener_buffer + x, d);
      x += 32;
    } while (x < width);
    src += src_stride;
    *wiener_buffer += width;
  }
}

inline __m256i WienerVertical7(const __m256i a[2], const __m256i filter[2]) {
  const __m256i round = _mm256_set1_epi32(1 << (kInterRoundBitsVertical - 1));
  const __m256i madd0 = _mm256_madd_epi16(a[0], filter[0]);
  const __m256i madd1 = _mm256_madd_epi16(a[1], filter[1]);
  const __m256i sum0 = _mm256_add_epi32(round, madd0);
  const __m256i sum1 = _mm256_add_epi32(sum0, madd1);
  return _mm256_srai_epi32(sum1, kInterRoundBitsVertical);
}

inline __m256i WienerVertical5(const __m256i a[2], const __m256i filter[2]) {
  const __m256i madd0 = _mm256_madd_epi16(a[0], filter[0]);
  const __m256i madd1 = _mm256_madd_epi16(a[1], filter[1]);
  const __m256i sum = _mm256_add_epi32(madd0, madd1);
  return _mm256_srai_epi32(sum, kInterRoundBitsVertical);
}

inline __m256i WienerVertical3(const __m256i a, const __m256i filter) {
  const __m256i round = _mm256_set1_epi32(1 << (kInterRoundBitsVertical - 1));
  const __m256i madd = _mm256_madd_epi16(a, filter);
  const __m256i sum = _mm256_add_epi32(round, madd);
  return _mm256_srai_epi32(sum, kInterRoundBitsVertical);
}

inline __m256i WienerVerticalFilter7(const __m256i a[7],
                                     const __m256i filter[2]) {
  __m256i b[2];
  const __m256i a06 = _mm256_add_epi16(a[0], a[6]);
  const __m256i a15 = _mm256_add_epi16(a[1], a[5]);
  const __m256i a24 = _mm256_add_epi16(a[2], a[4]);
  b[0] = _mm256_unpacklo_epi16(a06, a15);
  b[1] = _mm256_unpacklo_epi16(a24, a[3]);
  const __m256i sum0 = WienerVertical7(b, filter);
  b[0] = _mm256_unpackhi_epi16(a06, a15);
  b[1] = _mm256_unpackhi_epi16(a24, a[3]);
  const __m256i sum1 = WienerVertical7(b, filter);
  return _mm256_packs_epi32(sum0, sum1);
}

inline __m256i WienerVerticalFilter5(const __m256i a[5],
                                     const __m256i filter[2]) {
  const __m256i round = _mm256_set1_epi16(1 << (kInterRoundBitsVertical - 1));
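  // |round| is interleaved with the center row below so that the single
  // _mm256_madd_epi16() in WienerVertical5() computes a[2] * center_tap +
  // round * 1; the 1 comes from the high 16 bits of filter[1], see
  // WienerVerticalTap5().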
  __m256i b[2];
  const __m256i a04 = _mm256_add_epi16(a[0], a[4]);
  const __m256i a13 = _mm256_add_epi16(a[1], a[3]);
  b[0] = _mm256_unpacklo_epi16(a04, a13);
  b[1] = _mm256_unpacklo_epi16(a[2], round);
  const __m256i sum0 = WienerVertical5(b, filter);
  b[0] = _mm256_unpackhi_epi16(a04, a13);
  b[1] = _mm256_unpackhi_epi16(a[2], round);
  const __m256i sum1 = WienerVertical5(b, filter);
  return _mm256_packs_epi32(sum0, sum1);
}

inline __m256i WienerVerticalFilter3(const __m256i a[3], const __m256i filter) {
  __m256i b;
  const __m256i a02 = _mm256_add_epi16(a[0], a[2]);
  b = _mm256_unpacklo_epi16(a02, a[1]);
  const __m256i sum0 = WienerVertical3(b, filter);
  b = _mm256_unpackhi_epi16(a02, a[1]);
  const __m256i sum1 = WienerVertical3(b, filter);
  return _mm256_packs_epi32(sum0, sum1);
}

inline __m256i WienerVerticalTap7Kernel(const int16_t* wiener_buffer,
                                        const ptrdiff_t wiener_stride,
                                        const __m256i filter[2], __m256i a[7]) {
  a[0] = LoadAligned32(wiener_buffer + 0 * wiener_stride);
  a[1] = LoadAligned32(wiener_buffer + 1 * wiener_stride);
  a[2] = LoadAligned32(wiener_buffer + 2 * wiener_stride);
  a[3] = LoadAligned32(wiener_buffer + 3 * wiener_stride);
  a[4] = LoadAligned32(wiener_buffer + 4 * wiener_stride);
  a[5] = LoadAligned32(wiener_buffer + 5 * wiener_stride);
  a[6] = LoadAligned32(wiener_buffer + 6 * wiener_stride);
  return WienerVerticalFilter7(a, filter);
}

inline __m256i WienerVerticalTap5Kernel(const int16_t* wiener_buffer,
                                        const ptrdiff_t wiener_stride,
                                        const __m256i filter[2], __m256i a[5]) {
  a[0] = LoadAligned32(wiener_buffer + 0 * wiener_stride);
  a[1] = LoadAligned32(wiener_buffer + 1 * wiener_stride);
  a[2] = LoadAligned32(wiener_buffer + 2 * wiener_stride);
  a[3] = LoadAligned32(wiener_buffer + 3 * wiener_stride);
  a[4] = LoadAligned32(wiener_buffer + 4 * wiener_stride);
  return WienerVerticalFilter5(a, filter);
}

inline __m256i WienerVerticalTap3Kernel(const int16_t* wiener_buffer,
                                        const ptrdiff_t wiener_stride,
                                        const __m256i filter, __m256i a[3]) {
  a[0] = LoadAligned32(wiener_buffer + 0 * wiener_stride);
  a[1] = LoadAligned32(wiener_buffer + 1 * wiener_stride);
  a[2] = LoadAligned32(wiener_buffer + 2 * wiener_stride);
  return WienerVerticalFilter3(a, filter);
}

inline void WienerVerticalTap7Kernel2(const int16_t* wiener_buffer,
                                      const ptrdiff_t wiener_stride,
                                      const __m256i filter[2], __m256i d[2]) {
  __m256i a[8];
  d[0] = WienerVerticalTap7Kernel(wiener_buffer, wiener_stride, filter, a);
  a[7] = LoadAligned32(wiener_buffer + 7 * wiener_stride);
  d[1] = WienerVerticalFilter7(a + 1, filter);
}

inline void WienerVerticalTap5Kernel2(const int16_t* wiener_buffer,
                                      const ptrdiff_t wiener_stride,
                                      const __m256i filter[2], __m256i d[2]) {
  __m256i a[6];
  d[0] = WienerVerticalTap5Kernel(wiener_buffer, wiener_stride, filter, a);
  a[5] = LoadAligned32(wiener_buffer + 5 * wiener_stride);
  d[1] = WienerVerticalFilter5(a + 1, filter);
}

inline void WienerVerticalTap3Kernel2(const int16_t* wiener_buffer,
                                      const ptrdiff_t wiener_stride,
                                      const __m256i filter, __m256i d[2]) {
  __m256i a[4];
  d[0] = WienerVerticalTap3Kernel(wiener_buffer, wiener_stride, filter, a);
  a[3] = LoadAligned32(wiener_buffer + 3 * wiener_stride);
  d[1] = WienerVerticalFilter3(a + 1, filter);
}

inline void WienerVerticalTap7(const int16_t* wiener_buffer,
                               const ptrdiff_t width, const int height,
                               const int16_t coefficients[4], uint8_t* dst,
                               const ptrdiff_t dst_stride) {
  const __m256i c = _mm256_broadcastq_epi64(LoadLo8(coefficients));
  __m256i filter[2];
  filter[0] = _mm256_shuffle_epi32(c, 0x0);
  filter[1] = _mm256_shuffle_epi32(c, 0x55);
  for (int y = height >> 1; y > 0; --y) {
    ptrdiff_t x = 0;
    do {
      __m256i d[2][2];
      WienerVerticalTap7Kernel2(wiener_buffer + x + 0, width, filter, d[0]);
      WienerVerticalTap7Kernel2(wiener_buffer + x + 16, width, filter, d[1]);
      StoreUnaligned32(dst + x, _mm256_packus_epi16(d[0][0], d[1][0]));
      StoreUnaligned32(dst + dst_stride + x,
                       _mm256_packus_epi16(d[0][1], d[1][1]));
      x += 32;
    } while (x < width);
    dst += 2 * dst_stride;
    wiener_buffer += 2 * width;
  }

  if ((height & 1) != 0) {
    ptrdiff_t x = 0;
    do {
      __m256i a[7];
      const __m256i d0 =
          WienerVerticalTap7Kernel(wiener_buffer + x + 0, width, filter, a);
      const __m256i d1 =
          WienerVerticalTap7Kernel(wiener_buffer + x + 16, width, filter, a);
      StoreUnaligned32(dst + x, _mm256_packus_epi16(d0, d1));
      x += 32;
    } while (x < width);
  }
}

inline void WienerVerticalTap5(const int16_t* wiener_buffer,
                               const ptrdiff_t width, const int height,
                               const int16_t coefficients[3], uint8_t* dst,
                               const ptrdiff_t dst_stride) {
  const __m256i c = _mm256_broadcastd_epi32(Load4(coefficients));
  __m256i filter[2];
  filter[0] = _mm256_shuffle_epi32(c, 0);
  filter[1] =
      _mm256_set1_epi32((1 << 16) | static_cast<uint16_t>(coefficients[2]));
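  // filter[1] pairs the center tap with a multiplier of 1 (the high 16 bits),
  // so the rounding constant interleaved in WienerVerticalFilter5() is added
  // by the same _mm256_madd_epi16().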
  for (int y = height >> 1; y > 0; --y) {
    ptrdiff_t x = 0;
    do {
      __m256i d[2][2];
      WienerVerticalTap5Kernel2(wiener_buffer + x + 0, width, filter, d[0]);
      WienerVerticalTap5Kernel2(wiener_buffer + x + 16, width, filter, d[1]);
      StoreUnaligned32(dst + x, _mm256_packus_epi16(d[0][0], d[1][0]));
      StoreUnaligned32(dst + dst_stride + x,
                       _mm256_packus_epi16(d[0][1], d[1][1]));
      x += 32;
    } while (x < width);
    dst += 2 * dst_stride;
    wiener_buffer += 2 * width;
  }

  if ((height & 1) != 0) {
    ptrdiff_t x = 0;
    do {
      __m256i a[5];
      const __m256i d0 =
          WienerVerticalTap5Kernel(wiener_buffer + x + 0, width, filter, a);
      const __m256i d1 =
          WienerVerticalTap5Kernel(wiener_buffer + x + 16, width, filter, a);
      StoreUnaligned32(dst + x, _mm256_packus_epi16(d0, d1));
      x += 32;
    } while (x < width);
  }
}

inline void WienerVerticalTap3(const int16_t* wiener_buffer,
                               const ptrdiff_t width, const int height,
                               const int16_t coefficients[2], uint8_t* dst,
                               const ptrdiff_t dst_stride) {
  const __m256i filter =
      _mm256_set1_epi32(*reinterpret_cast<const int32_t*>(coefficients));
  for (int y = height >> 1; y > 0; --y) {
    ptrdiff_t x = 0;
    do {
      __m256i d[2][2];
      WienerVerticalTap3Kernel2(wiener_buffer + x + 0, width, filter, d[0]);
      WienerVerticalTap3Kernel2(wiener_buffer + x + 16, width, filter, d[1]);
      StoreUnaligned32(dst + x, _mm256_packus_epi16(d[0][0], d[1][0]));
      StoreUnaligned32(dst + dst_stride + x,
                       _mm256_packus_epi16(d[0][1], d[1][1]));
      x += 32;
    } while (x < width);
    dst += 2 * dst_stride;
    wiener_buffer += 2 * width;
  }

  if ((height & 1) != 0) {
    ptrdiff_t x = 0;
    do {
      __m256i a[3];
      const __m256i d0 =
          WienerVerticalTap3Kernel(wiener_buffer + x + 0, width, filter, a);
      const __m256i d1 =
          WienerVerticalTap3Kernel(wiener_buffer + x + 16, width, filter, a);
      StoreUnaligned32(dst + x, _mm256_packus_epi16(d0, d1));
      x += 32;
    } while (x < width);
  }
}

inline void WienerVerticalTap1Kernel(const int16_t* const wiener_buffer,
                                     uint8_t* const dst) {
  const __m256i a0 = LoadAligned32(wiener_buffer + 0);
  const __m256i a1 = LoadAligned32(wiener_buffer + 16);
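  // Round and shift right by 4 to undo the << 4 scaling applied in
  // WienerHorizontalTap1().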
  const __m256i b0 = _mm256_add_epi16(a0, _mm256_set1_epi16(8));
  const __m256i b1 = _mm256_add_epi16(a1, _mm256_set1_epi16(8));
  const __m256i c0 = _mm256_srai_epi16(b0, 4);
  const __m256i c1 = _mm256_srai_epi16(b1, 4);
  const __m256i d = _mm256_packus_epi16(c0, c1);
  StoreUnaligned32(dst, d);
}

inline void WienerVerticalTap1(const int16_t* wiener_buffer,
                               const ptrdiff_t width, const int height,
                               uint8_t* dst, const ptrdiff_t dst_stride) {
  for (int y = height >> 1; y > 0; --y) {
    ptrdiff_t x = 0;
    do {
      WienerVerticalTap1Kernel(wiener_buffer + x, dst + x);
      WienerVerticalTap1Kernel(wiener_buffer + width + x, dst + dst_stride + x);
      x += 32;
    } while (x < width);
    dst += 2 * dst_stride;
    wiener_buffer += 2 * width;
  }

  if ((height & 1) != 0) {
    ptrdiff_t x = 0;
    do {
      WienerVerticalTap1Kernel(wiener_buffer + x, dst + x);
      x += 32;
    } while (x < width);
  }
}

void WienerFilter_AVX2(
    const RestorationUnitInfo& LIBGAV1_RESTRICT restoration_info,
    const void* LIBGAV1_RESTRICT const source, const ptrdiff_t stride,
    const void* LIBGAV1_RESTRICT const top_border,
    const ptrdiff_t top_border_stride,
    const void* LIBGAV1_RESTRICT const bottom_border,
    const ptrdiff_t bottom_border_stride, const int width, const int height,
    RestorationBuffer* LIBGAV1_RESTRICT const restoration_buffer,
    void* LIBGAV1_RESTRICT const dest) {
  const int16_t* const number_leading_zero_coefficients =
      restoration_info.wiener_info.number_leading_zero_coefficients;
  const int number_rows_to_skip = std::max(
      static_cast<int>(number_leading_zero_coefficients[WienerInfo::kVertical]),
      1);
  const ptrdiff_t wiener_stride = Align(width, 32);
  int16_t* const wiener_buffer_vertical = restoration_buffer->wiener_buffer;
  // The values are saturated to 13 bits before storing.
  int16_t* wiener_buffer_horizontal =
      wiener_buffer_vertical + number_rows_to_skip * wiener_stride;

  // horizontal filtering.
  // Over-reads up to 15 - |kRestorationHorizontalBorder| values.
  const int height_horizontal =
      height + kWienerFilterTaps - 1 - 2 * number_rows_to_skip;
  const int height_extra = (height_horizontal - height) >> 1;
  assert(height_extra <= 2);
  const auto* const src = static_cast<const uint8_t*>(source);
  const auto* const top = static_cast<const uint8_t*>(top_border);
  const auto* const bottom = static_cast<const uint8_t*>(bottom_border);
  const __m128i c =
      LoadLo8(restoration_info.wiener_info.filter[WienerInfo::kHorizontal]);
  // In order to keep the horizontal pass intermediate values within 16 bits we
  // offset |filter[3]| by 128. The 128 offset will be added back in the loop.
  __m128i c_horizontal =
      _mm_sub_epi16(c, _mm_setr_epi16(0, 0, 0, 128, 0, 0, 0, 0));
  c_horizontal = _mm_packs_epi16(c_horizontal, c_horizontal);
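  // The coefficients are packed to 8 bits because the horizontal kernels use
  // _mm256_maddubs_epi16(), which takes signed 8-bit taps; the 128 offset
  // above keeps the center tap within the int8 range.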
  const __m256i coefficients_horizontal = _mm256_broadcastd_epi32(c_horizontal);
  if (number_leading_zero_coefficients[WienerInfo::kHorizontal] == 0) {
    WienerHorizontalTap7(top + (2 - height_extra) * top_border_stride - 3,
                         top_border_stride, wiener_stride, height_extra,
                         coefficients_horizontal, &wiener_buffer_horizontal);
    WienerHorizontalTap7(src - 3, stride, wiener_stride, height,
                         coefficients_horizontal, &wiener_buffer_horizontal);
    WienerHorizontalTap7(bottom - 3, bottom_border_stride, wiener_stride,
                         height_extra, coefficients_horizontal,
                         &wiener_buffer_horizontal);
  } else if (number_leading_zero_coefficients[WienerInfo::kHorizontal] == 1) {
    WienerHorizontalTap5(top + (2 - height_extra) * top_border_stride - 2,
                         top_border_stride, wiener_stride, height_extra,
                         coefficients_horizontal, &wiener_buffer_horizontal);
    WienerHorizontalTap5(src - 2, stride, wiener_stride, height,
                         coefficients_horizontal, &wiener_buffer_horizontal);
    WienerHorizontalTap5(bottom - 2, bottom_border_stride, wiener_stride,
                         height_extra, coefficients_horizontal,
                         &wiener_buffer_horizontal);
  } else if (number_leading_zero_coefficients[WienerInfo::kHorizontal] == 2) {
    // The maximum over-reads happen here.
    WienerHorizontalTap3(top + (2 - height_extra) * top_border_stride - 1,
                         top_border_stride, wiener_stride, height_extra,
                         coefficients_horizontal, &wiener_buffer_horizontal);
    WienerHorizontalTap3(src - 1, stride, wiener_stride, height,
                         coefficients_horizontal, &wiener_buffer_horizontal);
    WienerHorizontalTap3(bottom - 1, bottom_border_stride, wiener_stride,
                         height_extra, coefficients_horizontal,
                         &wiener_buffer_horizontal);
  } else {
    assert(number_leading_zero_coefficients[WienerInfo::kHorizontal] == 3);
    WienerHorizontalTap1(top + (2 - height_extra) * top_border_stride,
                         top_border_stride, wiener_stride, height_extra,
                         &wiener_buffer_horizontal);
    WienerHorizontalTap1(src, stride, wiener_stride, height,
                         &wiener_buffer_horizontal);
    WienerHorizontalTap1(bottom, bottom_border_stride, wiener_stride,
                         height_extra, &wiener_buffer_horizontal);
  }

  // vertical filtering.
  // Over-writes up to 15 values.
  const int16_t* const filter_vertical =
      restoration_info.wiener_info.filter[WienerInfo::kVertical];
  auto* dst = static_cast<uint8_t*>(dest);
  if (number_leading_zero_coefficients[WienerInfo::kVertical] == 0) {
    // Because the top row of |source| is a duplicate of the second row, and the
    // bottom row of |source| is a duplicate of its above row, we can duplicate
    // the top and bottom row of |wiener_buffer| accordingly.
    memcpy(wiener_buffer_horizontal, wiener_buffer_horizontal - wiener_stride,
           sizeof(*wiener_buffer_horizontal) * wiener_stride);
    memcpy(restoration_buffer->wiener_buffer,
           restoration_buffer->wiener_buffer + wiener_stride,
           sizeof(*restoration_buffer->wiener_buffer) * wiener_stride);
    WienerVerticalTap7(wiener_buffer_vertical, wiener_stride, height,
                       filter_vertical, dst, stride);
  } else if (number_leading_zero_coefficients[WienerInfo::kVertical] == 1) {
    WienerVerticalTap5(wiener_buffer_vertical + wiener_stride, wiener_stride,
                       height, filter_vertical + 1, dst, stride);
  } else if (number_leading_zero_coefficients[WienerInfo::kVertical] == 2) {
    WienerVerticalTap3(wiener_buffer_vertical + 2 * wiener_stride,
                       wiener_stride, height, filter_vertical + 2, dst, stride);
  } else {
    assert(number_leading_zero_coefficients[WienerInfo::kVertical] == 3);
    WienerVerticalTap1(wiener_buffer_vertical + 3 * wiener_stride,
                       wiener_stride, height, dst, stride);
  }
}

//------------------------------------------------------------------------------
// SGR

constexpr int kSumOffset = 24;

// SIMD overreads the number of bytes in SIMD registers - (width % 16) - 2 *
// padding pixels, where padding is 3 for Pass 1 and 2 for Pass 2. The number of
// bytes in SIMD registers is 16 for SSE4.1 and 32 for AVX2.
constexpr int kOverreadInBytesPass1_128 = 10;
constexpr int kOverreadInBytesPass2_128 = 12;
constexpr int kOverreadInBytesPass1_256 = kOverreadInBytesPass1_128 + 16;
constexpr int kOverreadInBytesPass2_256 = kOverreadInBytesPass2_128 + 16;
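// For example, with 16-byte registers Pass 1 over-reads at most 16 - 2 * 3 =
// 10 bytes and Pass 2 at most 16 - 2 * 2 = 12 bytes; the 32-byte AVX2 loads
// add another 16 bytes to each.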

inline void LoadAligned16x2U16(const uint16_t* const src[2], const ptrdiff_t x,
                               __m128i dst[2]) {
  dst[0] = LoadAligned16(src[0] + x);
  dst[1] = LoadAligned16(src[1] + x);
}

inline void LoadAligned32x2U16(const uint16_t* const src[2], const ptrdiff_t x,
                               __m256i dst[2]) {
  dst[0] = LoadAligned32(src[0] + x);
  dst[1] = LoadAligned32(src[1] + x);
}

inline void LoadAligned32x2U16Msan(const uint16_t* const src[2],
                                   const ptrdiff_t x, const ptrdiff_t border,
                                   __m256i dst[2]) {
  dst[0] = LoadAligned32Msan(src[0] + x, sizeof(**src) * (x + 16 - border));
  dst[1] = LoadAligned32Msan(src[1] + x, sizeof(**src) * (x + 16 - border));
}

inline void LoadAligned16x3U16(const uint16_t* const src[3], const ptrdiff_t x,
                               __m128i dst[3]) {
  dst[0] = LoadAligned16(src[0] + x);
  dst[1] = LoadAligned16(src[1] + x);
  dst[2] = LoadAligned16(src[2] + x);
}

inline void LoadAligned32x3U16(const uint16_t* const src[3], const ptrdiff_t x,
                               __m256i dst[3]) {
  dst[0] = LoadAligned32(src[0] + x);
  dst[1] = LoadAligned32(src[1] + x);
  dst[2] = LoadAligned32(src[2] + x);
}

inline void LoadAligned32x3U16Msan(const uint16_t* const src[3],
                                   const ptrdiff_t x, const ptrdiff_t border,
                                   __m256i dst[3]) {
  dst[0] = LoadAligned32Msan(src[0] + x, sizeof(**src) * (x + 16 - border));
  dst[1] = LoadAligned32Msan(src[1] + x, sizeof(**src) * (x + 16 - border));
  dst[2] = LoadAligned32Msan(src[2] + x, sizeof(**src) * (x + 16 - border));
}

inline void LoadAligned32U32(const uint32_t* const src, __m128i dst[2]) {
  dst[0] = LoadAligned16(src + 0);
  dst[1] = LoadAligned16(src + 4);
}

inline void LoadAligned32x2U32(const uint32_t* const src[2], const ptrdiff_t x,
                               __m128i dst[2][2]) {
  LoadAligned32U32(src[0] + x, dst[0]);
  LoadAligned32U32(src[1] + x, dst[1]);
}

inline void LoadAligned64x2U32(const uint32_t* const src[2], const ptrdiff_t x,
                               __m256i dst[2][2]) {
  LoadAligned64(src[0] + x, dst[0]);
  LoadAligned64(src[1] + x, dst[1]);
}

inline void LoadAligned64x2U32Msan(const uint32_t* const src[2],
                                   const ptrdiff_t x, const ptrdiff_t border,
                                   __m256i dst[2][2]) {
  LoadAligned64Msan(src[0] + x, sizeof(**src) * (x + 16 - border), dst[0]);
  LoadAligned64Msan(src[1] + x, sizeof(**src) * (x + 16 - border), dst[1]);
}

inline void LoadAligned32x3U32(const uint32_t* const src[3], const ptrdiff_t x,
                               __m128i dst[3][2]) {
  LoadAligned32U32(src[0] + x, dst[0]);
  LoadAligned32U32(src[1] + x, dst[1]);
  LoadAligned32U32(src[2] + x, dst[2]);
}

inline void LoadAligned64x3U32(const uint32_t* const src[3], const ptrdiff_t x,
                               __m256i dst[3][2]) {
  LoadAligned64(src[0] + x, dst[0]);
  LoadAligned64(src[1] + x, dst[1]);
  LoadAligned64(src[2] + x, dst[2]);
}

inline void LoadAligned64x3U32Msan(const uint32_t* const src[3],
                                   const ptrdiff_t x, const ptrdiff_t border,
                                   __m256i dst[3][2]) {
  LoadAligned64Msan(src[0] + x, sizeof(**src) * (x + 16 - border), dst[0]);
  LoadAligned64Msan(src[1] + x, sizeof(**src) * (x + 16 - border), dst[1]);
  LoadAligned64Msan(src[2] + x, sizeof(**src) * (x + 16 - border), dst[2]);
}

inline void StoreAligned32U32(uint32_t* const dst, const __m128i src[2]) {
  StoreAligned16(dst + 0, src[0]);
  StoreAligned16(dst + 4, src[1]);
}

// Don't use _mm_cvtepu8_epi16() or _mm_cvtepu16_epi32() in the following
// functions. Some compilers may generate super inefficient code and the whole
// decoder could be 15% slower.

inline __m128i VaddlLo8(const __m128i src0, const __m128i src1) {
  const __m128i s0 = _mm_unpacklo_epi8(src0, _mm_setzero_si128());
  const __m128i s1 = _mm_unpacklo_epi8(src1, _mm_setzero_si128());
  return _mm_add_epi16(s0, s1);
}

inline __m256i VaddlLo8(const __m256i src0, const __m256i src1) {
  const __m256i s0 = _mm256_unpacklo_epi8(src0, _mm256_setzero_si256());
  const __m256i s1 = _mm256_unpacklo_epi8(src1, _mm256_setzero_si256());
  return _mm256_add_epi16(s0, s1);
}

inline __m256i VaddlHi8(const __m256i src0, const __m256i src1) {
  const __m256i s0 = _mm256_unpackhi_epi8(src0, _mm256_setzero_si256());
  const __m256i s1 = _mm256_unpackhi_epi8(src1, _mm256_setzero_si256());
  return _mm256_add_epi16(s0, s1);
}

inline __m128i VaddlLo16(const __m128i src0, const __m128i src1) {
  const __m128i s0 = _mm_unpacklo_epi16(src0, _mm_setzero_si128());
  const __m128i s1 = _mm_unpacklo_epi16(src1, _mm_setzero_si128());
  return _mm_add_epi32(s0, s1);
}

inline __m256i VaddlLo16(const __m256i src0, const __m256i src1) {
  const __m256i s0 = _mm256_unpacklo_epi16(src0, _mm256_setzero_si256());
  const __m256i s1 = _mm256_unpacklo_epi16(src1, _mm256_setzero_si256());
  return _mm256_add_epi32(s0, s1);
}

inline __m128i VaddlHi16(const __m128i src0, const __m128i src1) {
  const __m128i s0 = _mm_unpackhi_epi16(src0, _mm_setzero_si128());
  const __m128i s1 = _mm_unpackhi_epi16(src1, _mm_setzero_si128());
  return _mm_add_epi32(s0, s1);
}

inline __m256i VaddlHi16(const __m256i src0, const __m256i src1) {
  const __m256i s0 = _mm256_unpackhi_epi16(src0, _mm256_setzero_si256());
  const __m256i s1 = _mm256_unpackhi_epi16(src1, _mm256_setzero_si256());
  return _mm256_add_epi32(s0, s1);
}

inline __m128i VaddwLo8(const __m128i src0, const __m128i src1) {
  const __m128i s1 = _mm_unpacklo_epi8(src1, _mm_setzero_si128());
  return _mm_add_epi16(src0, s1);
}

inline __m256i VaddwLo8(const __m256i src0, const __m256i src1) {
  const __m256i s1 = _mm256_unpacklo_epi8(src1, _mm256_setzero_si256());
  return _mm256_add_epi16(src0, s1);
}

inline __m256i VaddwHi8(const __m256i src0, const __m256i src1) {
  const __m256i s1 = _mm256_unpackhi_epi8(src1, _mm256_setzero_si256());
  return _mm256_add_epi16(src0, s1);
}

inline __m128i VaddwLo16(const __m128i src0, const __m128i src1) {
  const __m128i s1 = _mm_unpacklo_epi16(src1, _mm_setzero_si128());
  return _mm_add_epi32(src0, s1);
}

inline __m256i VaddwLo16(const __m256i src0, const __m256i src1) {
  const __m256i s1 = _mm256_unpacklo_epi16(src1, _mm256_setzero_si256());
  return _mm256_add_epi32(src0, s1);
}

inline __m128i VaddwHi16(const __m128i src0, const __m128i src1) {
  const __m128i s1 = _mm_unpackhi_epi16(src1, _mm_setzero_si128());
  return _mm_add_epi32(src0, s1);
}

inline __m256i VaddwHi16(const __m256i src0, const __m256i src1) {
  const __m256i s1 = _mm256_unpackhi_epi16(src1, _mm256_setzero_si256());
  return _mm256_add_epi32(src0, s1);
}

inline __m256i VmullNLo8(const __m256i src0, const int src1) {
  const __m256i s0 = _mm256_unpacklo_epi16(src0, _mm256_setzero_si256());
  return _mm256_madd_epi16(s0, _mm256_set1_epi32(src1));
}

inline __m256i VmullNHi8(const __m256i src0, const int src1) {
  const __m256i s0 = _mm256_unpackhi_epi16(src0, _mm256_setzero_si256());
  return _mm256_madd_epi16(s0, _mm256_set1_epi32(src1));
}

inline __m128i VmullLo16(const __m128i src0, const __m128i src1) {
  const __m128i s0 = _mm_unpacklo_epi16(src0, _mm_setzero_si128());
  const __m128i s1 = _mm_unpacklo_epi16(src1, _mm_setzero_si128());
  return _mm_madd_epi16(s0, s1);
}

inline __m256i VmullLo16(const __m256i src0, const __m256i src1) {
  const __m256i s0 = _mm256_unpacklo_epi16(src0, _mm256_setzero_si256());
  const __m256i s1 = _mm256_unpacklo_epi16(src1, _mm256_setzero_si256());
  return _mm256_madd_epi16(s0, s1);
}

inline __m128i VmullHi16(const __m128i src0, const __m128i src1) {
  const __m128i s0 = _mm_unpackhi_epi16(src0, _mm_setzero_si128());
  const __m128i s1 = _mm_unpackhi_epi16(src1, _mm_setzero_si128());
  return _mm_madd_epi16(s0, s1);
}

inline __m256i VmullHi16(const __m256i src0, const __m256i src1) {
  const __m256i s0 = _mm256_unpackhi_epi16(src0, _mm256_setzero_si256());
  const __m256i s1 = _mm256_unpackhi_epi16(src1, _mm256_setzero_si256());
  return _mm256_madd_epi16(s0, s1);
}

inline __m256i VrshrS32(const __m256i src0, const int src1) {
  const __m256i sum =
      _mm256_add_epi32(src0, _mm256_set1_epi32(1 << (src1 - 1)));
  return _mm256_srai_epi32(sum, src1);
}

inline __m128i VrshrU32(const __m128i src0, const int src1) {
  const __m128i sum = _mm_add_epi32(src0, _mm_set1_epi32(1 << (src1 - 1)));
  return _mm_srli_epi32(sum, src1);
}

inline __m256i VrshrU32(const __m256i src0, const int src1) {
  const __m256i sum =
      _mm256_add_epi32(src0, _mm256_set1_epi32(1 << (src1 - 1)));
  return _mm256_srli_epi32(sum, src1);
}

inline __m128i SquareLo8(const __m128i src) {
  const __m128i s = _mm_unpacklo_epi8(src, _mm_setzero_si128());
  return _mm_mullo_epi16(s, s);
}

inline __m256i SquareLo8(const __m256i src) {
  const __m256i s = _mm256_unpacklo_epi8(src, _mm256_setzero_si256());
  return _mm256_mullo_epi16(s, s);
}

inline __m128i SquareHi8(const __m128i src) {
  const __m128i s = _mm_unpackhi_epi8(src, _mm_setzero_si128());
  return _mm_mullo_epi16(s, s);
}

inline __m256i SquareHi8(const __m256i src) {
  const __m256i s = _mm256_unpackhi_epi8(src, _mm256_setzero_si256());
  return _mm256_mullo_epi16(s, s);
}

inline void Prepare3Lo8(const __m128i src, __m128i dst[3]) {
  dst[0] = src;
  dst[1] = _mm_srli_si128(src, 1);
  dst[2] = _mm_srli_si128(src, 2);
}

inline void Prepare3_8(const __m256i src[2], __m256i dst[3]) {
  dst[0] = _mm256_alignr_epi8(src[1], src[0], 0);
  dst[1] = _mm256_alignr_epi8(src[1], src[0], 1);
  dst[2] = _mm256_alignr_epi8(src[1], src[0], 2);
}

inline void Prepare3_16(const __m128i src[2], __m128i dst[3]) {
  dst[0] = src[0];
  dst[1] = _mm_alignr_epi8(src[1], src[0], 2);
  dst[2] = _mm_alignr_epi8(src[1], src[0], 4);
}

inline void Prepare3_16(const __m256i src[2], __m256i dst[3]) {
  dst[0] = src[0];
  dst[1] = _mm256_alignr_epi8(src[1], src[0], 2);
  dst[2] = _mm256_alignr_epi8(src[1], src[0], 4);
}

inline void Prepare5Lo8(const __m128i src, __m128i dst[5]) {
  dst[0] = src;
  dst[1] = _mm_srli_si128(src, 1);
  dst[2] = _mm_srli_si128(src, 2);
  dst[3] = _mm_srli_si128(src, 3);
  dst[4] = _mm_srli_si128(src, 4);
}

inline void Prepare5_16(const __m128i src[2], __m128i dst[5]) {
  Prepare3_16(src, dst);
  dst[3] = _mm_alignr_epi8(src[1], src[0], 6);
  dst[4] = _mm_alignr_epi8(src[1], src[0], 8);
}

inline void Prepare5_16(const __m256i src[2], __m256i dst[5]) {
  Prepare3_16(src, dst);
  dst[3] = _mm256_alignr_epi8(src[1], src[0], 6);
  dst[4] = _mm256_alignr_epi8(src[1], src[0], 8);
}

inline __m128i Sum3_16(const __m128i src0, const __m128i src1,
                       const __m128i src2) {
  const __m128i sum = _mm_add_epi16(src0, src1);
  return _mm_add_epi16(sum, src2);
}

inline __m256i Sum3_16(const __m256i src0, const __m256i src1,
                       const __m256i src2) {
  const __m256i sum = _mm256_add_epi16(src0, src1);
  return _mm256_add_epi16(sum, src2);
}

inline __m128i Sum3_16(const __m128i src[3]) {
  return Sum3_16(src[0], src[1], src[2]);
}

inline __m256i Sum3_16(const __m256i src[3]) {
  return Sum3_16(src[0], src[1], src[2]);
}

inline __m128i Sum3_32(const __m128i src0, const __m128i src1,
                       const __m128i src2) {
  const __m128i sum = _mm_add_epi32(src0, src1);
  return _mm_add_epi32(sum, src2);
}

inline __m256i Sum3_32(const __m256i src0, const __m256i src1,
                       const __m256i src2) {
  const __m256i sum = _mm256_add_epi32(src0, src1);
  return _mm256_add_epi32(sum, src2);
}

inline void Sum3_32(const __m128i src[3][2], __m128i dst[2]) {
  dst[0] = Sum3_32(src[0][0], src[1][0], src[2][0]);
  dst[1] = Sum3_32(src[0][1], src[1][1], src[2][1]);
}

inline void Sum3_32(const __m256i src[3][2], __m256i dst[2]) {
  dst[0] = Sum3_32(src[0][0], src[1][0], src[2][0]);
  dst[1] = Sum3_32(src[0][1], src[1][1], src[2][1]);
}

inline __m128i Sum3WLo16(const __m128i src[3]) {
  const __m128i sum = VaddlLo8(src[0], src[1]);
  return VaddwLo8(sum, src[2]);
}

inline __m256i Sum3WLo16(const __m256i src[3]) {
  const __m256i sum = VaddlLo8(src[0], src[1]);
  return VaddwLo8(sum, src[2]);
}

inline __m256i Sum3WHi16(const __m256i src[3]) {
  const __m256i sum = VaddlHi8(src[0], src[1]);
  return VaddwHi8(sum, src[2]);
}

inline __m128i Sum3WLo32(const __m128i src[3]) {
  const __m128i sum = VaddlLo16(src[0], src[1]);
  return VaddwLo16(sum, src[2]);
}

inline __m256i Sum3WLo32(const __m256i src[3]) {
  const __m256i sum = VaddlLo16(src[0], src[1]);
  return VaddwLo16(sum, src[2]);
}

inline __m128i Sum3WHi32(const __m128i src[3]) {
  const __m128i sum = VaddlHi16(src[0], src[1]);
  return VaddwHi16(sum, src[2]);
}

inline __m256i Sum3WHi32(const __m256i src[3]) {
  const __m256i sum = VaddlHi16(src[0], src[1]);
  return VaddwHi16(sum, src[2]);
}

inline __m128i Sum5_16(const __m128i src[5]) {
  const __m128i sum01 = _mm_add_epi16(src[0], src[1]);
  const __m128i sum23 = _mm_add_epi16(src[2], src[3]);
  const __m128i sum = _mm_add_epi16(sum01, sum23);
  return _mm_add_epi16(sum, src[4]);
}

inline __m256i Sum5_16(const __m256i src[5]) {
  const __m256i sum01 = _mm256_add_epi16(src[0], src[1]);
  const __m256i sum23 = _mm256_add_epi16(src[2], src[3]);
  const __m256i sum = _mm256_add_epi16(sum01, sum23);
  return _mm256_add_epi16(sum, src[4]);
}

inline __m128i Sum5_32(const __m128i* const src0, const __m128i* const src1,
                       const __m128i* const src2, const __m128i* const src3,
                       const __m128i* const src4) {
  const __m128i sum01 = _mm_add_epi32(*src0, *src1);
  const __m128i sum23 = _mm_add_epi32(*src2, *src3);
  const __m128i sum = _mm_add_epi32(sum01, sum23);
  return _mm_add_epi32(sum, *src4);
}

inline __m256i Sum5_32(const __m256i* const src0, const __m256i* const src1,
                       const __m256i* const src2, const __m256i* const src3,
                       const __m256i* const src4) {
  const __m256i sum01 = _mm256_add_epi32(*src0, *src1);
  const __m256i sum23 = _mm256_add_epi32(*src2, *src3);
  const __m256i sum = _mm256_add_epi32(sum01, sum23);
  return _mm256_add_epi32(sum, *src4);
}

inline void Sum5_32(const __m128i src[5][2], __m128i dst[2]) {
  dst[0] = Sum5_32(&src[0][0], &src[1][0], &src[2][0], &src[3][0], &src[4][0]);
  dst[1] = Sum5_32(&src[0][1], &src[1][1], &src[2][1], &src[3][1], &src[4][1]);
}

inline void Sum5_32(const __m256i src[5][2], __m256i dst[2]) {
  dst[0] = Sum5_32(&src[0][0], &src[1][0], &src[2][0], &src[3][0], &src[4][0]);
  dst[1] = Sum5_32(&src[0][1], &src[1][1], &src[2][1], &src[3][1], &src[4][1]);
}

inline __m128i Sum5WLo16(const __m128i src[5]) {
  const __m128i sum01 = VaddlLo8(src[0], src[1]);
  const __m128i sum23 = VaddlLo8(src[2], src[3]);
  const __m128i sum = _mm_add_epi16(sum01, sum23);
  return VaddwLo8(sum, src[4]);
}

inline __m256i Sum5WLo16(const __m256i src[5]) {
  const __m256i sum01 = VaddlLo8(src[0], src[1]);
  const __m256i sum23 = VaddlLo8(src[2], src[3]);
  const __m256i sum = _mm256_add_epi16(sum01, sum23);
  return VaddwLo8(sum, src[4]);
}

inline __m256i Sum5WHi16(const __m256i src[5]) {
  const __m256i sum01 = VaddlHi8(src[0], src[1]);
  const __m256i sum23 = VaddlHi8(src[2], src[3]);
  const __m256i sum = _mm256_add_epi16(sum01, sum23);
  return VaddwHi8(sum, src[4]);
}

inline __m128i Sum3Horizontal(const __m128i src) {
  __m128i s[3];
  Prepare3Lo8(src, s);
  return Sum3WLo16(s);
}

inline void Sum3Horizontal(const uint8_t* const src,
                           const ptrdiff_t over_read_in_bytes, __m256i dst[2]) {
  __m256i s[3];
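  // The three loads below are the same 32 pixels shifted by 0, 1 and 2 bytes,
  // so each output lane is a horizontal 3-tap box sum around that pixel.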
1040   s[0] = LoadUnaligned32Msan(src + 0, over_read_in_bytes + 0);
1041   s[1] = LoadUnaligned32Msan(src + 1, over_read_in_bytes + 1);
1042   s[2] = LoadUnaligned32Msan(src + 2, over_read_in_bytes + 2);
1043   dst[0] = Sum3WLo16(s);
1044   dst[1] = Sum3WHi16(s);
1045 }
1046 
Sum3WHorizontal(const __m128i src[2],__m128i dst[2])1047 inline void Sum3WHorizontal(const __m128i src[2], __m128i dst[2]) {
1048   __m128i s[3];
1049   Prepare3_16(src, s);
1050   dst[0] = Sum3WLo32(s);
1051   dst[1] = Sum3WHi32(s);
1052 }
1053 
Sum3WHorizontal(const __m256i src[2],__m256i dst[2])1054 inline void Sum3WHorizontal(const __m256i src[2], __m256i dst[2]) {
1055   __m256i s[3];
1056   Prepare3_16(src, s);
1057   dst[0] = Sum3WLo32(s);
1058   dst[1] = Sum3WHi32(s);
1059 }
1060 
Sum5Horizontal(const __m128i src)1061 inline __m128i Sum5Horizontal(const __m128i src) {
1062   __m128i s[5];
1063   Prepare5Lo8(src, s);
1064   return Sum5WLo16(s);
1065 }
1066 
Sum5Horizontal(const uint8_t * const src,const ptrdiff_t over_read_in_bytes,__m256i * const dst0,__m256i * const dst1)1067 inline void Sum5Horizontal(const uint8_t* const src,
1068                            const ptrdiff_t over_read_in_bytes,
1069                            __m256i* const dst0, __m256i* const dst1) {
1070   __m256i s[5];
1071   s[0] = LoadUnaligned32Msan(src + 0, over_read_in_bytes + 0);
1072   s[1] = LoadUnaligned32Msan(src + 1, over_read_in_bytes + 1);
1073   s[2] = LoadUnaligned32Msan(src + 2, over_read_in_bytes + 2);
1074   s[3] = LoadUnaligned32Msan(src + 3, over_read_in_bytes + 3);
1075   s[4] = LoadUnaligned32Msan(src + 4, over_read_in_bytes + 4);
1076   *dst0 = Sum5WLo16(s);
1077   *dst1 = Sum5WHi16(s);
1078 }
1079 
Sum5WHorizontal(const __m128i src[2],__m128i dst[2])1080 inline void Sum5WHorizontal(const __m128i src[2], __m128i dst[2]) {
1081   __m128i s[5];
1082   Prepare5_16(src, s);
1083   const __m128i sum01_lo = VaddlLo16(s[0], s[1]);
1084   const __m128i sum23_lo = VaddlLo16(s[2], s[3]);
1085   const __m128i sum0123_lo = _mm_add_epi32(sum01_lo, sum23_lo);
1086   dst[0] = VaddwLo16(sum0123_lo, s[4]);
1087   const __m128i sum01_hi = VaddlHi16(s[0], s[1]);
1088   const __m128i sum23_hi = VaddlHi16(s[2], s[3]);
1089   const __m128i sum0123_hi = _mm_add_epi32(sum01_hi, sum23_hi);
1090   dst[1] = VaddwHi16(sum0123_hi, s[4]);
1091 }
1092 
Sum5WHorizontal(const __m256i src[2],__m256i dst[2])1093 inline void Sum5WHorizontal(const __m256i src[2], __m256i dst[2]) {
1094   __m256i s[5];
1095   Prepare5_16(src, s);
1096   const __m256i sum01_lo = VaddlLo16(s[0], s[1]);
1097   const __m256i sum23_lo = VaddlLo16(s[2], s[3]);
1098   const __m256i sum0123_lo = _mm256_add_epi32(sum01_lo, sum23_lo);
1099   dst[0] = VaddwLo16(sum0123_lo, s[4]);
1100   const __m256i sum01_hi = VaddlHi16(s[0], s[1]);
1101   const __m256i sum23_hi = VaddlHi16(s[2], s[3]);
1102   const __m256i sum0123_hi = _mm256_add_epi32(sum01_hi, sum23_hi);
1103   dst[1] = VaddwHi16(sum0123_hi, s[4]);
1104 }
1105 
SumHorizontalLo(const __m128i src[5],__m128i * const row_sq3,__m128i * const row_sq5)1106 void SumHorizontalLo(const __m128i src[5], __m128i* const row_sq3,
1107                      __m128i* const row_sq5) {
1108   const __m128i sum04 = VaddlLo16(src[0], src[4]);
1109   *row_sq3 = Sum3WLo32(src + 1);
1110   *row_sq5 = _mm_add_epi32(sum04, *row_sq3);
1111 }
1112 
1113 void SumHorizontalLo(const __m256i src[5], __m256i* const row_sq3,
1114                      __m256i* const row_sq5) {
1115   const __m256i sum04 = VaddlLo16(src[0], src[4]);
1116   *row_sq3 = Sum3WLo32(src + 1);
1117   *row_sq5 = _mm256_add_epi32(sum04, *row_sq3);
1118 }
1119 
1120 void SumHorizontalHi(const __m128i src[5], __m128i* const row_sq3,
1121                      __m128i* const row_sq5) {
1122   const __m128i sum04 = VaddlHi16(src[0], src[4]);
1123   *row_sq3 = Sum3WHi32(src + 1);
1124   *row_sq5 = _mm_add_epi32(sum04, *row_sq3);
1125 }
1126 
1127 void SumHorizontalHi(const __m256i src[5], __m256i* const row_sq3,
1128                      __m256i* const row_sq5) {
1129   const __m256i sum04 = VaddlHi16(src[0], src[4]);
1130   *row_sq3 = Sum3WHi32(src + 1);
1131   *row_sq5 = _mm256_add_epi32(sum04, *row_sq3);
1132 }
1133 
1134 void SumHorizontalLo(const __m128i src, __m128i* const row3,
1135                      __m128i* const row5) {
1136   __m128i s[5];
1137   Prepare5Lo8(src, s);
1138   const __m128i sum04 = VaddlLo8(s[0], s[4]);
1139   *row3 = Sum3WLo16(s + 1);
1140   *row5 = _mm_add_epi16(sum04, *row3);
1141 }
1142 
1143 inline void SumHorizontal(const uint8_t* const src,
1144                           const ptrdiff_t over_read_in_bytes,
1145                           __m256i* const row3_0, __m256i* const row3_1,
1146                           __m256i* const row5_0, __m256i* const row5_1) {
1147   __m256i s[5];
1148   s[0] = LoadUnaligned32Msan(src + 0, over_read_in_bytes + 0);
1149   s[1] = LoadUnaligned32Msan(src + 1, over_read_in_bytes + 1);
1150   s[2] = LoadUnaligned32Msan(src + 2, over_read_in_bytes + 2);
1151   s[3] = LoadUnaligned32Msan(src + 3, over_read_in_bytes + 3);
1152   s[4] = LoadUnaligned32Msan(src + 4, over_read_in_bytes + 4);
1153   const __m256i sum04_lo = VaddlLo8(s[0], s[4]);
1154   const __m256i sum04_hi = VaddlHi8(s[0], s[4]);
1155   *row3_0 = Sum3WLo16(s + 1);
1156   *row3_1 = Sum3WHi16(s + 1);
1157   *row5_0 = _mm256_add_epi16(sum04_lo, *row3_0);
1158   *row5_1 = _mm256_add_epi16(sum04_hi, *row3_1);
1159 }
1160 
1161 inline void SumHorizontal(const __m128i src[2], __m128i* const row_sq3_0,
1162                           __m128i* const row_sq3_1, __m128i* const row_sq5_0,
1163                           __m128i* const row_sq5_1) {
1164   __m128i s[5];
1165   Prepare5_16(src, s);
1166   SumHorizontalLo(s, row_sq3_0, row_sq5_0);
1167   SumHorizontalHi(s, row_sq3_1, row_sq5_1);
1168 }
1169 
1170 inline void SumHorizontal(const __m256i src[2], __m256i* const row_sq3_0,
1171                           __m256i* const row_sq3_1, __m256i* const row_sq5_0,
1172                           __m256i* const row_sq5_1) {
1173   __m256i s[5];
1174   Prepare5_16(src, s);
1175   SumHorizontalLo(s, row_sq3_0, row_sq5_0);
1176   SumHorizontalHi(s, row_sq3_1, row_sq5_1);
1177 }
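// For illustration only, the per-element arithmetic shared by the
// SumHorizontal*() helpers above, which compute the 3-tap and 5-tap sums in a
// single pass by reusing the inner sum:
//   row3[i] = src[i + 1] + src[i + 2] + src[i + 3];
//   row5[i] = row3[i] + src[i] + src[i + 4];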
1178 
1179 inline __m256i Sum343Lo(const __m256i ma3[3]) {
1180   const __m256i sum = Sum3WLo16(ma3);
1181   const __m256i sum3 = Sum3_16(sum, sum, sum);
1182   return VaddwLo8(sum3, ma3[1]);
1183 }
1184 
1185 inline __m256i Sum343Hi(const __m256i ma3[3]) {
1186   const __m256i sum = Sum3WHi16(ma3);
1187   const __m256i sum3 = Sum3_16(sum, sum, sum);
1188   return VaddwHi8(sum3, ma3[1]);
1189 }
1190 
1191 inline __m256i Sum343WLo(const __m256i src[3]) {
1192   const __m256i sum = Sum3WLo32(src);
1193   const __m256i sum3 = Sum3_32(sum, sum, sum);
1194   return VaddwLo16(sum3, src[1]);
1195 }
1196 
1197 inline __m256i Sum343WHi(const __m256i src[3]) {
1198   const __m256i sum = Sum3WHi32(src);
1199   const __m256i sum3 = Sum3_32(sum, sum, sum);
1200   return VaddwHi16(sum3, src[1]);
1201 }
1202 
1203 inline void Sum343W(const __m256i src[2], __m256i dst[2]) {
1204   __m256i s[3];
1205   Prepare3_16(src, s);
1206   dst[0] = Sum343WLo(s);
1207   dst[1] = Sum343WHi(s);
1208 }
1209 
1210 inline __m256i Sum565Lo(const __m256i src[3]) {
1211   const __m256i sum = Sum3WLo16(src);
1212   const __m256i sum4 = _mm256_slli_epi16(sum, 2);
1213   const __m256i sum5 = _mm256_add_epi16(sum4, sum);
1214   return VaddwLo8(sum5, src[1]);
1215 }
1216 
1217 inline __m256i Sum565Hi(const __m256i src[3]) {
1218   const __m256i sum = Sum3WHi16(src);
1219   const __m256i sum4 = _mm256_slli_epi16(sum, 2);
1220   const __m256i sum5 = _mm256_add_epi16(sum4, sum);
1221   return VaddwHi8(sum5, src[1]);
1222 }
1223 
1224 inline __m256i Sum565WLo(const __m256i src[3]) {
1225   const __m256i sum = Sum3WLo32(src);
1226   const __m256i sum4 = _mm256_slli_epi32(sum, 2);
1227   const __m256i sum5 = _mm256_add_epi32(sum4, sum);
1228   return VaddwLo16(sum5, src[1]);
1229 }
1230 
1231 inline __m256i Sum565WHi(const __m256i src[3]) {
1232   const __m256i sum = Sum3WHi32(src);
1233   const __m256i sum4 = _mm256_slli_epi32(sum, 2);
1234   const __m256i sum5 = _mm256_add_epi32(sum4, sum);
1235   return VaddwHi16(sum5, src[1]);
1236 }
1237 
1238 inline void Sum565W(const __m256i src[2], __m256i dst[2]) {
1239   __m256i s[3];
1240   Prepare3_16(src, s);
1241   dst[0] = Sum565WLo(s);
1242   dst[1] = Sum565WHi(s);
1243 }
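// Expanding the arithmetic in the Sum343*()/Sum565*() helpers above
// (illustration only), with (a, b, c) denoting three consecutive rows:
//   Sum343*(a, b, c) = 3 * (a + b + c) + b = 3a + 4b + 3c
//   Sum565*(a, b, c) = 5 * (a + b + c) + b = 5a + 6b + 5c
// i.e. the (3, 4, 3) and (5, 6, 5) cross-row weights applied to the ma/b
// values in the self-guided filter passes.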
1244 
1245 inline void BoxSum(const uint8_t* src, const ptrdiff_t src_stride,
1246                    const ptrdiff_t width, const ptrdiff_t sum_stride,
1247                    const ptrdiff_t sum_width, uint16_t* sum3, uint16_t* sum5,
1248                    uint32_t* square_sum3, uint32_t* square_sum5) {
1249   int y = 2;
1250   do {
1251     const __m128i s0 =
1252         LoadUnaligned16Msan(src, kOverreadInBytesPass1_128 - width);
1253     __m128i sq_128[2], s3, s5, sq3[2], sq5[2];
1254     __m256i sq[3];
1255     sq_128[0] = SquareLo8(s0);
1256     sq_128[1] = SquareHi8(s0);
1257     SumHorizontalLo(s0, &s3, &s5);
1258     StoreAligned16(sum3, s3);
1259     StoreAligned16(sum5, s5);
1260     SumHorizontal(sq_128, &sq3[0], &sq3[1], &sq5[0], &sq5[1]);
1261     StoreAligned32U32(square_sum3, sq3);
1262     StoreAligned32U32(square_sum5, sq5);
1263     src += 8;
1264     sum3 += 8;
1265     sum5 += 8;
1266     square_sum3 += 8;
1267     square_sum5 += 8;
1268     sq[0] = SetrM128i(sq_128[1], sq_128[1]);
1269     ptrdiff_t x = sum_width;
1270     do {
1271       __m256i row3[2], row5[2], row_sq3[2], row_sq5[2];
1272       const __m256i s = LoadUnaligned32Msan(
1273           src + 8, sum_width - x + 16 + kOverreadInBytesPass1_256 - width);
1274       sq[1] = SquareLo8(s);
1275       sq[2] = SquareHi8(s);
1276       sq[0] = _mm256_permute2x128_si256(sq[0], sq[2], 0x21);
1277       SumHorizontal(src, sum_width - x + 8 + kOverreadInBytesPass1_256 - width,
1278                     &row3[0], &row3[1], &row5[0], &row5[1]);
1279       StoreAligned64(sum3, row3);
1280       StoreAligned64(sum5, row5);
1281       SumHorizontal(sq + 0, &row_sq3[0], &row_sq3[1], &row_sq5[0], &row_sq5[1]);
1282       StoreAligned64(square_sum3 + 0, row_sq3);
1283       StoreAligned64(square_sum5 + 0, row_sq5);
1284       SumHorizontal(sq + 1, &row_sq3[0], &row_sq3[1], &row_sq5[0], &row_sq5[1]);
1285       StoreAligned64(square_sum3 + 16, row_sq3);
1286       StoreAligned64(square_sum5 + 16, row_sq5);
1287       sq[0] = sq[2];
1288       src += 32;
1289       sum3 += 32;
1290       sum5 += 32;
1291       square_sum3 += 32;
1292       square_sum5 += 32;
1293       x -= 32;
1294     } while (x != 0);
1295     src += src_stride - sum_width - 8;
1296     sum3 += sum_stride - sum_width - 8;
1297     sum5 += sum_stride - sum_width - 8;
1298     square_sum3 += sum_stride - sum_width - 8;
1299     square_sum5 += sum_stride - sum_width - 8;
1300   } while (--y != 0);
1301 }
1302 
1303 template <int size>
1304 inline void BoxSum(const uint8_t* src, const ptrdiff_t src_stride,
1305                    const ptrdiff_t width, const ptrdiff_t sum_stride,
1306                    const ptrdiff_t sum_width, uint16_t* sums,
1307                    uint32_t* square_sums) {
1308   static_assert(size == 3 || size == 5, "");
1309   int kOverreadInBytes_128, kOverreadInBytes_256;
1310   if (size == 3) {
1311     kOverreadInBytes_128 = kOverreadInBytesPass2_128;
1312     kOverreadInBytes_256 = kOverreadInBytesPass2_256;
1313   } else {
1314     kOverreadInBytes_128 = kOverreadInBytesPass1_128;
1315     kOverreadInBytes_256 = kOverreadInBytesPass1_256;
1316   }
1317   int y = 2;
1318   do {
1319     const __m128i s = LoadUnaligned16Msan(src, kOverreadInBytes_128 - width);
1320     __m128i ss, sq_128[2], sqs[2];
1321     __m256i sq[3];
1322     sq_128[0] = SquareLo8(s);
1323     sq_128[1] = SquareHi8(s);
1324     if (size == 3) {
1325       ss = Sum3Horizontal(s);
1326       Sum3WHorizontal(sq_128, sqs);
1327     } else {
1328       ss = Sum5Horizontal(s);
1329       Sum5WHorizontal(sq_128, sqs);
1330     }
1331     StoreAligned16(sums, ss);
1332     StoreAligned32U32(square_sums, sqs);
1333     src += 8;
1334     sums += 8;
1335     square_sums += 8;
1336     sq[0] = SetrM128i(sq_128[1], sq_128[1]);
1337     ptrdiff_t x = sum_width;
1338     do {
1339       __m256i row[2], row_sq[4];
1340       const __m256i s = LoadUnaligned32Msan(
1341           src + 8, sum_width - x + 16 + kOverreadInBytes_256 - width);
1342       sq[1] = SquareLo8(s);
1343       sq[2] = SquareHi8(s);
1344       sq[0] = _mm256_permute2x128_si256(sq[0], sq[2], 0x21);
1345       if (size == 3) {
1346         Sum3Horizontal(src, sum_width - x + 8 + kOverreadInBytes_256 - width,
1347                        row);
1348         Sum3WHorizontal(sq + 0, row_sq + 0);
1349         Sum3WHorizontal(sq + 1, row_sq + 2);
1350       } else {
1351         Sum5Horizontal(src, sum_width - x + 8 + kOverreadInBytes_256 - width,
1352                        &row[0], &row[1]);
1353         Sum5WHorizontal(sq + 0, row_sq + 0);
1354         Sum5WHorizontal(sq + 1, row_sq + 2);
1355       }
1356       StoreAligned64(sums, row);
1357       StoreAligned64(square_sums + 0, row_sq + 0);
1358       StoreAligned64(square_sums + 16, row_sq + 2);
1359       sq[0] = sq[2];
1360       src += 32;
1361       sums += 32;
1362       square_sums += 32;
1363       x -= 32;
1364     } while (x != 0);
1365     src += src_stride - sum_width - 8;
1366     sums += sum_stride - sum_width - 8;
1367     square_sums += sum_stride - sum_width - 8;
1368   } while (--y != 0);
1369 }
1370 
1371 template <int n>
1372 inline __m128i CalculateMa(const __m128i sum, const __m128i sum_sq,
1373                            const uint32_t scale) {
1374   static_assert(n == 9 || n == 25, "");
1375   // a = |sum_sq|
1376   // d = |sum|
1377   // p = (a * n < d * d) ? 0 : a * n - d * d;
1378   const __m128i dxd = _mm_madd_epi16(sum, sum);
1379   // _mm_mullo_epi32() has high latency. Using shifts and additions instead.
1380   // Some compilers could do this for us but we make this explicit.
1381   // return _mm_mullo_epi32(sum_sq, _mm_set1_epi32(n));
1382   __m128i axn = _mm_add_epi32(sum_sq, _mm_slli_epi32(sum_sq, 3));
1383   if (n == 25) axn = _mm_add_epi32(axn, _mm_slli_epi32(sum_sq, 4));
1384   const __m128i sub = _mm_sub_epi32(axn, dxd);
1385   const __m128i p = _mm_max_epi32(sub, _mm_setzero_si128());
1386   const __m128i pxs = _mm_mullo_epi32(p, _mm_set1_epi32(scale));
1387   return VrshrU32(pxs, kSgrProjScaleBits);
1388 }
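// A scalar sketch of one 32-bit lane of CalculateMa() above, for illustration
// only, assuming |sum| holds a zero-extended box sum and |sum_sq| the matching
// sum of squares:
//   const uint32_t dxd = sum * sum;
//   const uint32_t axn = sum_sq * n;  // n is 9 or 25; done as x + (x << 3)
//                                     // (+ (x << 4) for n == 25).
//   const uint32_t p = (axn < dxd) ? 0 : axn - dxd;
//   const uint32_t ma_index =
//       (p * scale + (1 << (kSgrProjScaleBits - 1))) >> kSgrProjScaleBits;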
1389 
1390 template <int n>
1391 inline __m128i CalculateMa(const __m128i sum, const __m128i sum_sq[2],
1392                            const uint32_t scale) {
1393   static_assert(n == 9 || n == 25, "");
1394   const __m128i sum_lo = _mm_unpacklo_epi16(sum, _mm_setzero_si128());
1395   const __m128i sum_hi = _mm_unpackhi_epi16(sum, _mm_setzero_si128());
1396   const __m128i z0 = CalculateMa<n>(sum_lo, sum_sq[0], scale);
1397   const __m128i z1 = CalculateMa<n>(sum_hi, sum_sq[1], scale);
1398   return _mm_packus_epi32(z0, z1);
1399 }
1400 
1401 template <int n>
1402 inline __m256i CalculateMa(const __m256i sum, const __m256i sum_sq,
1403                            const uint32_t scale) {
1404   static_assert(n == 9 || n == 25, "");
1405   // a = |sum_sq|
1406   // d = |sum|
1407   // p = (a * n < d * d) ? 0 : a * n - d * d;
1408   const __m256i dxd = _mm256_madd_epi16(sum, sum);
1409   // _mm256_mullo_epi32() has high latency. Using shifts and additions instead.
1410   // Some compilers could do this for us but we make this explicit.
1411   // return _mm256_mullo_epi32(sum_sq, _mm256_set1_epi32(n));
1412   __m256i axn = _mm256_add_epi32(sum_sq, _mm256_slli_epi32(sum_sq, 3));
1413   if (n == 25) axn = _mm256_add_epi32(axn, _mm256_slli_epi32(sum_sq, 4));
1414   const __m256i sub = _mm256_sub_epi32(axn, dxd);
1415   const __m256i p = _mm256_max_epi32(sub, _mm256_setzero_si256());
1416   const __m256i pxs = _mm256_mullo_epi32(p, _mm256_set1_epi32(scale));
1417   return VrshrU32(pxs, kSgrProjScaleBits);
1418 }
1419 
1420 template <int n>
1421 inline __m256i CalculateMa(const __m256i sum, const __m256i sum_sq[2],
1422                            const uint32_t scale) {
1423   static_assert(n == 9 || n == 25, "");
1424   const __m256i sum_lo = _mm256_unpacklo_epi16(sum, _mm256_setzero_si256());
1425   const __m256i sum_hi = _mm256_unpackhi_epi16(sum, _mm256_setzero_si256());
1426   const __m256i z0 = CalculateMa<n>(sum_lo, sum_sq[0], scale);
1427   const __m256i z1 = CalculateMa<n>(sum_hi, sum_sq[1], scale);
1428   return _mm256_packus_epi32(z0, z1);
1429 }
1430 
1431 inline __m128i CalculateB5(const __m128i sum, const __m128i ma) {
1432   // one_over_n == 164.
1433   constexpr uint32_t one_over_n =
1434       ((1 << kSgrProjReciprocalBits) + (25 >> 1)) / 25;
1435   // one_over_n_quarter == 41.
1436   constexpr uint32_t one_over_n_quarter = one_over_n >> 2;
1437   static_assert(one_over_n == one_over_n_quarter << 2, "");
1438   // |ma| is in range [0, 255].
1439   const __m128i m = _mm_maddubs_epi16(ma, _mm_set1_epi16(one_over_n_quarter));
1440   const __m128i m0 = VmullLo16(m, sum);
1441   const __m128i m1 = VmullHi16(m, sum);
1442   const __m128i b_lo = VrshrU32(m0, kSgrProjReciprocalBits - 2);
1443   const __m128i b_hi = VrshrU32(m1, kSgrProjReciprocalBits - 2);
1444   return _mm_packus_epi32(b_lo, b_hi);
1445 }
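// Worked constants for CalculateB5() above (illustration only):
//   one_over_n         = ((1 << 12) + 12) / 25 = 164
//   one_over_n_quarter = 164 >> 2              = 41
// Multiplying by 41 first keeps each _mm_maddubs_epi16() product within a
// signed 16-bit lane (255 * 41 = 10455), whereas 255 * 164 = 41820 would not
// fit. The missing factor of 4 is recovered by shifting right by
// (kSgrProjReciprocalBits - 2) instead of kSgrProjReciprocalBits, so per lane
//   b = (ma * 41 * sum + (1 << 9)) >> 10 ~= (ma * 164 * sum + (1 << 11)) >> 12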
1446 
1447 inline __m256i CalculateB5(const __m256i sum, const __m256i ma) {
1448   // one_over_n == 164.
1449   constexpr uint32_t one_over_n =
1450       ((1 << kSgrProjReciprocalBits) + (25 >> 1)) / 25;
1451   // one_over_n_quarter == 41.
1452   constexpr uint32_t one_over_n_quarter = one_over_n >> 2;
1453   static_assert(one_over_n == one_over_n_quarter << 2, "");
1454   // |ma| is in range [0, 255].
1455   const __m256i m =
1456       _mm256_maddubs_epi16(ma, _mm256_set1_epi16(one_over_n_quarter));
1457   const __m256i m0 = VmullLo16(m, sum);
1458   const __m256i m1 = VmullHi16(m, sum);
1459   const __m256i b_lo = VrshrU32(m0, kSgrProjReciprocalBits - 2);
1460   const __m256i b_hi = VrshrU32(m1, kSgrProjReciprocalBits - 2);
1461   return _mm256_packus_epi32(b_lo, b_hi);
1462 }
1463 
1464 inline __m128i CalculateB3(const __m128i sum, const __m128i ma) {
1465   // one_over_n == 455.
1466   constexpr uint32_t one_over_n =
1467       ((1 << kSgrProjReciprocalBits) + (9 >> 1)) / 9;
1468   const __m128i m0 = VmullLo16(ma, sum);
1469   const __m128i m1 = VmullHi16(ma, sum);
1470   const __m128i m2 = _mm_mullo_epi32(m0, _mm_set1_epi32(one_over_n));
1471   const __m128i m3 = _mm_mullo_epi32(m1, _mm_set1_epi32(one_over_n));
1472   const __m128i b_lo = VrshrU32(m2, kSgrProjReciprocalBits);
1473   const __m128i b_hi = VrshrU32(m3, kSgrProjReciprocalBits);
1474   return _mm_packus_epi32(b_lo, b_hi);
1475 }
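// Worked constant for CalculateB3() above (illustration only):
//   one_over_n = ((1 << 12) + 4) / 9 = 455
// 455 does not fit in a byte, so unlike CalculateB5() the scaling uses full
// 32-bit multiplies: per lane, b = (ma * sum * 455 + (1 << 11)) >> 12.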
1476 
1477 inline __m256i CalculateB3(const __m256i sum, const __m256i ma) {
1478   // one_over_n == 455.
1479   constexpr uint32_t one_over_n =
1480       ((1 << kSgrProjReciprocalBits) + (9 >> 1)) / 9;
1481   const __m256i m0 = VmullLo16(ma, sum);
1482   const __m256i m1 = VmullHi16(ma, sum);
1483   const __m256i m2 = _mm256_mullo_epi32(m0, _mm256_set1_epi32(one_over_n));
1484   const __m256i m3 = _mm256_mullo_epi32(m1, _mm256_set1_epi32(one_over_n));
1485   const __m256i b_lo = VrshrU32(m2, kSgrProjReciprocalBits);
1486   const __m256i b_hi = VrshrU32(m3, kSgrProjReciprocalBits);
1487   return _mm256_packus_epi32(b_lo, b_hi);
1488 }
1489 
1490 inline void CalculateSumAndIndex5(const __m128i s5[5], const __m128i sq5[5][2],
1491                                   const uint32_t scale, __m128i* const sum,
1492                                   __m128i* const index) {
1493   __m128i sum_sq[2];
1494   *sum = Sum5_16(s5);
1495   Sum5_32(sq5, sum_sq);
1496   *index = CalculateMa<25>(*sum, sum_sq, scale);
1497 }
1498 
1499 inline void CalculateSumAndIndex5(const __m256i s5[5], const __m256i sq5[5][2],
1500                                   const uint32_t scale, __m256i* const sum,
1501                                   __m256i* const index) {
1502   __m256i sum_sq[2];
1503   *sum = Sum5_16(s5);
1504   Sum5_32(sq5, sum_sq);
1505   *index = CalculateMa<25>(*sum, sum_sq, scale);
1506 }
1507 
1508 inline void CalculateSumAndIndex3(const __m128i s3[3], const __m128i sq3[3][2],
1509                                   const uint32_t scale, __m128i* const sum,
1510                                   __m128i* const index) {
1511   __m128i sum_sq[2];
1512   *sum = Sum3_16(s3);
1513   Sum3_32(sq3, sum_sq);
1514   *index = CalculateMa<9>(*sum, sum_sq, scale);
1515 }
1516 
1517 inline void CalculateSumAndIndex3(const __m256i s3[3], const __m256i sq3[3][2],
1518                                   const uint32_t scale, __m256i* const sum,
1519                                   __m256i* const index) {
1520   __m256i sum_sq[2];
1521   *sum = Sum3_16(s3);
1522   Sum3_32(sq3, sum_sq);
1523   *index = CalculateMa<9>(*sum, sum_sq, scale);
1524 }
1525 
1526 template <int n>
1527 inline void LookupIntermediate(const __m128i sum, const __m128i index,
1528                                __m128i* const ma, __m128i* const b) {
1529   static_assert(n == 9 || n == 25, "");
1530   const __m128i idx = _mm_packus_epi16(index, index);
1531   // No store/load is actually emitted; the compiler keeps |temp| in a 64-bit
1532   // general-purpose register, which is faster than using _mm_extract_epi8().
1533   uint8_t temp[8];
1534   StoreLo8(temp, idx);
1535   *ma = _mm_cvtsi32_si128(kSgrMaLookup[temp[0]]);
1536   *ma = _mm_insert_epi8(*ma, kSgrMaLookup[temp[1]], 1);
1537   *ma = _mm_insert_epi8(*ma, kSgrMaLookup[temp[2]], 2);
1538   *ma = _mm_insert_epi8(*ma, kSgrMaLookup[temp[3]], 3);
1539   *ma = _mm_insert_epi8(*ma, kSgrMaLookup[temp[4]], 4);
1540   *ma = _mm_insert_epi8(*ma, kSgrMaLookup[temp[5]], 5);
1541   *ma = _mm_insert_epi8(*ma, kSgrMaLookup[temp[6]], 6);
1542   *ma = _mm_insert_epi8(*ma, kSgrMaLookup[temp[7]], 7);
1543   // b = ma * b * one_over_n
1544   // |ma| = [0, 255]
1545   // |sum| is a box sum with radius 1 or 2.
1546   // For the first pass radius is 2. Maximum value is 5x5x255 = 6375.
1547   // For the second pass radius is 1. Maximum value is 3x3x255 = 2295.
1548   // |one_over_n| = ((1 << kSgrProjReciprocalBits) + (n >> 1)) / n
1549   // When radius is 2 |n| is 25. |one_over_n| is 164.
1550   // When radius is 1 |n| is 9. |one_over_n| is 455.
1551   // |kSgrProjReciprocalBits| is 12.
1552   // Radius 2: 255 * 6375 * 164 >> 12 = 65088 (16 bits).
1553   // Radius 1: 255 * 2295 * 455 >> 12 = 65009 (16 bits).
1554   const __m128i maq = _mm_unpacklo_epi8(*ma, _mm_setzero_si128());
1555   *b = (n == 9) ? CalculateB3(sum, maq) : CalculateB5(sum, maq);
1556 }
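// A scalar sketch of the 8 lanes handled by LookupIntermediate() above, for
// illustration only (|index| has already been saturated to [0, 255] by the
// unsigned pack):
//   for (int i = 0; i < 8; ++i) {
//     ma[i] = kSgrMaLookup[index[i]];
//     b[i] = (ma[i] * sum[i] * one_over_n + (1 << 11)) >> 12;  // See above.
//   }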
1557 
1558 // Repeat the first 48 elements in kSgrMaLookup with a period of 16.
1559 alignas(32) constexpr uint8_t kSgrMaLookupAvx2[96] = {
1560     255, 128, 85, 64, 51, 43, 37, 32, 28, 26, 23, 21, 20, 18, 17, 16,
1561     255, 128, 85, 64, 51, 43, 37, 32, 28, 26, 23, 21, 20, 18, 17, 16,
1562     15,  14,  13, 13, 12, 12, 11, 11, 10, 10, 9,  9,  9,  9,  8,  8,
1563     15,  14,  13, 13, 12, 12, 11, 11, 10, 10, 9,  9,  9,  9,  8,  8,
1564     8,   8,   7,  7,  7,  7,  7,  6,  6,  6,  6,  6,  6,  6,  5,  5,
1565     8,   8,   7,  7,  7,  7,  7,  6,  6,  6,  6,  6,  6,  6,  5,  5};
1566 
1567 // Set the shuffle control mask of indices out of range [0, 15] to (1xxxxxxx)b
1568 // to get value 0 as the shuffle result. The most significant bit 1 comes
1569 // either from the comparison instruction, or from the sign bit of the index.
1570 inline __m256i ShuffleIndex(const __m256i table, const __m256i index) {
1571   __m256i mask;
1572   mask = _mm256_cmpgt_epi8(index, _mm256_set1_epi8(15));
1573   mask = _mm256_or_si256(mask, index);
1574   return _mm256_shuffle_epi8(table, mask);
1575 }
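// Worked example for ShuffleIndex() above (illustration only), following one
// byte with original index 20 through CalculateIntermediate():
//   row 0: 20 > 15, so the control byte becomes 0xff and the shuffle yields 0.
//   row 1: after subtracting 16 the index is 4, which selects byte 4 of |c1|.
//   row 2: after another subtraction the index is -12 (0xf4); its sign bit
//          again forces 0. OR-ing the three rows keeps only the row-1 value.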
1576 
1577 inline __m256i AdjustValue(const __m256i value, const __m256i index,
1578                            const int threshold) {
1579   const __m256i thresholds = _mm256_set1_epi8(threshold - 128);
1580   const __m256i offset = _mm256_cmpgt_epi8(index, thresholds);
1581   return _mm256_add_epi8(value, offset);
1582 }
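// Illustration of how AdjustValue() builds the values for indices above 47 in
// CalculateIntermediate() below: the value starts at 5 (via _mm256_max_epu8),
// and each call adds the 0xff (-1) comparison mask for indices above its
// threshold. For example:
//   index 50  -> 5
//   index 60  -> 4   (> 55)
//   index 120 -> 2   (> 55, > 72, > 101)
//   index 255 -> 0   (> 55, > 72, > 101, > 169, > 254)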
1583 
1584 template <int n>
1585 inline void CalculateIntermediate(const __m256i sum[2], const __m256i index[2],
1586                                   __m256i ma[3], __m256i b[2]) {
1587   static_assert(n == 9 || n == 25, "");
1588   // Use table lookup to read elements whose indices are less than 48.
1589   const __m256i c0 = LoadAligned32(kSgrMaLookupAvx2 + 0 * 32);
1590   const __m256i c1 = LoadAligned32(kSgrMaLookupAvx2 + 1 * 32);
1591   const __m256i c2 = LoadAligned32(kSgrMaLookupAvx2 + 2 * 32);
1592   const __m256i indices = _mm256_packus_epi16(index[0], index[1]);
1593   __m256i idx, mas;
1594   // Clip idx to 127 to apply signed comparison instructions.
1595   idx = _mm256_min_epu8(indices, _mm256_set1_epi8(127));
1596   // All elements whose indices are less than 48 are set to 0.
1597   // Get shuffle results for indices in range [0, 15].
1598   mas = ShuffleIndex(c0, idx);
1599   // Get shuffle results for indices in range [16, 31].
1600   // Subtract 16 to utilize the sign bit of the index.
1601   idx = _mm256_sub_epi8(idx, _mm256_set1_epi8(16));
1602   const __m256i res1 = ShuffleIndex(c1, idx);
1603   // Use OR instruction to combine shuffle results together.
1604   mas = _mm256_or_si256(mas, res1);
1605   // Get shuffle results for indices in range [32, 47].
1606   // Subtract 16 to utilize the sign bit of the index.
1607   idx = _mm256_sub_epi8(idx, _mm256_set1_epi8(16));
1608   const __m256i res2 = ShuffleIndex(c2, idx);
1609   mas = _mm256_or_si256(mas, res2);
1610 
1611   // For elements whose indices are larger than 47, the table values change
1612   // only rarely as the index increases, so they are computed with comparisons
1613   // and arithmetic operations instead of further lookups.
1614   // Add -128 to apply signed comparison instructions.
1615   idx = _mm256_add_epi8(indices, _mm256_set1_epi8(-128));
1616   // Elements whose indices are larger than 47 (with value 0) are set to 5.
1617   mas = _mm256_max_epu8(mas, _mm256_set1_epi8(5));
1618   mas = AdjustValue(mas, idx, 55);   // 55 is the last index whose value is 5.
1619   mas = AdjustValue(mas, idx, 72);   // 72 is the last index whose value is 4.
1620   mas = AdjustValue(mas, idx, 101);  // 101 is the last index whose value is 3.
1621   mas = AdjustValue(mas, idx, 169);  // 169 is the last index whose value is 2.
1622   mas = AdjustValue(mas, idx, 254);  // 254 is the last index whose value is 1.
1623 
1624   ma[2] = _mm256_permute4x64_epi64(mas, 0x93);     // 32-39 8-15 16-23 24-31
1625   ma[0] = _mm256_blend_epi32(ma[0], ma[2], 0xfc);  //  0-7  8-15 16-23 24-31
1626   ma[1] = _mm256_permute2x128_si256(ma[0], ma[2], 0x21);
1627 
1628   // b = ma * b * one_over_n
1629   // |ma| = [0, 255]
1630   // |sum| is a box sum with radius 1 or 2.
1631   // For the first pass radius is 2. Maximum value is 5x5x255 = 6375.
1632   // For the second pass radius is 1. Maximum value is 3x3x255 = 2295.
1633   // |one_over_n| = ((1 << kSgrProjReciprocalBits) + (n >> 1)) / n
1634   // When radius is 2 |n| is 25. |one_over_n| is 164.
1635   // When radius is 1 |n| is 9. |one_over_n| is 455.
1636   // |kSgrProjReciprocalBits| is 12.
1637   // Radius 2: 255 * 6375 * 164 >> 12 = 65088 (16 bits).
1638   // Radius 1: 255 * 2295 * 455 >> 12 = 65009 (16 bits).
1639   const __m256i maq0 = _mm256_unpackhi_epi8(ma[0], _mm256_setzero_si256());
1640   const __m256i maq1 = _mm256_unpacklo_epi8(ma[1], _mm256_setzero_si256());
1641   if (n == 9) {
1642     b[0] = CalculateB3(sum[0], maq0);
1643     b[1] = CalculateB3(sum[1], maq1);
1644   } else {
1645     b[0] = CalculateB5(sum[0], maq0);
1646     b[1] = CalculateB5(sum[1], maq1);
1647   }
1648 }
1649 
1650 inline void CalculateIntermediate5(const __m128i s5[5], const __m128i sq5[5][2],
1651                                    const uint32_t scale, __m128i* const ma,
1652                                    __m128i* const b) {
1653   __m128i sum, index;
1654   CalculateSumAndIndex5(s5, sq5, scale, &sum, &index);
1655   LookupIntermediate<25>(sum, index, ma, b);
1656 }
1657 
1658 inline void CalculateIntermediate3(const __m128i s3[3], const __m128i sq3[3][2],
1659                                    const uint32_t scale, __m128i* const ma,
1660                                    __m128i* const b) {
1661   __m128i sum, index;
1662   CalculateSumAndIndex3(s3, sq3, scale, &sum, &index);
1663   LookupIntermediate<9>(sum, index, ma, b);
1664 }
1665 
1666 inline void Store343_444(const __m256i b3[2], const ptrdiff_t x,
1667                          __m256i sum_b343[2], __m256i sum_b444[2],
1668                          uint32_t* const b343, uint32_t* const b444) {
1669   __m256i b[3], sum_b111[2];
1670   Prepare3_16(b3, b);
1671   sum_b111[0] = Sum3WLo32(b);
1672   sum_b111[1] = Sum3WHi32(b);
1673   sum_b444[0] = _mm256_slli_epi32(sum_b111[0], 2);
1674   sum_b444[1] = _mm256_slli_epi32(sum_b111[1], 2);
1675   StoreAligned64(b444 + x, sum_b444);
1676   sum_b343[0] = _mm256_sub_epi32(sum_b444[0], sum_b111[0]);
1677   sum_b343[1] = _mm256_sub_epi32(sum_b444[1], sum_b111[1]);
1678   sum_b343[0] = VaddwLo16(sum_b343[0], b[1]);
1679   sum_b343[1] = VaddwHi16(sum_b343[1], b[1]);
1680   StoreAligned64(b343 + x, sum_b343);
1681 }
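// The arithmetic in Store343_444() above, spelled out (illustration only):
//   sum_b111 = b0 + b1 + b2
//   sum_b444 = 4 * sum_b111
//   sum_b343 = sum_b444 - sum_b111 + b1 = 3 * (b0 + b1 + b2) + b1
// i.e. the stored rows hold the (4, 4, 4) and (3, 4, 3) weighted sums of three
// consecutive b rows, mirroring the ma444/ma343 rows written by the callers.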
1682 
1683 inline void Store343_444Lo(const __m256i ma3[3], const __m256i b3[2],
1684                            const ptrdiff_t x, __m256i* const sum_ma343,
1685                            __m256i* const sum_ma444, __m256i sum_b343[2],
1686                            __m256i sum_b444[2], uint16_t* const ma343,
1687                            uint16_t* const ma444, uint32_t* const b343,
1688                            uint32_t* const b444) {
1689   const __m256i sum_ma111 = Sum3WLo16(ma3);
1690   *sum_ma444 = _mm256_slli_epi16(sum_ma111, 2);
1691   StoreAligned32(ma444 + x, *sum_ma444);
1692   const __m256i sum333 = _mm256_sub_epi16(*sum_ma444, sum_ma111);
1693   *sum_ma343 = VaddwLo8(sum333, ma3[1]);
1694   StoreAligned32(ma343 + x, *sum_ma343);
1695   Store343_444(b3, x, sum_b343, sum_b444, b343, b444);
1696 }
1697 
1698 inline void Store343_444Hi(const __m256i ma3[3], const __m256i b3[2],
1699                            const ptrdiff_t x, __m256i* const sum_ma343,
1700                            __m256i* const sum_ma444, __m256i sum_b343[2],
1701                            __m256i sum_b444[2], uint16_t* const ma343,
1702                            uint16_t* const ma444, uint32_t* const b343,
1703                            uint32_t* const b444) {
1704   const __m256i sum_ma111 = Sum3WHi16(ma3);
1705   *sum_ma444 = _mm256_slli_epi16(sum_ma111, 2);
1706   StoreAligned32(ma444 + x, *sum_ma444);
1707   const __m256i sum333 = _mm256_sub_epi16(*sum_ma444, sum_ma111);
1708   *sum_ma343 = VaddwHi8(sum333, ma3[1]);
1709   StoreAligned32(ma343 + x, *sum_ma343);
1710   Store343_444(b3, x, sum_b343, sum_b444, b343, b444);
1711 }
1712 
1713 inline void Store343_444Lo(const __m256i ma3[3], const __m256i b3[2],
1714                            const ptrdiff_t x, __m256i* const sum_ma343,
1715                            __m256i sum_b343[2], uint16_t* const ma343,
1716                            uint16_t* const ma444, uint32_t* const b343,
1717                            uint32_t* const b444) {
1718   __m256i sum_ma444, sum_b444[2];
1719   Store343_444Lo(ma3, b3, x, sum_ma343, &sum_ma444, sum_b343, sum_b444, ma343,
1720                  ma444, b343, b444);
1721 }
1722 
1723 inline void Store343_444Hi(const __m256i ma3[3], const __m256i b3[2],
1724                            const ptrdiff_t x, __m256i* const sum_ma343,
1725                            __m256i sum_b343[2], uint16_t* const ma343,
1726                            uint16_t* const ma444, uint32_t* const b343,
1727                            uint32_t* const b444) {
1728   __m256i sum_ma444, sum_b444[2];
1729   Store343_444Hi(ma3, b3, x, sum_ma343, &sum_ma444, sum_b343, sum_b444, ma343,
1730                  ma444, b343, b444);
1731 }
1732 
1733 inline void Store343_444Lo(const __m256i ma3[3], const __m256i b3[2],
1734                            const ptrdiff_t x, uint16_t* const ma343,
1735                            uint16_t* const ma444, uint32_t* const b343,
1736                            uint32_t* const b444) {
1737   __m256i sum_ma343, sum_b343[2];
1738   Store343_444Lo(ma3, b3, x, &sum_ma343, sum_b343, ma343, ma444, b343, b444);
1739 }
1740 
1741 inline void Store343_444Hi(const __m256i ma3[3], const __m256i b3[2],
1742                            const ptrdiff_t x, uint16_t* const ma343,
1743                            uint16_t* const ma444, uint32_t* const b343,
1744                            uint32_t* const b444) {
1745   __m256i sum_ma343, sum_b343[2];
1746   Store343_444Hi(ma3, b3, x, &sum_ma343, sum_b343, ma343, ma444, b343, b444);
1747 }
1748 
1749 LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess5Lo(
1750     const __m128i s[2][3], const uint32_t scale, uint16_t* const sum5[5],
1751     uint32_t* const square_sum5[5], __m128i sq[2][2], __m128i* const ma,
1752     __m128i* const b) {
1753   __m128i s5[2][5], sq5[5][2];
1754   sq[0][1] = SquareHi8(s[0][0]);
1755   sq[1][1] = SquareHi8(s[1][0]);
1756   s5[0][3] = Sum5Horizontal(s[0][0]);
1757   StoreAligned16(sum5[3], s5[0][3]);
1758   s5[0][4] = Sum5Horizontal(s[1][0]);
1759   StoreAligned16(sum5[4], s5[0][4]);
1760   Sum5WHorizontal(sq[0], sq5[3]);
1761   StoreAligned32U32(square_sum5[3], sq5[3]);
1762   Sum5WHorizontal(sq[1], sq5[4]);
1763   StoreAligned32U32(square_sum5[4], sq5[4]);
1764   LoadAligned16x3U16(sum5, 0, s5[0]);
1765   LoadAligned32x3U32(square_sum5, 0, sq5);
1766   CalculateIntermediate5(s5[0], sq5, scale, ma, b);
1767 }
1768 
1769 LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess5(
1770     const uint8_t* const src0, const uint8_t* const src1,
1771     const ptrdiff_t over_read_in_bytes, const ptrdiff_t sum_width,
1772     const ptrdiff_t x, const uint32_t scale, uint16_t* const sum5[5],
1773     uint32_t* const square_sum5[5], __m256i sq[2][3], __m256i ma[3],
1774     __m256i b[3]) {
1775   const __m256i s0 = LoadUnaligned32Msan(src0 + 8, over_read_in_bytes + 8);
1776   const __m256i s1 = LoadUnaligned32Msan(src1 + 8, over_read_in_bytes + 8);
1777   __m256i s5[2][5], sq5[5][2], sum[2], index[2];
1778   sq[0][1] = SquareLo8(s0);
1779   sq[0][2] = SquareHi8(s0);
1780   sq[1][1] = SquareLo8(s1);
1781   sq[1][2] = SquareHi8(s1);
1782   sq[0][0] = _mm256_permute2x128_si256(sq[0][0], sq[0][2], 0x21);
1783   sq[1][0] = _mm256_permute2x128_si256(sq[1][0], sq[1][2], 0x21);
1784   Sum5Horizontal(src0, over_read_in_bytes, &s5[0][3], &s5[1][3]);
1785   Sum5Horizontal(src1, over_read_in_bytes, &s5[0][4], &s5[1][4]);
1786   StoreAligned32(sum5[3] + x + 0, s5[0][3]);
1787   StoreAligned32(sum5[3] + x + 16, s5[1][3]);
1788   StoreAligned32(sum5[4] + x + 0, s5[0][4]);
1789   StoreAligned32(sum5[4] + x + 16, s5[1][4]);
1790   Sum5WHorizontal(sq[0], sq5[3]);
1791   StoreAligned64(square_sum5[3] + x, sq5[3]);
1792   Sum5WHorizontal(sq[1], sq5[4]);
1793   StoreAligned64(square_sum5[4] + x, sq5[4]);
1794   LoadAligned32x3U16(sum5, x, s5[0]);
1795   LoadAligned64x3U32(square_sum5, x, sq5);
1796   CalculateSumAndIndex5(s5[0], sq5, scale, &sum[0], &index[0]);
1797 
1798   Sum5WHorizontal(sq[0] + 1, sq5[3]);
1799   StoreAligned64(square_sum5[3] + x + 16, sq5[3]);
1800   Sum5WHorizontal(sq[1] + 1, sq5[4]);
1801   StoreAligned64(square_sum5[4] + x + 16, sq5[4]);
1802   LoadAligned32x3U16Msan(sum5, x + 16, sum_width, s5[1]);
1803   LoadAligned64x3U32Msan(square_sum5, x + 16, sum_width, sq5);
1804   CalculateSumAndIndex5(s5[1], sq5, scale, &sum[1], &index[1]);
1805   CalculateIntermediate<25>(sum, index, ma, b + 1);
1806   b[0] = _mm256_permute2x128_si256(b[0], b[2], 0x21);
1807 }
1808 
1809 LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess5LastRowLo(
1810     const __m128i s, const uint32_t scale, const uint16_t* const sum5[5],
1811     const uint32_t* const square_sum5[5], __m128i sq[2], __m128i* const ma,
1812     __m128i* const b) {
1813   __m128i s5[5], sq5[5][2];
1814   sq[1] = SquareHi8(s);
1815   s5[3] = s5[4] = Sum5Horizontal(s);
1816   Sum5WHorizontal(sq, sq5[3]);
1817   sq5[4][0] = sq5[3][0];
1818   sq5[4][1] = sq5[3][1];
1819   LoadAligned16x3U16(sum5, 0, s5);
1820   LoadAligned32x3U32(square_sum5, 0, sq5);
1821   CalculateIntermediate5(s5, sq5, scale, ma, b);
1822 }
1823 
1824 LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess5LastRow(
1825     const uint8_t* const src, const ptrdiff_t over_read_in_bytes,
1826     const ptrdiff_t sum_width, const ptrdiff_t x, const uint32_t scale,
1827     const uint16_t* const sum5[5], const uint32_t* const square_sum5[5],
1828     __m256i sq[3], __m256i ma[3], __m256i b[3]) {
1829   const __m256i s = LoadUnaligned32Msan(src + 8, over_read_in_bytes + 8);
1830   __m256i s5[2][5], sq5[5][2], sum[2], index[2];
1831   sq[1] = SquareLo8(s);
1832   sq[2] = SquareHi8(s);
1833   sq[0] = _mm256_permute2x128_si256(sq[0], sq[2], 0x21);
1834   Sum5Horizontal(src, over_read_in_bytes, &s5[0][3], &s5[1][3]);
1835   s5[0][4] = s5[0][3];
1836   s5[1][4] = s5[1][3];
1837   Sum5WHorizontal(sq, sq5[3]);
1838   sq5[4][0] = sq5[3][0];
1839   sq5[4][1] = sq5[3][1];
1840   LoadAligned32x3U16(sum5, x, s5[0]);
1841   LoadAligned64x3U32(square_sum5, x, sq5);
1842   CalculateSumAndIndex5(s5[0], sq5, scale, &sum[0], &index[0]);
1843 
1844   Sum5WHorizontal(sq + 1, sq5[3]);
1845   sq5[4][0] = sq5[3][0];
1846   sq5[4][1] = sq5[3][1];
1847   LoadAligned32x3U16Msan(sum5, x + 16, sum_width, s5[1]);
1848   LoadAligned64x3U32Msan(square_sum5, x + 16, sum_width, sq5);
1849   CalculateSumAndIndex5(s5[1], sq5, scale, &sum[1], &index[1]);
1850   CalculateIntermediate<25>(sum, index, ma, b + 1);
1851   b[0] = _mm256_permute2x128_si256(b[0], b[2], 0x21);
1852 }
1853 
1854 LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess3Lo(
1855     const __m128i s, const uint32_t scale, uint16_t* const sum3[3],
1856     uint32_t* const square_sum3[3], __m128i sq[2], __m128i* const ma,
1857     __m128i* const b) {
1858   __m128i s3[3], sq3[3][2];
1859   sq[1] = SquareHi8(s);
1860   s3[2] = Sum3Horizontal(s);
1861   StoreAligned16(sum3[2], s3[2]);
1862   Sum3WHorizontal(sq, sq3[2]);
1863   StoreAligned32U32(square_sum3[2], sq3[2]);
1864   LoadAligned16x2U16(sum3, 0, s3);
1865   LoadAligned32x2U32(square_sum3, 0, sq3);
1866   CalculateIntermediate3(s3, sq3, scale, ma, b);
1867 }
1868 
1869 LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess3(
1870     const uint8_t* const src, const ptrdiff_t over_read_in_bytes,
1871     const ptrdiff_t x, const ptrdiff_t sum_width, const uint32_t scale,
1872     uint16_t* const sum3[3], uint32_t* const square_sum3[3], __m256i sq[3],
1873     __m256i ma[3], __m256i b[3]) {
1874   const __m256i s = LoadUnaligned32Msan(src + 8, over_read_in_bytes + 8);
1875   __m256i s3[4], sq3[3][2], sum[2], index[2];
1876   sq[1] = SquareLo8(s);
1877   sq[2] = SquareHi8(s);
1878   sq[0] = _mm256_permute2x128_si256(sq[0], sq[2], 0x21);
1879   Sum3Horizontal(src, over_read_in_bytes, s3 + 2);
1880   StoreAligned64(sum3[2] + x, s3 + 2);
1881   Sum3WHorizontal(sq + 0, sq3[2]);
1882   StoreAligned64(square_sum3[2] + x, sq3[2]);
1883   LoadAligned32x2U16(sum3, x, s3);
1884   LoadAligned64x2U32(square_sum3, x, sq3);
1885   CalculateSumAndIndex3(s3, sq3, scale, &sum[0], &index[0]);
1886 
1887   Sum3WHorizontal(sq + 1, sq3[2]);
1888   StoreAligned64(square_sum3[2] + x + 16, sq3[2]);
1889   LoadAligned32x2U16Msan(sum3, x + 16, sum_width, s3 + 1);
1890   LoadAligned64x2U32Msan(square_sum3, x + 16, sum_width, sq3);
1891   CalculateSumAndIndex3(s3 + 1, sq3, scale, &sum[1], &index[1]);
1892   CalculateIntermediate<9>(sum, index, ma, b + 1);
1893   b[0] = _mm256_permute2x128_si256(b[0], b[2], 0x21);
1894 }
1895 
1896 LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcessLo(
1897     const __m128i s[2], const uint16_t scales[2], uint16_t* const sum3[4],
1898     uint16_t* const sum5[5], uint32_t* const square_sum3[4],
1899     uint32_t* const square_sum5[5], __m128i sq[2][2], __m128i ma3[2],
1900     __m128i b3[2], __m128i* const ma5, __m128i* const b5) {
1901   __m128i s3[4], s5[5], sq3[4][2], sq5[5][2];
1902   sq[0][1] = SquareHi8(s[0]);
1903   sq[1][1] = SquareHi8(s[1]);
1904   SumHorizontalLo(s[0], &s3[2], &s5[3]);
1905   SumHorizontalLo(s[1], &s3[3], &s5[4]);
1906   StoreAligned16(sum3[2], s3[2]);
1907   StoreAligned16(sum3[3], s3[3]);
1908   StoreAligned16(sum5[3], s5[3]);
1909   StoreAligned16(sum5[4], s5[4]);
1910   SumHorizontal(sq[0], &sq3[2][0], &sq3[2][1], &sq5[3][0], &sq5[3][1]);
1911   StoreAligned32U32(square_sum3[2], sq3[2]);
1912   StoreAligned32U32(square_sum5[3], sq5[3]);
1913   SumHorizontal(sq[1], &sq3[3][0], &sq3[3][1], &sq5[4][0], &sq5[4][1]);
1914   StoreAligned32U32(square_sum3[3], sq3[3]);
1915   StoreAligned32U32(square_sum5[4], sq5[4]);
1916   LoadAligned16x2U16(sum3, 0, s3);
1917   LoadAligned32x2U32(square_sum3, 0, sq3);
1918   LoadAligned16x3U16(sum5, 0, s5);
1919   LoadAligned32x3U32(square_sum5, 0, sq5);
1920   // Note: in the SSE4_1 version, CalculateIntermediate() is called to replace
1921   // the slow LookupIntermediate() when calculating 16 intermediate data points.
1922   // However, when targeting AVX2 the compilers generate even slower code for
1923   // that approach, so CalculateIntermediate3() is kept here.
1924   CalculateIntermediate3(s3 + 0, sq3 + 0, scales[1], &ma3[0], &b3[0]);
1925   CalculateIntermediate3(s3 + 1, sq3 + 1, scales[1], &ma3[1], &b3[1]);
1926   CalculateIntermediate5(s5, sq5, scales[0], ma5, b5);
1927 }
1928 
1929 LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess(
1930     const uint8_t* const src0, const uint8_t* const src1,
1931     const ptrdiff_t over_read_in_bytes, const ptrdiff_t x,
1932     const uint16_t scales[2], uint16_t* const sum3[4], uint16_t* const sum5[5],
1933     uint32_t* const square_sum3[4], uint32_t* const square_sum5[5],
1934     const ptrdiff_t sum_width, __m256i sq[2][3], __m256i ma3[2][3],
1935     __m256i b3[2][5], __m256i ma5[3], __m256i b5[5]) {
1936   const __m256i s0 = LoadUnaligned32Msan(src0 + 8, over_read_in_bytes + 8);
1937   const __m256i s1 = LoadUnaligned32Msan(src1 + 8, over_read_in_bytes + 8);
1938   __m256i s3[2][4], s5[2][5], sq3[4][2], sq5[5][2], sum_3[2][2], index_3[2][2],
1939       sum_5[2], index_5[2];
1940   sq[0][1] = SquareLo8(s0);
1941   sq[0][2] = SquareHi8(s0);
1942   sq[1][1] = SquareLo8(s1);
1943   sq[1][2] = SquareHi8(s1);
1944   sq[0][0] = _mm256_permute2x128_si256(sq[0][0], sq[0][2], 0x21);
1945   sq[1][0] = _mm256_permute2x128_si256(sq[1][0], sq[1][2], 0x21);
1946   SumHorizontal(src0, over_read_in_bytes, &s3[0][2], &s3[1][2], &s5[0][3],
1947                 &s5[1][3]);
1948   SumHorizontal(src1, over_read_in_bytes, &s3[0][3], &s3[1][3], &s5[0][4],
1949                 &s5[1][4]);
1950   StoreAligned32(sum3[2] + x + 0, s3[0][2]);
1951   StoreAligned32(sum3[2] + x + 16, s3[1][2]);
1952   StoreAligned32(sum3[3] + x + 0, s3[0][3]);
1953   StoreAligned32(sum3[3] + x + 16, s3[1][3]);
1954   StoreAligned32(sum5[3] + x + 0, s5[0][3]);
1955   StoreAligned32(sum5[3] + x + 16, s5[1][3]);
1956   StoreAligned32(sum5[4] + x + 0, s5[0][4]);
1957   StoreAligned32(sum5[4] + x + 16, s5[1][4]);
1958   SumHorizontal(sq[0], &sq3[2][0], &sq3[2][1], &sq5[3][0], &sq5[3][1]);
1959   SumHorizontal(sq[1], &sq3[3][0], &sq3[3][1], &sq5[4][0], &sq5[4][1]);
1960   StoreAligned64(square_sum3[2] + x, sq3[2]);
1961   StoreAligned64(square_sum5[3] + x, sq5[3]);
1962   StoreAligned64(square_sum3[3] + x, sq3[3]);
1963   StoreAligned64(square_sum5[4] + x, sq5[4]);
1964   LoadAligned32x2U16(sum3, x, s3[0]);
1965   LoadAligned64x2U32(square_sum3, x, sq3);
1966   CalculateSumAndIndex3(s3[0], sq3, scales[1], &sum_3[0][0], &index_3[0][0]);
1967   CalculateSumAndIndex3(s3[0] + 1, sq3 + 1, scales[1], &sum_3[1][0],
1968                         &index_3[1][0]);
1969   LoadAligned32x3U16(sum5, x, s5[0]);
1970   LoadAligned64x3U32(square_sum5, x, sq5);
1971   CalculateSumAndIndex5(s5[0], sq5, scales[0], &sum_5[0], &index_5[0]);
1972 
1973   SumHorizontal(sq[0] + 1, &sq3[2][0], &sq3[2][1], &sq5[3][0], &sq5[3][1]);
1974   SumHorizontal(sq[1] + 1, &sq3[3][0], &sq3[3][1], &sq5[4][0], &sq5[4][1]);
1975   StoreAligned64(square_sum3[2] + x + 16, sq3[2]);
1976   StoreAligned64(square_sum5[3] + x + 16, sq5[3]);
1977   StoreAligned64(square_sum3[3] + x + 16, sq3[3]);
1978   StoreAligned64(square_sum5[4] + x + 16, sq5[4]);
1979   LoadAligned32x2U16Msan(sum3, x + 16, sum_width, s3[1]);
1980   LoadAligned64x2U32Msan(square_sum3, x + 16, sum_width, sq3);
1981   CalculateSumAndIndex3(s3[1], sq3, scales[1], &sum_3[0][1], &index_3[0][1]);
1982   CalculateSumAndIndex3(s3[1] + 1, sq3 + 1, scales[1], &sum_3[1][1],
1983                         &index_3[1][1]);
1984   CalculateIntermediate<9>(sum_3[0], index_3[0], ma3[0], b3[0] + 1);
1985   CalculateIntermediate<9>(sum_3[1], index_3[1], ma3[1], b3[1] + 1);
1986   LoadAligned32x3U16Msan(sum5, x + 16, sum_width, s5[1]);
1987   LoadAligned64x3U32Msan(square_sum5, x + 16, sum_width, sq5);
1988   CalculateSumAndIndex5(s5[1], sq5, scales[0], &sum_5[1], &index_5[1]);
1989   CalculateIntermediate<25>(sum_5, index_5, ma5, b5 + 1);
1990   b3[0][0] = _mm256_permute2x128_si256(b3[0][0], b3[0][2], 0x21);
1991   b3[1][0] = _mm256_permute2x128_si256(b3[1][0], b3[1][2], 0x21);
1992   b5[0] = _mm256_permute2x128_si256(b5[0], b5[2], 0x21);
1993 }
1994 
1995 LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcessLastRowLo(
1996     const __m128i s, const uint16_t scales[2], const uint16_t* const sum3[4],
1997     const uint16_t* const sum5[5], const uint32_t* const square_sum3[4],
1998     const uint32_t* const square_sum5[5], __m128i sq[2], __m128i* const ma3,
1999     __m128i* const ma5, __m128i* const b3, __m128i* const b5) {
2000   __m128i s3[3], s5[5], sq3[3][2], sq5[5][2];
2001   sq[1] = SquareHi8(s);
2002   SumHorizontalLo(s, &s3[2], &s5[3]);
2003   SumHorizontal(sq, &sq3[2][0], &sq3[2][1], &sq5[3][0], &sq5[3][1]);
2004   LoadAligned16x3U16(sum5, 0, s5);
2005   s5[4] = s5[3];
2006   LoadAligned32x3U32(square_sum5, 0, sq5);
2007   sq5[4][0] = sq5[3][0];
2008   sq5[4][1] = sq5[3][1];
2009   CalculateIntermediate5(s5, sq5, scales[0], ma5, b5);
2010   LoadAligned16x2U16(sum3, 0, s3);
2011   LoadAligned32x2U32(square_sum3, 0, sq3);
2012   CalculateIntermediate3(s3, sq3, scales[1], ma3, b3);
2013 }
2014 
2015 LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcessLastRow(
2016     const uint8_t* const src, const ptrdiff_t over_read_in_bytes,
2017     const ptrdiff_t sum_width, const ptrdiff_t x, const uint16_t scales[2],
2018     const uint16_t* const sum3[4], const uint16_t* const sum5[5],
2019     const uint32_t* const square_sum3[4], const uint32_t* const square_sum5[5],
2020     __m256i sq[6], __m256i ma3[2], __m256i ma5[2], __m256i b3[5],
2021     __m256i b5[5]) {
2022   const __m256i s0 = LoadUnaligned32Msan(src + 8, over_read_in_bytes + 8);
2023   __m256i s3[2][3], s5[2][5], sq3[4][2], sq5[5][2], sum_3[2], index_3[2],
2024       sum_5[2], index_5[2];
2025   sq[1] = SquareLo8(s0);
2026   sq[2] = SquareHi8(s0);
2027   sq[0] = _mm256_permute2x128_si256(sq[0], sq[2], 0x21);
2028   SumHorizontal(src, over_read_in_bytes, &s3[0][2], &s3[1][2], &s5[0][3],
2029                 &s5[1][3]);
2030   SumHorizontal(sq, &sq3[2][0], &sq3[2][1], &sq5[3][0], &sq5[3][1]);
2031   LoadAligned32x2U16(sum3, x, s3[0]);
2032   LoadAligned64x2U32(square_sum3, x, sq3);
2033   CalculateSumAndIndex3(s3[0], sq3, scales[1], &sum_3[0], &index_3[0]);
2034   LoadAligned32x3U16(sum5, x, s5[0]);
2035   s5[0][4] = s5[0][3];
2036   LoadAligned64x3U32(square_sum5, x, sq5);
2037   sq5[4][0] = sq5[3][0];
2038   sq5[4][1] = sq5[3][1];
2039   CalculateSumAndIndex5(s5[0], sq5, scales[0], &sum_5[0], &index_5[0]);
2040 
2041   SumHorizontal(sq + 1, &sq3[2][0], &sq3[2][1], &sq5[3][0], &sq5[3][1]);
2042   LoadAligned32x2U16Msan(sum3, x + 16, sum_width, s3[1]);
2043   LoadAligned64x2U32Msan(square_sum3, x + 16, sum_width, sq3);
2044   CalculateSumAndIndex3(s3[1], sq3, scales[1], &sum_3[1], &index_3[1]);
2045   CalculateIntermediate<9>(sum_3, index_3, ma3, b3 + 1);
2046   LoadAligned32x3U16Msan(sum5, x + 16, sum_width, s5[1]);
2047   s5[1][4] = s5[1][3];
2048   LoadAligned64x3U32Msan(square_sum5, x + 16, sum_width, sq5);
2049   sq5[4][0] = sq5[3][0];
2050   sq5[4][1] = sq5[3][1];
2051   CalculateSumAndIndex5(s5[1], sq5, scales[0], &sum_5[1], &index_5[1]);
2052   CalculateIntermediate<25>(sum_5, index_5, ma5, b5 + 1);
2053   b3[0] = _mm256_permute2x128_si256(b3[0], b3[2], 0x21);
2054   b5[0] = _mm256_permute2x128_si256(b5[0], b5[2], 0x21);
2055 }
2056 
2057 inline void BoxSumFilterPreProcess5(const uint8_t* const src0,
2058                                     const uint8_t* const src1, const int width,
2059                                     const uint32_t scale,
2060                                     uint16_t* const sum5[5],
2061                                     uint32_t* const square_sum5[5],
2062                                     const ptrdiff_t sum_width, uint16_t* ma565,
2063                                     uint32_t* b565) {
2064   __m128i ma0, b0, s[2][3], sq_128[2][2];
2065   __m256i mas[3], sq[2][3], bs[3];
2066   s[0][0] = LoadUnaligned16Msan(src0, kOverreadInBytesPass1_128 - width);
2067   s[1][0] = LoadUnaligned16Msan(src1, kOverreadInBytesPass1_128 - width);
2068   sq_128[0][0] = SquareLo8(s[0][0]);
2069   sq_128[1][0] = SquareLo8(s[1][0]);
2070   BoxFilterPreProcess5Lo(s, scale, sum5, square_sum5, sq_128, &ma0, &b0);
2071   sq[0][0] = SetrM128i(sq_128[0][0], sq_128[0][1]);
2072   sq[1][0] = SetrM128i(sq_128[1][0], sq_128[1][1]);
2073   mas[0] = SetrM128i(ma0, ma0);
2074   bs[0] = SetrM128i(b0, b0);
2075 
2076   int x = 0;
2077   do {
2078     __m256i ma5[3], ma[2], b[4];
2079     BoxFilterPreProcess5(src0 + x + 8, src1 + x + 8,
2080                          x + 8 + kOverreadInBytesPass1_256 - width, sum_width,
2081                          x + 8, scale, sum5, square_sum5, sq, mas, bs);
2082     Prepare3_8(mas, ma5);
2083     ma[0] = Sum565Lo(ma5);
2084     ma[1] = Sum565Hi(ma5);
2085     StoreAligned64(ma565, ma);
2086     Sum565W(bs + 0, b + 0);
2087     Sum565W(bs + 1, b + 2);
2088     StoreAligned64(b565, b + 0);
2089     StoreAligned64(b565 + 16, b + 2);
2090     sq[0][0] = sq[0][2];
2091     sq[1][0] = sq[1][2];
2092     mas[0] = mas[2];
2093     bs[0] = bs[2];
2094     ma565 += 32;
2095     b565 += 32;
2096     x += 32;
2097   } while (x < width);
2098 }
2099 
2100 template <bool calculate444>
2101 LIBGAV1_ALWAYS_INLINE void BoxSumFilterPreProcess3(
2102     const uint8_t* const src, const int width, const uint32_t scale,
2103     uint16_t* const sum3[3], uint32_t* const square_sum3[3],
2104     const ptrdiff_t sum_width, uint16_t* ma343, uint16_t* ma444, uint32_t* b343,
2105     uint32_t* b444) {
2106   const __m128i s = LoadUnaligned16Msan(src, kOverreadInBytesPass2_128 - width);
2107   __m128i ma0, sq_128[2], b0;
2108   __m256i mas[3], sq[3], bs[3];
2109   sq_128[0] = SquareLo8(s);
2110   BoxFilterPreProcess3Lo(s, scale, sum3, square_sum3, sq_128, &ma0, &b0);
2111   sq[0] = SetrM128i(sq_128[0], sq_128[1]);
2112   mas[0] = SetrM128i(ma0, ma0);
2113   bs[0] = SetrM128i(b0, b0);
2114 
2115   int x = 0;
2116   do {
2117     __m256i ma3[3];
2118     BoxFilterPreProcess3(src + x + 8, x + 8 + kOverreadInBytesPass2_256 - width,
2119                          x + 8, sum_width, scale, sum3, square_sum3, sq, mas,
2120                          bs);
2121     Prepare3_8(mas, ma3);
2122     if (calculate444) {  // NOLINT(readability-simplify-boolean-expr)
2123       Store343_444Lo(ma3, bs + 0, 0, ma343, ma444, b343, b444);
2124       Store343_444Hi(ma3, bs + 1, 16, ma343, ma444, b343, b444);
2125       ma444 += 32;
2126       b444 += 32;
2127     } else {
2128       __m256i ma[2], b[4];
2129       ma[0] = Sum343Lo(ma3);
2130       ma[1] = Sum343Hi(ma3);
2131       StoreAligned64(ma343, ma);
2132       Sum343W(bs + 0, b + 0);
2133       Sum343W(bs + 1, b + 2);
2134       StoreAligned64(b343 + 0, b + 0);
2135       StoreAligned64(b343 + 16, b + 2);
2136     }
2137     sq[0] = sq[2];
2138     mas[0] = mas[2];
2139     bs[0] = bs[2];
2140     ma343 += 32;
2141     b343 += 32;
2142     x += 32;
2143   } while (x < width);
2144 }
2145 
2146 inline void BoxSumFilterPreProcess(
2147     const uint8_t* const src0, const uint8_t* const src1, const int width,
2148     const uint16_t scales[2], uint16_t* const sum3[4], uint16_t* const sum5[5],
2149     uint32_t* const square_sum3[4], uint32_t* const square_sum5[5],
2150     const ptrdiff_t sum_width, uint16_t* const ma343[4], uint16_t* const ma444,
2151     uint16_t* ma565, uint32_t* const b343[4], uint32_t* const b444,
2152     uint32_t* b565) {
2153   __m128i s[2], ma3_128[2], ma5_0, sq_128[2][2], b3_128[2], b5_0;
2154   __m256i ma3[2][3], ma5[3], sq[2][3], b3[2][5], b5[5];
2155   s[0] = LoadUnaligned16Msan(src0, kOverreadInBytesPass1_128 - width);
2156   s[1] = LoadUnaligned16Msan(src1, kOverreadInBytesPass1_128 - width);
2157   sq_128[0][0] = SquareLo8(s[0]);
2158   sq_128[1][0] = SquareLo8(s[1]);
2159   BoxFilterPreProcessLo(s, scales, sum3, sum5, square_sum3, square_sum5, sq_128,
2160                         ma3_128, b3_128, &ma5_0, &b5_0);
2161   sq[0][0] = SetrM128i(sq_128[0][0], sq_128[0][1]);
2162   sq[1][0] = SetrM128i(sq_128[1][0], sq_128[1][1]);
2163   ma3[0][0] = SetrM128i(ma3_128[0], ma3_128[0]);
2164   ma3[1][0] = SetrM128i(ma3_128[1], ma3_128[1]);
2165   ma5[0] = SetrM128i(ma5_0, ma5_0);
2166   b3[0][0] = SetrM128i(b3_128[0], b3_128[0]);
2167   b3[1][0] = SetrM128i(b3_128[1], b3_128[1]);
2168   b5[0] = SetrM128i(b5_0, b5_0);
2169 
2170   int x = 0;
2171   do {
2172     __m256i ma[2], b[4], ma3x[3], ma5x[3];
2173     BoxFilterPreProcess(src0 + x + 8, src1 + x + 8,
2174                         x + 8 + kOverreadInBytesPass1_256 - width, x + 8,
2175                         scales, sum3, sum5, square_sum3, square_sum5, sum_width,
2176                         sq, ma3, b3, ma5, b5);
2177     Prepare3_8(ma3[0], ma3x);
2178     ma[0] = Sum343Lo(ma3x);
2179     ma[1] = Sum343Hi(ma3x);
2180     StoreAligned64(ma343[0] + x, ma);
2181     Sum343W(b3[0], b);
2182     StoreAligned64(b343[0] + x, b);
2183     Sum565W(b5, b);
2184     StoreAligned64(b565, b);
2185     Prepare3_8(ma3[1], ma3x);
2186     Store343_444Lo(ma3x, b3[1], x, ma343[1], ma444, b343[1], b444);
2187     Store343_444Hi(ma3x, b3[1] + 1, x + 16, ma343[1], ma444, b343[1], b444);
2188     Prepare3_8(ma5, ma5x);
2189     ma[0] = Sum565Lo(ma5x);
2190     ma[1] = Sum565Hi(ma5x);
2191     StoreAligned64(ma565, ma);
2192     Sum343W(b3[0] + 1, b);
2193     StoreAligned64(b343[0] + x + 16, b);
2194     Sum565W(b5 + 1, b);
2195     StoreAligned64(b565 + 16, b);
2196     sq[0][0] = sq[0][2];
2197     sq[1][0] = sq[1][2];
2198     ma3[0][0] = ma3[0][2];
2199     ma3[1][0] = ma3[1][2];
2200     ma5[0] = ma5[2];
2201     b3[0][0] = b3[0][2];
2202     b3[1][0] = b3[1][2];
2203     b5[0] = b5[2];
2204     ma565 += 32;
2205     b565 += 32;
2206     x += 32;
2207   } while (x < width);
2208 }
2209 
2210 template <int shift>
2211 inline __m256i FilterOutput(const __m256i ma_x_src, const __m256i b) {
2212   // ma: 255 * 32 = 8160 (13 bits)
2213   // b: 65088 * 32 = 2082816 (21 bits)
2214   // v: b - ma * 255 (22 bits)
2215   const __m256i v = _mm256_sub_epi32(b, ma_x_src);
2216   // kSgrProjSgrBits = 8
2217   // kSgrProjRestoreBits = 4
2218   // shift = 4 or 5
2219   // v >> 8 or 9 (13 bits)
2220   return VrshrS32(v, kSgrProjSgrBits + shift - kSgrProjRestoreBits);
2221 }
2222 
2223 template <int shift>
2224 inline __m256i CalculateFilteredOutput(const __m256i src, const __m256i ma,
2225                                        const __m256i b[2]) {
2226   const __m256i ma_x_src_lo = VmullLo16(ma, src);
2227   const __m256i ma_x_src_hi = VmullHi16(ma, src);
2228   const __m256i dst_lo = FilterOutput<shift>(ma_x_src_lo, b[0]);
2229   const __m256i dst_hi = FilterOutput<shift>(ma_x_src_hi, b[1]);
2230   return _mm256_packs_epi32(dst_lo, dst_hi);  // 13 bits
2231 }
2232 
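// Combines the 565 sums of two consecutive rows before filtering. The combined
// weight is twice that of a single row, hence the larger normalization shift
// (5 here versus 4 in the single-row CalculateFilteredOutput<4> path).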
2233 inline __m256i CalculateFilteredOutputPass1(const __m256i src,
2234                                             const __m256i ma[2],
2235                                             const __m256i b[2][2]) {
2236   const __m256i ma_sum = _mm256_add_epi16(ma[0], ma[1]);
2237   __m256i b_sum[2];
2238   b_sum[0] = _mm256_add_epi32(b[0][0], b[1][0]);
2239   b_sum[1] = _mm256_add_epi32(b[0][1], b[1][1]);
2240   return CalculateFilteredOutput<5>(src, ma_sum, b_sum);
2241 }
2242 
2243 inline __m256i CalculateFilteredOutputPass2(const __m256i src,
2244                                             const __m256i ma[3],
2245                                             const __m256i b[3][2]) {
2246   const __m256i ma_sum = Sum3_16(ma);
2247   __m256i b_sum[2];
2248   Sum3_32(b, b_sum);
2249   return CalculateFilteredOutput<5>(src, ma_sum, b_sum);
2250 }
2251 
2252 inline __m256i SelfGuidedFinal(const __m256i src, const __m256i v[2]) {
2253   const __m256i v_lo =
2254       VrshrS32(v[0], kSgrProjRestoreBits + kSgrProjPrecisionBits);
2255   const __m256i v_hi =
2256       VrshrS32(v[1], kSgrProjRestoreBits + kSgrProjPrecisionBits);
2257   const __m256i vv = _mm256_packs_epi32(v_lo, v_hi);
2258   return _mm256_add_epi16(src, vv);
2259 }
2260 
2261 inline __m256i SelfGuidedDoubleMultiplier(const __m256i src,
2262                                           const __m256i filter[2], const int w0,
2263                                           const int w2) {
2264   __m256i v[2];
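  // Pack w0 into the low 16 bits and w2 into the high 16 bits of every 32-bit
  // lane so that _mm256_madd_epi16() computes w0 * filter[0] + w2 * filter[1]
  // for each pixel in a single instruction.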
2265   const __m256i w0_w2 =
2266       _mm256_set1_epi32((w2 << 16) | static_cast<uint16_t>(w0));
2267   const __m256i f_lo = _mm256_unpacklo_epi16(filter[0], filter[1]);
2268   const __m256i f_hi = _mm256_unpackhi_epi16(filter[0], filter[1]);
2269   v[0] = _mm256_madd_epi16(w0_w2, f_lo);
2270   v[1] = _mm256_madd_epi16(w0_w2, f_hi);
2271   return SelfGuidedFinal(src, v);
2272 }
2273 
2274 inline __m256i SelfGuidedSingleMultiplier(const __m256i src,
2275                                           const __m256i filter, const int w0) {
2276   // weight: -96 to 96 (Sgrproj_Xqd_Min/Max)
2277   __m256i v[2];
2278   v[0] = VmullNLo8(filter, w0);
2279   v[1] = VmullNHi8(filter, w0);
2280   return SelfGuidedFinal(src, v);
2281 }
2282 
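// Pass 1 (5x5 box filter) produces two output rows per call, 32 pixels per
// loop iteration. The 565 sums for the new row pair are written to
// ma565[1]/b565[1]; the upper output row averages them with the previous
// pair's sums from ma565[0]/b565[0] (shift 5), while the lower row uses the
// new sums alone (shift 4).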
2283 LIBGAV1_ALWAYS_INLINE void BoxFilterPass1(
2284     const uint8_t* const src, const uint8_t* const src0,
2285     const uint8_t* const src1, const ptrdiff_t stride, uint16_t* const sum5[5],
2286     uint32_t* const square_sum5[5], const int width, const ptrdiff_t sum_width,
2287     const uint32_t scale, const int16_t w0, uint16_t* const ma565[2],
2288     uint32_t* const b565[2], uint8_t* const dst) {
2289   __m128i ma0, b0, s[2][3], sq_128[2][2];
2290   __m256i mas[3], sq[2][3], bs[3];
2291   s[0][0] = LoadUnaligned16Msan(src0, kOverreadInBytesPass1_128 - width);
2292   s[1][0] = LoadUnaligned16Msan(src1, kOverreadInBytesPass1_128 - width);
2293   sq_128[0][0] = SquareLo8(s[0][0]);
2294   sq_128[1][0] = SquareLo8(s[1][0]);
2295   BoxFilterPreProcess5Lo(s, scale, sum5, square_sum5, sq_128, &ma0, &b0);
2296   sq[0][0] = SetrM128i(sq_128[0][0], sq_128[0][1]);
2297   sq[1][0] = SetrM128i(sq_128[1][0], sq_128[1][1]);
2298   mas[0] = SetrM128i(ma0, ma0);
2299   bs[0] = SetrM128i(b0, b0);
2300 
2301   int x = 0;
2302   do {
2303     __m256i ma[3], ma5[3], b[2][2][2];
2304     BoxFilterPreProcess5(src0 + x + 8, src1 + x + 8,
2305                          x + 8 + kOverreadInBytesPass1_256 - width, sum_width,
2306                          x + 8, scale, sum5, square_sum5, sq, mas, bs);
2307     Prepare3_8(mas, ma5);
2308     ma[1] = Sum565Lo(ma5);
2309     ma[2] = Sum565Hi(ma5);
2310     StoreAligned64(ma565[1] + x, ma + 1);
2311     Sum565W(bs + 0, b[0][1]);
2312     Sum565W(bs + 1, b[1][1]);
2313     StoreAligned64(b565[1] + x + 0, b[0][1]);
2314     StoreAligned64(b565[1] + x + 16, b[1][1]);
2315     const __m256i sr0 = LoadUnaligned32(src + x);
2316     const __m256i sr1 = LoadUnaligned32(src + stride + x);
2317     const __m256i sr0_lo = _mm256_unpacklo_epi8(sr0, _mm256_setzero_si256());
2318     const __m256i sr1_lo = _mm256_unpacklo_epi8(sr1, _mm256_setzero_si256());
2319     ma[0] = LoadAligned32(ma565[0] + x);
2320     LoadAligned64(b565[0] + x, b[0][0]);
2321     const __m256i p00 = CalculateFilteredOutputPass1(sr0_lo, ma, b[0]);
2322     const __m256i p01 = CalculateFilteredOutput<4>(sr1_lo, ma[1], b[0][1]);
2323     const __m256i d00 = SelfGuidedSingleMultiplier(sr0_lo, p00, w0);
2324     const __m256i d10 = SelfGuidedSingleMultiplier(sr1_lo, p01, w0);
2325     const __m256i sr0_hi = _mm256_unpackhi_epi8(sr0, _mm256_setzero_si256());
2326     const __m256i sr1_hi = _mm256_unpackhi_epi8(sr1, _mm256_setzero_si256());
2327     ma[1] = LoadAligned32(ma565[0] + x + 16);
2328     LoadAligned64(b565[0] + x + 16, b[1][0]);
2329     const __m256i p10 = CalculateFilteredOutputPass1(sr0_hi, ma + 1, b[1]);
2330     const __m256i p11 = CalculateFilteredOutput<4>(sr1_hi, ma[2], b[1][1]);
2331     const __m256i d01 = SelfGuidedSingleMultiplier(sr0_hi, p10, w0);
2332     const __m256i d11 = SelfGuidedSingleMultiplier(sr1_hi, p11, w0);
2333     StoreUnaligned32(dst + x, _mm256_packus_epi16(d00, d01));
2334     StoreUnaligned32(dst + stride + x, _mm256_packus_epi16(d10, d11));
2335     sq[0][0] = sq[0][2];
2336     sq[1][0] = sq[1][2];
2337     mas[0] = mas[2];
2338     bs[0] = bs[2];
2339     x += 32;
2340   } while (x < width);
2341 }
2342 
2343 inline void BoxFilterPass1LastRow(
2344     const uint8_t* const src, const uint8_t* const src0, const int width,
2345     const ptrdiff_t sum_width, const uint32_t scale, const int16_t w0,
2346     uint16_t* const sum5[5], uint32_t* const square_sum5[5], uint16_t* ma565,
2347     uint32_t* b565, uint8_t* const dst) {
2348   const __m128i s0 =
2349       LoadUnaligned16Msan(src0, kOverreadInBytesPass1_128 - width);
2350   __m128i ma0, b0, sq_128[2];
2351   __m256i mas[3], sq[3], bs[3];
2352   sq_128[0] = SquareLo8(s0);
2353   BoxFilterPreProcess5LastRowLo(s0, scale, sum5, square_sum5, sq_128, &ma0,
2354                                 &b0);
2355   sq[0] = SetrM128i(sq_128[0], sq_128[1]);
2356   mas[0] = SetrM128i(ma0, ma0);
2357   bs[0] = SetrM128i(b0, b0);
2358 
2359   int x = 0;
2360   do {
2361     __m256i ma[3], ma5[3], b[2][2];
2362     BoxFilterPreProcess5LastRow(
2363         src0 + x + 8, x + 8 + kOverreadInBytesPass1_256 - width, sum_width,
2364         x + 8, scale, sum5, square_sum5, sq, mas, bs);
2365     Prepare3_8(mas, ma5);
2366     ma[1] = Sum565Lo(ma5);
2367     ma[2] = Sum565Hi(ma5);
2368     Sum565W(bs + 0, b[1]);
2369     const __m256i sr = LoadUnaligned32(src + x);
2370     const __m256i sr_lo = _mm256_unpacklo_epi8(sr, _mm256_setzero_si256());
2371     const __m256i sr_hi = _mm256_unpackhi_epi8(sr, _mm256_setzero_si256());
2372     ma[0] = LoadAligned32(ma565);
2373     LoadAligned64(b565 + 0, b[0]);
2374     const __m256i p0 = CalculateFilteredOutputPass1(sr_lo, ma, b);
2375     ma[1] = LoadAligned32(ma565 + 16);
2376     LoadAligned64(b565 + 16, b[0]);
2377     Sum565W(bs + 1, b[1]);
2378     const __m256i p1 = CalculateFilteredOutputPass1(sr_hi, ma + 1, b);
2379     const __m256i d0 = SelfGuidedSingleMultiplier(sr_lo, p0, w0);
2380     const __m256i d1 = SelfGuidedSingleMultiplier(sr_hi, p1, w0);
2381     StoreUnaligned32(dst + x, _mm256_packus_epi16(d0, d1));
2382     sq[0] = sq[2];
2383     mas[0] = mas[2];
2384     bs[0] = bs[2];
2385     ma565 += 32;
2386     b565 += 32;
2387     x += 32;
2388   } while (x < width);
2389 }
2390 
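// Pass 2 (3x3 box filter) produces one output row per call. Each output pixel
// combines three rows of weighted sums: the 343 sums from two rows above
// (ma343[0]/b343[0]), the 444 sums from one row above (ma444[0]/b444[0]), and
// the 343 sums of the new row written here via Store343_444*().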
2391 LIBGAV1_ALWAYS_INLINE void BoxFilterPass2(
2392     const uint8_t* const src, const uint8_t* const src0, const int width,
2393     const ptrdiff_t sum_width, const uint32_t scale, const int16_t w0,
2394     uint16_t* const sum3[3], uint32_t* const square_sum3[3],
2395     uint16_t* const ma343[3], uint16_t* const ma444[2], uint32_t* const b343[3],
2396     uint32_t* const b444[2], uint8_t* const dst) {
2397   const __m128i s0 =
2398       LoadUnaligned16Msan(src0, kOverreadInBytesPass2_128 - width);
2399   __m128i ma0, b0, sq_128[2];
2400   __m256i mas[3], sq[3], bs[3];
2401   sq_128[0] = SquareLo8(s0);
2402   BoxFilterPreProcess3Lo(s0, scale, sum3, square_sum3, sq_128, &ma0, &b0);
2403   sq[0] = SetrM128i(sq_128[0], sq_128[1]);
2404   mas[0] = SetrM128i(ma0, ma0);
2405   bs[0] = SetrM128i(b0, b0);
2406 
2407   int x = 0;
2408   do {
2409     __m256i ma[4], b[4][2], ma3[3];
2410     BoxFilterPreProcess3(src0 + x + 8,
2411                          x + 8 + kOverreadInBytesPass2_256 - width, x + 8,
2412                          sum_width, scale, sum3, square_sum3, sq, mas, bs);
2413     Prepare3_8(mas, ma3);
2414     Store343_444Lo(ma3, bs + 0, x + 0, &ma[2], b[2], ma343[2], ma444[1],
2415                    b343[2], b444[1]);
2416     Store343_444Hi(ma3, bs + 1, x + 16, &ma[3], b[3], ma343[2], ma444[1],
2417                    b343[2], b444[1]);
2418     const __m256i sr = LoadUnaligned32(src + x);
2419     const __m256i sr_lo = _mm256_unpacklo_epi8(sr, _mm256_setzero_si256());
2420     const __m256i sr_hi = _mm256_unpackhi_epi8(sr, _mm256_setzero_si256());
2421     ma[0] = LoadAligned32(ma343[0] + x);
2422     ma[1] = LoadAligned32(ma444[0] + x);
2423     LoadAligned64(b343[0] + x, b[0]);
2424     LoadAligned64(b444[0] + x, b[1]);
2425     const __m256i p0 = CalculateFilteredOutputPass2(sr_lo, ma, b);
2426     ma[1] = LoadAligned32(ma343[0] + x + 16);
2427     ma[2] = LoadAligned32(ma444[0] + x + 16);
2428     LoadAligned64(b343[0] + x + 16, b[1]);
2429     LoadAligned64(b444[0] + x + 16, b[2]);
2430     const __m256i p1 = CalculateFilteredOutputPass2(sr_hi, ma + 1, b + 1);
2431     const __m256i d0 = SelfGuidedSingleMultiplier(sr_lo, p0, w0);
2432     const __m256i d1 = SelfGuidedSingleMultiplier(sr_hi, p1, w0);
2433     StoreUnaligned32(dst + x, _mm256_packus_epi16(d0, d1));
2434     sq[0] = sq[2];
2435     mas[0] = mas[2];
2436     bs[0] = bs[2];
2437     x += 32;
2438   } while (x < width);
2439 }
2440 
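// Combined path: runs the 5x5 (pass 1) and 3x3 (pass 2) filters over the same
// two source rows and blends both filtered outputs with the source using the
// w0/w2 weights (SelfGuidedDoubleMultiplier()).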
2441 LIBGAV1_ALWAYS_INLINE void BoxFilter(
2442     const uint8_t* const src, const uint8_t* const src0,
2443     const uint8_t* const src1, const ptrdiff_t stride, const int width,
2444     const uint16_t scales[2], const int16_t w0, const int16_t w2,
2445     uint16_t* const sum3[4], uint16_t* const sum5[5],
2446     uint32_t* const square_sum3[4], uint32_t* const square_sum5[5],
2447     const ptrdiff_t sum_width, uint16_t* const ma343[4],
2448     uint16_t* const ma444[3], uint16_t* const ma565[2], uint32_t* const b343[4],
2449     uint32_t* const b444[3], uint32_t* const b565[2], uint8_t* const dst) {
2450   __m128i s[2], ma3_128[2], ma5_0, sq_128[2][2], b3_128[2], b5_0;
2451   __m256i ma3[2][3], ma5[3], sq[2][3], b3[2][5], b5[5];
2452   s[0] = LoadUnaligned16Msan(src0, kOverreadInBytesPass1_128 - width);
2453   s[1] = LoadUnaligned16Msan(src1, kOverreadInBytesPass1_128 - width);
2454   sq_128[0][0] = SquareLo8(s[0]);
2455   sq_128[1][0] = SquareLo8(s[1]);
2456   BoxFilterPreProcessLo(s, scales, sum3, sum5, square_sum3, square_sum5, sq_128,
2457                         ma3_128, b3_128, &ma5_0, &b5_0);
2458   sq[0][0] = SetrM128i(sq_128[0][0], sq_128[0][1]);
2459   sq[1][0] = SetrM128i(sq_128[1][0], sq_128[1][1]);
2460   ma3[0][0] = SetrM128i(ma3_128[0], ma3_128[0]);
2461   ma3[1][0] = SetrM128i(ma3_128[1], ma3_128[1]);
2462   ma5[0] = SetrM128i(ma5_0, ma5_0);
2463   b3[0][0] = SetrM128i(b3_128[0], b3_128[0]);
2464   b3[1][0] = SetrM128i(b3_128[1], b3_128[1]);
2465   b5[0] = SetrM128i(b5_0, b5_0);
2466 
2467   int x = 0;
2468   do {
2469     __m256i ma[3][3], mat[3][3], b[3][3][2], p[2][2], ma3x[2][3], ma5x[3];
2470     BoxFilterPreProcess(src0 + x + 8, src1 + x + 8,
2471                         x + 8 + kOverreadInBytesPass1_256 - width, x + 8,
2472                         scales, sum3, sum5, square_sum3, square_sum5, sum_width,
2473                         sq, ma3, b3, ma5, b5);
2474     Prepare3_8(ma3[0], ma3x[0]);
2475     Prepare3_8(ma3[1], ma3x[1]);
2476     Prepare3_8(ma5, ma5x);
2477     Store343_444Lo(ma3x[0], b3[0], x, &ma[1][2], &ma[2][1], b[1][2], b[2][1],
2478                    ma343[2], ma444[1], b343[2], b444[1]);
2479     Store343_444Lo(ma3x[1], b3[1], x, &ma[2][2], b[2][2], ma343[3], ma444[2],
2480                    b343[3], b444[2]);
2481     ma[0][1] = Sum565Lo(ma5x);
2482     ma[0][2] = Sum565Hi(ma5x);
2483     mat[0][1] = ma[0][2];
2484     StoreAligned64(ma565[1] + x, ma[0] + 1);
2485     Sum565W(b5, b[0][1]);
2486     StoreAligned64(b565[1] + x, b[0][1]);
2487     const __m256i sr0 = LoadUnaligned32(src + x);
2488     const __m256i sr1 = LoadUnaligned32(src + stride + x);
2489     const __m256i sr0_lo = _mm256_unpacklo_epi8(sr0, _mm256_setzero_si256());
2490     const __m256i sr1_lo = _mm256_unpacklo_epi8(sr1, _mm256_setzero_si256());
2491     ma[0][0] = LoadAligned32(ma565[0] + x);
2492     LoadAligned64(b565[0] + x, b[0][0]);
2493     p[0][0] = CalculateFilteredOutputPass1(sr0_lo, ma[0], b[0]);
2494     p[1][0] = CalculateFilteredOutput<4>(sr1_lo, ma[0][1], b[0][1]);
2495     ma[1][0] = LoadAligned32(ma343[0] + x);
2496     ma[1][1] = LoadAligned32(ma444[0] + x);
2497     LoadAligned64(b343[0] + x, b[1][0]);
2498     LoadAligned64(b444[0] + x, b[1][1]);
2499     p[0][1] = CalculateFilteredOutputPass2(sr0_lo, ma[1], b[1]);
2500     const __m256i d00 = SelfGuidedDoubleMultiplier(sr0_lo, p[0], w0, w2);
2501     ma[2][0] = LoadAligned32(ma343[1] + x);
2502     LoadAligned64(b343[1] + x, b[2][0]);
2503     p[1][1] = CalculateFilteredOutputPass2(sr1_lo, ma[2], b[2]);
2504     const __m256i d10 = SelfGuidedDoubleMultiplier(sr1_lo, p[1], w0, w2);
2505 
2506     Sum565W(b5 + 1, b[0][1]);
2507     StoreAligned64(b565[1] + x + 16, b[0][1]);
2508     Store343_444Hi(ma3x[0], b3[0] + 1, x + 16, &mat[1][2], &mat[2][1], b[1][2],
2509                    b[2][1], ma343[2], ma444[1], b343[2], b444[1]);
2510     Store343_444Hi(ma3x[1], b3[1] + 1, x + 16, &mat[2][2], b[2][2], ma343[3],
2511                    ma444[2], b343[3], b444[2]);
2512     const __m256i sr0_hi = _mm256_unpackhi_epi8(sr0, _mm256_setzero_si256());
2513     const __m256i sr1_hi = _mm256_unpackhi_epi8(sr1, _mm256_setzero_si256());
2514     mat[0][0] = LoadAligned32(ma565[0] + x + 16);
2515     LoadAligned64(b565[0] + x + 16, b[0][0]);
2516     p[0][0] = CalculateFilteredOutputPass1(sr0_hi, mat[0], b[0]);
2517     p[1][0] = CalculateFilteredOutput<4>(sr1_hi, mat[0][1], b[0][1]);
2518     mat[1][0] = LoadAligned32(ma343[0] + x + 16);
2519     mat[1][1] = LoadAligned32(ma444[0] + x + 16);
2520     LoadAligned64(b343[0] + x + 16, b[1][0]);
2521     LoadAligned64(b444[0] + x + 16, b[1][1]);
2522     p[0][1] = CalculateFilteredOutputPass2(sr0_hi, mat[1], b[1]);
2523     const __m256i d01 = SelfGuidedDoubleMultiplier(sr0_hi, p[0], w0, w2);
2524     mat[2][0] = LoadAligned32(ma343[1] + x + 16);
2525     LoadAligned64(b343[1] + x + 16, b[2][0]);
2526     p[1][1] = CalculateFilteredOutputPass2(sr1_hi, mat[2], b[2]);
2527     const __m256i d11 = SelfGuidedDoubleMultiplier(sr1_hi, p[1], w0, w2);
2528     StoreUnaligned32(dst + x, _mm256_packus_epi16(d00, d01));
2529     StoreUnaligned32(dst + stride + x, _mm256_packus_epi16(d10, d11));
2530     sq[0][0] = sq[0][2];
2531     sq[1][0] = sq[1][2];
2532     ma3[0][0] = ma3[0][2];
2533     ma3[1][0] = ma3[1][2];
2534     ma5[0] = ma5[2];
2535     b3[0][0] = b3[0][2];
2536     b3[1][0] = b3[1][2];
2537     b5[0] = b5[2];
2538     x += 32;
2539   } while (x < width);
2540 }
2541 
2542 inline void BoxFilterLastRow(
2543     const uint8_t* const src, const uint8_t* const src0, const int width,
2544     const ptrdiff_t sum_width, const uint16_t scales[2], const int16_t w0,
2545     const int16_t w2, uint16_t* const sum3[4], uint16_t* const sum5[5],
2546     uint32_t* const square_sum3[4], uint32_t* const square_sum5[5],
2547     uint16_t* const ma343, uint16_t* const ma444, uint16_t* const ma565,
2548     uint32_t* const b343, uint32_t* const b444, uint32_t* const b565,
2549     uint8_t* const dst) {
2550   const __m128i s0 =
2551       LoadUnaligned16Msan(src0, kOverreadInBytesPass1_128 - width);
2552   __m128i ma3_0, ma5_0, b3_0, b5_0, sq_128[2];
2553   __m256i ma3[3], ma5[3], sq[3], b3[3], b5[3];
2554   sq_128[0] = SquareLo8(s0);
2555   BoxFilterPreProcessLastRowLo(s0, scales, sum3, sum5, square_sum3, square_sum5,
2556                                sq_128, &ma3_0, &ma5_0, &b3_0, &b5_0);
2557   sq[0] = SetrM128i(sq_128[0], sq_128[1]);
2558   ma3[0] = SetrM128i(ma3_0, ma3_0);
2559   ma5[0] = SetrM128i(ma5_0, ma5_0);
2560   b3[0] = SetrM128i(b3_0, b3_0);
2561   b5[0] = SetrM128i(b5_0, b5_0);
2562 
2563   int x = 0;
2564   do {
2565     __m256i ma[3], mat[3], b[3][2], p[2], ma3x[3], ma5x[3];
2566     BoxFilterPreProcessLastRow(src0 + x + 8,
2567                                x + 8 + kOverreadInBytesPass1_256 - width,
2568                                sum_width, x + 8, scales, sum3, sum5,
2569                                square_sum3, square_sum5, sq, ma3, ma5, b3, b5);
2570     Prepare3_8(ma3, ma3x);
2571     Prepare3_8(ma5, ma5x);
2572     ma[1] = Sum565Lo(ma5x);
2573     Sum565W(b5, b[1]);
2574     ma[2] = Sum343Lo(ma3x);
2575     Sum343W(b3, b[2]);
2576     const __m256i sr = LoadUnaligned32(src + x);
2577     const __m256i sr_lo = _mm256_unpacklo_epi8(sr, _mm256_setzero_si256());
2578     ma[0] = LoadAligned32(ma565 + x);
2579     LoadAligned64(b565 + x, b[0]);
2580     p[0] = CalculateFilteredOutputPass1(sr_lo, ma, b);
2581     ma[0] = LoadAligned32(ma343 + x);
2582     ma[1] = LoadAligned32(ma444 + x);
2583     LoadAligned64(b343 + x, b[0]);
2584     LoadAligned64(b444 + x, b[1]);
2585     p[1] = CalculateFilteredOutputPass2(sr_lo, ma, b);
2586     const __m256i d0 = SelfGuidedDoubleMultiplier(sr_lo, p, w0, w2);
2587 
2588     mat[1] = Sum565Hi(ma5x);
2589     Sum565W(b5 + 1, b[1]);
2590     mat[2] = Sum343Hi(ma3x);
2591     Sum343W(b3 + 1, b[2]);
2592     const __m256i sr_hi = _mm256_unpackhi_epi8(sr, _mm256_setzero_si256());
2593     mat[0] = LoadAligned32(ma565 + x + 16);
2594     LoadAligned64(b565 + x + 16, b[0]);
2595     p[0] = CalculateFilteredOutputPass1(sr_hi, mat, b);
2596     mat[0] = LoadAligned32(ma343 + x + 16);
2597     mat[1] = LoadAligned32(ma444 + x + 16);
2598     LoadAligned64(b343 + x + 16, b[0]);
2599     LoadAligned64(b444 + x + 16, b[1]);
2600     p[1] = CalculateFilteredOutputPass2(sr_hi, mat, b);
2601     const __m256i d1 = SelfGuidedDoubleMultiplier(sr_hi, p, w0, w2);
2602     StoreUnaligned32(dst + x, _mm256_packus_epi16(d0, d1));
2603     sq[0] = sq[2];
2604     ma3[0] = ma3[2];
2605     ma5[0] = ma5[2];
2606     b3[0] = b3[2];
2607     b5[0] = b5[2];
2608     x += 32;
2609   } while (x < width);
2610 }
2611 
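// Driver for the combined two-pass case. Rows are processed two at a time;
// the per-row sum and intermediate buffers in |sgr_buffer| act as ring
// buffers that are rotated with Circulate*PointersBy2()/std::swap() after
// each row pair, and the bottom border rows are handled after the main loop.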
2612 LIBGAV1_ALWAYS_INLINE void BoxFilterProcess(
2613     const RestorationUnitInfo& restoration_info, const uint8_t* src,
2614     const ptrdiff_t stride, const uint8_t* const top_border,
2615     const ptrdiff_t top_border_stride, const uint8_t* bottom_border,
2616     const ptrdiff_t bottom_border_stride, const int width, const int height,
2617     SgrBuffer* const sgr_buffer, uint8_t* dst) {
2618   const auto temp_stride = Align<ptrdiff_t>(width, 32);
2619   const auto sum_width = temp_stride + 8;
2620   const auto sum_stride = temp_stride + 32;
2621   const int sgr_proj_index = restoration_info.sgr_proj_info.index;
2622   const uint16_t* const scales = kSgrScaleParameter[sgr_proj_index];  // < 2^12.
2623   const int16_t w0 = restoration_info.sgr_proj_info.multiplier[0];
2624   const int16_t w1 = restoration_info.sgr_proj_info.multiplier[1];
2625   const int16_t w2 = (1 << kSgrProjPrecisionBits) - w0 - w1;
2626   uint16_t *sum3[4], *sum5[5], *ma343[4], *ma444[3], *ma565[2];
2627   uint32_t *square_sum3[4], *square_sum5[5], *b343[4], *b444[3], *b565[2];
2628   sum3[0] = sgr_buffer->sum3 + kSumOffset;
2629   square_sum3[0] = sgr_buffer->square_sum3 + kSumOffset;
2630   ma343[0] = sgr_buffer->ma343;
2631   b343[0] = sgr_buffer->b343;
2632   for (int i = 1; i <= 3; ++i) {
2633     sum3[i] = sum3[i - 1] + sum_stride;
2634     square_sum3[i] = square_sum3[i - 1] + sum_stride;
2635     ma343[i] = ma343[i - 1] + temp_stride;
2636     b343[i] = b343[i - 1] + temp_stride;
2637   }
2638   sum5[0] = sgr_buffer->sum5 + kSumOffset;
2639   square_sum5[0] = sgr_buffer->square_sum5 + kSumOffset;
2640   for (int i = 1; i <= 4; ++i) {
2641     sum5[i] = sum5[i - 1] + sum_stride;
2642     square_sum5[i] = square_sum5[i - 1] + sum_stride;
2643   }
2644   ma444[0] = sgr_buffer->ma444;
2645   b444[0] = sgr_buffer->b444;
2646   for (int i = 1; i <= 2; ++i) {
2647     ma444[i] = ma444[i - 1] + temp_stride;
2648     b444[i] = b444[i - 1] + temp_stride;
2649   }
2650   ma565[0] = sgr_buffer->ma565;
2651   ma565[1] = ma565[0] + temp_stride;
2652   b565[0] = sgr_buffer->b565;
2653   b565[1] = b565[0] + temp_stride;
2654   assert(scales[0] != 0);
2655   assert(scales[1] != 0);
2656   BoxSum(top_border, top_border_stride, width, sum_stride, temp_stride, sum3[0],
2657          sum5[1], square_sum3[0], square_sum5[1]);
2658   sum5[0] = sum5[1];
2659   square_sum5[0] = square_sum5[1];
2660   const uint8_t* const s = (height > 1) ? src + stride : bottom_border;
2661   BoxSumFilterPreProcess(src, s, width, scales, sum3, sum5, square_sum3,
2662                          square_sum5, sum_width, ma343, ma444[0], ma565[0],
2663                          b343, b444[0], b565[0]);
2664   sum5[0] = sgr_buffer->sum5 + kSumOffset;
2665   square_sum5[0] = sgr_buffer->square_sum5 + kSumOffset;
2666 
2667   for (int y = (height >> 1) - 1; y > 0; --y) {
2668     Circulate4PointersBy2<uint16_t>(sum3);
2669     Circulate4PointersBy2<uint32_t>(square_sum3);
2670     Circulate5PointersBy2<uint16_t>(sum5);
2671     Circulate5PointersBy2<uint32_t>(square_sum5);
2672     BoxFilter(src + 3, src + 2 * stride, src + 3 * stride, stride, width,
2673               scales, w0, w2, sum3, sum5, square_sum3, square_sum5, sum_width,
2674               ma343, ma444, ma565, b343, b444, b565, dst);
2675     src += 2 * stride;
2676     dst += 2 * stride;
2677     Circulate4PointersBy2<uint16_t>(ma343);
2678     Circulate4PointersBy2<uint32_t>(b343);
2679     std::swap(ma444[0], ma444[2]);
2680     std::swap(b444[0], b444[2]);
2681     std::swap(ma565[0], ma565[1]);
2682     std::swap(b565[0], b565[1]);
2683   }
2684 
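  // Handle the remaining bottom rows. For even heights both rows of the last
  // pair come from |bottom_border|; for odd heights the pair is the next
  // source row plus the first border row, and the final odd row is filtered
  // by BoxFilterLastRow() below.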
2685   Circulate4PointersBy2<uint16_t>(sum3);
2686   Circulate4PointersBy2<uint32_t>(square_sum3);
2687   Circulate5PointersBy2<uint16_t>(sum5);
2688   Circulate5PointersBy2<uint32_t>(square_sum5);
2689   if ((height & 1) == 0 || height > 1) {
2690     const uint8_t* sr[2];
2691     if ((height & 1) == 0) {
2692       sr[0] = bottom_border;
2693       sr[1] = bottom_border + bottom_border_stride;
2694     } else {
2695       sr[0] = src + 2 * stride;
2696       sr[1] = bottom_border;
2697     }
2698     BoxFilter(src + 3, sr[0], sr[1], stride, width, scales, w0, w2, sum3, sum5,
2699               square_sum3, square_sum5, sum_width, ma343, ma444, ma565, b343,
2700               b444, b565, dst);
2701   }
2702   if ((height & 1) != 0) {
2703     if (height > 1) {
2704       src += 2 * stride;
2705       dst += 2 * stride;
2706       Circulate4PointersBy2<uint16_t>(sum3);
2707       Circulate4PointersBy2<uint32_t>(square_sum3);
2708       Circulate5PointersBy2<uint16_t>(sum5);
2709       Circulate5PointersBy2<uint32_t>(square_sum5);
2710       Circulate4PointersBy2<uint16_t>(ma343);
2711       Circulate4PointersBy2<uint32_t>(b343);
2712       std::swap(ma444[0], ma444[2]);
2713       std::swap(b444[0], b444[2]);
2714       std::swap(ma565[0], ma565[1]);
2715       std::swap(b565[0], b565[1]);
2716     }
2717     BoxFilterLastRow(src + 3, bottom_border + bottom_border_stride, width,
2718                      sum_width, scales, w0, w2, sum3, sum5, square_sum3,
2719                      square_sum5, ma343[0], ma444[0], ma565[0], b343[0],
2720                      b444[0], b565[0], dst);
2721   }
2722 }
2723 
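// Driver for the pass-1-only case (radius_pass_1 == 0): only the 5x5 box
// filter is applied and its output is blended with the source using w0.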
2724 inline void BoxFilterProcessPass1(const RestorationUnitInfo& restoration_info,
2725                                   const uint8_t* src, const ptrdiff_t stride,
2726                                   const uint8_t* const top_border,
2727                                   const ptrdiff_t top_border_stride,
2728                                   const uint8_t* bottom_border,
2729                                   const ptrdiff_t bottom_border_stride,
2730                                   const int width, const int height,
2731                                   SgrBuffer* const sgr_buffer, uint8_t* dst) {
2732   const auto temp_stride = Align<ptrdiff_t>(width, 32);
2733   const auto sum_width = temp_stride + 8;
2734   const auto sum_stride = temp_stride + 32;
2735   const int sgr_proj_index = restoration_info.sgr_proj_info.index;
2736   const uint32_t scale = kSgrScaleParameter[sgr_proj_index][0];  // < 2^12.
2737   const int16_t w0 = restoration_info.sgr_proj_info.multiplier[0];
2738   uint16_t *sum5[5], *ma565[2];
2739   uint32_t *square_sum5[5], *b565[2];
2740   sum5[0] = sgr_buffer->sum5 + kSumOffset;
2741   square_sum5[0] = sgr_buffer->square_sum5 + kSumOffset;
2742   for (int i = 1; i <= 4; ++i) {
2743     sum5[i] = sum5[i - 1] + sum_stride;
2744     square_sum5[i] = square_sum5[i - 1] + sum_stride;
2745   }
2746   ma565[0] = sgr_buffer->ma565;
2747   ma565[1] = ma565[0] + temp_stride;
2748   b565[0] = sgr_buffer->b565;
2749   b565[1] = b565[0] + temp_stride;
2750   assert(scale != 0);
2751   BoxSum<5>(top_border, top_border_stride, width, sum_stride, temp_stride,
2752             sum5[1], square_sum5[1]);
2753   sum5[0] = sum5[1];
2754   square_sum5[0] = square_sum5[1];
2755   const uint8_t* const s = (height > 1) ? src + stride : bottom_border;
2756   BoxSumFilterPreProcess5(src, s, width, scale, sum5, square_sum5, sum_width,
2757                           ma565[0], b565[0]);
2758   sum5[0] = sgr_buffer->sum5 + kSumOffset;
2759   square_sum5[0] = sgr_buffer->square_sum5 + kSumOffset;
2760 
2761   for (int y = (height >> 1) - 1; y > 0; --y) {
2762     Circulate5PointersBy2<uint16_t>(sum5);
2763     Circulate5PointersBy2<uint32_t>(square_sum5);
2764     BoxFilterPass1(src + 3, src + 2 * stride, src + 3 * stride, stride, sum5,
2765                    square_sum5, width, sum_width, scale, w0, ma565, b565, dst);
2766     src += 2 * stride;
2767     dst += 2 * stride;
2768     std::swap(ma565[0], ma565[1]);
2769     std::swap(b565[0], b565[1]);
2770   }
2771 
2772   Circulate5PointersBy2<uint16_t>(sum5);
2773   Circulate5PointersBy2<uint32_t>(square_sum5);
2774   if ((height & 1) == 0 || height > 1) {
2775     const uint8_t* sr[2];
2776     if ((height & 1) == 0) {
2777       sr[0] = bottom_border;
2778       sr[1] = bottom_border + bottom_border_stride;
2779     } else {
2780       sr[0] = src + 2 * stride;
2781       sr[1] = bottom_border;
2782     }
2783     BoxFilterPass1(src + 3, sr[0], sr[1], stride, sum5, square_sum5, width,
2784                    sum_width, scale, w0, ma565, b565, dst);
2785   }
2786   if ((height & 1) != 0) {
2787     src += 3;
2788     if (height > 1) {
2789       src += 2 * stride;
2790       dst += 2 * stride;
2791       std::swap(ma565[0], ma565[1]);
2792       std::swap(b565[0], b565[1]);
2793       Circulate5PointersBy2<uint16_t>(sum5);
2794       Circulate5PointersBy2<uint32_t>(square_sum5);
2795     }
2796     BoxFilterPass1LastRow(src, bottom_border + bottom_border_stride, width,
2797                           sum_width, scale, w0, sum5, square_sum5, ma565[0],
2798                           b565[0], dst);
2799   }
2800 }
2801 
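// Driver for the pass-2-only case (radius_pass_0 == 0): only the 3x3 box
// filter is applied. multiplier[0] is 0 here, so the blend weight reduces to
// w0 = (1 << kSgrProjPrecisionBits) - w1.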
2802 inline void BoxFilterProcessPass2(const RestorationUnitInfo& restoration_info,
2803                                   const uint8_t* src, const ptrdiff_t stride,
2804                                   const uint8_t* const top_border,
2805                                   const ptrdiff_t top_border_stride,
2806                                   const uint8_t* bottom_border,
2807                                   const ptrdiff_t bottom_border_stride,
2808                                   const int width, const int height,
2809                                   SgrBuffer* const sgr_buffer, uint8_t* dst) {
2810   assert(restoration_info.sgr_proj_info.multiplier[0] == 0);
2811   const auto temp_stride = Align<ptrdiff_t>(width, 32);
2812   const auto sum_width = temp_stride + 8;
2813   const auto sum_stride = temp_stride + 32;
2814   const int16_t w1 = restoration_info.sgr_proj_info.multiplier[1];
2815   const int16_t w0 = (1 << kSgrProjPrecisionBits) - w1;
2816   const int sgr_proj_index = restoration_info.sgr_proj_info.index;
2817   const uint32_t scale = kSgrScaleParameter[sgr_proj_index][1];  // < 2^12.
2818   uint16_t *sum3[3], *ma343[3], *ma444[2];
2819   uint32_t *square_sum3[3], *b343[3], *b444[2];
2820   sum3[0] = sgr_buffer->sum3 + kSumOffset;
2821   square_sum3[0] = sgr_buffer->square_sum3 + kSumOffset;
2822   ma343[0] = sgr_buffer->ma343;
2823   b343[0] = sgr_buffer->b343;
2824   for (int i = 1; i <= 2; ++i) {
2825     sum3[i] = sum3[i - 1] + sum_stride;
2826     square_sum3[i] = square_sum3[i - 1] + sum_stride;
2827     ma343[i] = ma343[i - 1] + temp_stride;
2828     b343[i] = b343[i - 1] + temp_stride;
2829   }
2830   ma444[0] = sgr_buffer->ma444;
2831   ma444[1] = ma444[0] + temp_stride;
2832   b444[0] = sgr_buffer->b444;
2833   b444[1] = b444[0] + temp_stride;
2834   assert(scale != 0);
2835   BoxSum<3>(top_border, top_border_stride, width, sum_stride, temp_stride,
2836             sum3[0], square_sum3[0]);
2837   BoxSumFilterPreProcess3<false>(src, width, scale, sum3, square_sum3,
2838                                  sum_width, ma343[0], nullptr, b343[0],
2839                                  nullptr);
2840   Circulate3PointersBy1<uint16_t>(sum3);
2841   Circulate3PointersBy1<uint32_t>(square_sum3);
2842   const uint8_t* s;
2843   if (height > 1) {
2844     s = src + stride;
2845   } else {
2846     s = bottom_border;
2847     bottom_border += bottom_border_stride;
2848   }
2849   BoxSumFilterPreProcess3<true>(s, width, scale, sum3, square_sum3, sum_width,
2850                                 ma343[1], ma444[0], b343[1], b444[0]);
2851 
2852   for (int y = height - 2; y > 0; --y) {
2853     Circulate3PointersBy1<uint16_t>(sum3);
2854     Circulate3PointersBy1<uint32_t>(square_sum3);
2855     BoxFilterPass2(src + 2, src + 2 * stride, width, sum_width, scale, w0, sum3,
2856                    square_sum3, ma343, ma444, b343, b444, dst);
2857     src += stride;
2858     dst += stride;
2859     Circulate3PointersBy1<uint16_t>(ma343);
2860     Circulate3PointersBy1<uint32_t>(b343);
2861     std::swap(ma444[0], ma444[1]);
2862     std::swap(b444[0], b444[1]);
2863   }
2864 
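  // Filter the last one or two rows, substituting |bottom_border| rows for
  // the unavailable rows below the restoration unit.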
2865   int y = std::min(height, 2);
2866   src += 2;
2867   do {
2868     Circulate3PointersBy1<uint16_t>(sum3);
2869     Circulate3PointersBy1<uint32_t>(square_sum3);
2870     BoxFilterPass2(src, bottom_border, width, sum_width, scale, w0, sum3,
2871                    square_sum3, ma343, ma444, b343, b444, dst);
2872     src += stride;
2873     dst += stride;
2874     bottom_border += bottom_border_stride;
2875     Circulate3PointersBy1<uint16_t>(ma343);
2876     Circulate3PointersBy1<uint32_t>(b343);
2877     std::swap(ma444[0], ma444[1]);
2878     std::swap(b444[0], b444[1]);
2879   } while (--y != 0);
2880 }
2881 
2882 // If |width| is not a multiple of 32, up to 31 extra pixels are written to
2883 // |dest| at the end of each row. Overwriting them is safe because they are
2884 // not part of the visible frame.
2885 void SelfGuidedFilter_AVX2(
2886     const RestorationUnitInfo& LIBGAV1_RESTRICT restoration_info,
2887     const void* LIBGAV1_RESTRICT const source, const ptrdiff_t stride,
2888     const void* LIBGAV1_RESTRICT const top_border,
2889     const ptrdiff_t top_border_stride,
2890     const void* LIBGAV1_RESTRICT const bottom_border,
2891     const ptrdiff_t bottom_border_stride, const int width, const int height,
2892     RestorationBuffer* LIBGAV1_RESTRICT const restoration_buffer,
2893     void* LIBGAV1_RESTRICT const dest) {
2894   const int index = restoration_info.sgr_proj_info.index;
2895   const int radius_pass_0 = kSgrProjParams[index][0];  // 2 or 0
2896   const int radius_pass_1 = kSgrProjParams[index][2];  // 1 or 0
2897   const auto* const src = static_cast<const uint8_t*>(source);
2898   const auto* top = static_cast<const uint8_t*>(top_border);
2899   const auto* bottom = static_cast<const uint8_t*>(bottom_border);
2900   auto* const dst = static_cast<uint8_t*>(dest);
2901   SgrBuffer* const sgr_buffer = &restoration_buffer->sgr_buffer;
2902   if (radius_pass_1 == 0) {
2903     // |radius_pass_0| and |radius_pass_1| cannot both be 0, so we have the
2904     // following assertion.
2905     assert(radius_pass_0 != 0);
2906     BoxFilterProcessPass1(restoration_info, src - 3, stride, top - 3,
2907                           top_border_stride, bottom - 3, bottom_border_stride,
2908                           width, height, sgr_buffer, dst);
2909   } else if (radius_pass_0 == 0) {
2910     BoxFilterProcessPass2(restoration_info, src - 2, stride, top - 2,
2911                           top_border_stride, bottom - 2, bottom_border_stride,
2912                           width, height, sgr_buffer, dst);
2913   } else {
2914     BoxFilterProcess(restoration_info, src - 3, stride, top - 3,
2915                      top_border_stride, bottom - 3, bottom_border_stride, width,
2916                      height, sgr_buffer, dst);
2917   }
2918 }
2919 
2920 void Init8bpp() {
2921   Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth8);
2922   assert(dsp != nullptr);
2923 #if DSP_ENABLED_8BPP_AVX2(WienerFilter)
2924   dsp->loop_restorations[0] = WienerFilter_AVX2;
2925 #endif
2926 #if DSP_ENABLED_8BPP_AVX2(SelfGuidedFilter)
2927   dsp->loop_restorations[1] = SelfGuidedFilter_AVX2;
2928 #endif
2929 }
2930 
2931 }  // namespace
2932 }  // namespace low_bitdepth
2933 
2934 void LoopRestorationInit_AVX2() { low_bitdepth::Init8bpp(); }
2935 
2936 }  // namespace dsp
2937 }  // namespace libgav1
2938 
2939 #else   // !LIBGAV1_TARGETING_AVX2
2940 namespace libgav1 {
2941 namespace dsp {
2942 
2943 void LoopRestorationInit_AVX2() {}
2944 
2945 }  // namespace dsp
2946 }  // namespace libgav1
2947 #endif  // LIBGAV1_TARGETING_AVX2
2948