// Copyright 2020 The libgav1 Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "src/dsp/loop_restoration.h"
#include "src/utils/cpu.h"

#if LIBGAV1_TARGETING_AVX2 && LIBGAV1_MAX_BITDEPTH >= 10
#include <immintrin.h>

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstring>

#include "src/dsp/common.h"
#include "src/dsp/constants.h"
#include "src/dsp/dsp.h"
#include "src/dsp/x86/common_avx2.h"
#include "src/utils/common.h"
#include "src/utils/constants.h"

namespace libgav1 {
namespace dsp {
namespace {

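// Rounds the 32-bit horizontal filter sums, packs them to 16 bits and clamps
// them to [-offset, limit - offset] before storing to |wiener_buffer|.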
inline void WienerHorizontalClip(const __m256i s[2],
                                 int16_t* const wiener_buffer) {
  constexpr int offset =
      1 << (10 + kWienerFilterBits - kInterRoundBitsHorizontal - 1);
  constexpr int limit = (offset << 2) - 1;
  const __m256i offsets = _mm256_set1_epi16(-offset);
  const __m256i limits = _mm256_set1_epi16(limit - offset);
  const __m256i round = _mm256_set1_epi32(1 << (kInterRoundBitsHorizontal - 1));
  const __m256i sum0 = _mm256_add_epi32(s[0], round);
  const __m256i sum1 = _mm256_add_epi32(s[1], round);
  const __m256i rounded_sum0 =
      _mm256_srai_epi32(sum0, kInterRoundBitsHorizontal);
  const __m256i rounded_sum1 =
      _mm256_srai_epi32(sum1, kInterRoundBitsHorizontal);
  const __m256i rounded_sum = _mm256_packs_epi32(rounded_sum0, rounded_sum1);
  const __m256i d0 = _mm256_max_epi16(rounded_sum, offsets);
  const __m256i d1 = _mm256_min_epi16(d0, limits);
  StoreAligned32(wiener_buffer, d1);
}

inline void WienerHorizontalTap7Kernel(const __m256i s[7],
                                       const __m256i filter[2],
                                       int16_t* const wiener_buffer) {
  const __m256i s06 = _mm256_add_epi16(s[0], s[6]);
  const __m256i s15 = _mm256_add_epi16(s[1], s[5]);
  const __m256i s24 = _mm256_add_epi16(s[2], s[4]);
  const __m256i ss0 = _mm256_unpacklo_epi16(s06, s15);
  const __m256i ss1 = _mm256_unpackhi_epi16(s06, s15);
  const __m256i ss2 = _mm256_unpacklo_epi16(s24, s[3]);
  const __m256i ss3 = _mm256_unpackhi_epi16(s24, s[3]);
  __m256i madds[4];
  madds[0] = _mm256_madd_epi16(ss0, filter[0]);
  madds[1] = _mm256_madd_epi16(ss1, filter[0]);
  madds[2] = _mm256_madd_epi16(ss2, filter[1]);
  madds[3] = _mm256_madd_epi16(ss3, filter[1]);
  madds[0] = _mm256_add_epi32(madds[0], madds[2]);
  madds[1] = _mm256_add_epi32(madds[1], madds[3]);
  WienerHorizontalClip(madds, wiener_buffer);
}

inline void WienerHorizontalTap5Kernel(const __m256i s[5], const __m256i filter,
                                       int16_t* const wiener_buffer) {
  const __m256i s04 = _mm256_add_epi16(s[0], s[4]);
  const __m256i s13 = _mm256_add_epi16(s[1], s[3]);
  const __m256i s2d = _mm256_add_epi16(s[2], s[2]);
  const __m256i s0m = _mm256_sub_epi16(s04, s2d);
  const __m256i s1m = _mm256_sub_epi16(s13, s2d);
  const __m256i ss0 = _mm256_unpacklo_epi16(s0m, s1m);
  const __m256i ss1 = _mm256_unpackhi_epi16(s0m, s1m);
  __m256i madds[2];
  madds[0] = _mm256_madd_epi16(ss0, filter);
  madds[1] = _mm256_madd_epi16(ss1, filter);
  const __m256i s2_lo = _mm256_unpacklo_epi16(s[2], _mm256_setzero_si256());
  const __m256i s2_hi = _mm256_unpackhi_epi16(s[2], _mm256_setzero_si256());
  const __m256i s2x128_lo = _mm256_slli_epi32(s2_lo, 7);
  const __m256i s2x128_hi = _mm256_slli_epi32(s2_hi, 7);
  madds[0] = _mm256_add_epi32(madds[0], s2x128_lo);
  madds[1] = _mm256_add_epi32(madds[1], s2x128_hi);
  WienerHorizontalClip(madds, wiener_buffer);
}

inline void WienerHorizontalTap3Kernel(const __m256i s[3], const __m256i filter,
                                       int16_t* const wiener_buffer) {
  const __m256i s02 = _mm256_add_epi16(s[0], s[2]);
  const __m256i ss0 = _mm256_unpacklo_epi16(s02, s[1]);
  const __m256i ss1 = _mm256_unpackhi_epi16(s02, s[1]);
  __m256i madds[2];
  madds[0] = _mm256_madd_epi16(ss0, filter);
  madds[1] = _mm256_madd_epi16(ss1, filter);
  WienerHorizontalClip(madds, wiener_buffer);
}

inline void WienerHorizontalTap7(const uint16_t* src,
                                 const ptrdiff_t src_stride,
                                 const ptrdiff_t width, const int height,
                                 const __m256i* const coefficients,
                                 int16_t** const wiener_buffer) {
  __m256i filter[2];
  filter[0] = _mm256_shuffle_epi32(*coefficients, 0x0);
  filter[1] = _mm256_shuffle_epi32(*coefficients, 0x55);
  for (int y = height; y != 0; --y) {
    ptrdiff_t x = 0;
    do {
      __m256i s[7];
      s[0] = LoadUnaligned32(src + x + 0);
      s[1] = LoadUnaligned32(src + x + 1);
      s[2] = LoadUnaligned32(src + x + 2);
      s[3] = LoadUnaligned32(src + x + 3);
      s[4] = LoadUnaligned32(src + x + 4);
      s[5] = LoadUnaligned32(src + x + 5);
      s[6] = LoadUnaligned32(src + x + 6);
      WienerHorizontalTap7Kernel(s, filter, *wiener_buffer + x);
      x += 16;
    } while (x < width);
    src += src_stride;
    *wiener_buffer += width;
  }
}

inline void WienerHorizontalTap5(const uint16_t* src,
                                 const ptrdiff_t src_stride,
                                 const ptrdiff_t width, const int height,
                                 const __m256i* const coefficients,
                                 int16_t** const wiener_buffer) {
  const __m256i filter =
      _mm256_shuffle_epi8(*coefficients, _mm256_set1_epi32(0x05040302));
  for (int y = height; y != 0; --y) {
    ptrdiff_t x = 0;
    do {
      __m256i s[5];
      s[0] = LoadUnaligned32(src + x + 0);
      s[1] = LoadUnaligned32(src + x + 1);
      s[2] = LoadUnaligned32(src + x + 2);
      s[3] = LoadUnaligned32(src + x + 3);
      s[4] = LoadUnaligned32(src + x + 4);
      WienerHorizontalTap5Kernel(s, filter, *wiener_buffer + x);
      x += 16;
    } while (x < width);
    src += src_stride;
    *wiener_buffer += width;
  }
}

inline void WienerHorizontalTap3(const uint16_t* src,
                                 const ptrdiff_t src_stride,
                                 const ptrdiff_t width, const int height,
                                 const __m256i* const coefficients,
                                 int16_t** const wiener_buffer) {
  const auto filter = _mm256_shuffle_epi32(*coefficients, 0x55);
  for (int y = height; y != 0; --y) {
    ptrdiff_t x = 0;
    do {
      __m256i s[3];
      s[0] = LoadUnaligned32(src + x + 0);
      s[1] = LoadUnaligned32(src + x + 1);
      s[2] = LoadUnaligned32(src + x + 2);
      WienerHorizontalTap3Kernel(s, filter, *wiener_buffer + x);
      x += 16;
    } while (x < width);
    src += src_stride;
    *wiener_buffer += width;
  }
}

inline void WienerHorizontalTap1(const uint16_t* src,
                                 const ptrdiff_t src_stride,
                                 const ptrdiff_t width, const int height,
                                 int16_t** const wiener_buffer) {
  for (int y = height; y != 0; --y) {
    ptrdiff_t x = 0;
    do {
      const __m256i s0 = LoadUnaligned32(src + x);
      const __m256i d0 = _mm256_slli_epi16(s0, 4);
      StoreAligned32(*wiener_buffer + x, d0);
      x += 16;
    } while (x < width);
    src += src_stride;
    *wiener_buffer += width;
  }
}

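// WienerVertical7/5/3 accumulate the paired vertical taps with madd and shift
// the 32-bit sums right by kInterRoundBitsVertical.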
inline __m256i WienerVertical7(const __m256i a[4], const __m256i filter[4]) {
  const __m256i madd0 = _mm256_madd_epi16(a[0], filter[0]);
  const __m256i madd1 = _mm256_madd_epi16(a[1], filter[1]);
  const __m256i madd2 = _mm256_madd_epi16(a[2], filter[2]);
  const __m256i madd3 = _mm256_madd_epi16(a[3], filter[3]);
  const __m256i madd01 = _mm256_add_epi32(madd0, madd1);
  const __m256i madd23 = _mm256_add_epi32(madd2, madd3);
  const __m256i sum = _mm256_add_epi32(madd01, madd23);
  return _mm256_srai_epi32(sum, kInterRoundBitsVertical);
}

inline __m256i WienerVertical5(const __m256i a[3], const __m256i filter[3]) {
  const __m256i madd0 = _mm256_madd_epi16(a[0], filter[0]);
  const __m256i madd1 = _mm256_madd_epi16(a[1], filter[1]);
  const __m256i madd2 = _mm256_madd_epi16(a[2], filter[2]);
  const __m256i madd01 = _mm256_add_epi32(madd0, madd1);
  const __m256i sum = _mm256_add_epi32(madd01, madd2);
  return _mm256_srai_epi32(sum, kInterRoundBitsVertical);
}

inline __m256i WienerVertical3(const __m256i a[2], const __m256i filter[2]) {
  const __m256i madd0 = _mm256_madd_epi16(a[0], filter[0]);
  const __m256i madd1 = _mm256_madd_epi16(a[1], filter[1]);
  const __m256i sum = _mm256_add_epi32(madd0, madd1);
  return _mm256_srai_epi32(sum, kInterRoundBitsVertical);
}

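// Packs the two 32-bit sums to unsigned 16 bits and clamps the result to the
// 10-bit maximum pixel value.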
inline __m256i WienerVerticalClip(const __m256i s[2]) {
  const __m256i d = _mm256_packus_epi32(s[0], s[1]);
  return _mm256_min_epu16(d, _mm256_set1_epi16(1023));
}

inline __m256i WienerVerticalFilter7(const __m256i a[7],
                                     const __m256i filter[2]) {
  const __m256i round = _mm256_set1_epi16(1 << (kInterRoundBitsVertical - 1));
  __m256i b[4], c[2];
  b[0] = _mm256_unpacklo_epi16(a[0], a[1]);
  b[1] = _mm256_unpacklo_epi16(a[2], a[3]);
  b[2] = _mm256_unpacklo_epi16(a[4], a[5]);
  b[3] = _mm256_unpacklo_epi16(a[6], round);
  c[0] = WienerVertical7(b, filter);
  b[0] = _mm256_unpackhi_epi16(a[0], a[1]);
  b[1] = _mm256_unpackhi_epi16(a[2], a[3]);
  b[2] = _mm256_unpackhi_epi16(a[4], a[5]);
  b[3] = _mm256_unpackhi_epi16(a[6], round);
  c[1] = WienerVertical7(b, filter);
  return WienerVerticalClip(c);
}

inline __m256i WienerVerticalFilter5(const __m256i a[5],
                                     const __m256i filter[3]) {
  const __m256i round = _mm256_set1_epi16(1 << (kInterRoundBitsVertical - 1));
  __m256i b[3], c[2];
  b[0] = _mm256_unpacklo_epi16(a[0], a[1]);
  b[1] = _mm256_unpacklo_epi16(a[2], a[3]);
  b[2] = _mm256_unpacklo_epi16(a[4], round);
  c[0] = WienerVertical5(b, filter);
  b[0] = _mm256_unpackhi_epi16(a[0], a[1]);
  b[1] = _mm256_unpackhi_epi16(a[2], a[3]);
  b[2] = _mm256_unpackhi_epi16(a[4], round);
  c[1] = WienerVertical5(b, filter);
  return WienerVerticalClip(c);
}

inline __m256i WienerVerticalFilter3(const __m256i a[3],
                                     const __m256i filter[2]) {
  const __m256i round = _mm256_set1_epi16(1 << (kInterRoundBitsVertical - 1));
  __m256i b[2], c[2];
  b[0] = _mm256_unpacklo_epi16(a[0], a[1]);
  b[1] = _mm256_unpacklo_epi16(a[2], round);
  c[0] = WienerVertical3(b, filter);
  b[0] = _mm256_unpackhi_epi16(a[0], a[1]);
  b[1] = _mm256_unpackhi_epi16(a[2], round);
  c[1] = WienerVertical3(b, filter);
  return WienerVerticalClip(c);
}

inline __m256i WienerVerticalTap7Kernel(const int16_t* wiener_buffer,
                                        const ptrdiff_t wiener_stride,
                                        const __m256i filter[2], __m256i a[7]) {
  a[0] = LoadAligned32(wiener_buffer + 0 * wiener_stride);
  a[1] = LoadAligned32(wiener_buffer + 1 * wiener_stride);
  a[2] = LoadAligned32(wiener_buffer + 2 * wiener_stride);
  a[3] = LoadAligned32(wiener_buffer + 3 * wiener_stride);
  a[4] = LoadAligned32(wiener_buffer + 4 * wiener_stride);
  a[5] = LoadAligned32(wiener_buffer + 5 * wiener_stride);
  a[6] = LoadAligned32(wiener_buffer + 6 * wiener_stride);
  return WienerVerticalFilter7(a, filter);
}

inline __m256i WienerVerticalTap5Kernel(const int16_t* wiener_buffer,
                                        const ptrdiff_t wiener_stride,
                                        const __m256i filter[3], __m256i a[5]) {
  a[0] = LoadAligned32(wiener_buffer + 0 * wiener_stride);
  a[1] = LoadAligned32(wiener_buffer + 1 * wiener_stride);
  a[2] = LoadAligned32(wiener_buffer + 2 * wiener_stride);
  a[3] = LoadAligned32(wiener_buffer + 3 * wiener_stride);
  a[4] = LoadAligned32(wiener_buffer + 4 * wiener_stride);
  return WienerVerticalFilter5(a, filter);
}

inline __m256i WienerVerticalTap3Kernel(const int16_t* wiener_buffer,
                                        const ptrdiff_t wiener_stride,
                                        const __m256i filter[2], __m256i a[3]) {
  a[0] = LoadAligned32(wiener_buffer + 0 * wiener_stride);
  a[1] = LoadAligned32(wiener_buffer + 1 * wiener_stride);
  a[2] = LoadAligned32(wiener_buffer + 2 * wiener_stride);
  return WienerVerticalFilter3(a, filter);
}

inline void WienerVerticalTap7Kernel2(const int16_t* wiener_buffer,
                                      const ptrdiff_t wiener_stride,
                                      const __m256i filter[2], __m256i d[2]) {
  __m256i a[8];
  d[0] = WienerVerticalTap7Kernel(wiener_buffer, wiener_stride, filter, a);
  a[7] = LoadAligned32(wiener_buffer + 7 * wiener_stride);
  d[1] = WienerVerticalFilter7(a + 1, filter);
}

inline void WienerVerticalTap5Kernel2(const int16_t* wiener_buffer,
                                      const ptrdiff_t wiener_stride,
                                      const __m256i filter[3], __m256i d[2]) {
  __m256i a[6];
  d[0] = WienerVerticalTap5Kernel(wiener_buffer, wiener_stride, filter, a);
  a[5] = LoadAligned32(wiener_buffer + 5 * wiener_stride);
  d[1] = WienerVerticalFilter5(a + 1, filter);
}

inline void WienerVerticalTap3Kernel2(const int16_t* wiener_buffer,
                                      const ptrdiff_t wiener_stride,
                                      const __m256i filter[2], __m256i d[2]) {
  __m256i a[4];
  d[0] = WienerVerticalTap3Kernel(wiener_buffer, wiener_stride, filter, a);
  a[3] = LoadAligned32(wiener_buffer + 3 * wiener_stride);
  d[1] = WienerVerticalFilter3(a + 1, filter);
}

inline void WienerVerticalTap7(const int16_t* wiener_buffer,
                               const ptrdiff_t width, const int height,
                               const int16_t coefficients[4], uint16_t* dst,
                               const ptrdiff_t dst_stride) {
  const __m256i c = _mm256_broadcastq_epi64(LoadLo8(coefficients));
  __m256i filter[4];
  filter[0] = _mm256_shuffle_epi32(c, 0x0);
  filter[1] = _mm256_shuffle_epi32(c, 0x55);
  filter[2] = _mm256_shuffle_epi8(c, _mm256_set1_epi32(0x03020504));
  filter[3] =
      _mm256_set1_epi32((1 << 16) | static_cast<uint16_t>(coefficients[0]));
  for (int y = height >> 1; y > 0; --y) {
    ptrdiff_t x = 0;
    do {
      __m256i d[2];
      WienerVerticalTap7Kernel2(wiener_buffer + x, width, filter, d);
      StoreUnaligned32(dst + x, d[0]);
      StoreUnaligned32(dst + dst_stride + x, d[1]);
      x += 16;
    } while (x < width);
    dst += 2 * dst_stride;
    wiener_buffer += 2 * width;
  }

  if ((height & 1) != 0) {
    ptrdiff_t x = 0;
    do {
      __m256i a[7];
      const __m256i d =
          WienerVerticalTap7Kernel(wiener_buffer + x, width, filter, a);
      StoreUnaligned32(dst + x, d);
      x += 16;
    } while (x < width);
  }
}

inline void WienerVerticalTap5(const int16_t* wiener_buffer,
                               const ptrdiff_t width, const int height,
                               const int16_t coefficients[3], uint16_t* dst,
                               const ptrdiff_t dst_stride) {
  const __m256i c = _mm256_broadcastq_epi64(LoadLo8(coefficients));
  __m256i filter[3];
  filter[0] = _mm256_shuffle_epi32(c, 0x0);
  filter[1] = _mm256_shuffle_epi8(c, _mm256_set1_epi32(0x03020504));
  filter[2] =
      _mm256_set1_epi32((1 << 16) | static_cast<uint16_t>(coefficients[0]));
  for (int y = height >> 1; y > 0; --y) {
    ptrdiff_t x = 0;
    do {
      __m256i d[2];
      WienerVerticalTap5Kernel2(wiener_buffer + x, width, filter, d);
      StoreUnaligned32(dst + x, d[0]);
      StoreUnaligned32(dst + dst_stride + x, d[1]);
      x += 16;
    } while (x < width);
    dst += 2 * dst_stride;
    wiener_buffer += 2 * width;
  }

  if ((height & 1) != 0) {
    ptrdiff_t x = 0;
    do {
      __m256i a[5];
      const __m256i d =
          WienerVerticalTap5Kernel(wiener_buffer + x, width, filter, a);
      StoreUnaligned32(dst + x, d);
      x += 16;
    } while (x < width);
  }
}

inline void WienerVerticalTap3(const int16_t* wiener_buffer,
                               const ptrdiff_t width, const int height,
                               const int16_t coefficients[2], uint16_t* dst,
                               const ptrdiff_t dst_stride) {
  __m256i filter[2];
  filter[0] =
      _mm256_set1_epi32(*reinterpret_cast<const int32_t*>(coefficients));
  filter[1] =
      _mm256_set1_epi32((1 << 16) | static_cast<uint16_t>(coefficients[0]));
  for (int y = height >> 1; y > 0; --y) {
    ptrdiff_t x = 0;
    do {
      __m256i d[2][2];
      WienerVerticalTap3Kernel2(wiener_buffer + x, width, filter, d[0]);
      StoreUnaligned32(dst + x, d[0][0]);
      StoreUnaligned32(dst + dst_stride + x, d[0][1]);
      x += 16;
    } while (x < width);
    dst += 2 * dst_stride;
    wiener_buffer += 2 * width;
  }

  if ((height & 1) != 0) {
    ptrdiff_t x = 0;
    do {
      __m256i a[3];
      const __m256i d =
          WienerVerticalTap3Kernel(wiener_buffer + x, width, filter, a);
      StoreUnaligned32(dst + x, d);
      x += 16;
    } while (x < width);
  }
}

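// The 1-tap vertical pass undoes the left shift by 4 applied in
// WienerHorizontalTap1: round, shift right by 4 and clamp to [0, 1023].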
inline void WienerVerticalTap1Kernel(const int16_t* const wiener_buffer,
                                     uint16_t* const dst) {
  const __m256i a = LoadAligned32(wiener_buffer);
  const __m256i b = _mm256_add_epi16(a, _mm256_set1_epi16(8));
  const __m256i c = _mm256_srai_epi16(b, 4);
  const __m256i d = _mm256_max_epi16(c, _mm256_setzero_si256());
  const __m256i e = _mm256_min_epi16(d, _mm256_set1_epi16(1023));
  StoreUnaligned32(dst, e);
}

inline void WienerVerticalTap1(const int16_t* wiener_buffer,
                               const ptrdiff_t width, const int height,
                               uint16_t* dst, const ptrdiff_t dst_stride) {
  for (int y = height >> 1; y > 0; --y) {
    ptrdiff_t x = 0;
    do {
      WienerVerticalTap1Kernel(wiener_buffer + x, dst + x);
      WienerVerticalTap1Kernel(wiener_buffer + width + x, dst + dst_stride + x);
      x += 16;
    } while (x < width);
    dst += 2 * dst_stride;
    wiener_buffer += 2 * width;
  }

  if ((height & 1) != 0) {
    ptrdiff_t x = 0;
    do {
      WienerVerticalTap1Kernel(wiener_buffer + x, dst + x);
      x += 16;
    } while (x < width);
  }
}

void WienerFilter_AVX2(
    const RestorationUnitInfo& LIBGAV1_RESTRICT restoration_info,
    const void* LIBGAV1_RESTRICT const source, const ptrdiff_t stride,
    const void* LIBGAV1_RESTRICT const top_border,
    const ptrdiff_t top_border_stride,
    const void* LIBGAV1_RESTRICT const bottom_border,
    const ptrdiff_t bottom_border_stride, const int width, const int height,
    RestorationBuffer* LIBGAV1_RESTRICT const restoration_buffer,
    void* LIBGAV1_RESTRICT const dest) {
  const int16_t* const number_leading_zero_coefficients =
      restoration_info.wiener_info.number_leading_zero_coefficients;
  const int number_rows_to_skip = std::max(
      static_cast<int>(number_leading_zero_coefficients[WienerInfo::kVertical]),
      1);
  const ptrdiff_t wiener_stride = Align(width, 16);
  int16_t* const wiener_buffer_vertical = restoration_buffer->wiener_buffer;
  // The values are saturated to 13 bits before storing.
  int16_t* wiener_buffer_horizontal =
      wiener_buffer_vertical + number_rows_to_skip * wiener_stride;

  // horizontal filtering.
  // Over-reads up to 15 - |kRestorationHorizontalBorder| values.
  const int height_horizontal =
      height + kWienerFilterTaps - 1 - 2 * number_rows_to_skip;
  const int height_extra = (height_horizontal - height) >> 1;
  assert(height_extra <= 2);
  const auto* const src = static_cast<const uint16_t*>(source);
  const auto* const top = static_cast<const uint16_t*>(top_border);
  const auto* const bottom = static_cast<const uint16_t*>(bottom_border);
  const __m128i c =
      LoadLo8(restoration_info.wiener_info.filter[WienerInfo::kHorizontal]);
  const __m256i coefficients_horizontal = _mm256_broadcastq_epi64(c);
  if (number_leading_zero_coefficients[WienerInfo::kHorizontal] == 0) {
    WienerHorizontalTap7(top + (2 - height_extra) * top_border_stride - 3,
                         top_border_stride, wiener_stride, height_extra,
                         &coefficients_horizontal, &wiener_buffer_horizontal);
    WienerHorizontalTap7(src - 3, stride, wiener_stride, height,
                         &coefficients_horizontal, &wiener_buffer_horizontal);
    WienerHorizontalTap7(bottom - 3, bottom_border_stride, wiener_stride,
                         height_extra, &coefficients_horizontal,
                         &wiener_buffer_horizontal);
  } else if (number_leading_zero_coefficients[WienerInfo::kHorizontal] == 1) {
    WienerHorizontalTap5(top + (2 - height_extra) * top_border_stride - 2,
                         top_border_stride, wiener_stride, height_extra,
                         &coefficients_horizontal, &wiener_buffer_horizontal);
    WienerHorizontalTap5(src - 2, stride, wiener_stride, height,
                         &coefficients_horizontal, &wiener_buffer_horizontal);
    WienerHorizontalTap5(bottom - 2, bottom_border_stride, wiener_stride,
                         height_extra, &coefficients_horizontal,
                         &wiener_buffer_horizontal);
  } else if (number_leading_zero_coefficients[WienerInfo::kHorizontal] == 2) {
    // The maximum over-reads happen here.
    WienerHorizontalTap3(top + (2 - height_extra) * top_border_stride - 1,
                         top_border_stride, wiener_stride, height_extra,
                         &coefficients_horizontal, &wiener_buffer_horizontal);
    WienerHorizontalTap3(src - 1, stride, wiener_stride, height,
                         &coefficients_horizontal, &wiener_buffer_horizontal);
    WienerHorizontalTap3(bottom - 1, bottom_border_stride, wiener_stride,
                         height_extra, &coefficients_horizontal,
                         &wiener_buffer_horizontal);
  } else {
    assert(number_leading_zero_coefficients[WienerInfo::kHorizontal] == 3);
    WienerHorizontalTap1(top + (2 - height_extra) * top_border_stride,
                         top_border_stride, wiener_stride, height_extra,
                         &wiener_buffer_horizontal);
    WienerHorizontalTap1(src, stride, wiener_stride, height,
                         &wiener_buffer_horizontal);
    WienerHorizontalTap1(bottom, bottom_border_stride, wiener_stride,
                         height_extra, &wiener_buffer_horizontal);
  }

  // vertical filtering.
  // Over-writes up to 15 values.
  const int16_t* const filter_vertical =
      restoration_info.wiener_info.filter[WienerInfo::kVertical];
  auto* dst = static_cast<uint16_t*>(dest);
  if (number_leading_zero_coefficients[WienerInfo::kVertical] == 0) {
    // Because the top row of |source| is a duplicate of the second row, and the
    // bottom row of |source| is a duplicate of its above row, we can duplicate
    // the top and bottom row of |wiener_buffer| accordingly.
    memcpy(wiener_buffer_horizontal, wiener_buffer_horizontal - wiener_stride,
           sizeof(*wiener_buffer_horizontal) * wiener_stride);
    memcpy(restoration_buffer->wiener_buffer,
           restoration_buffer->wiener_buffer + wiener_stride,
           sizeof(*restoration_buffer->wiener_buffer) * wiener_stride);
    WienerVerticalTap7(wiener_buffer_vertical, wiener_stride, height,
                       filter_vertical, dst, stride);
  } else if (number_leading_zero_coefficients[WienerInfo::kVertical] == 1) {
    WienerVerticalTap5(wiener_buffer_vertical + wiener_stride, wiener_stride,
                       height, filter_vertical + 1, dst, stride);
  } else if (number_leading_zero_coefficients[WienerInfo::kVertical] == 2) {
    WienerVerticalTap3(wiener_buffer_vertical + 2 * wiener_stride,
                       wiener_stride, height, filter_vertical + 2, dst, stride);
  } else {
    assert(number_leading_zero_coefficients[WienerInfo::kVertical] == 3);
    WienerVerticalTap1(wiener_buffer_vertical + 3 * wiener_stride,
                       wiener_stride, height, dst, stride);
  }
}

//------------------------------------------------------------------------------
// SGR

constexpr int kSumOffset = 24;

// SIMD loads over-read (number of pixels in a SIMD register) - (width % 8) -
// 2 * padding pixels, where padding is 3 for Pass 1 and 2 for Pass 2. A SIMD
// register holds 16 bytes for SSE4.1 and 32 bytes for AVX2.
constexpr int kOverreadInBytesPass1_128 = 4;
constexpr int kOverreadInBytesPass2_128 = 8;
constexpr int kOverreadInBytesPass1_256 = kOverreadInBytesPass1_128 + 16;
constexpr int kOverreadInBytesPass2_256 = kOverreadInBytesPass2_128 + 16;

inline void LoadAligned16x2U16(const uint16_t* const src[2], const ptrdiff_t x,
                               __m128i dst[2]) {
  dst[0] = LoadAligned16(src[0] + x);
  dst[1] = LoadAligned16(src[1] + x);
}

inline void LoadAligned32x2U16(const uint16_t* const src[2], const ptrdiff_t x,
                               __m256i dst[2]) {
  dst[0] = LoadAligned32(src[0] + x);
  dst[1] = LoadAligned32(src[1] + x);
}

inline void LoadAligned32x2U16Msan(const uint16_t* const src[2],
                                   const ptrdiff_t x, const ptrdiff_t border,
                                   __m256i dst[2]) {
  dst[0] = LoadAligned32Msan(src[0] + x, sizeof(**src) * (x + 16 - border));
  dst[1] = LoadAligned32Msan(src[1] + x, sizeof(**src) * (x + 16 - border));
}

inline void LoadAligned16x3U16(const uint16_t* const src[3], const ptrdiff_t x,
                               __m128i dst[3]) {
  dst[0] = LoadAligned16(src[0] + x);
  dst[1] = LoadAligned16(src[1] + x);
  dst[2] = LoadAligned16(src[2] + x);
}

inline void LoadAligned32x3U16(const uint16_t* const src[3], const ptrdiff_t x,
                               __m256i dst[3]) {
  dst[0] = LoadAligned32(src[0] + x);
  dst[1] = LoadAligned32(src[1] + x);
  dst[2] = LoadAligned32(src[2] + x);
}

inline void LoadAligned32x3U16Msan(const uint16_t* const src[3],
                                   const ptrdiff_t x, const ptrdiff_t border,
                                   __m256i dst[3]) {
  dst[0] = LoadAligned32Msan(src[0] + x, sizeof(**src) * (x + 16 - border));
  dst[1] = LoadAligned32Msan(src[1] + x, sizeof(**src) * (x + 16 - border));
  dst[2] = LoadAligned32Msan(src[2] + x, sizeof(**src) * (x + 16 - border));
}

inline void LoadAligned32U32(const uint32_t* const src, __m128i dst[2]) {
  dst[0] = LoadAligned16(src + 0);
  dst[1] = LoadAligned16(src + 4);
}

inline void LoadAligned32x2U32(const uint32_t* const src[2], const ptrdiff_t x,
                               __m128i dst[2][2]) {
  LoadAligned32U32(src[0] + x, dst[0]);
  LoadAligned32U32(src[1] + x, dst[1]);
}

inline void LoadAligned64x2U32(const uint32_t* const src[2], const ptrdiff_t x,
                               __m256i dst[2][2]) {
  LoadAligned64(src[0] + x, dst[0]);
  LoadAligned64(src[1] + x, dst[1]);
}

inline void LoadAligned64x2U32Msan(const uint32_t* const src[2],
                                   const ptrdiff_t x, const ptrdiff_t border,
                                   __m256i dst[2][2]) {
  LoadAligned64Msan(src[0] + x, sizeof(**src) * (x + 16 - border), dst[0]);
  LoadAligned64Msan(src[1] + x, sizeof(**src) * (x + 16 - border), dst[1]);
}

inline void LoadAligned32x3U32(const uint32_t* const src[3], const ptrdiff_t x,
                               __m128i dst[3][2]) {
  LoadAligned32U32(src[0] + x, dst[0]);
  LoadAligned32U32(src[1] + x, dst[1]);
  LoadAligned32U32(src[2] + x, dst[2]);
}

inline void LoadAligned64x3U32(const uint32_t* const src[3], const ptrdiff_t x,
                               __m256i dst[3][2]) {
  LoadAligned64(src[0] + x, dst[0]);
  LoadAligned64(src[1] + x, dst[1]);
  LoadAligned64(src[2] + x, dst[2]);
}

inline void LoadAligned64x3U32Msan(const uint32_t* const src[3],
                                   const ptrdiff_t x, const ptrdiff_t border,
                                   __m256i dst[3][2]) {
  LoadAligned64Msan(src[0] + x, sizeof(**src) * (x + 16 - border), dst[0]);
  LoadAligned64Msan(src[1] + x, sizeof(**src) * (x + 16 - border), dst[1]);
  LoadAligned64Msan(src[2] + x, sizeof(**src) * (x + 16 - border), dst[2]);
}

inline void StoreAligned32U32(uint32_t* const dst, const __m128i src[2]) {
  StoreAligned16(dst + 0, src[0]);
  StoreAligned16(dst + 4, src[1]);
}

// The AVX2 ymm register holds ma[0], ma[1], ..., ma[7], and ma[16], ma[17],
// ..., ma[23].
// There is an 8 pixel gap between the first half and the second half.
constexpr int kMaStoreOffset = 8;

inline void StoreAligned32_ma(uint16_t* src, const __m256i v) {
  StoreAligned16(src + 0 * 8, _mm256_extracti128_si256(v, 0));
  StoreAligned16(src + 2 * 8, _mm256_extracti128_si256(v, 1));
}

inline void StoreAligned64_ma(uint16_t* src, const __m256i v[2]) {
  // The next 4 lines are much faster than:
  // StoreAligned32(src + 0, _mm256_permute2x128_si256(v[0], v[1], 0x20));
  // StoreAligned32(src + 16, _mm256_permute2x128_si256(v[0], v[1], 0x31));
  StoreAligned16(src + 0 * 8, _mm256_extracti128_si256(v[0], 0));
  StoreAligned16(src + 1 * 8, _mm256_extracti128_si256(v[1], 0));
  StoreAligned16(src + 2 * 8, _mm256_extracti128_si256(v[0], 1));
  StoreAligned16(src + 3 * 8, _mm256_extracti128_si256(v[1], 1));
}

// Don't use _mm_cvtepu8_epi16() or _mm_cvtepu16_epi32() in the following
// functions. Some compilers may generate super inefficient code and the whole
// decoder could be 15% slower.

inline __m256i VaddlLo8(const __m256i src0, const __m256i src1) {
  const __m256i s0 = _mm256_unpacklo_epi8(src0, _mm256_setzero_si256());
  const __m256i s1 = _mm256_unpacklo_epi8(src1, _mm256_setzero_si256());
  return _mm256_add_epi16(s0, s1);
}

inline __m256i VaddlHi8(const __m256i src0, const __m256i src1) {
  const __m256i s0 = _mm256_unpackhi_epi8(src0, _mm256_setzero_si256());
  const __m256i s1 = _mm256_unpackhi_epi8(src1, _mm256_setzero_si256());
  return _mm256_add_epi16(s0, s1);
}

inline __m256i VaddwLo8(const __m256i src0, const __m256i src1) {
  const __m256i s1 = _mm256_unpacklo_epi8(src1, _mm256_setzero_si256());
  return _mm256_add_epi16(src0, s1);
}

inline __m256i VaddwHi8(const __m256i src0, const __m256i src1) {
  const __m256i s1 = _mm256_unpackhi_epi8(src1, _mm256_setzero_si256());
  return _mm256_add_epi16(src0, s1);
}

inline __m256i VmullNLo8(const __m256i src0, const int src1) {
  const __m256i s0 = _mm256_unpacklo_epi16(src0, _mm256_setzero_si256());
  return _mm256_madd_epi16(s0, _mm256_set1_epi32(src1));
}

inline __m256i VmullNHi8(const __m256i src0, const int src1) {
  const __m256i s0 = _mm256_unpackhi_epi16(src0, _mm256_setzero_si256());
  return _mm256_madd_epi16(s0, _mm256_set1_epi32(src1));
}

inline __m128i VmullLo16(const __m128i src0, const __m128i src1) {
  const __m128i s0 = _mm_unpacklo_epi16(src0, _mm_setzero_si128());
  const __m128i s1 = _mm_unpacklo_epi16(src1, _mm_setzero_si128());
  return _mm_madd_epi16(s0, s1);
}

inline __m256i VmullLo16(const __m256i src0, const __m256i src1) {
  const __m256i s0 = _mm256_unpacklo_epi16(src0, _mm256_setzero_si256());
  const __m256i s1 = _mm256_unpacklo_epi16(src1, _mm256_setzero_si256());
  return _mm256_madd_epi16(s0, s1);
}

inline __m128i VmullHi16(const __m128i src0, const __m128i src1) {
  const __m128i s0 = _mm_unpackhi_epi16(src0, _mm_setzero_si128());
  const __m128i s1 = _mm_unpackhi_epi16(src1, _mm_setzero_si128());
  return _mm_madd_epi16(s0, s1);
}

inline __m256i VmullHi16(const __m256i src0, const __m256i src1) {
  const __m256i s0 = _mm256_unpackhi_epi16(src0, _mm256_setzero_si256());
  const __m256i s1 = _mm256_unpackhi_epi16(src1, _mm256_setzero_si256());
  return _mm256_madd_epi16(s0, s1);
}

inline __m128i VrshrU16(const __m128i src0, const int src1) {
  const __m128i sum = _mm_add_epi16(src0, _mm_set1_epi16(1 << (src1 - 1)));
  return _mm_srli_epi16(sum, src1);
}

inline __m256i VrshrU16(const __m256i src0, const int src1) {
  const __m256i sum =
      _mm256_add_epi16(src0, _mm256_set1_epi16(1 << (src1 - 1)));
  return _mm256_srli_epi16(sum, src1);
}

inline __m256i VrshrS32(const __m256i src0, const int src1) {
  const __m256i sum =
      _mm256_add_epi32(src0, _mm256_set1_epi32(1 << (src1 - 1)));
  return _mm256_srai_epi32(sum, src1);
}

inline __m128i VrshrU32(const __m128i src0, const int src1) {
  const __m128i sum = _mm_add_epi32(src0, _mm_set1_epi32(1 << (src1 - 1)));
  return _mm_srli_epi32(sum, src1);
}

inline __m256i VrshrU32(const __m256i src0, const int src1) {
  const __m256i sum =
      _mm256_add_epi32(src0, _mm256_set1_epi32(1 << (src1 - 1)));
  return _mm256_srli_epi32(sum, src1);
}

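// Widens 16-bit values to 32 bits and computes their squares.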
inline void Square(const __m128i src, __m128i dst[2]) {
  const __m128i s0 = _mm_unpacklo_epi16(src, _mm_setzero_si128());
  const __m128i s1 = _mm_unpackhi_epi16(src, _mm_setzero_si128());
  dst[0] = _mm_madd_epi16(s0, s0);
  dst[1] = _mm_madd_epi16(s1, s1);
}

inline void Square(const __m256i src, __m256i dst[2]) {
  const __m256i s0 = _mm256_unpacklo_epi16(src, _mm256_setzero_si256());
  const __m256i s1 = _mm256_unpackhi_epi16(src, _mm256_setzero_si256());
  dst[0] = _mm256_madd_epi16(s0, s0);
  dst[1] = _mm256_madd_epi16(s1, s1);
}

inline void Prepare3_8(const __m256i src[2], __m256i dst[3]) {
  dst[0] = _mm256_alignr_epi8(src[1], src[0], 0);
  dst[1] = _mm256_alignr_epi8(src[1], src[0], 1);
  dst[2] = _mm256_alignr_epi8(src[1], src[0], 2);
}

inline void Prepare3_16(const __m128i src[2], __m128i dst[3]) {
  dst[0] = src[0];
  dst[1] = _mm_alignr_epi8(src[1], src[0], 2);
  dst[2] = _mm_alignr_epi8(src[1], src[0], 4);
}

inline void Prepare3_32(const __m128i src[2], __m128i dst[3]) {
  dst[0] = src[0];
  dst[1] = _mm_alignr_epi8(src[1], src[0], 4);
  dst[2] = _mm_alignr_epi8(src[1], src[0], 8);
}

inline void Prepare3_32(const __m256i src[2], __m256i dst[3]) {
  dst[0] = src[0];
  dst[1] = _mm256_alignr_epi8(src[1], src[0], 4);
  dst[2] = _mm256_alignr_epi8(src[1], src[0], 8);
}

inline void Prepare5_16(const __m128i src[2], __m128i dst[5]) {
  Prepare3_16(src, dst);
  dst[3] = _mm_alignr_epi8(src[1], src[0], 6);
  dst[4] = _mm_alignr_epi8(src[1], src[0], 8);
}

inline void Prepare5_32(const __m128i src[2], __m128i dst[5]) {
  Prepare3_32(src, dst);
  dst[3] = _mm_alignr_epi8(src[1], src[0], 12);
  dst[4] = src[1];
}

inline void Prepare5_32(const __m256i src[2], __m256i dst[5]) {
  Prepare3_32(src, dst);
  dst[3] = _mm256_alignr_epi8(src[1], src[0], 12);
  dst[4] = src[1];
}

inline __m128i Sum3_16(const __m128i src0, const __m128i src1,
                       const __m128i src2) {
  const __m128i sum = _mm_add_epi16(src0, src1);
  return _mm_add_epi16(sum, src2);
}

inline __m256i Sum3_16(const __m256i src0, const __m256i src1,
                       const __m256i src2) {
  const __m256i sum = _mm256_add_epi16(src0, src1);
  return _mm256_add_epi16(sum, src2);
}

inline __m128i Sum3_16(const __m128i src[3]) {
  return Sum3_16(src[0], src[1], src[2]);
}

inline __m256i Sum3_16(const __m256i src[3]) {
  return Sum3_16(src[0], src[1], src[2]);
}

inline __m128i Sum3_32(const __m128i src0, const __m128i src1,
                       const __m128i src2) {
  const __m128i sum = _mm_add_epi32(src0, src1);
  return _mm_add_epi32(sum, src2);
}

inline __m256i Sum3_32(const __m256i src0, const __m256i src1,
                       const __m256i src2) {
  const __m256i sum = _mm256_add_epi32(src0, src1);
  return _mm256_add_epi32(sum, src2);
}

inline __m128i Sum3_32(const __m128i src[3]) {
  return Sum3_32(src[0], src[1], src[2]);
}

inline __m256i Sum3_32(const __m256i src[3]) {
  return Sum3_32(src[0], src[1], src[2]);
}

inline void Sum3_32(const __m128i src[3][2], __m128i dst[2]) {
  dst[0] = Sum3_32(src[0][0], src[1][0], src[2][0]);
  dst[1] = Sum3_32(src[0][1], src[1][1], src[2][1]);
}

inline void Sum3_32(const __m256i src[3][2], __m256i dst[2]) {
  dst[0] = Sum3_32(src[0][0], src[1][0], src[2][0]);
  dst[1] = Sum3_32(src[0][1], src[1][1], src[2][1]);
}

inline __m256i Sum3WLo16(const __m256i src[3]) {
  const __m256i sum = VaddlLo8(src[0], src[1]);
  return VaddwLo8(sum, src[2]);
}

inline __m256i Sum3WHi16(const __m256i src[3]) {
  const __m256i sum = VaddlHi8(src[0], src[1]);
  return VaddwHi8(sum, src[2]);
}

inline __m128i Sum5_16(const __m128i src[5]) {
  const __m128i sum01 = _mm_add_epi16(src[0], src[1]);
  const __m128i sum23 = _mm_add_epi16(src[2], src[3]);
  const __m128i sum = _mm_add_epi16(sum01, sum23);
  return _mm_add_epi16(sum, src[4]);
}

inline __m256i Sum5_16(const __m256i src[5]) {
  const __m256i sum01 = _mm256_add_epi16(src[0], src[1]);
  const __m256i sum23 = _mm256_add_epi16(src[2], src[3]);
  const __m256i sum = _mm256_add_epi16(sum01, sum23);
  return _mm256_add_epi16(sum, src[4]);
}

inline __m128i Sum5_32(const __m128i* const src0, const __m128i* const src1,
                       const __m128i* const src2, const __m128i* const src3,
                       const __m128i* const src4) {
  const __m128i sum01 = _mm_add_epi32(*src0, *src1);
  const __m128i sum23 = _mm_add_epi32(*src2, *src3);
  const __m128i sum = _mm_add_epi32(sum01, sum23);
  return _mm_add_epi32(sum, *src4);
}

inline __m256i Sum5_32(const __m256i* const src0, const __m256i* const src1,
                       const __m256i* const src2, const __m256i* const src3,
                       const __m256i* const src4) {
  const __m256i sum01 = _mm256_add_epi32(*src0, *src1);
  const __m256i sum23 = _mm256_add_epi32(*src2, *src3);
  const __m256i sum = _mm256_add_epi32(sum01, sum23);
  return _mm256_add_epi32(sum, *src4);
}

inline __m128i Sum5_32(const __m128i src[5]) {
  return Sum5_32(&src[0], &src[1], &src[2], &src[3], &src[4]);
}

inline __m256i Sum5_32(const __m256i src[5]) {
  return Sum5_32(&src[0], &src[1], &src[2], &src[3], &src[4]);
}

inline void Sum5_32(const __m128i src[5][2], __m128i dst[2]) {
  dst[0] = Sum5_32(&src[0][0], &src[1][0], &src[2][0], &src[3][0], &src[4][0]);
  dst[1] = Sum5_32(&src[0][1], &src[1][1], &src[2][1], &src[3][1], &src[4][1]);
}

inline void Sum5_32(const __m256i src[5][2], __m256i dst[2]) {
  dst[0] = Sum5_32(&src[0][0], &src[1][0], &src[2][0], &src[3][0], &src[4][0]);
  dst[1] = Sum5_32(&src[0][1], &src[1][1], &src[2][1], &src[3][1], &src[4][1]);
}

inline __m128i Sum3Horizontal16(const __m128i src[2]) {
  __m128i s[3];
  Prepare3_16(src, s);
  return Sum3_16(s);
}

inline __m256i Sum3Horizontal16(const uint16_t* const src,
                                const ptrdiff_t over_read_in_bytes) {
  __m256i s[3];
  s[0] = LoadUnaligned32Msan(src + 0, over_read_in_bytes + 0);
  s[1] = LoadUnaligned32Msan(src + 1, over_read_in_bytes + 2);
  s[2] = LoadUnaligned32Msan(src + 2, over_read_in_bytes + 4);
  return Sum3_16(s);
}

inline __m128i Sum5Horizontal16(const __m128i src[2]) {
  __m128i s[5];
  Prepare5_16(src, s);
  return Sum5_16(s);
}

inline __m256i Sum5Horizontal16(const uint16_t* const src,
                                const ptrdiff_t over_read_in_bytes) {
  __m256i s[5];
  s[0] = LoadUnaligned32Msan(src + 0, over_read_in_bytes + 0);
  s[1] = LoadUnaligned32Msan(src + 1, over_read_in_bytes + 2);
  s[2] = LoadUnaligned32Msan(src + 2, over_read_in_bytes + 4);
  s[3] = LoadUnaligned32Msan(src + 3, over_read_in_bytes + 6);
  s[4] = LoadUnaligned32Msan(src + 4, over_read_in_bytes + 8);
  return Sum5_16(s);
}

inline void SumHorizontal16(const uint16_t* const src,
                            const ptrdiff_t over_read_in_bytes,
                            __m256i* const row3, __m256i* const row5) {
  __m256i s[5];
  s[0] = LoadUnaligned32Msan(src + 0, over_read_in_bytes + 0);
  s[1] = LoadUnaligned32Msan(src + 1, over_read_in_bytes + 2);
  s[2] = LoadUnaligned32Msan(src + 2, over_read_in_bytes + 4);
  s[3] = LoadUnaligned32Msan(src + 3, over_read_in_bytes + 6);
  s[4] = LoadUnaligned32Msan(src + 4, over_read_in_bytes + 8);
  const __m256i sum04 = _mm256_add_epi16(s[0], s[4]);
  *row3 = Sum3_16(s + 1);
  *row5 = _mm256_add_epi16(sum04, *row3);
}

inline void SumHorizontal16(const uint16_t* const src,
                            const ptrdiff_t over_read_in_bytes,
                            __m256i* const row3_0, __m256i* const row3_1,
                            __m256i* const row5_0, __m256i* const row5_1) {
  SumHorizontal16(src + 0, over_read_in_bytes + 0, row3_0, row5_0);
  SumHorizontal16(src + 16, over_read_in_bytes + 32, row3_1, row5_1);
}

inline void SumHorizontal32(const __m128i src[5], __m128i* const row_sq3,
                            __m128i* const row_sq5) {
  const __m128i sum04 = _mm_add_epi32(src[0], src[4]);
  *row_sq3 = Sum3_32(src + 1);
  *row_sq5 = _mm_add_epi32(sum04, *row_sq3);
}

inline void SumHorizontal32(const __m256i src[5], __m256i* const row_sq3,
                            __m256i* const row_sq5) {
  const __m256i sum04 = _mm256_add_epi32(src[0], src[4]);
  *row_sq3 = Sum3_32(src + 1);
  *row_sq5 = _mm256_add_epi32(sum04, *row_sq3);
}

inline void SumHorizontal32(const __m128i src[3], __m128i* const row_sq3_0,
                            __m128i* const row_sq3_1, __m128i* const row_sq5_0,
                            __m128i* const row_sq5_1) {
  __m128i s[5];
  Prepare5_32(src + 0, s);
  SumHorizontal32(s, row_sq3_0, row_sq5_0);
  Prepare5_32(src + 1, s);
  SumHorizontal32(s, row_sq3_1, row_sq5_1);
}

inline void SumHorizontal32(const __m256i src[3], __m256i* const row_sq3_0,
                            __m256i* const row_sq3_1, __m256i* const row_sq5_0,
                            __m256i* const row_sq5_1) {
  __m256i s[5];
  Prepare5_32(src + 0, s);
  SumHorizontal32(s, row_sq3_0, row_sq5_0);
  Prepare5_32(src + 1, s);
  SumHorizontal32(s, row_sq3_1, row_sq5_1);
}

inline void Sum3Horizontal32(const __m128i src[3], __m128i dst[2]) {
  __m128i s[3];
  Prepare3_32(src + 0, s);
  dst[0] = Sum3_32(s);
  Prepare3_32(src + 1, s);
  dst[1] = Sum3_32(s);
}

inline void Sum3Horizontal32(const __m256i src[3], __m256i dst[2]) {
  __m256i s[3];
  Prepare3_32(src + 0, s);
  dst[0] = Sum3_32(s);
  Prepare3_32(src + 1, s);
  dst[1] = Sum3_32(s);
}

inline void Sum5Horizontal32(const __m128i src[3], __m128i dst[2]) {
  __m128i s[5];
  Prepare5_32(src + 0, s);
  dst[0] = Sum5_32(s);
  Prepare5_32(src + 1, s);
  dst[1] = Sum5_32(s);
}

inline void Sum5Horizontal32(const __m256i src[3], __m256i dst[2]) {
  __m256i s[5];
  Prepare5_32(src + 0, s);
  dst[0] = Sum5_32(s);
  Prepare5_32(src + 1, s);
  dst[1] = Sum5_32(s);
}

void SumHorizontal16(const __m128i src[2], __m128i* const row3,
                     __m128i* const row5) {
  __m128i s[5];
  Prepare5_16(src, s);
  const __m128i sum04 = _mm_add_epi16(s[0], s[4]);
  *row3 = Sum3_16(s + 1);
  *row5 = _mm_add_epi16(sum04, *row3);
}

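// Sum343*() compute the weighted sum 3 * src[0] + 4 * src[1] + 3 * src[2].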
inline __m256i Sum343Lo(const __m256i ma3[3]) {
  const __m256i sum = Sum3WLo16(ma3);
  const __m256i sum3 = Sum3_16(sum, sum, sum);
  return VaddwLo8(sum3, ma3[1]);
}

inline __m256i Sum343Hi(const __m256i ma3[3]) {
  const __m256i sum = Sum3WHi16(ma3);
  const __m256i sum3 = Sum3_16(sum, sum, sum);
  return VaddwHi8(sum3, ma3[1]);
}

inline __m256i Sum343(const __m256i src[3]) {
  const __m256i sum = Sum3_32(src);
  const __m256i sum3 = Sum3_32(sum, sum, sum);
  return _mm256_add_epi32(sum3, src[1]);
}

inline void Sum343(const __m256i src[3], __m256i dst[2]) {
  __m256i s[3];
  Prepare3_32(src + 0, s);
  dst[0] = Sum343(s);
  Prepare3_32(src + 1, s);
  dst[1] = Sum343(s);
}

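// Sum565*() compute the weighted sum 5 * src[0] + 6 * src[1] + 5 * src[2].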
Sum565Lo(const __m256i src[3])1108 inline __m256i Sum565Lo(const __m256i src[3]) {
1109   const __m256i sum = Sum3WLo16(src);
1110   const __m256i sum4 = _mm256_slli_epi16(sum, 2);
1111   const __m256i sum5 = _mm256_add_epi16(sum4, sum);
1112   return VaddwLo8(sum5, src[1]);
1113 }
1114 
Sum565Hi(const __m256i src[3])1115 inline __m256i Sum565Hi(const __m256i src[3]) {
1116   const __m256i sum = Sum3WHi16(src);
1117   const __m256i sum4 = _mm256_slli_epi16(sum, 2);
1118   const __m256i sum5 = _mm256_add_epi16(sum4, sum);
1119   return VaddwHi8(sum5, src[1]);
1120 }
1121 
1122 inline __m256i Sum565(const __m256i src[3]) {
1123   const __m256i sum = Sum3_32(src);
1124   const __m256i sum4 = _mm256_slli_epi32(sum, 2);
1125   const __m256i sum5 = _mm256_add_epi32(sum4, sum);
1126   return _mm256_add_epi32(sum5, src[1]);
1127 }
1128 
1129 inline void Sum565(const __m256i src[3], __m256i dst[2]) {
1130   __m256i s[3];
1131   Prepare3_32(src + 0, s);
1132   dst[0] = Sum565(s);
1133   Prepare3_32(src + 1, s);
1134   dst[1] = Sum565(s);
1135 }
1136 
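// Computes horizontal sums of width 3 and width 5, for both the pixels and
// their squares, over two source rows, and stores them to the |sum3|/|sum5|
// and |square_sum3|/|square_sum5| buffers. The first 8 columns are handled
// with 128-bit vectors; the rest of the row uses 256-bit vectors.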
1137 inline void BoxSum(const uint16_t* src, const ptrdiff_t src_stride,
1138                    const ptrdiff_t width, const ptrdiff_t sum_stride,
1139                    const ptrdiff_t sum_width, uint16_t* sum3, uint16_t* sum5,
1140                    uint32_t* square_sum3, uint32_t* square_sum5) {
1141   const ptrdiff_t overread_in_bytes_128 =
1142       kOverreadInBytesPass1_128 - sizeof(*src) * width;
1143   const ptrdiff_t overread_in_bytes_256 =
1144       kOverreadInBytesPass1_256 - sizeof(*src) * width;
1145   int y = 2;
1146   do {
1147     __m128i s0[2], sq_128[4], s3, s5, sq3[2], sq5[2];
1148     __m256i sq[8];
1149     s0[0] = LoadUnaligned16Msan(src + 0, overread_in_bytes_128 + 0);
1150     s0[1] = LoadUnaligned16Msan(src + 8, overread_in_bytes_128 + 16);
1151     Square(s0[0], sq_128 + 0);
1152     Square(s0[1], sq_128 + 2);
1153     SumHorizontal16(s0, &s3, &s5);
1154     StoreAligned16(sum3, s3);
1155     StoreAligned16(sum5, s5);
1156     SumHorizontal32(sq_128, &sq3[0], &sq3[1], &sq5[0], &sq5[1]);
1157     StoreAligned32U32(square_sum3, sq3);
1158     StoreAligned32U32(square_sum5, sq5);
1159     src += 8;
1160     sum3 += 8;
1161     sum5 += 8;
1162     square_sum3 += 8;
1163     square_sum5 += 8;
1164     sq[0] = SetrM128i(sq_128[2], sq_128[2]);
1165     sq[1] = SetrM128i(sq_128[3], sq_128[3]);
1166     ptrdiff_t x = sum_width;
1167     do {
1168       __m256i s[2], row3[2], row5[2], row_sq3[2], row_sq5[2];
1169       s[0] = LoadUnaligned32Msan(
1170           src + 8, overread_in_bytes_256 + sizeof(*src) * (sum_width - x + 8));
1171       s[1] = LoadUnaligned32Msan(
1172           src + 24,
1173           overread_in_bytes_256 + sizeof(*src) * (sum_width - x + 24));
1174       Square(s[0], sq + 2);
1175       Square(s[1], sq + 6);
1176       sq[0] = _mm256_permute2x128_si256(sq[0], sq[2], 0x21);
1177       sq[1] = _mm256_permute2x128_si256(sq[1], sq[3], 0x21);
1178       sq[4] = _mm256_permute2x128_si256(sq[2], sq[6], 0x21);
1179       sq[5] = _mm256_permute2x128_si256(sq[3], sq[7], 0x21);
1180       SumHorizontal16(
1181           src, overread_in_bytes_256 + sizeof(*src) * (sum_width - x + 8),
1182           &row3[0], &row3[1], &row5[0], &row5[1]);
1183       StoreAligned64(sum3, row3);
1184       StoreAligned64(sum5, row5);
1185       SumHorizontal32(sq + 0, &row_sq3[0], &row_sq3[1], &row_sq5[0],
1186                       &row_sq5[1]);
1187       StoreAligned64(square_sum3 + 0, row_sq3);
1188       StoreAligned64(square_sum5 + 0, row_sq5);
1189       SumHorizontal32(sq + 4, &row_sq3[0], &row_sq3[1], &row_sq5[0],
1190                       &row_sq5[1]);
1191       StoreAligned64(square_sum3 + 16, row_sq3);
1192       StoreAligned64(square_sum5 + 16, row_sq5);
1193       sq[0] = sq[6];
1194       sq[1] = sq[7];
1195       src += 32;
1196       sum3 += 32;
1197       sum5 += 32;
1198       square_sum3 += 32;
1199       square_sum5 += 32;
1200       x -= 32;
1201     } while (x != 0);
1202     src += src_stride - sum_width - 8;
1203     sum3 += sum_stride - sum_width - 8;
1204     sum5 += sum_stride - sum_width - 8;
1205     square_sum3 += sum_stride - sum_width - 8;
1206     square_sum5 += sum_stride - sum_width - 8;
1207   } while (--y != 0);
1208 }
1209 
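// Same as above, but computes only the width 3 (size == 3) or width 5
// (size == 5) sums, for the case where a single self-guided pass is run.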
1210 template <int size>
1211 inline void BoxSum(const uint16_t* src, const ptrdiff_t src_stride,
1212                    const ptrdiff_t width, const ptrdiff_t sum_stride,
1213                    const ptrdiff_t sum_width, uint16_t* sums,
1214                    uint32_t* square_sums) {
1215   static_assert(size == 3 || size == 5, "");
1216   int overread_in_bytes_128, overread_in_bytes_256;
1217   if (size == 3) {
1218     overread_in_bytes_128 = kOverreadInBytesPass2_128;
1219     overread_in_bytes_256 = kOverreadInBytesPass2_256;
1220   } else {
1221     overread_in_bytes_128 = kOverreadInBytesPass1_128;
1222     overread_in_bytes_256 = kOverreadInBytesPass1_256;
1223   }
1224   overread_in_bytes_128 -= sizeof(*src) * width;
1225   overread_in_bytes_256 -= sizeof(*src) * width;
1226   int y = 2;
1227   do {
1228     __m128i s_128[2], ss, sq_128[4], sqs[2];
1229     __m256i sq[8];
1230     s_128[0] = LoadUnaligned16Msan(src + 0, overread_in_bytes_128);
1231     s_128[1] = LoadUnaligned16Msan(src + 8, overread_in_bytes_128 + 16);
1232     Square(s_128[0], sq_128 + 0);
1233     Square(s_128[1], sq_128 + 2);
1234     if (size == 3) {
1235       ss = Sum3Horizontal16(s_128);
1236       Sum3Horizontal32(sq_128, sqs);
1237     } else {
1238       ss = Sum5Horizontal16(s_128);
1239       Sum5Horizontal32(sq_128, sqs);
1240     }
1241     StoreAligned16(sums, ss);
1242     StoreAligned32U32(square_sums, sqs);
1243     src += 8;
1244     sums += 8;
1245     square_sums += 8;
1246     sq[0] = SetrM128i(sq_128[2], sq_128[2]);
1247     sq[1] = SetrM128i(sq_128[3], sq_128[3]);
1248     ptrdiff_t x = sum_width;
1249     do {
1250       __m256i s[2], row[2], row_sq[4];
1251       s[0] = LoadUnaligned32Msan(
1252           src + 8, overread_in_bytes_256 + sizeof(*src) * (sum_width - x + 8));
1253       s[1] = LoadUnaligned32Msan(
1254           src + 24,
1255           overread_in_bytes_256 + sizeof(*src) * (sum_width - x + 24));
1256       Square(s[0], sq + 2);
1257       Square(s[1], sq + 6);
1258       sq[0] = _mm256_permute2x128_si256(sq[0], sq[2], 0x21);
1259       sq[1] = _mm256_permute2x128_si256(sq[1], sq[3], 0x21);
1260       sq[4] = _mm256_permute2x128_si256(sq[2], sq[6], 0x21);
1261       sq[5] = _mm256_permute2x128_si256(sq[3], sq[7], 0x21);
1262       if (size == 3) {
1263         row[0] = Sum3Horizontal16(
1264             src, overread_in_bytes_256 + sizeof(*src) * (sum_width - x + 8));
1265         row[1] =
1266             Sum3Horizontal16(src + 16, overread_in_bytes_256 +
1267                                            sizeof(*src) * (sum_width - x + 24));
1268         Sum3Horizontal32(sq + 0, row_sq + 0);
1269         Sum3Horizontal32(sq + 4, row_sq + 2);
1270       } else {
1271         row[0] = Sum5Horizontal16(
1272             src, overread_in_bytes_256 + sizeof(*src) * (sum_width - x + 8));
1273         row[1] =
1274             Sum5Horizontal16(src + 16, overread_in_bytes_256 +
1275                                            sizeof(*src) * (sum_width - x + 24));
1276         Sum5Horizontal32(sq + 0, row_sq + 0);
1277         Sum5Horizontal32(sq + 4, row_sq + 2);
1278       }
1279       StoreAligned64(sums, row);
1280       StoreAligned64(square_sums + 0, row_sq + 0);
1281       StoreAligned64(square_sums + 16, row_sq + 2);
1282       sq[0] = sq[6];
1283       sq[1] = sq[7];
1284       src += 32;
1285       sums += 32;
1286       square_sums += 32;
1287       x -= 32;
1288     } while (x != 0);
1289     src += src_stride - sum_width - 8;
1290     sums += sum_stride - sum_width - 8;
1291     square_sums += sum_stride - sum_width - 8;
1292   } while (--y != 0);
1293 }
1294 
1295 template <int n>
1296 inline __m128i CalculateMa(const __m128i sum, const __m128i sum_sq,
1297                            const uint32_t scale) {
1298   static_assert(n == 9 || n == 25, "");
1299   // a = |sum_sq|
1300   // d = |sum|
1301   // p = (a * n < d * d) ? 0 : a * n - d * d;
1302   const __m128i dxd = _mm_madd_epi16(sum, sum);
1303   // _mm_mullo_epi32() has high latency. Using shifts and additions instead.
1304   // Some compilers could do this for us but we make this explicit.
1305   // return _mm_mullo_epi32(sum_sq, _mm_set1_epi32(n));
1306   __m128i axn = _mm_add_epi32(sum_sq, _mm_slli_epi32(sum_sq, 3));
1307   if (n == 25) axn = _mm_add_epi32(axn, _mm_slli_epi32(sum_sq, 4));
1308   const __m128i sub = _mm_sub_epi32(axn, dxd);
1309   const __m128i p = _mm_max_epi32(sub, _mm_setzero_si128());
1310   const __m128i pxs = _mm_mullo_epi32(p, _mm_set1_epi32(scale));
1311   return VrshrU32(pxs, kSgrProjScaleBits);
1312 }
1313 
1314 template <int n>
1315 inline __m128i CalculateMa(const __m128i sum, const __m128i sum_sq[2],
1316                            const uint32_t scale) {
1317   static_assert(n == 9 || n == 25, "");
1318   const __m128i b = VrshrU16(sum, 2);
1319   const __m128i sum_lo = _mm_unpacklo_epi16(b, _mm_setzero_si128());
1320   const __m128i sum_hi = _mm_unpackhi_epi16(b, _mm_setzero_si128());
1321   const __m128i z0 = CalculateMa<n>(sum_lo, VrshrU32(sum_sq[0], 4), scale);
1322   const __m128i z1 = CalculateMa<n>(sum_hi, VrshrU32(sum_sq[1], 4), scale);
1323   return _mm_packus_epi32(z0, z1);
1324 }
1325 
1326 template <int n>
1327 inline __m256i CalculateMa(const __m256i sum, const __m256i sum_sq,
1328                            const uint32_t scale) {
1329   static_assert(n == 9 || n == 25, "");
1330   // a = |sum_sq|
1331   // d = |sum|
1332   // p = (a * n < d * d) ? 0 : a * n - d * d;
1333   const __m256i dxd = _mm256_madd_epi16(sum, sum);
1334   // _mm256_mullo_epi32() has high latency. Using shifts and additions instead.
1335   // Some compilers could do this for us but we make this explicit.
1336   // return _mm256_mullo_epi32(sum_sq, _mm256_set1_epi32(n));
1337   __m256i axn = _mm256_add_epi32(sum_sq, _mm256_slli_epi32(sum_sq, 3));
1338   if (n == 25) axn = _mm256_add_epi32(axn, _mm256_slli_epi32(sum_sq, 4));
1339   const __m256i sub = _mm256_sub_epi32(axn, dxd);
1340   const __m256i p = _mm256_max_epi32(sub, _mm256_setzero_si256());
1341   const __m256i pxs = _mm256_mullo_epi32(p, _mm256_set1_epi32(scale));
1342   return VrshrU32(pxs, kSgrProjScaleBits);
1343 }
1344 
1345 template <int n>
1346 inline __m256i CalculateMa(const __m256i sum, const __m256i sum_sq[2],
1347                            const uint32_t scale) {
1348   static_assert(n == 9 || n == 25, "");
1349   const __m256i b = VrshrU16(sum, 2);
1350   const __m256i sum_lo = _mm256_unpacklo_epi16(b, _mm256_setzero_si256());
1351   const __m256i sum_hi = _mm256_unpackhi_epi16(b, _mm256_setzero_si256());
1352   const __m256i z0 = CalculateMa<n>(sum_lo, VrshrU32(sum_sq[0], 4), scale);
1353   const __m256i z1 = CalculateMa<n>(sum_hi, VrshrU32(sum_sq[1], 4), scale);
1354   return _mm256_packus_epi32(z0, z1);
1355 }
1356 
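// CalculateB5() computes b = ma * sum * one_over_n with rounding. Since
// one_over_n (164) does not fit in a signed 8-bit lane, the code multiplies by
// one_over_n / 4 using maddubs and compensates by shifting 2 fewer bits.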
1357 inline void CalculateB5(const __m128i sum, const __m128i ma, __m128i b[2]) {
1358   // one_over_n == 164.
1359   constexpr uint32_t one_over_n =
1360       ((1 << kSgrProjReciprocalBits) + (25 >> 1)) / 25;
1361   // one_over_n_quarter == 41.
1362   constexpr uint32_t one_over_n_quarter = one_over_n >> 2;
1363   static_assert(one_over_n == one_over_n_quarter << 2, "");
1364   // |ma| is in range [0, 255].
1365   const __m128i m = _mm_maddubs_epi16(ma, _mm_set1_epi16(one_over_n_quarter));
1366   const __m128i m0 = VmullLo16(m, sum);
1367   const __m128i m1 = VmullHi16(m, sum);
1368   b[0] = VrshrU32(m0, kSgrProjReciprocalBits - 2);
1369   b[1] = VrshrU32(m1, kSgrProjReciprocalBits - 2);
1370 }
1371 
1372 inline void CalculateB5(const __m256i sum, const __m256i ma, __m256i b[2]) {
1373   // one_over_n == 164.
1374   constexpr uint32_t one_over_n =
1375       ((1 << kSgrProjReciprocalBits) + (25 >> 1)) / 25;
1376   // one_over_n_quarter == 41.
1377   constexpr uint32_t one_over_n_quarter = one_over_n >> 2;
1378   static_assert(one_over_n == one_over_n_quarter << 2, "");
1379   // |ma| is in range [0, 255].
1380   const __m256i m =
1381       _mm256_maddubs_epi16(ma, _mm256_set1_epi16(one_over_n_quarter));
1382   const __m256i m0 = VmullLo16(m, sum);
1383   const __m256i m1 = VmullHi16(m, sum);
1384   b[0] = VrshrU32(m0, kSgrProjReciprocalBits - 2);
1385   b[1] = VrshrU32(m1, kSgrProjReciprocalBits - 2);
1386 }
1387 
1388 inline void CalculateB3(const __m128i sum, const __m128i ma, __m128i b[2]) {
1389   // one_over_n == 455.
1390   constexpr uint32_t one_over_n =
1391       ((1 << kSgrProjReciprocalBits) + (9 >> 1)) / 9;
1392   const __m128i m0 = VmullLo16(ma, sum);
1393   const __m128i m1 = VmullHi16(ma, sum);
1394   const __m128i m2 = _mm_mullo_epi32(m0, _mm_set1_epi32(one_over_n));
1395   const __m128i m3 = _mm_mullo_epi32(m1, _mm_set1_epi32(one_over_n));
1396   b[0] = VrshrU32(m2, kSgrProjReciprocalBits);
1397   b[1] = VrshrU32(m3, kSgrProjReciprocalBits);
1398 }
1399 
1400 inline void CalculateB3(const __m256i sum, const __m256i ma, __m256i b[2]) {
1401   // one_over_n == 455.
1402   constexpr uint32_t one_over_n =
1403       ((1 << kSgrProjReciprocalBits) + (9 >> 1)) / 9;
1404   const __m256i m0 = VmullLo16(ma, sum);
1405   const __m256i m1 = VmullHi16(ma, sum);
1406   const __m256i m2 = _mm256_mullo_epi32(m0, _mm256_set1_epi32(one_over_n));
1407   const __m256i m3 = _mm256_mullo_epi32(m1, _mm256_set1_epi32(one_over_n));
1408   b[0] = VrshrU32(m2, kSgrProjReciprocalBits);
1409   b[1] = VrshrU32(m3, kSgrProjReciprocalBits);
1410 }
1411 
1412 inline void CalculateSumAndIndex5(const __m128i s5[5], const __m128i sq5[5][2],
1413                                   const uint32_t scale, __m128i* const sum,
1414                                   __m128i* const index) {
1415   __m128i sum_sq[2];
1416   *sum = Sum5_16(s5);
1417   Sum5_32(sq5, sum_sq);
1418   *index = CalculateMa<25>(*sum, sum_sq, scale);
1419 }
1420 
1421 inline void CalculateSumAndIndex5(const __m256i s5[5], const __m256i sq5[5][2],
1422                                   const uint32_t scale, __m256i* const sum,
1423                                   __m256i* const index) {
1424   __m256i sum_sq[2];
1425   *sum = Sum5_16(s5);
1426   Sum5_32(sq5, sum_sq);
1427   *index = CalculateMa<25>(*sum, sum_sq, scale);
1428 }
1429 
1430 inline void CalculateSumAndIndex3(const __m128i s3[3], const __m128i sq3[3][2],
1431                                   const uint32_t scale, __m128i* const sum,
1432                                   __m128i* const index) {
1433   __m128i sum_sq[2];
1434   *sum = Sum3_16(s3);
1435   Sum3_32(sq3, sum_sq);
1436   *index = CalculateMa<9>(*sum, sum_sq, scale);
1437 }
1438 
1439 inline void CalculateSumAndIndex3(const __m256i s3[3], const __m256i sq3[3][2],
1440                                   const uint32_t scale, __m256i* const sum,
1441                                   __m256i* const index) {
1442   __m256i sum_sq[2];
1443   *sum = Sum3_16(s3);
1444   Sum3_32(sq3, sum_sq);
1445   *index = CalculateMa<9>(*sum, sum_sq, scale);
1446 }
1447 
1448 template <int n>
1449 inline void LookupIntermediate(const __m128i sum, const __m128i index,
1450                                __m128i* const ma, __m128i b[2]) {
1451   static_assert(n == 9 || n == 25, "");
1452   const __m128i idx = _mm_packus_epi16(index, index);
1453   // The store and load are not actually emitted; the compiler keeps |temp| in
1454   // a 64-bit general-purpose register, which is faster than _mm_extract_epi8().
1455   uint8_t temp[8];
1456   StoreLo8(temp, idx);
1457   *ma = _mm_cvtsi32_si128(kSgrMaLookup[temp[0]]);
1458   *ma = _mm_insert_epi8(*ma, kSgrMaLookup[temp[1]], 1);
1459   *ma = _mm_insert_epi8(*ma, kSgrMaLookup[temp[2]], 2);
1460   *ma = _mm_insert_epi8(*ma, kSgrMaLookup[temp[3]], 3);
1461   *ma = _mm_insert_epi8(*ma, kSgrMaLookup[temp[4]], 4);
1462   *ma = _mm_insert_epi8(*ma, kSgrMaLookup[temp[5]], 5);
1463   *ma = _mm_insert_epi8(*ma, kSgrMaLookup[temp[6]], 6);
1464   *ma = _mm_insert_epi8(*ma, kSgrMaLookup[temp[7]], 7);
1465   // b = ma * b * one_over_n
1466   // |ma| = [0, 255]
1467   // |sum| is a box sum with radius 1 or 2.
1468   // For the first pass radius is 2. Maximum value is 5x5x255 = 6375.
1469   // For the second pass radius is 1. Maximum value is 3x3x255 = 2295.
1470   // |one_over_n| = ((1 << kSgrProjReciprocalBits) + (n >> 1)) / n
1471   // When radius is 2 |n| is 25. |one_over_n| is 164.
1472   // When radius is 1 |n| is 9. |one_over_n| is 455.
1473   // |kSgrProjReciprocalBits| is 12.
1474   // Radius 2: 255 * 6375 * 164 >> 12 = 65088 (16 bits).
1475   // Radius 1: 255 * 2295 * 455 >> 12 = 65009 (16 bits).
1476   const __m128i maq = _mm_unpacklo_epi8(*ma, _mm_setzero_si128());
1477   if (n == 9) {
1478     CalculateB3(sum, maq, b);
1479   } else {
1480     CalculateB5(sum, maq, b);
1481   }
1482 }
1483 
1484 // Repeat the first 48 elements in kSgrMaLookup with a period of 16.
1485 alignas(32) constexpr uint8_t kSgrMaLookupAvx2[96] = {
1486     255, 128, 85, 64, 51, 43, 37, 32, 28, 26, 23, 21, 20, 18, 17, 16,
1487     255, 128, 85, 64, 51, 43, 37, 32, 28, 26, 23, 21, 20, 18, 17, 16,
1488     15,  14,  13, 13, 12, 12, 11, 11, 10, 10, 9,  9,  9,  9,  8,  8,
1489     15,  14,  13, 13, 12, 12, 11, 11, 10, 10, 9,  9,  9,  9,  8,  8,
1490     8,   8,   7,  7,  7,  7,  7,  6,  6,  6,  6,  6,  6,  6,  5,  5,
1491     8,   8,   7,  7,  7,  7,  7,  6,  6,  6,  6,  6,  6,  6,  5,  5};
1492 
1493 // Set the shuffle control mask of indices out of range [0, 15] to (1xxxxxxx)b
1494 // to get value 0 as the shuffle result. The most significant bit 1 comes
1495 // either from the comparison instruction, or from the sign bit of the index.
1496 inline __m128i ShuffleIndex(const __m128i table, const __m128i index) {
1497   __m128i mask;
1498   mask = _mm_cmpgt_epi8(index, _mm_set1_epi8(15));
1499   mask = _mm_or_si128(mask, index);
1500   return _mm_shuffle_epi8(table, mask);
1501 }
1502 
1503 inline __m256i ShuffleIndex(const __m256i table, const __m256i index) {
1504   __m256i mask;
1505   mask = _mm256_cmpgt_epi8(index, _mm256_set1_epi8(15));
1506   mask = _mm256_or_si256(mask, index);
1507   return _mm256_shuffle_epi8(table, mask);
1508 }
1509 
1510 inline __m128i AdjustValue(const __m128i value, const __m128i index,
1511                            const int threshold) {
1512   const __m128i thresholds = _mm_set1_epi8(threshold - 128);
1513   const __m128i offset = _mm_cmpgt_epi8(index, thresholds);
1514   return _mm_add_epi8(value, offset);
1515 }
1516 
1517 inline __m256i AdjustValue(const __m256i value, const __m256i index,
1518                            const int threshold) {
1519   const __m256i thresholds = _mm256_set1_epi8(threshold - 128);
1520   const __m256i offset = _mm256_cmpgt_epi8(index, thresholds);
1521   return _mm256_add_epi8(value, offset);
1522 }
1523 
1524 inline void CalculateIntermediate(const __m128i sum[2], const __m128i index[2],
1525                                   __m128i* const ma, __m128i b0[2],
1526                                   __m128i b1[2]) {
1527   // Use table lookup to read elements whose indices are less than 48.
1528   const __m128i c0 = LoadAligned16(kSgrMaLookup + 0 * 16);
1529   const __m128i c1 = LoadAligned16(kSgrMaLookup + 1 * 16);
1530   const __m128i c2 = LoadAligned16(kSgrMaLookup + 2 * 16);
1531   const __m128i indices = _mm_packus_epi16(index[0], index[1]);
1532   __m128i idx;
1533   // Clip idx to 127 to apply signed comparison instructions.
1534   idx = _mm_min_epu8(indices, _mm_set1_epi8(127));
1535   // All elements whose indices are less than 48 are set to 0.
1536   // Get shuffle results for indices in range [0, 15].
1537   *ma = ShuffleIndex(c0, idx);
1538   // Get shuffle results for indices in range [16, 31].
1539   // Subtract 16 to utilize the sign bit of the index.
1540   idx = _mm_sub_epi8(idx, _mm_set1_epi8(16));
1541   const __m128i res1 = ShuffleIndex(c1, idx);
1542   // Use OR instruction to combine shuffle results together.
1543   *ma = _mm_or_si128(*ma, res1);
1544   // Get shuffle results for indices in range [32, 47].
1545   // Subtract 16 to utilize the sign bit of the index.
1546   idx = _mm_sub_epi8(idx, _mm_set1_epi8(16));
1547   const __m128i res2 = ShuffleIndex(c2, idx);
1548   *ma = _mm_or_si128(*ma, res2);
1549 
1550   // For elements whose indices are larger than 47, the values change only
1551   // rarely as the index increases, so we use comparison and arithmetic
1552   // operations to calculate their values.
1553   // Add -128 to apply signed comparison instructions.
1554   idx = _mm_add_epi8(indices, _mm_set1_epi8(-128));
1555   // Elements whose indices are larger than 47 (with value 0) are set to 5.
1556   *ma = _mm_max_epu8(*ma, _mm_set1_epi8(5));
1557   *ma = AdjustValue(*ma, idx, 55);   // 55 is the last index whose value is 5.
1558   *ma = AdjustValue(*ma, idx, 72);   // 72 is the last index whose value is 4.
1559   *ma = AdjustValue(*ma, idx, 101);  // 101 is the last index whose value is 3.
1560   *ma = AdjustValue(*ma, idx, 169);  // 169 is the last index whose value is 2.
1561   *ma = AdjustValue(*ma, idx, 254);  // 254 is the last index whose value is 1.
1562 
1563   // b = ma * b * one_over_n
1564   // |ma| = [0, 255]
1565   // |sum| is a box sum with radius 1 or 2.
1566   // For the first pass radius is 2. Maximum value is 5x5x255 = 6375.
1567   // For the second pass radius is 1. Maximum value is 3x3x255 = 2295.
1568   // |one_over_n| = ((1 << kSgrProjReciprocalBits) + (n >> 1)) / n
1569   // When radius is 2 |n| is 25. |one_over_n| is 164.
1570   // When radius is 1 |n| is 9. |one_over_n| is 455.
1571   // |kSgrProjReciprocalBits| is 12.
1572   // Radius 2: 255 * 6375 * 164 >> 12 = 65088 (16 bits).
1573   // Radius 1: 255 * 2295 * 455 >> 12 = 65009 (16 bits).
1574   const __m128i maq0 = _mm_unpacklo_epi8(*ma, _mm_setzero_si128());
1575   CalculateB3(sum[0], maq0, b0);
1576   const __m128i maq1 = _mm_unpackhi_epi8(*ma, _mm_setzero_si128());
1577   CalculateB3(sum[1], maq1, b1);
1578 }
1579 
1580 template <int n>
1581 inline void CalculateIntermediate(const __m256i sum[2], const __m256i index[2],
1582                                   __m256i ma[3], __m256i b0[2], __m256i b1[2]) {
1583   static_assert(n == 9 || n == 25, "");
1584   // Use table lookup to read elements whose indices are less than 48.
1585   const __m256i c0 = LoadAligned32(kSgrMaLookupAvx2 + 0 * 32);
1586   const __m256i c1 = LoadAligned32(kSgrMaLookupAvx2 + 1 * 32);
1587   const __m256i c2 = LoadAligned32(kSgrMaLookupAvx2 + 2 * 32);
1588   const __m256i indices = _mm256_packus_epi16(index[0], index[1]);  // 0 2 1 3
1589   __m256i idx, mas;
1590   // Clip idx to 127 to apply signed comparison instructions.
1591   idx = _mm256_min_epu8(indices, _mm256_set1_epi8(127));
1592   // All elements whose indices are less than 48 are set to 0.
1593   // Get shuffle results for indices in range [0, 15].
1594   mas = ShuffleIndex(c0, idx);
1595   // Get shuffle results for indices in range [16, 31].
1596   // Subtract 16 to utilize the sign bit of the index.
1597   idx = _mm256_sub_epi8(idx, _mm256_set1_epi8(16));
1598   const __m256i res1 = ShuffleIndex(c1, idx);
1599   // Use OR instruction to combine shuffle results together.
1600   mas = _mm256_or_si256(mas, res1);
1601   // Get shuffle results for indices in range [32, 47].
1602   // Subtract 16 to utilize the sign bit of the index.
1603   idx = _mm256_sub_epi8(idx, _mm256_set1_epi8(16));
1604   const __m256i res2 = ShuffleIndex(c2, idx);
1605   mas = _mm256_or_si256(mas, res2);
1606 
1607   // For elements whose indices are larger than 47, the values change only
1608   // rarely as the index increases, so we use comparison and arithmetic
1609   // operations to calculate their values.
1610   // Add -128 to apply signed comparison instructions.
1611   idx = _mm256_add_epi8(indices, _mm256_set1_epi8(-128));
1612   // Elements whose indices are larger than 47 (with value 0) are set to 5.
1613   mas = _mm256_max_epu8(mas, _mm256_set1_epi8(5));
1614   mas = AdjustValue(mas, idx, 55);   // 55 is the last index whose value is 5.
1615   mas = AdjustValue(mas, idx, 72);   // 72 is the last index whose value is 4.
1616   mas = AdjustValue(mas, idx, 101);  // 101 is the last index whose value is 3.
1617   mas = AdjustValue(mas, idx, 169);  // 169 is the last index whose value is 2.
1618   mas = AdjustValue(mas, idx, 254);  // 254 is the last index whose value is 1.
1619 
1620   ma[2] = _mm256_permute4x64_epi64(mas, 0x63);     // 32-39 8-15 16-23 24-31
1621   ma[0] = _mm256_blend_epi32(ma[0], ma[2], 0xfc);  //  0-7  8-15 16-23 24-31
1622   ma[1] = _mm256_permute2x128_si256(ma[0], ma[2], 0x21);
1623 
1624   // b = ma * b * one_over_n
1625   // |ma| = [0, 255]
1626   // |sum| is a box sum with radius 1 or 2.
1627   // For the first pass radius is 2. Maximum value is 5x5x255 = 6375.
1628   // For the second pass radius is 1. Maximum value is 3x3x255 = 2295.
1629   // |one_over_n| = ((1 << kSgrProjReciprocalBits) + (n >> 1)) / n
1630   // When radius is 2 |n| is 25. |one_over_n| is 164.
1631   // When radius is 1 |n| is 9. |one_over_n| is 455.
1632   // |kSgrProjReciprocalBits| is 12.
1633   // Radius 2: 255 * 6375 * 164 >> 12 = 65088 (16 bits).
1634   // Radius 1: 255 * 2295 * 455 >> 12 = 65009 (16 bits).
1635   const __m256i maq0 = _mm256_unpackhi_epi8(ma[0], _mm256_setzero_si256());
1636   const __m256i maq1 = _mm256_unpacklo_epi8(ma[1], _mm256_setzero_si256());
1637   __m256i sums[2];
1638   sums[0] = _mm256_permute2x128_si256(sum[0], sum[1], 0x20);
1639   sums[1] = _mm256_permute2x128_si256(sum[0], sum[1], 0x31);
1640   if (n == 9) {
1641     CalculateB3(sums[0], maq0, b0);
1642     CalculateB3(sums[1], maq1, b1);
1643   } else {
1644     CalculateB5(sums[0], maq0, b0);
1645     CalculateB5(sums[1], maq1, b1);
1646   }
1647 }
1648 
1649 inline void CalculateIntermediate5(const __m128i s5[5], const __m128i sq5[5][2],
1650                                    const uint32_t scale, __m128i* const ma,
1651                                    __m128i b[2]) {
1652   __m128i sum, index;
1653   CalculateSumAndIndex5(s5, sq5, scale, &sum, &index);
1654   LookupIntermediate<25>(sum, index, ma, b);
1655 }
1656 
1657 inline void CalculateIntermediate3(const __m128i s3[3], const __m128i sq3[3][2],
1658                                    const uint32_t scale, __m128i* const ma,
1659                                    __m128i b[2]) {
1660   __m128i sum, index;
1661   CalculateSumAndIndex3(s3, sq3, scale, &sum, &index);
1662   LookupIntermediate<9>(sum, index, ma, b);
1663 }
1664 
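// Computes the 4-4-4 and 3-4-3 weighted sums of the |b3| values
// (sum_b444 = 4 * Sum3, sum_b343 = sum_b444 - Sum3 + b[1]) and stores them to
// |b444| and |b343| at offset |x| for use by subsequent rows.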
1665 inline void Store343_444(const __m256i b3[3], const ptrdiff_t x,
1666                          __m256i sum_b343[2], __m256i sum_b444[2],
1667                          uint32_t* const b343, uint32_t* const b444) {
1668   __m256i b[3], sum_b111[2];
1669   Prepare3_32(b3 + 0, b);
1670   sum_b111[0] = Sum3_32(b);
1671   sum_b444[0] = _mm256_slli_epi32(sum_b111[0], 2);
1672   sum_b343[0] = _mm256_sub_epi32(sum_b444[0], sum_b111[0]);
1673   sum_b343[0] = _mm256_add_epi32(sum_b343[0], b[1]);
1674   Prepare3_32(b3 + 1, b);
1675   sum_b111[1] = Sum3_32(b);
1676   sum_b444[1] = _mm256_slli_epi32(sum_b111[1], 2);
1677   sum_b343[1] = _mm256_sub_epi32(sum_b444[1], sum_b111[1]);
1678   sum_b343[1] = _mm256_add_epi32(sum_b343[1], b[1]);
1679   StoreAligned64(b444 + x, sum_b444);
1680   StoreAligned64(b343 + x, sum_b343);
1681 }
1682 
1683 inline void Store343_444Lo(const __m256i ma3[3], const __m256i b3[2],
1684                            const ptrdiff_t x, __m256i* const sum_ma343,
1685                            __m256i* const sum_ma444, __m256i sum_b343[2],
1686                            __m256i sum_b444[2], uint16_t* const ma343,
1687                            uint16_t* const ma444, uint32_t* const b343,
1688                            uint32_t* const b444) {
1689   const __m256i sum_ma111 = Sum3WLo16(ma3);
1690   *sum_ma444 = _mm256_slli_epi16(sum_ma111, 2);
1691   StoreAligned32_ma(ma444 + x, *sum_ma444);
1692   const __m256i sum333 = _mm256_sub_epi16(*sum_ma444, sum_ma111);
1693   *sum_ma343 = VaddwLo8(sum333, ma3[1]);
1694   StoreAligned32_ma(ma343 + x, *sum_ma343);
1695   Store343_444(b3, x, sum_b343, sum_b444, b343, b444);
1696 }
1697 
1698 inline void Store343_444Hi(const __m256i ma3[3], const __m256i b3[2],
1699                            const ptrdiff_t x, __m256i* const sum_ma343,
1700                            __m256i* const sum_ma444, __m256i sum_b343[2],
1701                            __m256i sum_b444[2], uint16_t* const ma343,
1702                            uint16_t* const ma444, uint32_t* const b343,
1703                            uint32_t* const b444) {
1704   const __m256i sum_ma111 = Sum3WHi16(ma3);
1705   *sum_ma444 = _mm256_slli_epi16(sum_ma111, 2);
1706   StoreAligned32_ma(ma444 + x, *sum_ma444);
1707   const __m256i sum333 = _mm256_sub_epi16(*sum_ma444, sum_ma111);
1708   *sum_ma343 = VaddwHi8(sum333, ma3[1]);
1709   StoreAligned32_ma(ma343 + x, *sum_ma343);
1710   Store343_444(b3, x + kMaStoreOffset, sum_b343, sum_b444, b343, b444);
1711 }
1712 
1713 inline void Store343_444Lo(const __m256i ma3[3], const __m256i b3[2],
1714                            const ptrdiff_t x, __m256i* const sum_ma343,
1715                            __m256i sum_b343[2], uint16_t* const ma343,
1716                            uint16_t* const ma444, uint32_t* const b343,
1717                            uint32_t* const b444) {
1718   __m256i sum_ma444, sum_b444[2];
1719   Store343_444Lo(ma3, b3, x, sum_ma343, &sum_ma444, sum_b343, sum_b444, ma343,
1720                  ma444, b343, b444);
1721 }
1722 
1723 inline void Store343_444Hi(const __m256i ma3[3], const __m256i b3[2],
1724                            const ptrdiff_t x, __m256i* const sum_ma343,
1725                            __m256i sum_b343[2], uint16_t* const ma343,
1726                            uint16_t* const ma444, uint32_t* const b343,
1727                            uint32_t* const b444) {
1728   __m256i sum_ma444, sum_b444[2];
1729   Store343_444Hi(ma3, b3, x, sum_ma343, &sum_ma444, sum_b343, sum_b444, ma343,
1730                  ma444, b343, b444);
1731 }
1732 
1733 inline void Store343_444Lo(const __m256i ma3[3], const __m256i b3[2],
1734                            const ptrdiff_t x, uint16_t* const ma343,
1735                            uint16_t* const ma444, uint32_t* const b343,
1736                            uint32_t* const b444) {
1737   __m256i sum_ma343, sum_b343[2];
1738   Store343_444Lo(ma3, b3, x, &sum_ma343, sum_b343, ma343, ma444, b343, b444);
1739 }
1740 
1741 inline void Store343_444Hi(const __m256i ma3[3], const __m256i b3[2],
1742                            const ptrdiff_t x, uint16_t* const ma343,
1743                            uint16_t* const ma444, uint32_t* const b343,
1744                            uint32_t* const b444) {
1745   __m256i sum_ma343, sum_b343[2];
1746   Store343_444Hi(ma3, b3, x, &sum_ma343, sum_b343, ma343, ma444, b343, b444);
1747 }
1748 
1749 // Don't combine the following 2 functions; the combined version would be slower.
1750 inline void Store343_444(const __m256i ma3[3], const __m256i b3[6],
1751                          const ptrdiff_t x, __m256i* const sum_ma343_lo,
1752                          __m256i* const sum_ma343_hi,
1753                          __m256i* const sum_ma444_lo,
1754                          __m256i* const sum_ma444_hi, __m256i sum_b343_lo[2],
1755                          __m256i sum_b343_hi[2], __m256i sum_b444_lo[2],
1756                          __m256i sum_b444_hi[2], uint16_t* const ma343,
1757                          uint16_t* const ma444, uint32_t* const b343,
1758                          uint32_t* const b444) {
1759   __m256i sum_mat343[2], sum_mat444[2];
1760   const __m256i sum_ma111_lo = Sum3WLo16(ma3);
1761   sum_mat444[0] = _mm256_slli_epi16(sum_ma111_lo, 2);
1762   const __m256i sum333_lo = _mm256_sub_epi16(sum_mat444[0], sum_ma111_lo);
1763   sum_mat343[0] = VaddwLo8(sum333_lo, ma3[1]);
1764   Store343_444(b3, x, sum_b343_lo, sum_b444_lo, b343, b444);
1765   const __m256i sum_ma111_hi = Sum3WHi16(ma3);
1766   sum_mat444[1] = _mm256_slli_epi16(sum_ma111_hi, 2);
1767   *sum_ma444_lo = _mm256_permute2x128_si256(sum_mat444[0], sum_mat444[1], 0x20);
1768   *sum_ma444_hi = _mm256_permute2x128_si256(sum_mat444[0], sum_mat444[1], 0x31);
1769   StoreAligned32(ma444 + x + 0, *sum_ma444_lo);
1770   StoreAligned32(ma444 + x + 16, *sum_ma444_hi);
1771   const __m256i sum333_hi = _mm256_sub_epi16(sum_mat444[1], sum_ma111_hi);
1772   sum_mat343[1] = VaddwHi8(sum333_hi, ma3[1]);
1773   *sum_ma343_lo = _mm256_permute2x128_si256(sum_mat343[0], sum_mat343[1], 0x20);
1774   *sum_ma343_hi = _mm256_permute2x128_si256(sum_mat343[0], sum_mat343[1], 0x31);
1775   StoreAligned32(ma343 + x + 0, *sum_ma343_lo);
1776   StoreAligned32(ma343 + x + 16, *sum_ma343_hi);
1777   Store343_444(b3 + 3, x + 16, sum_b343_hi, sum_b444_hi, b343, b444);
1778 }
1779 
1780 inline void Store343_444(const __m256i ma3[3], const __m256i b3[6],
1781                          const ptrdiff_t x, __m256i* const sum_ma343_lo,
1782                          __m256i* const sum_ma343_hi, __m256i sum_b343_lo[2],
1783                          __m256i sum_b343_hi[2], uint16_t* const ma343,
1784                          uint16_t* const ma444, uint32_t* const b343,
1785                          uint32_t* const b444) {
1786   __m256i sum_ma444[2], sum_b444[2], sum_mat343[2];
1787   const __m256i sum_ma111_lo = Sum3WLo16(ma3);
1788   sum_ma444[0] = _mm256_slli_epi16(sum_ma111_lo, 2);
1789   const __m256i sum333_lo = _mm256_sub_epi16(sum_ma444[0], sum_ma111_lo);
1790   sum_mat343[0] = VaddwLo8(sum333_lo, ma3[1]);
1791   Store343_444(b3, x, sum_b343_lo, sum_b444, b343, b444);
1792   const __m256i sum_ma111_hi = Sum3WHi16(ma3);
1793   sum_ma444[1] = _mm256_slli_epi16(sum_ma111_hi, 2);
1794   StoreAligned64_ma(ma444 + x, sum_ma444);
1795   const __m256i sum333_hi = _mm256_sub_epi16(sum_ma444[1], sum_ma111_hi);
1796   sum_mat343[1] = VaddwHi8(sum333_hi, ma3[1]);
1797   *sum_ma343_lo = _mm256_permute2x128_si256(sum_mat343[0], sum_mat343[1], 0x20);
1798   *sum_ma343_hi = _mm256_permute2x128_si256(sum_mat343[0], sum_mat343[1], 0x31);
1799   StoreAligned32(ma343 + x + 0, *sum_ma343_lo);
1800   StoreAligned32(ma343 + x + 16, *sum_ma343_hi);
1801   Store343_444(b3 + 3, x + 16, sum_b343_hi, sum_b444, b343, b444);
1802 }
1803 
1804 inline void PermuteB(const __m256i t[4], __m256i b[7]) {
1805   // Input:
1806   //                             0     1      2     3  // b[0]
1807   //                             4     5      6     7  // b[1]
1808   //  8     9     10    11      24    25     26    27  // t[0]
1809   // 12    13     14    15      28    29     30    31  // t[1]
1810   // 16    17     18    19      32    33     34    35  // t[2]
1811   // 20    21     22    23      36    37     38    39  // t[3]
1812 
1813   // Output:
1814   //  0     1      2     3       8     9     10    11  // b[0]
1815   //  4     5      6     7      12    13     14    15  // b[1]
1816   //  8     9     10    11      16    17     18    19  // b[2]
1817   // 16    17     18    19      24    25     26    27  // b[3]
1818   // 20    21     22    23      28    29     30    31  // b[4]
1819   // 24    25     26    27      32    33     34    35  // b[5]
1820   // 20    21     22    23      36    37     38    39  // b[6]
1821   b[0] = _mm256_permute2x128_si256(b[0], t[0], 0x21);
1822   b[1] = _mm256_permute2x128_si256(b[1], t[1], 0x21);
1823   b[2] = _mm256_permute2x128_si256(t[0], t[2], 0x20);
1824   b[3] = _mm256_permute2x128_si256(t[2], t[0], 0x30);
1825   b[4] = _mm256_permute2x128_si256(t[3], t[1], 0x30);
1826   b[5] = _mm256_permute2x128_si256(t[0], t[2], 0x31);
1827   b[6] = t[3];
1828 }
1829 
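// The *Lo() helpers below use 128-bit vectors to produce the leftmost 8
// output columns of the current row(s); the corresponding 256-bit functions
// then process the rest of the row.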
1830 LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess5Lo(
1831     const __m128i s[2][2], const uint32_t scale, uint16_t* const sum5[5],
1832     uint32_t* const square_sum5[5], __m128i sq[2][4], __m128i* const ma,
1833     __m128i b[2]) {
1834   __m128i s5[2][5], sq5[5][2];
1835   Square(s[0][1], sq[0] + 2);
1836   Square(s[1][1], sq[1] + 2);
1837   s5[0][3] = Sum5Horizontal16(s[0]);
1838   StoreAligned16(sum5[3], s5[0][3]);
1839   s5[0][4] = Sum5Horizontal16(s[1]);
1840   StoreAligned16(sum5[4], s5[0][4]);
1841   Sum5Horizontal32(sq[0], sq5[3]);
1842   StoreAligned32U32(square_sum5[3], sq5[3]);
1843   Sum5Horizontal32(sq[1], sq5[4]);
1844   StoreAligned32U32(square_sum5[4], sq5[4]);
1845   LoadAligned16x3U16(sum5, 0, s5[0]);
1846   LoadAligned32x3U32(square_sum5, 0, sq5);
1847   CalculateIntermediate5(s5[0], sq5, scale, ma, b);
1848 }
1849 
1850 LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess5(
1851     const uint16_t* const src0, const uint16_t* const src1,
1852     const ptrdiff_t over_read_in_bytes, const ptrdiff_t sum_width,
1853     const ptrdiff_t x, const uint32_t scale, uint16_t* const sum5[5],
1854     uint32_t* const square_sum5[5], __m256i sq[2][8], __m256i ma[3],
1855     __m256i b[3]) {
1856   __m256i s[2], s5[2][5], sq5[5][2], sum[2], index[2], t[4];
1857   s[0] = LoadUnaligned32Msan(src0 + 8, over_read_in_bytes + 16);
1858   s[1] = LoadUnaligned32Msan(src1 + 8, over_read_in_bytes + 16);
1859   Square(s[0], sq[0] + 2);
1860   Square(s[1], sq[1] + 2);
1861   sq[0][0] = _mm256_permute2x128_si256(sq[0][0], sq[0][2], 0x21);
1862   sq[0][1] = _mm256_permute2x128_si256(sq[0][1], sq[0][3], 0x21);
1863   sq[1][0] = _mm256_permute2x128_si256(sq[1][0], sq[1][2], 0x21);
1864   sq[1][1] = _mm256_permute2x128_si256(sq[1][1], sq[1][3], 0x21);
1865   s5[0][3] = Sum5Horizontal16(src0 + 0, over_read_in_bytes + 0);
1866   s5[1][3] = Sum5Horizontal16(src0 + 16, over_read_in_bytes + 32);
1867   s5[0][4] = Sum5Horizontal16(src1 + 0, over_read_in_bytes + 0);
1868   s5[1][4] = Sum5Horizontal16(src1 + 16, over_read_in_bytes + 32);
1869   StoreAligned32(sum5[3] + x + 0, s5[0][3]);
1870   StoreAligned32(sum5[3] + x + 16, s5[1][3]);
1871   StoreAligned32(sum5[4] + x + 0, s5[0][4]);
1872   StoreAligned32(sum5[4] + x + 16, s5[1][4]);
1873   Sum5Horizontal32(sq[0], sq5[3]);
1874   StoreAligned64(square_sum5[3] + x, sq5[3]);
1875   Sum5Horizontal32(sq[1], sq5[4]);
1876   StoreAligned64(square_sum5[4] + x, sq5[4]);
1877   LoadAligned32x3U16(sum5, x, s5[0]);
1878   LoadAligned64x3U32(square_sum5, x, sq5);
1879   CalculateSumAndIndex5(s5[0], sq5, scale, &sum[0], &index[0]);
1880 
1881   s[0] = LoadUnaligned32Msan(src0 + 24, over_read_in_bytes + 48);
1882   s[1] = LoadUnaligned32Msan(src1 + 24, over_read_in_bytes + 48);
1883   Square(s[0], sq[0] + 6);
1884   Square(s[1], sq[1] + 6);
1885   sq[0][4] = _mm256_permute2x128_si256(sq[0][2], sq[0][6], 0x21);
1886   sq[0][5] = _mm256_permute2x128_si256(sq[0][3], sq[0][7], 0x21);
1887   sq[1][4] = _mm256_permute2x128_si256(sq[1][2], sq[1][6], 0x21);
1888   sq[1][5] = _mm256_permute2x128_si256(sq[1][3], sq[1][7], 0x21);
1889   Sum5Horizontal32(sq[0] + 4, sq5[3]);
1890   StoreAligned64(square_sum5[3] + x + 16, sq5[3]);
1891   Sum5Horizontal32(sq[1] + 4, sq5[4]);
1892   StoreAligned64(square_sum5[4] + x + 16, sq5[4]);
1893   LoadAligned32x3U16Msan(sum5, x + 16, sum_width, s5[1]);
1894   LoadAligned64x3U32Msan(square_sum5, x + 16, sum_width, sq5);
1895   CalculateSumAndIndex5(s5[1], sq5, scale, &sum[1], &index[1]);
1896   CalculateIntermediate<25>(sum, index, ma, t, t + 2);
1897   PermuteB(t, b);
1898 }
1899 
1900 LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess5LastRowLo(
1901     const __m128i s[2], const uint32_t scale, const uint16_t* const sum5[5],
1902     const uint32_t* const square_sum5[5], __m128i sq[4], __m128i* const ma,
1903     __m128i b[2]) {
1904   __m128i s5[5], sq5[5][2];
1905   Square(s[1], sq + 2);
1906   s5[3] = s5[4] = Sum5Horizontal16(s);
1907   Sum5Horizontal32(sq, sq5[3]);
1908   sq5[4][0] = sq5[3][0];
1909   sq5[4][1] = sq5[3][1];
1910   LoadAligned16x3U16(sum5, 0, s5);
1911   LoadAligned32x3U32(square_sum5, 0, sq5);
1912   CalculateIntermediate5(s5, sq5, scale, ma, b);
1913 }
1914 
1915 LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess5LastRow(
1916     const uint16_t* const src, const ptrdiff_t over_read_in_bytes,
1917     const ptrdiff_t sum_width, const ptrdiff_t x, const uint32_t scale,
1918     const uint16_t* const sum5[5], const uint32_t* const square_sum5[5],
1919     __m256i sq[3], __m256i ma[3], __m256i b[3]) {
1920   const __m256i s0 = LoadUnaligned32Msan(src + 8, over_read_in_bytes + 16);
1921   __m256i s5[2][5], sq5[5][2], sum[2], index[2], t[4];
1922   Square(s0, sq + 2);
1923   sq[0] = _mm256_permute2x128_si256(sq[0], sq[2], 0x21);
1924   sq[1] = _mm256_permute2x128_si256(sq[1], sq[3], 0x21);
1925   s5[0][3] = Sum5Horizontal16(src + 0, over_read_in_bytes + 0);
1926   s5[1][3] = Sum5Horizontal16(src + 16, over_read_in_bytes + 32);
1927   s5[0][4] = s5[0][3];
1928   s5[1][4] = s5[1][3];
1929   Sum5Horizontal32(sq, sq5[3]);
1930   sq5[4][0] = sq5[3][0];
1931   sq5[4][1] = sq5[3][1];
1932   LoadAligned32x3U16(sum5, x, s5[0]);
1933   LoadAligned64x3U32(square_sum5, x, sq5);
1934   CalculateSumAndIndex5(s5[0], sq5, scale, &sum[0], &index[0]);
1935 
1936   const __m256i s1 = LoadUnaligned32Msan(src + 24, over_read_in_bytes + 48);
1937   Square(s1, sq + 6);
1938   sq[4] = _mm256_permute2x128_si256(sq[2], sq[6], 0x21);
1939   sq[5] = _mm256_permute2x128_si256(sq[3], sq[7], 0x21);
1940   Sum5Horizontal32(sq + 4, sq5[3]);
1941   sq5[4][0] = sq5[3][0];
1942   sq5[4][1] = sq5[3][1];
1943   LoadAligned32x3U16Msan(sum5, x + 16, sum_width, s5[1]);
1944   LoadAligned64x3U32Msan(square_sum5, x + 16, sum_width, sq5);
1945   CalculateSumAndIndex5(s5[1], sq5, scale, &sum[1], &index[1]);
1946   CalculateIntermediate<25>(sum, index, ma, t, t + 2);
1947   PermuteB(t, b);
1948 }
1949 
1950 LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess3Lo(
1951     const __m128i s[2], const uint32_t scale, uint16_t* const sum3[3],
1952     uint32_t* const square_sum3[3], __m128i sq[4], __m128i* const ma,
1953     __m128i b[2]) {
1954   __m128i s3[3], sq3[3][2];
1955   Square(s[1], sq + 2);
1956   s3[2] = Sum3Horizontal16(s);
1957   StoreAligned16(sum3[2], s3[2]);
1958   Sum3Horizontal32(sq, sq3[2]);
1959   StoreAligned32U32(square_sum3[2], sq3[2]);
1960   LoadAligned16x2U16(sum3, 0, s3);
1961   LoadAligned32x2U32(square_sum3, 0, sq3);
1962   CalculateIntermediate3(s3, sq3, scale, ma, b);
1963 }
1964 
1965 LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess3(
1966     const uint16_t* const src, const ptrdiff_t over_read_in_bytes,
1967     const ptrdiff_t x, const ptrdiff_t sum_width, const uint32_t scale,
1968     uint16_t* const sum3[3], uint32_t* const square_sum3[3], __m256i sq[8],
1969     __m256i ma[3], __m256i b[7]) {
1970   __m256i s[2], s3[4], sq3[3][2], sum[2], index[2], t[4];
1971   s[0] = LoadUnaligned32Msan(src + 8, over_read_in_bytes + 16);
1972   s[1] = LoadUnaligned32Msan(src + 24, over_read_in_bytes + 48);
1973   Square(s[0], sq + 2);
1974   sq[0] = _mm256_permute2x128_si256(sq[0], sq[2], 0x21);
1975   sq[1] = _mm256_permute2x128_si256(sq[1], sq[3], 0x21);
1976   s3[2] = Sum3Horizontal16(src, over_read_in_bytes);
1977   s3[3] = Sum3Horizontal16(src + 16, over_read_in_bytes + 32);
1978   StoreAligned64(sum3[2] + x, s3 + 2);
1979   Sum3Horizontal32(sq + 0, sq3[2]);
1980   StoreAligned64(square_sum3[2] + x, sq3[2]);
1981   LoadAligned32x2U16(sum3, x, s3);
1982   LoadAligned64x2U32(square_sum3, x, sq3);
1983   CalculateSumAndIndex3(s3, sq3, scale, &sum[0], &index[0]);
1984 
1985   Square(s[1], sq + 6);
1986   sq[4] = _mm256_permute2x128_si256(sq[2], sq[6], 0x21);
1987   sq[5] = _mm256_permute2x128_si256(sq[3], sq[7], 0x21);
1988   Sum3Horizontal32(sq + 4, sq3[2]);
1989   StoreAligned64(square_sum3[2] + x + 16, sq3[2]);
1990   LoadAligned32x2U16Msan(sum3, x + 16, sum_width, s3 + 1);
1991   LoadAligned64x2U32Msan(square_sum3, x + 16, sum_width, sq3);
1992   CalculateSumAndIndex3(s3 + 1, sq3, scale, &sum[1], &index[1]);
1993   CalculateIntermediate<9>(sum, index, ma, t, t + 2);
1994   PermuteB(t, b);
1995 }
1996 
1997 LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcessLo(
1998     const __m128i s[2][4], const uint16_t scales[2], uint16_t* const sum3[4],
1999     uint16_t* const sum5[5], uint32_t* const square_sum3[4],
2000     uint32_t* const square_sum5[5], __m128i sq[2][8], __m128i ma3[2][3],
2001     __m128i b3[2][10], __m128i* const ma5, __m128i b5[2]) {
2002   __m128i s3[4], s5[5], sq3[4][2], sq5[5][2], sum[2], index[2];
2003   Square(s[0][1], sq[0] + 2);
2004   Square(s[1][1], sq[1] + 2);
2005   SumHorizontal16(s[0], &s3[2], &s5[3]);
2006   SumHorizontal16(s[1], &s3[3], &s5[4]);
2007   StoreAligned16(sum3[2], s3[2]);
2008   StoreAligned16(sum3[3], s3[3]);
2009   StoreAligned16(sum5[3], s5[3]);
2010   StoreAligned16(sum5[4], s5[4]);
2011   SumHorizontal32(sq[0], &sq3[2][0], &sq3[2][1], &sq5[3][0], &sq5[3][1]);
2012   StoreAligned32U32(square_sum3[2], sq3[2]);
2013   StoreAligned32U32(square_sum5[3], sq5[3]);
2014   SumHorizontal32(sq[1], &sq3[3][0], &sq3[3][1], &sq5[4][0], &sq5[4][1]);
2015   StoreAligned32U32(square_sum3[3], sq3[3]);
2016   StoreAligned32U32(square_sum5[4], sq5[4]);
2017   LoadAligned16x2U16(sum3, 0, s3);
2018   LoadAligned32x2U32(square_sum3, 0, sq3);
2019   LoadAligned16x3U16(sum5, 0, s5);
2020   LoadAligned32x3U32(square_sum5, 0, sq5);
2021   CalculateSumAndIndex3(s3 + 0, sq3 + 0, scales[1], &sum[0], &index[0]);
2022   CalculateSumAndIndex3(s3 + 1, sq3 + 1, scales[1], &sum[1], &index[1]);
2023   CalculateIntermediate(sum, index, &ma3[0][0], b3[0], b3[1]);
2024   ma3[1][0] = _mm_srli_si128(ma3[0][0], 8);
2025   CalculateIntermediate5(s5, sq5, scales[0], ma5, b5);
2026 }
2027 
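// Pre-processes a block of columns from two source rows for both self-guided
// passes at once: updates the |sum3|/|sum5| and |square_sum3|/|square_sum5|
// row buffers and computes the |ma3|/|b3| and |ma5|/|b5| intermediates.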
2028 LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcess(
2029     const uint16_t* const src0, const uint16_t* const src1,
2030     const ptrdiff_t over_read_in_bytes, const ptrdiff_t x,
2031     const uint16_t scales[2], uint16_t* const sum3[4], uint16_t* const sum5[5],
2032     uint32_t* const square_sum3[4], uint32_t* const square_sum5[5],
2033     const ptrdiff_t sum_width, __m256i sq[2][8], __m256i ma3[2][3],
2034     __m256i b3[2][7], __m256i ma5[3], __m256i b5[5]) {
2035   __m256i s[2], s3[2][4], s5[2][5], sq3[4][2], sq5[5][2], sum_3[2][2],
2036       index_3[2][2], sum_5[2], index_5[2], t[4];
2037   s[0] = LoadUnaligned32Msan(src0 + 8, over_read_in_bytes + 16);
2038   s[1] = LoadUnaligned32Msan(src1 + 8, over_read_in_bytes + 16);
2039   Square(s[0], sq[0] + 2);
2040   Square(s[1], sq[1] + 2);
2041   sq[0][0] = _mm256_permute2x128_si256(sq[0][0], sq[0][2], 0x21);
2042   sq[0][1] = _mm256_permute2x128_si256(sq[0][1], sq[0][3], 0x21);
2043   sq[1][0] = _mm256_permute2x128_si256(sq[1][0], sq[1][2], 0x21);
2044   sq[1][1] = _mm256_permute2x128_si256(sq[1][1], sq[1][3], 0x21);
2045   SumHorizontal16(src0, over_read_in_bytes, &s3[0][2], &s3[1][2], &s5[0][3],
2046                   &s5[1][3]);
2047   SumHorizontal16(src1, over_read_in_bytes, &s3[0][3], &s3[1][3], &s5[0][4],
2048                   &s5[1][4]);
2049   StoreAligned32(sum3[2] + x + 0, s3[0][2]);
2050   StoreAligned32(sum3[2] + x + 16, s3[1][2]);
2051   StoreAligned32(sum3[3] + x + 0, s3[0][3]);
2052   StoreAligned32(sum3[3] + x + 16, s3[1][3]);
2053   StoreAligned32(sum5[3] + x + 0, s5[0][3]);
2054   StoreAligned32(sum5[3] + x + 16, s5[1][3]);
2055   StoreAligned32(sum5[4] + x + 0, s5[0][4]);
2056   StoreAligned32(sum5[4] + x + 16, s5[1][4]);
2057   SumHorizontal32(sq[0], &sq3[2][0], &sq3[2][1], &sq5[3][0], &sq5[3][1]);
2058   SumHorizontal32(sq[1], &sq3[3][0], &sq3[3][1], &sq5[4][0], &sq5[4][1]);
2059   StoreAligned64(square_sum3[2] + x, sq3[2]);
2060   StoreAligned64(square_sum5[3] + x, sq5[3]);
2061   StoreAligned64(square_sum3[3] + x, sq3[3]);
2062   StoreAligned64(square_sum5[4] + x, sq5[4]);
2063   LoadAligned32x2U16(sum3, x, s3[0]);
2064   LoadAligned64x2U32(square_sum3, x, sq3);
2065   CalculateSumAndIndex3(s3[0], sq3, scales[1], &sum_3[0][0], &index_3[0][0]);
2066   CalculateSumAndIndex3(s3[0] + 1, sq3 + 1, scales[1], &sum_3[1][0],
2067                         &index_3[1][0]);
2068   LoadAligned32x3U16(sum5, x, s5[0]);
2069   LoadAligned64x3U32(square_sum5, x, sq5);
2070   CalculateSumAndIndex5(s5[0], sq5, scales[0], &sum_5[0], &index_5[0]);
2071 
2072   s[0] = LoadUnaligned32Msan(src0 + 24, over_read_in_bytes + 48);
2073   s[1] = LoadUnaligned32Msan(src1 + 24, over_read_in_bytes + 48);
2074   Square(s[0], sq[0] + 6);
2075   Square(s[1], sq[1] + 6);
2076   sq[0][4] = _mm256_permute2x128_si256(sq[0][2], sq[0][6], 0x21);
2077   sq[0][5] = _mm256_permute2x128_si256(sq[0][3], sq[0][7], 0x21);
2078   sq[1][4] = _mm256_permute2x128_si256(sq[1][2], sq[1][6], 0x21);
2079   sq[1][5] = _mm256_permute2x128_si256(sq[1][3], sq[1][7], 0x21);
2080   SumHorizontal32(sq[0] + 4, &sq3[2][0], &sq3[2][1], &sq5[3][0], &sq5[3][1]);
2081   SumHorizontal32(sq[1] + 4, &sq3[3][0], &sq3[3][1], &sq5[4][0], &sq5[4][1]);
2082   StoreAligned64(square_sum3[2] + x + 16, sq3[2]);
2083   StoreAligned64(square_sum5[3] + x + 16, sq5[3]);
2084   StoreAligned64(square_sum3[3] + x + 16, sq3[3]);
2085   StoreAligned64(square_sum5[4] + x + 16, sq5[4]);
2086   LoadAligned32x2U16Msan(sum3, x + 16, sum_width, s3[1]);
2087   LoadAligned64x2U32Msan(square_sum3, x + 16, sum_width, sq3);
2088   CalculateSumAndIndex3(s3[1], sq3, scales[1], &sum_3[0][1], &index_3[0][1]);
2089   CalculateSumAndIndex3(s3[1] + 1, sq3 + 1, scales[1], &sum_3[1][1],
2090                         &index_3[1][1]);
2091   CalculateIntermediate<9>(sum_3[0], index_3[0], ma3[0], t, t + 2);
2092   PermuteB(t, b3[0]);
2093   CalculateIntermediate<9>(sum_3[1], index_3[1], ma3[1], t, t + 2);
2094   PermuteB(t, b3[1]);
2095   LoadAligned32x3U16Msan(sum5, x + 16, sum_width, s5[1]);
2096   LoadAligned64x3U32Msan(square_sum5, x + 16, sum_width, sq5);
2097   CalculateSumAndIndex5(s5[1], sq5, scales[0], &sum_5[1], &index_5[1]);
2098   CalculateIntermediate<25>(sum_5, index_5, ma5, t, t + 2);
2099   PermuteB(t, b5);
2100 }
2101 
2102 LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcessLastRowLo(
2103     const __m128i s[2], const uint16_t scales[2], const uint16_t* const sum3[4],
2104     const uint16_t* const sum5[5], const uint32_t* const square_sum3[4],
2105     const uint32_t* const square_sum5[5], __m128i sq[4], __m128i* const ma3,
2106     __m128i* const ma5, __m128i b3[2], __m128i b5[2]) {
2107   __m128i s3[3], s5[5], sq3[3][2], sq5[5][2];
2108   Square(s[1], sq + 2);
2109   SumHorizontal16(s, &s3[2], &s5[3]);
2110   SumHorizontal32(sq, &sq3[2][0], &sq3[2][1], &sq5[3][0], &sq5[3][1]);
2111   LoadAligned16x3U16(sum5, 0, s5);
2112   s5[4] = s5[3];
2113   LoadAligned32x3U32(square_sum5, 0, sq5);
2114   sq5[4][0] = sq5[3][0];
2115   sq5[4][1] = sq5[3][1];
2116   CalculateIntermediate5(s5, sq5, scales[0], ma5, b5);
2117   LoadAligned16x2U16(sum3, 0, s3);
2118   LoadAligned32x2U32(square_sum3, 0, sq3);
2119   CalculateIntermediate3(s3, sq3, scales[1], ma3, b3);
2120 }
2121 
2122 LIBGAV1_ALWAYS_INLINE void BoxFilterPreProcessLastRow(
2123     const uint16_t* const src, const ptrdiff_t over_read_in_bytes,
2124     const ptrdiff_t sum_width, const ptrdiff_t x, const uint16_t scales[2],
2125     const uint16_t* const sum3[4], const uint16_t* const sum5[5],
2126     const uint32_t* const square_sum3[4], const uint32_t* const square_sum5[5],
2127     __m256i sq[6], __m256i ma3[2], __m256i ma5[2], __m256i b3[5],
2128     __m256i b5[5]) {
2129   const __m256i s0 = LoadUnaligned32Msan(src + 8, over_read_in_bytes + 16);
2130   __m256i s3[2][3], s5[2][5], sq3[4][2], sq5[5][2], sum_3[2], index_3[2],
2131       sum_5[2], index_5[2], t[4];
2132   Square(s0, sq + 2);
2133   sq[0] = _mm256_permute2x128_si256(sq[0], sq[2], 0x21);
2134   sq[1] = _mm256_permute2x128_si256(sq[1], sq[3], 0x21);
2135   SumHorizontal16(src, over_read_in_bytes, &s3[0][2], &s3[1][2], &s5[0][3],
2136                   &s5[1][3]);
2137   SumHorizontal32(sq, &sq3[2][0], &sq3[2][1], &sq5[3][0], &sq5[3][1]);
2138   LoadAligned32x2U16(sum3, x, s3[0]);
2139   LoadAligned64x2U32(square_sum3, x, sq3);
2140   CalculateSumAndIndex3(s3[0], sq3, scales[1], &sum_3[0], &index_3[0]);
2141   LoadAligned32x3U16(sum5, x, s5[0]);
2142   s5[0][4] = s5[0][3];
2143   LoadAligned64x3U32(square_sum5, x, sq5);
2144   sq5[4][0] = sq5[3][0];
2145   sq5[4][1] = sq5[3][1];
2146   CalculateSumAndIndex5(s5[0], sq5, scales[0], &sum_5[0], &index_5[0]);
2147 
2148   const __m256i s1 = LoadUnaligned32Msan(src + 24, over_read_in_bytes + 48);
2149   Square(s1, sq + 6);
2150   sq[4] = _mm256_permute2x128_si256(sq[2], sq[6], 0x21);
2151   sq[5] = _mm256_permute2x128_si256(sq[3], sq[7], 0x21);
2152   SumHorizontal32(sq + 4, &sq3[2][0], &sq3[2][1], &sq5[3][0], &sq5[3][1]);
2153   LoadAligned32x2U16Msan(sum3, x + 16, sum_width, s3[1]);
2154   LoadAligned64x2U32Msan(square_sum3, x + 16, sum_width, sq3);
2155   CalculateSumAndIndex3(s3[1], sq3, scales[1], &sum_3[1], &index_3[1]);
2156   CalculateIntermediate<9>(sum_3, index_3, ma3, t, t + 2);
2157   PermuteB(t, b3);
2158   LoadAligned32x3U16Msan(sum5, x + 16, sum_width, s5[1]);
2159   s5[1][4] = s5[1][3];
2160   LoadAligned64x3U32Msan(square_sum5, x + 16, sum_width, sq5);
2161   sq5[4][0] = sq5[3][0];
2162   sq5[4][1] = sq5[3][1];
2163   CalculateSumAndIndex5(s5[1], sq5, scales[0], &sum_5[1], &index_5[1]);
2164   CalculateIntermediate<25>(sum_5, index_5, ma5, t, t + 2);
2165   PermuteB(t, b5);
2166 }
2167 
2168 inline void BoxSumFilterPreProcess5(const uint16_t* const src0,
2169                                     const uint16_t* const src1, const int width,
2170                                     const uint32_t scale,
2171                                     uint16_t* const sum5[5],
2172                                     uint32_t* const square_sum5[5],
2173                                     const ptrdiff_t sum_width, uint16_t* ma565,
2174                                     uint32_t* b565) {
2175   const ptrdiff_t overread_in_bytes =
2176       kOverreadInBytesPass1_128 - sizeof(*src0) * width;
2177   __m128i s[2][2], ma0, sq_128[2][4], b0[2];
2178   __m256i mas[3], sq[2][8], bs[10];
2179   s[0][0] = LoadUnaligned16Msan(src0 + 0, overread_in_bytes + 0);
2180   s[0][1] = LoadUnaligned16Msan(src0 + 8, overread_in_bytes + 16);
2181   s[1][0] = LoadUnaligned16Msan(src1 + 0, overread_in_bytes + 0);
2182   s[1][1] = LoadUnaligned16Msan(src1 + 8, overread_in_bytes + 16);
2183   Square(s[0][0], sq_128[0]);
2184   Square(s[1][0], sq_128[1]);
2185   BoxFilterPreProcess5Lo(s, scale, sum5, square_sum5, sq_128, &ma0, b0);
2186   sq[0][0] = SetrM128i(sq_128[0][2], sq_128[0][2]);
2187   sq[0][1] = SetrM128i(sq_128[0][3], sq_128[0][3]);
2188   sq[1][0] = SetrM128i(sq_128[1][2], sq_128[1][2]);
2189   sq[1][1] = SetrM128i(sq_128[1][3], sq_128[1][3]);
2190   mas[0] = SetrM128i(ma0, ma0);
2191   bs[0] = SetrM128i(b0[0], b0[0]);
2192   bs[1] = SetrM128i(b0[1], b0[1]);
2193 
2194   int x = 0;
2195   do {
2196     __m256i ma5[3], ma[2], b[4];
2197     BoxFilterPreProcess5(
2198         src0 + x + 8, src1 + x + 8,
2199         kOverreadInBytesPass1_256 + sizeof(*src0) * (x + 8 - width), sum_width,
2200         x + 8, scale, sum5, square_sum5, sq, mas, bs);
2201     Prepare3_8(mas, ma5);
2202     ma[0] = Sum565Lo(ma5);
2203     ma[1] = Sum565Hi(ma5);
2204     StoreAligned64_ma(ma565, ma);
2205     Sum565(bs + 0, b + 0);
2206     Sum565(bs + 3, b + 2);
2207     StoreAligned64(b565, b + 0);
2208     StoreAligned64(b565 + 16, b + 2);
2209     sq[0][0] = sq[0][6];
2210     sq[0][1] = sq[0][7];
2211     sq[1][0] = sq[1][6];
2212     sq[1][1] = sq[1][7];
2213     mas[0] = mas[2];
2214     bs[0] = bs[5];
2215     bs[1] = bs[6];
2216     ma565 += 32;
2217     b565 += 32;
2218     x += 32;
2219   } while (x < width);
2220 }
2221 
2222 template <bool calculate444>
2223 LIBGAV1_ALWAYS_INLINE void BoxSumFilterPreProcess3(
2224     const uint16_t* const src, const int width, const uint32_t scale,
2225     uint16_t* const sum3[3], uint32_t* const square_sum3[3],
2226     const ptrdiff_t sum_width, uint16_t* ma343, uint16_t* ma444, uint32_t* b343,
2227     uint32_t* b444) {
2228   const ptrdiff_t overread_in_bytes_128 =
2229       kOverreadInBytesPass2_128 - sizeof(*src) * width;
2230   __m128i s[2], ma0, sq_128[4], b0[2];
2231   __m256i mas[3], sq[8], bs[7];
2232   s[0] = LoadUnaligned16Msan(src + 0, overread_in_bytes_128 + 0);
2233   s[1] = LoadUnaligned16Msan(src + 8, overread_in_bytes_128 + 16);
2234   Square(s[0], sq_128);
2235   BoxFilterPreProcess3Lo(s, scale, sum3, square_sum3, sq_128, &ma0, b0);
2236   sq[0] = SetrM128i(sq_128[2], sq_128[2]);
2237   sq[1] = SetrM128i(sq_128[3], sq_128[3]);
2238   mas[0] = SetrM128i(ma0, ma0);
2239   bs[0] = SetrM128i(b0[0], b0[0]);
2240   bs[1] = SetrM128i(b0[1], b0[1]);
2241 
2242   int x = 0;
2243   do {
2244     __m256i ma3[3];
2245     BoxFilterPreProcess3(
2246         src + x + 8, kOverreadInBytesPass2_256 + sizeof(*src) * (x + 8 - width),
2247         x + 8, sum_width, scale, sum3, square_sum3, sq, mas, bs);
2248     Prepare3_8(mas, ma3);
2249     if (calculate444) {  // NOLINT(readability-simplify-boolean-expr)
2250       Store343_444Lo(ma3, bs + 0, 0, ma343, ma444, b343, b444);
2251       Store343_444Hi(ma3, bs + 3, kMaStoreOffset, ma343, ma444, b343, b444);
2252       ma444 += 32;
2253       b444 += 32;
2254     } else {
2255       __m256i ma[2], b[4];
2256       ma[0] = Sum343Lo(ma3);
2257       ma[1] = Sum343Hi(ma3);
2258       StoreAligned64_ma(ma343, ma);
2259       Sum343(bs + 0, b + 0);
2260       Sum343(bs + 3, b + 2);
2261       StoreAligned64(b343 + 0, b + 0);
2262       StoreAligned64(b343 + 16, b + 2);
2263     }
2264     sq[0] = sq[6];
2265     sq[1] = sq[7];
2266     mas[0] = mas[2];
2267     bs[0] = bs[5];
2268     bs[1] = bs[6];
2269     ma343 += 32;
2270     b343 += 32;
2271     x += 32;
2272   } while (x < width);
2273 }
2274 
2275 inline void BoxSumFilterPreProcess(
2276     const uint16_t* const src0, const uint16_t* const src1, const int width,
2277     const uint16_t scales[2], uint16_t* const sum3[4], uint16_t* const sum5[5],
2278     uint32_t* const square_sum3[4], uint32_t* const square_sum5[5],
2279     const ptrdiff_t sum_width, uint16_t* const ma343[4], uint16_t* const ma444,
2280     uint16_t* ma565, uint32_t* const b343[4], uint32_t* const b444,
2281     uint32_t* b565) {
2282   const ptrdiff_t overread_in_bytes =
2283       kOverreadInBytesPass1_128 - sizeof(*src0) * width;
2284   __m128i s[2][4], ma3_128[2][3], ma5_128[3], sq_128[2][8], b3_128[2][10],
2285       b5_128[10];
2286   __m256i ma3[2][3], ma5[3], sq[2][8], b3[2][7], b5[7];
2287   s[0][0] = LoadUnaligned16Msan(src0 + 0, overread_in_bytes + 0);
2288   s[0][1] = LoadUnaligned16Msan(src0 + 8, overread_in_bytes + 16);
2289   s[1][0] = LoadUnaligned16Msan(src1 + 0, overread_in_bytes + 0);
2290   s[1][1] = LoadUnaligned16Msan(src1 + 8, overread_in_bytes + 16);
2291   Square(s[0][0], sq_128[0]);
2292   Square(s[1][0], sq_128[1]);
2293   BoxFilterPreProcessLo(s, scales, sum3, sum5, square_sum3, square_sum5, sq_128,
2294                         ma3_128, b3_128, &ma5_128[0], b5_128);
2295   sq[0][0] = SetrM128i(sq_128[0][2], sq_128[0][2]);
2296   sq[0][1] = SetrM128i(sq_128[0][3], sq_128[0][3]);
2297   sq[1][0] = SetrM128i(sq_128[1][2], sq_128[1][2]);
2298   sq[1][1] = SetrM128i(sq_128[1][3], sq_128[1][3]);
2299   ma3[0][0] = SetrM128i(ma3_128[0][0], ma3_128[0][0]);
2300   ma3[1][0] = SetrM128i(ma3_128[1][0], ma3_128[1][0]);
2301   ma5[0] = SetrM128i(ma5_128[0], ma5_128[0]);
2302   b3[0][0] = SetrM128i(b3_128[0][0], b3_128[0][0]);
2303   b3[0][1] = SetrM128i(b3_128[0][1], b3_128[0][1]);
2304   b3[1][0] = SetrM128i(b3_128[1][0], b3_128[1][0]);
2305   b3[1][1] = SetrM128i(b3_128[1][1], b3_128[1][1]);
2306   b5[0] = SetrM128i(b5_128[0], b5_128[0]);
2307   b5[1] = SetrM128i(b5_128[1], b5_128[1]);
2308 
2309   int x = 0;
2310   do {
2311     __m256i ma[2], b[4], ma3x[3], ma5x[3];
2312     BoxFilterPreProcess(
2313         src0 + x + 8, src1 + x + 8,
2314         kOverreadInBytesPass1_256 + sizeof(*src0) * (x + 8 - width), x + 8,
2315         scales, sum3, sum5, square_sum3, square_sum5, sum_width, sq, ma3, b3,
2316         ma5, b5);
2317     Prepare3_8(ma3[0], ma3x);
2318     ma[0] = Sum343Lo(ma3x);
2319     ma[1] = Sum343Hi(ma3x);
2320     StoreAligned64_ma(ma343[0] + x, ma);
2321     Sum343(b3[0], b);
2322     Sum343(b3[0] + 3, b + 2);
2323     StoreAligned64(b343[0] + x, b);
2324     StoreAligned64(b343[0] + x + 16, b + 2);
2325     Prepare3_8(ma3[1], ma3x);
2326     Store343_444Lo(ma3x, b3[1], x, ma343[1], ma444, b343[1], b444);
2327     Store343_444Hi(ma3x, b3[1] + 3, x + kMaStoreOffset, ma343[1], ma444,
2328                    b343[1], b444);
2329     Prepare3_8(ma5, ma5x);
2330     ma[0] = Sum565Lo(ma5x);
2331     ma[1] = Sum565Hi(ma5x);
2332     StoreAligned64_ma(ma565, ma);
2333     Sum565(b5, b);
2334     StoreAligned64(b565, b);
2335     Sum565(b5 + 3, b);
2336     StoreAligned64(b565 + 16, b);
2337     sq[0][0] = sq[0][6];
2338     sq[0][1] = sq[0][7];
2339     sq[1][0] = sq[1][6];
2340     sq[1][1] = sq[1][7];
2341     ma3[0][0] = ma3[0][2];
2342     ma3[1][0] = ma3[1][2];
2343     ma5[0] = ma5[2];
2344     b3[0][0] = b3[0][5];
2345     b3[0][1] = b3[0][6];
2346     b3[1][0] = b3[1][5];
2347     b3[1][1] = b3[1][6];
2348     b5[0] = b5[5];
2349     b5[1] = b5[6];
2350     ma565 += 32;
2351     b565 += 32;
2352     x += 32;
2353   } while (x < width);
2354 }
2355 
2356 template <int shift>
2357 inline __m256i FilterOutput(const __m256i ma_x_src, const __m256i b) {
2358   // ma: 255 * 32 = 8160 (13 bits)
2359   // b: 65088 * 32 = 2082816 (21 bits)
2360   // v: b - ma * 255 (22 bits)
2361   const __m256i v = _mm256_sub_epi32(b, ma_x_src);
2362   // kSgrProjSgrBits = 8
2363   // kSgrProjRestoreBits = 4
2364   // shift = 4 or 5
2365   // v >> 8 or 9 (13 bits)
2366   return VrshrS32(v, kSgrProjSgrBits + shift - kSgrProjRestoreBits);
2367 }
2368 
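// Illustrative scalar sketch, not part of the upstream implementation: one
// 32-bit lane of FilterOutput<shift>() above, using the same constants. The
// rounding shift mirrors VrshrS32().
template <int shift>
inline int32_t FilterOutputScalar(const int32_t ma_x_src, const int32_t b) {
  const int32_t v = b - ma_x_src;  // Up to 22 bits.
  const int n = kSgrProjSgrBits + shift - kSgrProjRestoreBits;  // 8 or 9.
  return (v + (1 << (n - 1))) >> n;  // Rounded right shift, 13-bit result.
}
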
2369 template <int shift>
2370 inline __m256i CalculateFilteredOutput(const __m256i src, const __m256i ma,
2371                                        const __m256i b[2]) {
2372   const __m256i ma_x_src_lo = VmullLo16(ma, src);
2373   const __m256i ma_x_src_hi = VmullHi16(ma, src);
2374   const __m256i dst_lo = FilterOutput<shift>(ma_x_src_lo, b[0]);
2375   const __m256i dst_hi = FilterOutput<shift>(ma_x_src_hi, b[1]);
2376   return _mm256_packs_epi32(dst_lo, dst_hi);  // 13 bits
2377 }
2378 
2379 inline __m256i CalculateFilteredOutputPass1(const __m256i src,
2380                                             const __m256i ma[2],
2381                                             const __m256i b[2][2]) {
2382   const __m256i ma_sum = _mm256_add_epi16(ma[0], ma[1]);
2383   __m256i b_sum[2];
2384   b_sum[0] = _mm256_add_epi32(b[0][0], b[1][0]);
2385   b_sum[1] = _mm256_add_epi32(b[0][1], b[1][1]);
2386   return CalculateFilteredOutput<5>(src, ma_sum, b_sum);
2387 }
2388 
2389 inline __m256i CalculateFilteredOutputPass2(const __m256i src,
2390                                             const __m256i ma[3],
2391                                             const __m256i b[3][2]) {
2392   const __m256i ma_sum = Sum3_16(ma);
2393   __m256i b_sum[2];
2394   Sum3_32(b, b_sum);
2395   return CalculateFilteredOutput<5>(src, ma_sum, b_sum);
2396 }
2397 
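// Illustrative note, not part of the upstream code: per 32-bit lane the two
// helpers above reduce to
//   pass 1: FilterOutput<5>(src * (ma[0] + ma[1]), b[0] + b[1])
//   pass 2: FilterOutput<5>(src * (ma[0] + ma[1] + ma[2]), b[0] + b[1] + b[2])
// i.e. pass 1 blends the 565 intermediates of two rows and pass 2 blends three
// 343/444 rows before the shared rounding shift.
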
2398 inline __m256i SelfGuidedFinal(const __m256i src, const __m256i v[2]) {
2399   const __m256i v_lo =
2400       VrshrS32(v[0], kSgrProjRestoreBits + kSgrProjPrecisionBits);
2401   const __m256i v_hi =
2402       VrshrS32(v[1], kSgrProjRestoreBits + kSgrProjPrecisionBits);
2403   const __m256i vv = _mm256_packs_epi32(v_lo, v_hi);
2404   return _mm256_add_epi16(src, vv);
2405 }
2406 
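// Illustrative scalar sketch, not part of the upstream implementation: one
// lane of SelfGuidedFinal() above. The vector version additionally saturates
// to int16 when packing; the final 10-bit clamp happens later in
// ClipAndStore().
inline int32_t SelfGuidedFinalScalar(const int32_t src, const int32_t v) {
  const int n = kSgrProjRestoreBits + kSgrProjPrecisionBits;
  return src + ((v + (1 << (n - 1))) >> n);  // Rounded right shift, then add.
}
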
2407 inline __m256i SelfGuidedDoubleMultiplier(const __m256i src,
2408                                           const __m256i filter[2], const int w0,
2409                                           const int w2) {
2410   __m256i v[2];
2411   const __m256i w0_w2 =
2412       _mm256_set1_epi32((w2 << 16) | static_cast<uint16_t>(w0));
2413   const __m256i f_lo = _mm256_unpacklo_epi16(filter[0], filter[1]);
2414   const __m256i f_hi = _mm256_unpackhi_epi16(filter[0], filter[1]);
2415   v[0] = _mm256_madd_epi16(w0_w2, f_lo);
2416   v[1] = _mm256_madd_epi16(w0_w2, f_hi);
2417   return SelfGuidedFinal(src, v);
2418 }
2419 
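// Illustrative scalar sketch, not part of the upstream implementation: the
// per-lane weighting that SelfGuidedDoubleMultiplier() above evaluates with a
// single _mm256_madd_epi16 on the packed (w0, w2) pair.
inline int32_t SelfGuidedCombineScalar(const int16_t filter0,
                                       const int16_t filter1, const int w0,
                                       const int w2) {
  return w0 * filter0 + w2 * filter1;  // Fed to the SelfGuidedFinal() shift.
}
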
2420 inline __m256i SelfGuidedSingleMultiplier(const __m256i src,
2421                                           const __m256i filter, const int w0) {
2422   // weight: -96 to 96 (Sgrproj_Xqd_Min/Max)
2423   __m256i v[2];
2424   v[0] = VmullNLo8(filter, w0);
2425   v[1] = VmullNHi8(filter, w0);
2426   return SelfGuidedFinal(src, v);
2427 }
2428 
2429 inline void ClipAndStore(uint16_t* const dst, const __m256i val) {
2430   const __m256i val0 = _mm256_max_epi16(val, _mm256_setzero_si256());
2431   const __m256i val1 = _mm256_min_epi16(val0, _mm256_set1_epi16(1023));
2432   StoreUnaligned32(dst, val1);
2433 }
2434 
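// Illustrative scalar sketch, not part of the upstream implementation:
// per-pixel equivalent of ClipAndStore() above, clamping the filtered value to
// the valid 10-bit range before it is written out.
inline uint16_t ClipPixel10bpp(const int value) {
  return static_cast<uint16_t>(std::min(std::max(value, 0), 1023));
}
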
2435 LIBGAV1_ALWAYS_INLINE void BoxFilterPass1(
2436     const uint16_t* const src, const uint16_t* const src0,
2437     const uint16_t* const src1, const ptrdiff_t stride, uint16_t* const sum5[5],
2438     uint32_t* const square_sum5[5], const int width, const ptrdiff_t sum_width,
2439     const uint32_t scale, const int16_t w0, uint16_t* const ma565[2],
2440     uint32_t* const b565[2], uint16_t* const dst) {
2441   const ptrdiff_t overread_in_bytes =
2442       kOverreadInBytesPass1_128 - sizeof(*src0) * width;
2443   __m128i s[2][2], ma0, sq_128[2][4], b0[2];
2444   __m256i mas[3], sq[2][8], bs[7];
2445   s[0][0] = LoadUnaligned16Msan(src0 + 0, overread_in_bytes + 0);
2446   s[0][1] = LoadUnaligned16Msan(src0 + 8, overread_in_bytes + 16);
2447   s[1][0] = LoadUnaligned16Msan(src1 + 0, overread_in_bytes + 0);
2448   s[1][1] = LoadUnaligned16Msan(src1 + 8, overread_in_bytes + 16);
2449   Square(s[0][0], sq_128[0]);
2450   Square(s[1][0], sq_128[1]);
2451   BoxFilterPreProcess5Lo(s, scale, sum5, square_sum5, sq_128, &ma0, b0);
2452   sq[0][0] = SetrM128i(sq_128[0][2], sq_128[0][2]);
2453   sq[0][1] = SetrM128i(sq_128[0][3], sq_128[0][3]);
2454   sq[1][0] = SetrM128i(sq_128[1][2], sq_128[1][2]);
2455   sq[1][1] = SetrM128i(sq_128[1][3], sq_128[1][3]);
2456   mas[0] = SetrM128i(ma0, ma0);
2457   bs[0] = SetrM128i(b0[0], b0[0]);
2458   bs[1] = SetrM128i(b0[1], b0[1]);
2459 
2460   int x = 0;
2461   do {
2462     __m256i ma5[3], ma[4], b[4][2];
2463     BoxFilterPreProcess5(
2464         src0 + x + 8, src1 + x + 8,
2465         kOverreadInBytesPass1_256 + sizeof(*src0) * (x + 8 - width), sum_width,
2466         x + 8, scale, sum5, square_sum5, sq, mas, bs);
2467     Prepare3_8(mas, ma5);
2468     ma[2] = Sum565Lo(ma5);
2469     ma[3] = Sum565Hi(ma5);
2470     ma[1] = _mm256_permute2x128_si256(ma[2], ma[3], 0x20);
2471     ma[3] = _mm256_permute2x128_si256(ma[2], ma[3], 0x31);
2472     StoreAligned32(ma565[1] + x + 0, ma[1]);
2473     StoreAligned32(ma565[1] + x + 16, ma[3]);
2474     Sum565(bs + 0, b[1]);
2475     Sum565(bs + 3, b[3]);
2476     StoreAligned64(b565[1] + x, b[1]);
2477     StoreAligned64(b565[1] + x + 16, b[3]);
2478     const __m256i sr0_lo = LoadUnaligned32(src + x + 0);
2479     ma[0] = LoadAligned32(ma565[0] + x);
2480     LoadAligned64(b565[0] + x, b[0]);
2481     const __m256i p0 = CalculateFilteredOutputPass1(sr0_lo, ma, b);
2482     const __m256i d0 = SelfGuidedSingleMultiplier(sr0_lo, p0, w0);
2483     ClipAndStore(dst + x + 0, d0);
2484     const __m256i sr0_hi = LoadUnaligned32(src + x + 16);
2485     ma[2] = LoadAligned32(ma565[0] + x + 16);
2486     LoadAligned64(b565[0] + x + 16, b[2]);
2487     const __m256i p1 = CalculateFilteredOutputPass1(sr0_hi, ma + 2, b + 2);
2488     const __m256i d1 = SelfGuidedSingleMultiplier(sr0_hi, p1, w0);
2489     ClipAndStore(dst + x + 16, d1);
2490     const __m256i sr1_lo = LoadUnaligned32(src + stride + x + 0);
2491     const __m256i p10 = CalculateFilteredOutput<4>(sr1_lo, ma[1], b[1]);
2492     const __m256i d10 = SelfGuidedSingleMultiplier(sr1_lo, p10, w0);
2493     ClipAndStore(dst + stride + x + 0, d10);
2494     const __m256i sr1_hi = LoadUnaligned32(src + stride + x + 16);
2495     const __m256i p11 = CalculateFilteredOutput<4>(sr1_hi, ma[3], b[3]);
2496     const __m256i d11 = SelfGuidedSingleMultiplier(sr1_hi, p11, w0);
2497     ClipAndStore(dst + stride + x + 16, d11);
2498     sq[0][0] = sq[0][6];
2499     sq[0][1] = sq[0][7];
2500     sq[1][0] = sq[1][6];
2501     sq[1][1] = sq[1][7];
2502     mas[0] = mas[2];
2503     bs[0] = bs[5];
2504     bs[1] = bs[6];
2505     x += 32;
2506   } while (x < width);
2507 }
2508 
2509 inline void BoxFilterPass1LastRow(
2510     const uint16_t* const src, const uint16_t* const src0, const int width,
2511     const ptrdiff_t sum_width, const uint32_t scale, const int16_t w0,
2512     uint16_t* const sum5[5], uint32_t* const square_sum5[5], uint16_t* ma565,
2513     uint32_t* b565, uint16_t* const dst) {
2514   const ptrdiff_t overread_in_bytes =
2515       kOverreadInBytesPass1_128 - sizeof(*src0) * width;
2516   __m128i s[2], ma0[2], sq_128[8], b0[6];
2517   __m256i mas[3], sq[8], bs[7];
2518   s[0] = LoadUnaligned16Msan(src0 + 0, overread_in_bytes + 0);
2519   s[1] = LoadUnaligned16Msan(src0 + 8, overread_in_bytes + 16);
2520   Square(s[0], sq_128);
2521   BoxFilterPreProcess5LastRowLo(s, scale, sum5, square_sum5, sq_128, &ma0[0],
2522                                 b0);
2523   sq[0] = SetrM128i(sq_128[2], sq_128[2]);
2524   sq[1] = SetrM128i(sq_128[3], sq_128[3]);
2525   mas[0] = SetrM128i(ma0[0], ma0[0]);
2526   bs[0] = SetrM128i(b0[0], b0[0]);
2527   bs[1] = SetrM128i(b0[1], b0[1]);
2528 
2529   int x = 0;
2530   do {
2531     __m256i ma5[3], ma[4], b[4][2];
2532     BoxFilterPreProcess5LastRow(
2533         src0 + x + 8,
2534         kOverreadInBytesPass1_256 + sizeof(*src0) * (x + 8 - width), sum_width,
2535         x + 8, scale, sum5, square_sum5, sq, mas, bs);
2536     Prepare3_8(mas, ma5);
2537     ma[2] = Sum565Lo(ma5);
2538     ma[3] = Sum565Hi(ma5);
2539     Sum565(bs + 0, b[1]);
2540     Sum565(bs + 3, b[3]);
2541     const __m256i sr0_lo = LoadUnaligned32(src + x + 0);
2542     ma[0] = LoadAligned32(ma565 + x);
2543     ma[1] = _mm256_permute2x128_si256(ma[2], ma[3], 0x20);
2544     LoadAligned64(b565 + x, b[0]);
2545     const __m256i p0 = CalculateFilteredOutputPass1(sr0_lo, ma, b);
2546     const __m256i d0 = SelfGuidedSingleMultiplier(sr0_lo, p0, w0);
2547     ClipAndStore(dst + x + 0, d0);
2548     const __m256i sr0_hi = LoadUnaligned32(src + x + 16);
2549     ma[0] = LoadAligned32(ma565 + x + 16);
2550     ma[1] = _mm256_permute2x128_si256(ma[2], ma[3], 0x31);
2551     LoadAligned64(b565 + x + 16, b[2]);
2552     const __m256i p1 = CalculateFilteredOutputPass1(sr0_hi, ma, b + 2);
2553     const __m256i d1 = SelfGuidedSingleMultiplier(sr0_hi, p1, w0);
2554     ClipAndStore(dst + x + 16, d1);
2555     sq[0] = sq[6];
2556     sq[1] = sq[7];
2557     mas[0] = mas[2];
2558     bs[0] = bs[5];
2559     bs[1] = bs[6];
2560     x += 32;
2561   } while (x < width);
2562 }
2563 
2564 LIBGAV1_ALWAYS_INLINE void BoxFilterPass2(
2565     const uint16_t* const src, const uint16_t* const src0, const int width,
2566     const ptrdiff_t sum_width, const uint32_t scale, const int16_t w0,
2567     uint16_t* const sum3[3], uint32_t* const square_sum3[3],
2568     uint16_t* const ma343[3], uint16_t* const ma444[2], uint32_t* const b343[3],
2569     uint32_t* const b444[2], uint16_t* const dst) {
2570   const ptrdiff_t overread_in_bytes_128 =
2571       kOverreadInBytesPass2_128 - sizeof(*src0) * width;
2572   __m128i s0[2], ma0, sq_128[4], b0[2];
2573   __m256i mas[3], sq[8], bs[7];
2574   s0[0] = LoadUnaligned16Msan(src0 + 0, overread_in_bytes_128 + 0);
2575   s0[1] = LoadUnaligned16Msan(src0 + 8, overread_in_bytes_128 + 16);
2576   Square(s0[0], sq_128);
2577   BoxFilterPreProcess3Lo(s0, scale, sum3, square_sum3, sq_128, &ma0, b0);
2578   sq[0] = SetrM128i(sq_128[2], sq_128[2]);
2579   sq[1] = SetrM128i(sq_128[3], sq_128[3]);
2580   mas[0] = SetrM128i(ma0, ma0);
2581   bs[0] = SetrM128i(b0[0], b0[0]);
2582   bs[1] = SetrM128i(b0[1], b0[1]);
2583 
2584   int x = 0;
2585   do {
2586     __m256i ma[4], b[4][2], ma3[3];
2587     BoxFilterPreProcess3(
2588         src0 + x + 8,
2589         kOverreadInBytesPass2_256 + sizeof(*src0) * (x + 8 - width), x + 8,
2590         sum_width, scale, sum3, square_sum3, sq, mas, bs);
2591     Prepare3_8(mas, ma3);
2592     Store343_444(ma3, bs, x, &ma[2], &ma[3], b[2], b[3], ma343[2], ma444[1],
2593                  b343[2], b444[1]);
2594     const __m256i sr_lo = LoadUnaligned32(src + x + 0);
2595     const __m256i sr_hi = LoadUnaligned32(src + x + 16);
2596     ma[0] = LoadAligned32(ma343[0] + x);
2597     ma[1] = LoadAligned32(ma444[0] + x);
2598     LoadAligned64(b343[0] + x, b[0]);
2599     LoadAligned64(b444[0] + x, b[1]);
2600     const __m256i p0 = CalculateFilteredOutputPass2(sr_lo, ma, b);
2601     ma[1] = LoadAligned32(ma343[0] + x + 16);
2602     ma[2] = LoadAligned32(ma444[0] + x + 16);
2603     LoadAligned64(b343[0] + x + 16, b[1]);
2604     LoadAligned64(b444[0] + x + 16, b[2]);
2605     const __m256i p1 = CalculateFilteredOutputPass2(sr_hi, ma + 1, b + 1);
2606     const __m256i d0 = SelfGuidedSingleMultiplier(sr_lo, p0, w0);
2607     const __m256i d1 = SelfGuidedSingleMultiplier(sr_hi, p1, w0);
2608     ClipAndStore(dst + x + 0, d0);
2609     ClipAndStore(dst + x + 16, d1);
2610     sq[0] = sq[6];
2611     sq[1] = sq[7];
2612     mas[0] = mas[2];
2613     bs[0] = bs[5];
2614     bs[1] = bs[6];
2615     x += 32;
2616   } while (x < width);
2617 }
2618 
2619 LIBGAV1_ALWAYS_INLINE void BoxFilter(
2620     const uint16_t* const src, const uint16_t* const src0,
2621     const uint16_t* const src1, const ptrdiff_t stride, const int width,
2622     const uint16_t scales[2], const int16_t w0, const int16_t w2,
2623     uint16_t* const sum3[4], uint16_t* const sum5[5],
2624     uint32_t* const square_sum3[4], uint32_t* const square_sum5[5],
2625     const ptrdiff_t sum_width, uint16_t* const ma343[4],
2626     uint16_t* const ma444[3], uint16_t* const ma565[2], uint32_t* const b343[4],
2627     uint32_t* const b444[3], uint32_t* const b565[2], uint16_t* const dst) {
2628   const ptrdiff_t overread_in_bytes =
2629       kOverreadInBytesPass1_128 - sizeof(*src0) * width;
2630   __m128i s[2][4], ma3_128[2][3], ma5_0, sq_128[2][8], b3_128[2][10], b5_128[2];
2631   __m256i ma3[2][3], ma5[3], sq[2][8], b3[2][7], b5[7];
2632   s[0][0] = LoadUnaligned16Msan(src0 + 0, overread_in_bytes + 0);
2633   s[0][1] = LoadUnaligned16Msan(src0 + 8, overread_in_bytes + 16);
2634   s[1][0] = LoadUnaligned16Msan(src1 + 0, overread_in_bytes + 0);
2635   s[1][1] = LoadUnaligned16Msan(src1 + 8, overread_in_bytes + 16);
2636   Square(s[0][0], sq_128[0]);
2637   Square(s[1][0], sq_128[1]);
2638   BoxFilterPreProcessLo(s, scales, sum3, sum5, square_sum3, square_sum5, sq_128,
2639                         ma3_128, b3_128, &ma5_0, b5_128);
2640   sq[0][0] = SetrM128i(sq_128[0][2], sq_128[0][2]);
2641   sq[0][1] = SetrM128i(sq_128[0][3], sq_128[0][3]);
2642   sq[1][0] = SetrM128i(sq_128[1][2], sq_128[1][2]);
2643   sq[1][1] = SetrM128i(sq_128[1][3], sq_128[1][3]);
2644   ma3[0][0] = SetrM128i(ma3_128[0][0], ma3_128[0][0]);
2645   ma3[1][0] = SetrM128i(ma3_128[1][0], ma3_128[1][0]);
2646   ma5[0] = SetrM128i(ma5_0, ma5_0);
2647   b3[0][0] = SetrM128i(b3_128[0][0], b3_128[0][0]);
2648   b3[0][1] = SetrM128i(b3_128[0][1], b3_128[0][1]);
2649   b3[1][0] = SetrM128i(b3_128[1][0], b3_128[1][0]);
2650   b3[1][1] = SetrM128i(b3_128[1][1], b3_128[1][1]);
2651   b5[0] = SetrM128i(b5_128[0], b5_128[0]);
2652   b5[1] = SetrM128i(b5_128[1], b5_128[1]);
2653 
2654   int x = 0;
2655   do {
2656     __m256i ma[3][4], mat[3][3], b[3][3][2], bt[3][3][2], p[2][2], ma3x[2][3],
2657         ma5x[3];
2658     BoxFilterPreProcess(
2659         src0 + x + 8, src1 + x + 8,
2660         kOverreadInBytesPass1_256 + sizeof(*src0) * (x + 8 - width), x + 8,
2661         scales, sum3, sum5, square_sum3, square_sum5, sum_width, sq, ma3, b3,
2662         ma5, b5);
2663     Prepare3_8(ma3[0], ma3x[0]);
2664     Prepare3_8(ma3[1], ma3x[1]);
2665     Prepare3_8(ma5, ma5x);
2666     Store343_444(ma3x[0], b3[0], x, &ma[1][2], &mat[1][2], &ma[2][1],
2667                  &mat[2][1], b[1][2], bt[1][2], b[2][1], bt[2][1], ma343[2],
2668                  ma444[1], b343[2], b444[1]);
2669     Store343_444(ma3x[1], b3[1], x, &ma[2][2], &mat[2][2], b[2][2], bt[2][2],
2670                  ma343[3], ma444[2], b343[3], b444[2]);
2671 
2672     ma[0][2] = Sum565Lo(ma5x);
2673     ma[0][3] = Sum565Hi(ma5x);
2674     ma[0][1] = _mm256_permute2x128_si256(ma[0][2], ma[0][3], 0x20);
2675     ma[0][3] = _mm256_permute2x128_si256(ma[0][2], ma[0][3], 0x31);
2676     StoreAligned32(ma565[1] + x + 0, ma[0][1]);
2677     StoreAligned32(ma565[1] + x + 16, ma[0][3]);
2678     Sum565(b5, b[0][1]);
2679     StoreAligned64(b565[1] + x, b[0][1]);
2680     const __m256i sr0_lo = LoadUnaligned32(src + x);
2681     const __m256i sr1_lo = LoadUnaligned32(src + stride + x);
2682     ma[0][0] = LoadAligned32(ma565[0] + x);
2683     LoadAligned64(b565[0] + x, b[0][0]);
2684     p[0][0] = CalculateFilteredOutputPass1(sr0_lo, ma[0], b[0]);
2685     p[1][0] = CalculateFilteredOutput<4>(sr1_lo, ma[0][1], b[0][1]);
2686     ma[1][0] = LoadAligned32(ma343[0] + x);
2687     ma[1][1] = LoadAligned32(ma444[0] + x);
2688     // Keeping the following 4 redundant lines is faster. There are not
2689     // enough registers available, and spilling these values to memory and
2690     // reloading them would be even slower.
2691     ma[1][2] = LoadAligned32(ma343[2] + x);  // Redundant line 1.
2692     LoadAligned64(b343[0] + x, b[1][0]);
2693     LoadAligned64(b444[0] + x, b[1][1]);
2694     p[0][1] = CalculateFilteredOutputPass2(sr0_lo, ma[1], b[1]);
2695     ma[2][0] = LoadAligned32(ma343[1] + x);
2696     ma[2][1] = LoadAligned32(ma444[1] + x);  // Redundant line 2.
2697     LoadAligned64(b343[1] + x, b[2][0]);
2698     p[1][1] = CalculateFilteredOutputPass2(sr1_lo, ma[2], b[2]);
2699     const __m256i d00 = SelfGuidedDoubleMultiplier(sr0_lo, p[0], w0, w2);
2700     ClipAndStore(dst + x, d00);
2701     const __m256i d10x = SelfGuidedDoubleMultiplier(sr1_lo, p[1], w0, w2);
2702     ClipAndStore(dst + stride + x, d10x);
2703 
2704     Sum565(b5 + 3, bt[0][1]);
2705     StoreAligned64(b565[1] + x + 16, bt[0][1]);
2706     const __m256i sr0_hi = LoadUnaligned32(src + x + 16);
2707     const __m256i sr1_hi = LoadUnaligned32(src + stride + x + 16);
2708     ma[0][2] = LoadAligned32(ma565[0] + x + 16);
2709     LoadAligned64(b565[0] + x + 16, bt[0][0]);
2710     p[0][0] = CalculateFilteredOutputPass1(sr0_hi, ma[0] + 2, bt[0]);
2711     p[1][0] = CalculateFilteredOutput<4>(sr1_hi, ma[0][3], bt[0][1]);
2712     mat[1][0] = LoadAligned32(ma343[0] + x + 16);
2713     mat[1][1] = LoadAligned32(ma444[0] + x + 16);
2714     mat[1][2] = LoadAligned32(ma343[2] + x + 16);  // Redundant line 3.
2715     LoadAligned64(b343[0] + x + 16, bt[1][0]);
2716     LoadAligned64(b444[0] + x + 16, bt[1][1]);
2717     p[0][1] = CalculateFilteredOutputPass2(sr0_hi, mat[1], bt[1]);
2718     mat[2][0] = LoadAligned32(ma343[1] + x + 16);
2719     mat[2][1] = LoadAligned32(ma444[1] + x + 16);  // Redundant line 4.
2720     LoadAligned64(b343[1] + x + 16, bt[2][0]);
2721     p[1][1] = CalculateFilteredOutputPass2(sr1_hi, mat[2], bt[2]);
2722     const __m256i d01 = SelfGuidedDoubleMultiplier(sr0_hi, p[0], w0, w2);
2723     ClipAndStore(dst + x + 16, d01);
2724     const __m256i d11 = SelfGuidedDoubleMultiplier(sr1_hi, p[1], w0, w2);
2725     ClipAndStore(dst + stride + x + 16, d11);
2726 
2727     sq[0][0] = sq[0][6];
2728     sq[0][1] = sq[0][7];
2729     sq[1][0] = sq[1][6];
2730     sq[1][1] = sq[1][7];
2731     ma3[0][0] = ma3[0][2];
2732     ma3[1][0] = ma3[1][2];
2733     ma5[0] = ma5[2];
2734     b3[0][0] = b3[0][5];
2735     b3[0][1] = b3[0][6];
2736     b3[1][0] = b3[1][5];
2737     b3[1][1] = b3[1][6];
2738     b5[0] = b5[5];
2739     b5[1] = b5[6];
2740     x += 32;
2741   } while (x < width);
2742 }
2743 
2744 inline void BoxFilterLastRow(
2745     const uint16_t* const src, const uint16_t* const src0, const int width,
2746     const ptrdiff_t sum_width, const uint16_t scales[2], const int16_t w0,
2747     const int16_t w2, uint16_t* const sum3[4], uint16_t* const sum5[5],
2748     uint32_t* const square_sum3[4], uint32_t* const square_sum5[5],
2749     uint16_t* const ma343, uint16_t* const ma444, uint16_t* const ma565,
2750     uint32_t* const b343, uint32_t* const b444, uint32_t* const b565,
2751     uint16_t* const dst) {
2752   const ptrdiff_t overread_in_bytes =
2753       kOverreadInBytesPass1_128 - sizeof(*src0) * width;
2754   __m128i s[2], ma3_0, ma5_0, sq_128[4], b3_128[2], b5_128[2];
2755   __m256i ma3[3], ma5[3], sq[8], b3[7], b5[7];
2756   s[0] = LoadUnaligned16Msan(src0 + 0, overread_in_bytes + 0);
2757   s[1] = LoadUnaligned16Msan(src0 + 8, overread_in_bytes + 16);
2758   Square(s[0], sq_128);
2759   BoxFilterPreProcessLastRowLo(s, scales, sum3, sum5, square_sum3, square_sum5,
2760                                sq_128, &ma3_0, &ma5_0, b3_128, b5_128);
2761   sq[0] = SetrM128i(sq_128[2], sq_128[2]);
2762   sq[1] = SetrM128i(sq_128[3], sq_128[3]);
2763   ma3[0] = SetrM128i(ma3_0, ma3_0);
2764   ma5[0] = SetrM128i(ma5_0, ma5_0);
2765   b3[0] = SetrM128i(b3_128[0], b3_128[0]);
2766   b3[1] = SetrM128i(b3_128[1], b3_128[1]);
2767   b5[0] = SetrM128i(b5_128[0], b5_128[0]);
2768   b5[1] = SetrM128i(b5_128[1], b5_128[1]);
2769 
2770   int x = 0;
2771   do {
2772     __m256i ma[4], mat[4], b[3][2], bt[3][2], ma3x[3], ma5x[3], p[2];
2773     BoxFilterPreProcessLastRow(
2774         src0 + x + 8,
2775         kOverreadInBytesPass1_256 + sizeof(*src0) * (x + 8 - width), sum_width,
2776         x + 8, scales, sum3, sum5, square_sum3, square_sum5, sq, ma3, ma5, b3,
2777         b5);
2778     Prepare3_8(ma3, ma3x);
2779     Prepare3_8(ma5, ma5x);
2780     ma[2] = Sum565Lo(ma5x);
2781     Sum565(b5, b[1]);
2782     mat[1] = Sum565Hi(ma5x);
2783     Sum565(b5 + 3, bt[1]);
2784     ma[3] = Sum343Lo(ma3x);
2785     Sum343(b3, b[2]);
2786     mat[2] = Sum343Hi(ma3x);
2787     Sum343(b3 + 3, bt[2]);
2788 
2789     const __m256i sr_lo = LoadUnaligned32(src + x);
2790     ma[0] = LoadAligned32(ma565 + x);
2791     ma[1] = _mm256_permute2x128_si256(ma[2], mat[1], 0x20);
2792     mat[1] = _mm256_permute2x128_si256(ma[2], mat[1], 0x31);
2793     LoadAligned64(b565 + x, b[0]);
2794     p[0] = CalculateFilteredOutputPass1(sr_lo, ma, b);
2795     ma[0] = LoadAligned32(ma343 + x);
2796     ma[1] = LoadAligned32(ma444 + x);
2797     ma[2] = _mm256_permute2x128_si256(ma[3], mat[2], 0x20);
2798     LoadAligned64(b343 + x, b[0]);
2799     LoadAligned64(b444 + x, b[1]);
2800     p[1] = CalculateFilteredOutputPass2(sr_lo, ma, b);
2801     const __m256i d0 = SelfGuidedDoubleMultiplier(sr_lo, p, w0, w2);
2802 
2803     const __m256i sr_hi = LoadUnaligned32(src + x + 16);
2804     mat[0] = LoadAligned32(ma565 + x + 16);
2805     LoadAligned64(b565 + x + 16, bt[0]);
2806     p[0] = CalculateFilteredOutputPass1(sr_hi, mat, bt);
2807     mat[0] = LoadAligned32(ma343 + x + 16);
2808     mat[1] = LoadAligned32(ma444 + x + 16);
2809     mat[2] = _mm256_permute2x128_si256(ma[3], mat[2], 0x31);
2810     LoadAligned64(b343 + x + 16, bt[0]);
2811     LoadAligned64(b444 + x + 16, bt[1]);
2812     p[1] = CalculateFilteredOutputPass2(sr_hi, mat, bt);
2813     const __m256i d1 = SelfGuidedDoubleMultiplier(sr_hi, p, w0, w2);
2814     ClipAndStore(dst + x + 0, d0);
2815     ClipAndStore(dst + x + 16, d1);
2816 
2817     sq[0] = sq[6];
2818     sq[1] = sq[7];
2819     ma3[0] = ma3[2];
2820     ma5[0] = ma5[2];
2821     b3[0] = b3[5];
2822     b3[1] = b3[6];
2823     b5[0] = b5[5];
2824     b5[1] = b5[6];
2825     x += 32;
2826   } while (x < width);
2827 }
2828 
2829 LIBGAV1_ALWAYS_INLINE void BoxFilterProcess(
2830     const RestorationUnitInfo& restoration_info, const uint16_t* src,
2831     const ptrdiff_t stride, const uint16_t* const top_border,
2832     const ptrdiff_t top_border_stride, const uint16_t* bottom_border,
2833     const ptrdiff_t bottom_border_stride, const int width, const int height,
2834     SgrBuffer* const sgr_buffer, uint16_t* dst) {
2835   const auto temp_stride = Align<ptrdiff_t>(width, 32);
2836   const auto sum_width = temp_stride + 8;
2837   const auto sum_stride = temp_stride + 32;
2838   const int sgr_proj_index = restoration_info.sgr_proj_info.index;
2839   const uint16_t* const scales = kSgrScaleParameter[sgr_proj_index];  // < 2^12.
2840   const int16_t w0 = restoration_info.sgr_proj_info.multiplier[0];
2841   const int16_t w1 = restoration_info.sgr_proj_info.multiplier[1];
2842   const int16_t w2 = (1 << kSgrProjPrecisionBits) - w0 - w1;
2843   uint16_t *sum3[4], *sum5[5], *ma343[4], *ma444[3], *ma565[2];
2844   uint32_t *square_sum3[4], *square_sum5[5], *b343[4], *b444[3], *b565[2];
2845   sum3[0] = sgr_buffer->sum3 + kSumOffset;
2846   square_sum3[0] = sgr_buffer->square_sum3 + kSumOffset;
2847   ma343[0] = sgr_buffer->ma343;
2848   b343[0] = sgr_buffer->b343;
2849   for (int i = 1; i <= 3; ++i) {
2850     sum3[i] = sum3[i - 1] + sum_stride;
2851     square_sum3[i] = square_sum3[i - 1] + sum_stride;
2852     ma343[i] = ma343[i - 1] + temp_stride;
2853     b343[i] = b343[i - 1] + temp_stride;
2854   }
2855   sum5[0] = sgr_buffer->sum5 + kSumOffset;
2856   square_sum5[0] = sgr_buffer->square_sum5 + kSumOffset;
2857   for (int i = 1; i <= 4; ++i) {
2858     sum5[i] = sum5[i - 1] + sum_stride;
2859     square_sum5[i] = square_sum5[i - 1] + sum_stride;
2860   }
2861   ma444[0] = sgr_buffer->ma444;
2862   b444[0] = sgr_buffer->b444;
2863   for (int i = 1; i <= 2; ++i) {
2864     ma444[i] = ma444[i - 1] + temp_stride;
2865     b444[i] = b444[i - 1] + temp_stride;
2866   }
2867   ma565[0] = sgr_buffer->ma565;
2868   ma565[1] = ma565[0] + temp_stride;
2869   b565[0] = sgr_buffer->b565;
2870   b565[1] = b565[0] + temp_stride;
2871   assert(scales[0] != 0);
2872   assert(scales[1] != 0);
2873   BoxSum(top_border, top_border_stride, width, sum_stride, temp_stride, sum3[0],
2874          sum5[1], square_sum3[0], square_sum5[1]);
2875   sum5[0] = sum5[1];
2876   square_sum5[0] = square_sum5[1];
2877   const uint16_t* const s = (height > 1) ? src + stride : bottom_border;
2878   BoxSumFilterPreProcess(src, s, width, scales, sum3, sum5, square_sum3,
2879                          square_sum5, sum_width, ma343, ma444[0], ma565[0],
2880                          b343, b444[0], b565[0]);
2881   sum5[0] = sgr_buffer->sum5 + kSumOffset;
2882   square_sum5[0] = sgr_buffer->square_sum5 + kSumOffset;
2883 
2884   for (int y = (height >> 1) - 1; y > 0; --y) {
2885     Circulate4PointersBy2<uint16_t>(sum3);
2886     Circulate4PointersBy2<uint32_t>(square_sum3);
2887     Circulate5PointersBy2<uint16_t>(sum5);
2888     Circulate5PointersBy2<uint32_t>(square_sum5);
2889     BoxFilter(src + 3, src + 2 * stride, src + 3 * stride, stride, width,
2890               scales, w0, w2, sum3, sum5, square_sum3, square_sum5, sum_width,
2891               ma343, ma444, ma565, b343, b444, b565, dst);
2892     src += 2 * stride;
2893     dst += 2 * stride;
2894     Circulate4PointersBy2<uint16_t>(ma343);
2895     Circulate4PointersBy2<uint32_t>(b343);
2896     std::swap(ma444[0], ma444[2]);
2897     std::swap(b444[0], b444[2]);
2898     std::swap(ma565[0], ma565[1]);
2899     std::swap(b565[0], b565[1]);
2900   }
2901 
2902   Circulate4PointersBy2<uint16_t>(sum3);
2903   Circulate4PointersBy2<uint32_t>(square_sum3);
2904   Circulate5PointersBy2<uint16_t>(sum5);
2905   Circulate5PointersBy2<uint32_t>(square_sum5);
2906   if ((height & 1) == 0 || height > 1) {
2907     const uint16_t* sr[2];
2908     if ((height & 1) == 0) {
2909       sr[0] = bottom_border;
2910       sr[1] = bottom_border + bottom_border_stride;
2911     } else {
2912       sr[0] = src + 2 * stride;
2913       sr[1] = bottom_border;
2914     }
2915     BoxFilter(src + 3, sr[0], sr[1], stride, width, scales, w0, w2, sum3, sum5,
2916               square_sum3, square_sum5, sum_width, ma343, ma444, ma565, b343,
2917               b444, b565, dst);
2918   }
2919   if ((height & 1) != 0) {
2920     if (height > 1) {
2921       src += 2 * stride;
2922       dst += 2 * stride;
2923       Circulate4PointersBy2<uint16_t>(sum3);
2924       Circulate4PointersBy2<uint32_t>(square_sum3);
2925       Circulate5PointersBy2<uint16_t>(sum5);
2926       Circulate5PointersBy2<uint32_t>(square_sum5);
2927       Circulate4PointersBy2<uint16_t>(ma343);
2928       Circulate4PointersBy2<uint32_t>(b343);
2929       std::swap(ma444[0], ma444[2]);
2930       std::swap(b444[0], b444[2]);
2931       std::swap(ma565[0], ma565[1]);
2932       std::swap(b565[0], b565[1]);
2933     }
2934     BoxFilterLastRow(src + 3, bottom_border + bottom_border_stride, width,
2935                      sum_width, scales, w0, w2, sum3, sum5, square_sum3,
2936                      square_sum5, ma343[0], ma444[0], ma565[0], b343[0],
2937                      b444[0], b565[0], dst);
2938   }
2939 }
2940 
2941 inline void BoxFilterProcessPass1(const RestorationUnitInfo& restoration_info,
2942                                   const uint16_t* src, const ptrdiff_t stride,
2943                                   const uint16_t* const top_border,
2944                                   const ptrdiff_t top_border_stride,
2945                                   const uint16_t* bottom_border,
2946                                   const ptrdiff_t bottom_border_stride,
2947                                   const int width, const int height,
2948                                   SgrBuffer* const sgr_buffer, uint16_t* dst) {
2949   const auto temp_stride = Align<ptrdiff_t>(width, 32);
2950   const auto sum_width = temp_stride + 8;
2951   const auto sum_stride = temp_stride + 32;
2952   const int sgr_proj_index = restoration_info.sgr_proj_info.index;
2953   const uint32_t scale = kSgrScaleParameter[sgr_proj_index][0];  // < 2^12.
2954   const int16_t w0 = restoration_info.sgr_proj_info.multiplier[0];
2955   uint16_t *sum5[5], *ma565[2];
2956   uint32_t *square_sum5[5], *b565[2];
2957   sum5[0] = sgr_buffer->sum5 + kSumOffset;
2958   square_sum5[0] = sgr_buffer->square_sum5 + kSumOffset;
2959   for (int i = 1; i <= 4; ++i) {
2960     sum5[i] = sum5[i - 1] + sum_stride;
2961     square_sum5[i] = square_sum5[i - 1] + sum_stride;
2962   }
2963   ma565[0] = sgr_buffer->ma565;
2964   ma565[1] = ma565[0] + temp_stride;
2965   b565[0] = sgr_buffer->b565;
2966   b565[1] = b565[0] + temp_stride;
2967   assert(scale != 0);
2968   BoxSum<5>(top_border, top_border_stride, width, sum_stride, temp_stride,
2969             sum5[1], square_sum5[1]);
2970   sum5[0] = sum5[1];
2971   square_sum5[0] = square_sum5[1];
2972   const uint16_t* const s = (height > 1) ? src + stride : bottom_border;
2973   BoxSumFilterPreProcess5(src, s, width, scale, sum5, square_sum5, sum_width,
2974                           ma565[0], b565[0]);
2975   sum5[0] = sgr_buffer->sum5 + kSumOffset;
2976   square_sum5[0] = sgr_buffer->square_sum5 + kSumOffset;
2977 
2978   for (int y = (height >> 1) - 1; y > 0; --y) {
2979     Circulate5PointersBy2<uint16_t>(sum5);
2980     Circulate5PointersBy2<uint32_t>(square_sum5);
2981     BoxFilterPass1(src + 3, src + 2 * stride, src + 3 * stride, stride, sum5,
2982                    square_sum5, width, sum_width, scale, w0, ma565, b565, dst);
2983     src += 2 * stride;
2984     dst += 2 * stride;
2985     std::swap(ma565[0], ma565[1]);
2986     std::swap(b565[0], b565[1]);
2987   }
2988 
2989   Circulate5PointersBy2<uint16_t>(sum5);
2990   Circulate5PointersBy2<uint32_t>(square_sum5);
2991   if ((height & 1) == 0 || height > 1) {
2992     const uint16_t* sr[2];
2993     if ((height & 1) == 0) {
2994       sr[0] = bottom_border;
2995       sr[1] = bottom_border + bottom_border_stride;
2996     } else {
2997       sr[0] = src + 2 * stride;
2998       sr[1] = bottom_border;
2999     }
3000     BoxFilterPass1(src + 3, sr[0], sr[1], stride, sum5, square_sum5, width,
3001                    sum_width, scale, w0, ma565, b565, dst);
3002   }
3003   if ((height & 1) != 0) {
3004     src += 3;
3005     if (height > 1) {
3006       src += 2 * stride;
3007       dst += 2 * stride;
3008       std::swap(ma565[0], ma565[1]);
3009       std::swap(b565[0], b565[1]);
3010       Circulate5PointersBy2<uint16_t>(sum5);
3011       Circulate5PointersBy2<uint32_t>(square_sum5);
3012     }
3013     BoxFilterPass1LastRow(src, bottom_border + bottom_border_stride, width,
3014                           sum_width, scale, w0, sum5, square_sum5, ma565[0],
3015                           b565[0], dst);
3016   }
3017 }
3018 
3019 inline void BoxFilterProcessPass2(const RestorationUnitInfo& restoration_info,
3020                                   const uint16_t* src, const ptrdiff_t stride,
3021                                   const uint16_t* const top_border,
3022                                   const ptrdiff_t top_border_stride,
3023                                   const uint16_t* bottom_border,
3024                                   const ptrdiff_t bottom_border_stride,
3025                                   const int width, const int height,
3026                                   SgrBuffer* const sgr_buffer, uint16_t* dst) {
3027   assert(restoration_info.sgr_proj_info.multiplier[0] == 0);
3028   const auto temp_stride = Align<ptrdiff_t>(width, 32);
3029   const auto sum_width = temp_stride + 8;
3030   const auto sum_stride = temp_stride + 32;
3031   const int16_t w1 = restoration_info.sgr_proj_info.multiplier[1];
3032   const int16_t w0 = (1 << kSgrProjPrecisionBits) - w1;
3033   const int sgr_proj_index = restoration_info.sgr_proj_info.index;
3034   const uint32_t scale = kSgrScaleParameter[sgr_proj_index][1];  // < 2^12.
3035   uint16_t *sum3[3], *ma343[3], *ma444[2];
3036   uint32_t *square_sum3[3], *b343[3], *b444[2];
3037   sum3[0] = sgr_buffer->sum3 + kSumOffset;
3038   square_sum3[0] = sgr_buffer->square_sum3 + kSumOffset;
3039   ma343[0] = sgr_buffer->ma343;
3040   b343[0] = sgr_buffer->b343;
3041   for (int i = 1; i <= 2; ++i) {
3042     sum3[i] = sum3[i - 1] + sum_stride;
3043     square_sum3[i] = square_sum3[i - 1] + sum_stride;
3044     ma343[i] = ma343[i - 1] + temp_stride;
3045     b343[i] = b343[i - 1] + temp_stride;
3046   }
3047   ma444[0] = sgr_buffer->ma444;
3048   ma444[1] = ma444[0] + temp_stride;
3049   b444[0] = sgr_buffer->b444;
3050   b444[1] = b444[0] + temp_stride;
3051   assert(scale != 0);
3052   BoxSum<3>(top_border, top_border_stride, width, sum_stride, temp_stride,
3053             sum3[0], square_sum3[0]);
3054   BoxSumFilterPreProcess3<false>(src, width, scale, sum3, square_sum3,
3055                                  sum_width, ma343[0], nullptr, b343[0],
3056                                  nullptr);
3057   Circulate3PointersBy1<uint16_t>(sum3);
3058   Circulate3PointersBy1<uint32_t>(square_sum3);
3059   const uint16_t* s;
3060   if (height > 1) {
3061     s = src + stride;
3062   } else {
3063     s = bottom_border;
3064     bottom_border += bottom_border_stride;
3065   }
3066   BoxSumFilterPreProcess3<true>(s, width, scale, sum3, square_sum3, sum_width,
3067                                 ma343[1], ma444[0], b343[1], b444[0]);
3068 
3069   for (int y = height - 2; y > 0; --y) {
3070     Circulate3PointersBy1<uint16_t>(sum3);
3071     Circulate3PointersBy1<uint32_t>(square_sum3);
3072     BoxFilterPass2(src + 2, src + 2 * stride, width, sum_width, scale, w0, sum3,
3073                    square_sum3, ma343, ma444, b343, b444, dst);
3074     src += stride;
3075     dst += stride;
3076     Circulate3PointersBy1<uint16_t>(ma343);
3077     Circulate3PointersBy1<uint32_t>(b343);
3078     std::swap(ma444[0], ma444[1]);
3079     std::swap(b444[0], b444[1]);
3080   }
3081 
3082   int y = std::min(height, 2);
3083   src += 2;
3084   do {
3085     Circulate3PointersBy1<uint16_t>(sum3);
3086     Circulate3PointersBy1<uint32_t>(square_sum3);
3087     BoxFilterPass2(src, bottom_border, width, sum_width, scale, w0, sum3,
3088                    square_sum3, ma343, ma444, b343, b444, dst);
3089     src += stride;
3090     dst += stride;
3091     bottom_border += bottom_border_stride;
3092     Circulate3PointersBy1<uint16_t>(ma343);
3093     Circulate3PointersBy1<uint32_t>(b343);
3094     std::swap(ma444[0], ma444[1]);
3095     std::swap(b444[0], b444[1]);
3096   } while (--y != 0);
3097 }
3098 
3099 // If |width| is not a multiple of 32, up to 31 extra pixels are written to
3100 // |dest| at the end of each row. It is safe to overwrite the output as it
3101 // will not be part of the visible frame.
3102 void SelfGuidedFilter_AVX2(
3103     const RestorationUnitInfo& LIBGAV1_RESTRICT restoration_info,
3104     const void* LIBGAV1_RESTRICT const source, const ptrdiff_t stride,
3105     const void* LIBGAV1_RESTRICT const top_border,
3106     const ptrdiff_t top_border_stride,
3107     const void* LIBGAV1_RESTRICT const bottom_border,
3108     const ptrdiff_t bottom_border_stride, const int width, const int height,
3109     RestorationBuffer* LIBGAV1_RESTRICT const restoration_buffer,
3110     void* LIBGAV1_RESTRICT const dest) {
3111   const int index = restoration_info.sgr_proj_info.index;
3112   const int radius_pass_0 = kSgrProjParams[index][0];  // 2 or 0
3113   const int radius_pass_1 = kSgrProjParams[index][2];  // 1 or 0
3114   const auto* const src = static_cast<const uint16_t*>(source);
3115   const auto* const top = static_cast<const uint16_t*>(top_border);
3116   const auto* const bottom = static_cast<const uint16_t*>(bottom_border);
3117   auto* const dst = static_cast<uint16_t*>(dest);
3118   SgrBuffer* const sgr_buffer = &restoration_buffer->sgr_buffer;
3119   if (radius_pass_1 == 0) {
3120     // |radius_pass_0| and |radius_pass_1| cannot both be 0, so we have the
3121     // following assertion.
3122     assert(radius_pass_0 != 0);
3123     BoxFilterProcessPass1(restoration_info, src - 3, stride, top - 3,
3124                           top_border_stride, bottom - 3, bottom_border_stride,
3125                           width, height, sgr_buffer, dst);
3126   } else if (radius_pass_0 == 0) {
3127     BoxFilterProcessPass2(restoration_info, src - 2, stride, top - 2,
3128                           top_border_stride, bottom - 2, bottom_border_stride,
3129                           width, height, sgr_buffer, dst);
3130   } else {
3131     BoxFilterProcess(restoration_info, src - 3, stride, top - 3,
3132                      top_border_stride, bottom - 3, bottom_border_stride, width,
3133                      height, sgr_buffer, dst);
3134   }
3135 }
3136 
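// Note: WienerFilter_AVX2 and SelfGuidedFilter_AVX2 are not called directly;
// Init10bpp() below registers them in the 10 bpp dsp table, from which callers
// obtain them as dsp->loop_restorations[0] and [1].
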
3137 void Init10bpp() {
3138   Dsp* const dsp = dsp_internal::GetWritableDspTable(kBitdepth10);
3139   assert(dsp != nullptr);
3140 #if DSP_ENABLED_10BPP_AVX2(WienerFilter)
3141   dsp->loop_restorations[0] = WienerFilter_AVX2;
3142 #endif
3143 #if DSP_ENABLED_10BPP_AVX2(SelfGuidedFilter)
3144   dsp->loop_restorations[1] = SelfGuidedFilter_AVX2;
3145 #endif
3146 }
3147 
3148 }  // namespace
3149 
3150 void LoopRestorationInit10bpp_AVX2() { Init10bpp(); }
3151 
3152 }  // namespace dsp
3153 }  // namespace libgav1
3154 
3155 #else   // !(LIBGAV1_TARGETING_AVX2 && LIBGAV1_MAX_BITDEPTH >= 10)
3156 namespace libgav1 {
3157 namespace dsp {
3158 
3159 void LoopRestorationInit10bpp_AVX2() {}
3160 
3161 }  // namespace dsp
3162 }  // namespace libgav1
3163 #endif  // LIBGAV1_TARGETING_AVX2 && LIBGAV1_MAX_BITDEPTH >= 10
3164