/*
 * Copyright (c) 2018, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <emmintrin.h>
#include <immintrin.h>

#include "config/av1_rtcd.h"

#include "aom_dsp/aom_dsp_common.h"
#include "aom_dsp/aom_filter.h"
#include "aom_dsp/x86/convolve_avx2.h"
#include "aom_dsp/x86/convolve_common_intrin.h"
#include "aom_dsp/x86/convolve_sse4_1.h"
#include "aom_dsp/x86/mem_sse2.h"
#include "aom_dsp/x86/synonyms_avx2.h"

#include "av1/common/convolve.h"

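// Broadcasts the forward/backward compound weights and interleaves them as
// 16-bit (w0, w1) pairs, one pair per 32-bit lane, in the layout comp_avg()
// expects.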
static inline __m256i unpack_weights_avx2(ConvolveParams *conv_params) {
  const int w0 = conv_params->fwd_offset;
  const int w1 = conv_params->bck_offset;
  const __m256i wt0 = _mm256_set1_epi16((int16_t)w0);
  const __m256i wt1 = _mm256_set1_epi16((int16_t)w1);
  const __m256i wt = _mm256_unpacklo_epi16(wt0, wt1);
  return wt;
}

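// Loads two unaligned 16-byte rows into one 256-bit register: row 'a' in the
// lower 128 bits, row 'b' in the upper 128 bits.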
static inline __m256i load_line2_avx2(const void *a, const void *b) {
  return _mm256_permute2x128_si256(
      _mm256_castsi128_si256(_mm_loadu_si128((__m128i *)a)),
      _mm256_castsi128_si256(_mm_loadu_si128((__m128i *)b)), 0x20);
}

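// Horizontal-only dist-weighted compound convolution. Two rows of up to 8
// pixels are filtered per iteration, with a reduced 4-tap path when the outer
// filter taps are zero. Results are either averaged against the intermediate
// CONV_BUF_TYPE buffer and written to dst0, or stored to that buffer.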
void av1_dist_wtd_convolve_x_avx2(const uint8_t *src, int src_stride,
                                  uint8_t *dst0, int dst_stride0, int w, int h,
                                  const InterpFilterParams *filter_params_x,
                                  const int subpel_x_qn,
                                  ConvolveParams *conv_params) {
  CONV_BUF_TYPE *dst = conv_params->dst;
  int dst_stride = conv_params->dst_stride;
  const int bd = 8;
  int i, j, is_horiz_4tap = 0;
  const int bits = FILTER_BITS - conv_params->round_1;
  const __m256i wt = unpack_weights_avx2(conv_params);
  const int do_average = conv_params->do_average;
  const int use_dist_wtd_comp_avg = conv_params->use_dist_wtd_comp_avg;
  const int offset_0 =
      bd + 2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1;
  const int offset = (1 << offset_0) + (1 << (offset_0 - 1));
  const __m256i offset_const = _mm256_set1_epi16(offset);
  const int rounding_shift =
      2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1;
  const __m256i rounding_const = _mm256_set1_epi16((1 << rounding_shift) >> 1);

  assert(bits >= 0);
  assert(conv_params->round_0 > 0);

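  // prepare_coeffs_lowbd() divides the filter coefficients by 2 (see the "+1"
  // note in av1_dist_wtd_convolve_y_avx2()), so the horizontal rounding shift
  // here is round_0 - 1 rather than round_0.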
  const __m256i round_const =
      _mm256_set1_epi16((1 << (conv_params->round_0 - 1)) >> 1);
  const __m128i round_shift = _mm_cvtsi32_si128(conv_params->round_0 - 1);

  __m256i filt[4], coeffs[4];

  filt[0] = _mm256_load_si256((__m256i const *)filt_global_avx2);
  filt[1] = _mm256_load_si256((__m256i const *)(filt_global_avx2 + 32));

  prepare_coeffs_lowbd(filter_params_x, subpel_x_qn, coeffs);

  // Condition for checking valid horz_filt taps
  if (!(_mm256_extract_epi32(_mm256_or_si256(coeffs[0], coeffs[3]), 0)))
    is_horiz_4tap = 1;

  // horz_filt as 4 tap
  if (is_horiz_4tap) {
    const int fo_horiz = 1;
    const uint8_t *const src_ptr = src - fo_horiz;
    for (i = 0; i < h; i += 2) {
      const uint8_t *src_data = src_ptr + i * src_stride;
      CONV_BUF_TYPE *dst_data = dst + i * dst_stride;
      for (j = 0; j < w; j += 8) {
        const __m256i data =
            load_line2_avx2(&src_data[j], &src_data[j + src_stride]);

        __m256i res = convolve_lowbd_x_4tap(data, coeffs + 1, filt);
        res = _mm256_sra_epi16(_mm256_add_epi16(res, round_const), round_shift);
        res = _mm256_slli_epi16(res, bits);

        const __m256i res_unsigned = _mm256_add_epi16(res, offset_const);

        // Accumulate values into the destination buffer
        if (do_average) {
          const __m256i data_ref_0 =
              load_line2_avx2(&dst_data[j], &dst_data[j + dst_stride]);
          const __m256i comp_avg_res =
              comp_avg(&data_ref_0, &res_unsigned, &wt, use_dist_wtd_comp_avg);

          const __m256i round_result = convolve_rounding(
              &comp_avg_res, &offset_const, &rounding_const, rounding_shift);

          const __m256i res_8 = _mm256_packus_epi16(round_result, round_result);
          const __m128i res_0 = _mm256_castsi256_si128(res_8);
          const __m128i res_1 = _mm256_extracti128_si256(res_8, 1);

          if (w > 4) {
            _mm_storel_epi64((__m128i *)(&dst0[i * dst_stride0 + j]), res_0);
            _mm_storel_epi64(
                (__m128i *)((&dst0[i * dst_stride0 + j + dst_stride0])), res_1);
          } else {
            *(int *)(&dst0[i * dst_stride0 + j]) = _mm_cvtsi128_si32(res_0);
            *(int *)(&dst0[i * dst_stride0 + j + dst_stride0]) =
                _mm_cvtsi128_si32(res_1);
          }
        } else {
          const __m128i res_0 = _mm256_castsi256_si128(res_unsigned);
          _mm_store_si128((__m128i *)(&dst[i * dst_stride + j]), res_0);

          const __m128i res_1 = _mm256_extracti128_si256(res_unsigned, 1);
          _mm_store_si128((__m128i *)(&dst[i * dst_stride + j + dst_stride]),
                          res_1);
        }
      }
    }
  } else {
    const int fo_horiz = filter_params_x->taps / 2 - 1;
    const uint8_t *const src_ptr = src - fo_horiz;

    filt[2] = _mm256_load_si256((__m256i const *)(filt_global_avx2 + 32 * 2));
    filt[3] = _mm256_load_si256((__m256i const *)(filt_global_avx2 + 32 * 3));
    for (i = 0; i < h; i += 2) {
      const uint8_t *src_data = src_ptr + i * src_stride;
      CONV_BUF_TYPE *dst_data = dst + i * dst_stride;
      for (j = 0; j < w; j += 8) {
        const __m256i data =
            load_line2_avx2(&src_data[j], &src_data[j + src_stride]);

        __m256i res = convolve_lowbd_x(data, coeffs, filt);

        res = _mm256_sra_epi16(_mm256_add_epi16(res, round_const), round_shift);

        res = _mm256_slli_epi16(res, bits);

        const __m256i res_unsigned = _mm256_add_epi16(res, offset_const);

        // Accumulate values into the destination buffer
        if (do_average) {
          const __m256i data_ref_0 =
              load_line2_avx2(&dst_data[j], &dst_data[j + dst_stride]);
          const __m256i comp_avg_res =
              comp_avg(&data_ref_0, &res_unsigned, &wt, use_dist_wtd_comp_avg);

          const __m256i round_result = convolve_rounding(
              &comp_avg_res, &offset_const, &rounding_const, rounding_shift);

          const __m256i res_8 = _mm256_packus_epi16(round_result, round_result);
          const __m128i res_0 = _mm256_castsi256_si128(res_8);
          const __m128i res_1 = _mm256_extracti128_si256(res_8, 1);

          if (w > 4) {
            _mm_storel_epi64((__m128i *)(&dst0[i * dst_stride0 + j]), res_0);
            _mm_storel_epi64(
                (__m128i *)((&dst0[i * dst_stride0 + j + dst_stride0])), res_1);
          } else {
            *(int *)(&dst0[i * dst_stride0 + j]) = _mm_cvtsi128_si32(res_0);
            *(int *)(&dst0[i * dst_stride0 + j + dst_stride0]) =
                _mm_cvtsi128_si32(res_1);
          }
        } else {
          const __m128i res_0 = _mm256_castsi256_si128(res_unsigned);
          _mm_store_si128((__m128i *)(&dst[i * dst_stride + j]), res_0);

          const __m128i res_1 = _mm256_extracti128_si256(res_unsigned, 1);
          _mm_store_si128((__m128i *)(&dst[i * dst_stride + j + dst_stride]),
                          res_1);
        }
      }
    }
  }
}

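// Vertical-only dist-weighted compound convolution. Columns are processed 16
// at a time and two output rows per iteration; the unpacked row pairs in s[]
// form a sliding window so each iteration loads only two new source rows. A
// reduced 4-tap path is taken when the outer filter taps are zero.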
void av1_dist_wtd_convolve_y_avx2(const uint8_t *src, int src_stride,
                                  uint8_t *dst0, int dst_stride0, int w, int h,
                                  const InterpFilterParams *filter_params_y,
                                  const int subpel_y_qn,
                                  ConvolveParams *conv_params) {
  CONV_BUF_TYPE *dst = conv_params->dst;
  int dst_stride = conv_params->dst_stride;
  const int bd = 8;
  int i, j, is_vert_4tap = 0;
  // +1 to compensate for dividing the filter coeffs by 2
  const int left_shift = FILTER_BITS - conv_params->round_0 + 1;
  const __m256i round_const =
      _mm256_set1_epi32((1 << conv_params->round_1) >> 1);
  const __m128i round_shift = _mm_cvtsi32_si128(conv_params->round_1);
  const __m256i wt = unpack_weights_avx2(conv_params);
  const int do_average = conv_params->do_average;
  const int use_dist_wtd_comp_avg = conv_params->use_dist_wtd_comp_avg;
  const int offset_0 =
      bd + 2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1;
  const int offset = (1 << offset_0) + (1 << (offset_0 - 1));
  const __m256i offset_const = _mm256_set1_epi16(offset);
  const int offset_1 = (1 << (bd + FILTER_BITS - 2));
  const __m256i offset_const_1 = _mm256_set1_epi16(offset_1);
  const __m256i offset_const_2 = _mm256_set1_epi16((1 << offset_0));
  const int rounding_shift =
      2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1;
  const __m256i rounding_const = _mm256_set1_epi16((1 << rounding_shift) >> 1);
  const __m256i zero = _mm256_setzero_si256();
  __m256i coeffs[4], s[8];

  assert((FILTER_BITS - conv_params->round_0) >= 0);

  prepare_coeffs_lowbd(filter_params_y, subpel_y_qn, coeffs);

  // Condition for checking valid vert_filt taps
  if (!(_mm256_extract_epi32(_mm256_or_si256(coeffs[0], coeffs[3]), 0)))
    is_vert_4tap = 1;

  if (is_vert_4tap) {
    const int fo_vert = 1;
    const uint8_t *const src_ptr = src - fo_vert * src_stride;
    for (j = 0; j < w; j += 16) {
      const uint8_t *data = &src_ptr[j];
      __m256i src4;
      // Load lines a and b. Line a to lower 128, line b to upper 128
      {
        __m256i src_ab[4];
        __m256i src_a[5];
        src_a[0] = _mm256_castsi128_si256(_mm_loadu_si128((__m128i *)data));
        for (int kk = 0; kk < 4; ++kk) {
          data += src_stride;
          src_a[kk + 1] =
              _mm256_castsi128_si256(_mm_loadu_si128((__m128i *)data));
          src_ab[kk] =
              _mm256_permute2x128_si256(src_a[kk], src_a[kk + 1], 0x20);
        }
        src4 = src_a[4];
        s[0] = _mm256_unpacklo_epi8(src_ab[0], src_ab[1]);
        s[1] = _mm256_unpacklo_epi8(src_ab[2], src_ab[3]);

        s[3] = _mm256_unpackhi_epi8(src_ab[0], src_ab[1]);
        s[4] = _mm256_unpackhi_epi8(src_ab[2], src_ab[3]);
      }

      for (i = 0; i < h; i += 2) {
        data = &src_ptr[(i + 5) * src_stride + j];
        const __m256i src5 =
            _mm256_castsi128_si256(_mm_loadu_si128((__m128i *)data));
        const __m256i src_45a = _mm256_permute2x128_si256(src4, src5, 0x20);

        src4 = _mm256_castsi128_si256(
            _mm_loadu_si128((__m128i *)(data + src_stride)));
        const __m256i src_56a = _mm256_permute2x128_si256(src5, src4, 0x20);

        s[2] = _mm256_unpacklo_epi8(src_45a, src_56a);
        s[5] = _mm256_unpackhi_epi8(src_45a, src_56a);

        __m256i res_lo = convolve_lowbd_4tap(s, coeffs + 1);

        res_lo = _mm256_add_epi16(res_lo, offset_const_1);

        const __m256i res_lo_0_32b = _mm256_unpacklo_epi16(res_lo, zero);
        const __m256i res_lo_0_shift =
            _mm256_slli_epi32(res_lo_0_32b, left_shift);
        const __m256i res_lo_0_round = _mm256_sra_epi32(
            _mm256_add_epi32(res_lo_0_shift, round_const), round_shift);

        const __m256i res_lo_1_32b = _mm256_unpackhi_epi16(res_lo, zero);
        const __m256i res_lo_1_shift =
            _mm256_slli_epi32(res_lo_1_32b, left_shift);
        const __m256i res_lo_1_round = _mm256_sra_epi32(
            _mm256_add_epi32(res_lo_1_shift, round_const), round_shift);

        const __m256i res_lo_round =
            _mm256_packs_epi32(res_lo_0_round, res_lo_1_round);

        const __m256i res_lo_unsigned =
            _mm256_add_epi16(res_lo_round, offset_const_2);

        if (w - j < 16) {
          if (do_average) {
            const __m256i data_ref_0 =
                load_line2_avx2(&dst[i * dst_stride + j],
                                &dst[i * dst_stride + j + dst_stride]);
            const __m256i comp_avg_res = comp_avg(&data_ref_0, &res_lo_unsigned,
                                                  &wt, use_dist_wtd_comp_avg);

            const __m256i round_result = convolve_rounding(
                &comp_avg_res, &offset_const, &rounding_const, rounding_shift);

            const __m256i res_8 =
                _mm256_packus_epi16(round_result, round_result);
            const __m128i res_0 = _mm256_castsi256_si128(res_8);
            const __m128i res_1 = _mm256_extracti128_si256(res_8, 1);

            if (w - j > 4) {
              _mm_storel_epi64((__m128i *)(&dst0[i * dst_stride0 + j]), res_0);
              _mm_storel_epi64(
                  (__m128i *)((&dst0[i * dst_stride0 + j + dst_stride0])),
                  res_1);
            } else {
              *(int *)(&dst0[i * dst_stride0 + j]) = _mm_cvtsi128_si32(res_0);
              *(int *)(&dst0[i * dst_stride0 + j + dst_stride0]) =
                  _mm_cvtsi128_si32(res_1);
            }
          } else {
            const __m128i res_0 = _mm256_castsi256_si128(res_lo_unsigned);
            _mm_store_si128((__m128i *)(&dst[i * dst_stride + j]), res_0);

            const __m128i res_1 = _mm256_extracti128_si256(res_lo_unsigned, 1);
            _mm_store_si128((__m128i *)(&dst[i * dst_stride + j + dst_stride]),
                            res_1);
          }
        } else {
          __m256i res_hi = convolve_lowbd_4tap(s + 3, coeffs + 1);

          res_hi = _mm256_add_epi16(res_hi, offset_const_1);

          const __m256i res_hi_0_32b = _mm256_unpacklo_epi16(res_hi, zero);
          const __m256i res_hi_0_shift =
              _mm256_slli_epi32(res_hi_0_32b, left_shift);
          const __m256i res_hi_0_round = _mm256_sra_epi32(
              _mm256_add_epi32(res_hi_0_shift, round_const), round_shift);

          const __m256i res_hi_1_32b = _mm256_unpackhi_epi16(res_hi, zero);
          const __m256i res_hi_1_shift =
              _mm256_slli_epi32(res_hi_1_32b, left_shift);
          const __m256i res_hi_1_round = _mm256_sra_epi32(
              _mm256_add_epi32(res_hi_1_shift, round_const), round_shift);

          const __m256i res_hi_round =
              _mm256_packs_epi32(res_hi_0_round, res_hi_1_round);

          const __m256i res_hi_unsigned =
              _mm256_add_epi16(res_hi_round, offset_const_2);

          if (do_average) {
            const __m256i data_ref_0_lo =
                load_line2_avx2(&dst[i * dst_stride + j],
                                &dst[i * dst_stride + j + dst_stride]);

            const __m256i data_ref_0_hi =
                load_line2_avx2(&dst[i * dst_stride + j + 8],
                                &dst[i * dst_stride + j + 8 + dst_stride]);

            const __m256i comp_avg_res_lo = comp_avg(
                &data_ref_0_lo, &res_lo_unsigned, &wt, use_dist_wtd_comp_avg);

            const __m256i comp_avg_res_hi = comp_avg(
                &data_ref_0_hi, &res_hi_unsigned, &wt, use_dist_wtd_comp_avg);

            const __m256i round_result_lo =
                convolve_rounding(&comp_avg_res_lo, &offset_const,
                                  &rounding_const, rounding_shift);

            const __m256i round_result_hi =
                convolve_rounding(&comp_avg_res_hi, &offset_const,
                                  &rounding_const, rounding_shift);

            const __m256i res_8 =
                _mm256_packus_epi16(round_result_lo, round_result_hi);
            const __m128i res_0 = _mm256_castsi256_si128(res_8);
            const __m128i res_1 = _mm256_extracti128_si256(res_8, 1);

            _mm_store_si128((__m128i *)(&dst0[i * dst_stride0 + j]), res_0);
            _mm_store_si128(
                (__m128i *)((&dst0[i * dst_stride0 + j + dst_stride0])), res_1);

          } else {
            const __m128i res_lo_0 = _mm256_castsi256_si128(res_lo_unsigned);
            _mm_store_si128((__m128i *)(&dst[i * dst_stride + j]), res_lo_0);

            const __m128i res_lo_1 =
                _mm256_extracti128_si256(res_lo_unsigned, 1);
            _mm_store_si128((__m128i *)(&dst[i * dst_stride + j + dst_stride]),
                            res_lo_1);

            const __m128i res_hi_0 = _mm256_castsi256_si128(res_hi_unsigned);
            _mm_store_si128((__m128i *)(&dst[i * dst_stride + j + 8]),
                            res_hi_0);

            const __m128i res_hi_1 =
                _mm256_extracti128_si256(res_hi_unsigned, 1);
            _mm_store_si128(
                (__m128i *)(&dst[i * dst_stride + j + 8 + dst_stride]),
                res_hi_1);
          }
        }
        s[0] = s[1];
        s[1] = s[2];

        s[3] = s[4];
        s[4] = s[5];
      }
    }
  } else {
    const int fo_vert = filter_params_y->taps / 2 - 1;
    const uint8_t *const src_ptr = src - fo_vert * src_stride;
    for (j = 0; j < w; j += 16) {
      const uint8_t *data = &src_ptr[j];
      __m256i src6;
      // Load lines a and b. Line a to lower 128, line b to upper 128
      {
        __m256i src_ab[7];
        __m256i src_a[7];
        src_a[0] = _mm256_castsi128_si256(_mm_loadu_si128((__m128i *)data));
        for (int kk = 0; kk < 6; ++kk) {
          data += src_stride;
          src_a[kk + 1] =
              _mm256_castsi128_si256(_mm_loadu_si128((__m128i *)data));
          src_ab[kk] =
              _mm256_permute2x128_si256(src_a[kk], src_a[kk + 1], 0x20);
        }
        src6 = src_a[6];
        s[0] = _mm256_unpacklo_epi8(src_ab[0], src_ab[1]);
        s[1] = _mm256_unpacklo_epi8(src_ab[2], src_ab[3]);
        s[2] = _mm256_unpacklo_epi8(src_ab[4], src_ab[5]);
        s[4] = _mm256_unpackhi_epi8(src_ab[0], src_ab[1]);
        s[5] = _mm256_unpackhi_epi8(src_ab[2], src_ab[3]);
        s[6] = _mm256_unpackhi_epi8(src_ab[4], src_ab[5]);
      }

      for (i = 0; i < h; i += 2) {
        data = &src_ptr[(i + 7) * src_stride + j];
        const __m256i src7 =
            _mm256_castsi128_si256(_mm_loadu_si128((__m128i *)data));
        const __m256i src_67a = _mm256_permute2x128_si256(src6, src7, 0x20);

        src6 = _mm256_castsi128_si256(
            _mm_loadu_si128((__m128i *)(data + src_stride)));
        const __m256i src_78a = _mm256_permute2x128_si256(src7, src6, 0x20);

        s[3] = _mm256_unpacklo_epi8(src_67a, src_78a);
        s[7] = _mm256_unpackhi_epi8(src_67a, src_78a);

        __m256i res_lo = convolve_lowbd(s, coeffs);

        res_lo = _mm256_add_epi16(res_lo, offset_const_1);

        const __m256i res_lo_0_32b = _mm256_unpacklo_epi16(res_lo, zero);
        const __m256i res_lo_0_shift =
            _mm256_slli_epi32(res_lo_0_32b, left_shift);
        const __m256i res_lo_0_round = _mm256_sra_epi32(
            _mm256_add_epi32(res_lo_0_shift, round_const), round_shift);

        const __m256i res_lo_1_32b = _mm256_unpackhi_epi16(res_lo, zero);
        const __m256i res_lo_1_shift =
            _mm256_slli_epi32(res_lo_1_32b, left_shift);
        const __m256i res_lo_1_round = _mm256_sra_epi32(
            _mm256_add_epi32(res_lo_1_shift, round_const), round_shift);

        const __m256i res_lo_round =
            _mm256_packs_epi32(res_lo_0_round, res_lo_1_round);

        const __m256i res_lo_unsigned =
            _mm256_add_epi16(res_lo_round, offset_const_2);

        if (w - j < 16) {
          if (do_average) {
            const __m256i data_ref_0 =
                load_line2_avx2(&dst[i * dst_stride + j],
                                &dst[i * dst_stride + j + dst_stride]);
            const __m256i comp_avg_res = comp_avg(&data_ref_0, &res_lo_unsigned,
                                                  &wt, use_dist_wtd_comp_avg);

            const __m256i round_result = convolve_rounding(
                &comp_avg_res, &offset_const, &rounding_const, rounding_shift);

            const __m256i res_8 =
                _mm256_packus_epi16(round_result, round_result);
            const __m128i res_0 = _mm256_castsi256_si128(res_8);
            const __m128i res_1 = _mm256_extracti128_si256(res_8, 1);

            if (w - j > 4) {
              _mm_storel_epi64((__m128i *)(&dst0[i * dst_stride0 + j]), res_0);
              _mm_storel_epi64(
                  (__m128i *)((&dst0[i * dst_stride0 + j + dst_stride0])),
                  res_1);
            } else {
              *(int *)(&dst0[i * dst_stride0 + j]) = _mm_cvtsi128_si32(res_0);
              *(int *)(&dst0[i * dst_stride0 + j + dst_stride0]) =
                  _mm_cvtsi128_si32(res_1);
            }
          } else {
            const __m128i res_0 = _mm256_castsi256_si128(res_lo_unsigned);
            _mm_store_si128((__m128i *)(&dst[i * dst_stride + j]), res_0);

            const __m128i res_1 = _mm256_extracti128_si256(res_lo_unsigned, 1);
            _mm_store_si128((__m128i *)(&dst[i * dst_stride + j + dst_stride]),
                            res_1);
          }
        } else {
          __m256i res_hi = convolve_lowbd(s + 4, coeffs);

          res_hi = _mm256_add_epi16(res_hi, offset_const_1);

          const __m256i res_hi_0_32b = _mm256_unpacklo_epi16(res_hi, zero);
          const __m256i res_hi_0_shift =
              _mm256_slli_epi32(res_hi_0_32b, left_shift);
          const __m256i res_hi_0_round = _mm256_sra_epi32(
              _mm256_add_epi32(res_hi_0_shift, round_const), round_shift);

          const __m256i res_hi_1_32b = _mm256_unpackhi_epi16(res_hi, zero);
          const __m256i res_hi_1_shift =
              _mm256_slli_epi32(res_hi_1_32b, left_shift);
          const __m256i res_hi_1_round = _mm256_sra_epi32(
              _mm256_add_epi32(res_hi_1_shift, round_const), round_shift);

          const __m256i res_hi_round =
              _mm256_packs_epi32(res_hi_0_round, res_hi_1_round);

          const __m256i res_hi_unsigned =
              _mm256_add_epi16(res_hi_round, offset_const_2);

          if (do_average) {
            const __m256i data_ref_0_lo =
                load_line2_avx2(&dst[i * dst_stride + j],
                                &dst[i * dst_stride + j + dst_stride]);

            const __m256i data_ref_0_hi =
                load_line2_avx2(&dst[i * dst_stride + j + 8],
                                &dst[i * dst_stride + j + 8 + dst_stride]);

            const __m256i comp_avg_res_lo = comp_avg(
                &data_ref_0_lo, &res_lo_unsigned, &wt, use_dist_wtd_comp_avg);

            const __m256i comp_avg_res_hi = comp_avg(
                &data_ref_0_hi, &res_hi_unsigned, &wt, use_dist_wtd_comp_avg);

            const __m256i round_result_lo =
                convolve_rounding(&comp_avg_res_lo, &offset_const,
                                  &rounding_const, rounding_shift);

            const __m256i round_result_hi =
                convolve_rounding(&comp_avg_res_hi, &offset_const,
                                  &rounding_const, rounding_shift);

            const __m256i res_8 =
                _mm256_packus_epi16(round_result_lo, round_result_hi);
            const __m128i res_0 = _mm256_castsi256_si128(res_8);
            const __m128i res_1 = _mm256_extracti128_si256(res_8, 1);

            _mm_store_si128((__m128i *)(&dst0[i * dst_stride0 + j]), res_0);
            _mm_store_si128(
                (__m128i *)((&dst0[i * dst_stride0 + j + dst_stride0])), res_1);

          } else {
            const __m128i res_lo_0 = _mm256_castsi256_si128(res_lo_unsigned);
            _mm_store_si128((__m128i *)(&dst[i * dst_stride + j]), res_lo_0);

            const __m128i res_lo_1 =
                _mm256_extracti128_si256(res_lo_unsigned, 1);
            _mm_store_si128((__m128i *)(&dst[i * dst_stride + j + dst_stride]),
                            res_lo_1);

            const __m128i res_hi_0 = _mm256_castsi256_si128(res_hi_unsigned);
            _mm_store_si128((__m128i *)(&dst[i * dst_stride + j + 8]),
                            res_hi_0);

            const __m128i res_hi_1 =
                _mm256_extracti128_si256(res_hi_unsigned, 1);
            _mm_store_si128(
                (__m128i *)(&dst[i * dst_stride + j + 8 + dst_stride]),
                res_hi_1);
          }
        }
        s[0] = s[1];
        s[1] = s[2];
        s[2] = s[3];

        s[4] = s[5];
        s[5] = s[6];
        s[6] = s[7];
      }
    }
  }
}

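// 2D (horizontal then vertical) dist-weighted compound convolution. The
// horizontal pass writes 16-bit intermediates to im_block, which the vertical
// pass filters at 32-bit precision. Reduced 4-tap paths exist for either
// direction; otherwise the shared DIST_WTD_CONVOLVE_*_FILTER_8TAP macros are
// used.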
void av1_dist_wtd_convolve_2d_avx2(const uint8_t *src, int src_stride,
                                   uint8_t *dst0, int dst_stride0, int w, int h,
                                   const InterpFilterParams *filter_params_x,
                                   const InterpFilterParams *filter_params_y,
                                   const int subpel_x_qn, const int subpel_y_qn,
                                   ConvolveParams *conv_params) {
  CONV_BUF_TYPE *dst = conv_params->dst;
  int dst_stride = conv_params->dst_stride;
  const int bd = 8;

  DECLARE_ALIGNED(32, int16_t, im_block[(MAX_SB_SIZE + MAX_FILTER_TAP) * 8]);

  int im_stride = 8;
  int i, is_horiz_4tap = 0, is_vert_4tap = 0;
  const __m256i wt = unpack_weights_avx2(conv_params);
  const int do_average = conv_params->do_average;
  const int use_dist_wtd_comp_avg = conv_params->use_dist_wtd_comp_avg;
  const int offset_0 =
      bd + 2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1;
  const int offset = (1 << offset_0) + (1 << (offset_0 - 1));
  const __m256i offset_const = _mm256_set1_epi16(offset);
  const int rounding_shift =
      2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1;
  const __m256i rounding_const = _mm256_set1_epi16((1 << rounding_shift) >> 1);

  assert(conv_params->round_0 > 0);

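  // round_const_h folds in (1 << (bd + FILTER_BITS - 2)) so the horizontal
  // intermediates stay non-negative in int16_t; after the horizontal shift and
  // the vertical filter gain this bias grows to
  // (1 << (bd + 2 * FILTER_BITS - round_0 - 1)), which round_const_v subtracts
  // back out.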
  const __m256i round_const_h = _mm256_set1_epi16(
      ((1 << (conv_params->round_0 - 1)) >> 1) + (1 << (bd + FILTER_BITS - 2)));
  const __m128i round_shift_h = _mm_cvtsi32_si128(conv_params->round_0 - 1);

  const __m256i round_const_v = _mm256_set1_epi32(
      ((1 << conv_params->round_1) >> 1) -
      (1 << (bd + 2 * FILTER_BITS - conv_params->round_0 - 1)));
  const __m128i round_shift_v = _mm_cvtsi32_si128(conv_params->round_1);

  __m256i filt[4], coeffs_x[4], coeffs_y[4];

  filt[0] = _mm256_load_si256((__m256i const *)filt_global_avx2);
  filt[1] = _mm256_load_si256((__m256i const *)(filt_global_avx2 + 32));

  prepare_coeffs_lowbd(filter_params_x, subpel_x_qn, coeffs_x);
  prepare_coeffs(filter_params_y, subpel_y_qn, coeffs_y);

  // Condition for checking valid horz_filt taps
  if (!(_mm256_extract_epi32(_mm256_or_si256(coeffs_x[0], coeffs_x[3]), 0)))
    is_horiz_4tap = 1;

  // Condition for checking valid vert_filt taps
  if (!(_mm256_extract_epi32(_mm256_or_si256(coeffs_y[0], coeffs_y[3]), 0)))
    is_vert_4tap = 1;

  if (is_horiz_4tap) {
    int im_h = h + filter_params_y->taps - 1;
    const int fo_vert = filter_params_y->taps / 2 - 1;
    const int fo_horiz = 1;
    const uint8_t *const src_ptr = src - fo_vert * src_stride - fo_horiz;
    for (int j = 0; j < w; j += 8) {
      /* Horizontal filter */
      const uint8_t *src_h = src_ptr + j;
      for (i = 0; i < im_h; i += 2) {
        __m256i data =
            _mm256_castsi128_si256(_mm_loadu_si128((__m128i *)src_h));
        if (i + 1 < im_h)
          data = _mm256_inserti128_si256(
              data, _mm_loadu_si128((__m128i *)(src_h + src_stride)), 1);
        src_h += (src_stride << 1);
        __m256i res = convolve_lowbd_x_4tap(data, coeffs_x + 1, filt);

        res = _mm256_sra_epi16(_mm256_add_epi16(res, round_const_h),
                               round_shift_h);

        _mm256_store_si256((__m256i *)&im_block[i * im_stride], res);
      }
      DIST_WTD_CONVOLVE_VERTICAL_FILTER_8TAP;
    }
  } else if (is_vert_4tap) {
    int im_h = h + 3;
    const int fo_vert = 1;
    const int fo_horiz = filter_params_x->taps / 2 - 1;
    const uint8_t *const src_ptr = src - fo_vert * src_stride - fo_horiz;

    filt[2] = _mm256_load_si256((__m256i const *)(filt_global_avx2 + 32 * 2));
    filt[3] = _mm256_load_si256((__m256i const *)(filt_global_avx2 + 32 * 3));

    for (int j = 0; j < w; j += 8) {
      /* Horizontal filter */
      const uint8_t *src_h = src_ptr + j;
      DIST_WTD_CONVOLVE_HORIZONTAL_FILTER_8TAP;

      /* Vertical filter */
      __m256i s[6];
      __m256i s0 = _mm256_loadu_si256((__m256i *)(im_block + 0 * im_stride));
      __m256i s1 = _mm256_loadu_si256((__m256i *)(im_block + 1 * im_stride));
      __m256i s2 = _mm256_loadu_si256((__m256i *)(im_block + 2 * im_stride));
      __m256i s3 = _mm256_loadu_si256((__m256i *)(im_block + 3 * im_stride));

      s[0] = _mm256_unpacklo_epi16(s0, s1);
      s[1] = _mm256_unpacklo_epi16(s2, s3);

      s[3] = _mm256_unpackhi_epi16(s0, s1);
      s[4] = _mm256_unpackhi_epi16(s2, s3);

      for (i = 0; i < h; i += 2) {
        const int16_t *data = &im_block[i * im_stride];

        const __m256i s4 =
            _mm256_loadu_si256((__m256i *)(data + 4 * im_stride));
        const __m256i s5 =
            _mm256_loadu_si256((__m256i *)(data + 5 * im_stride));

        s[2] = _mm256_unpacklo_epi16(s4, s5);
        s[5] = _mm256_unpackhi_epi16(s4, s5);

        const __m256i res_a = convolve_4tap(s, coeffs_y + 1);
        const __m256i res_a_round = _mm256_sra_epi32(
            _mm256_add_epi32(res_a, round_const_v), round_shift_v);

        if (w - j > 4) {
          const __m256i res_b = convolve_4tap(s + 3, coeffs_y + 1);
          const __m256i res_b_round = _mm256_sra_epi32(
              _mm256_add_epi32(res_b, round_const_v), round_shift_v);
          const __m256i res_16b = _mm256_packs_epi32(res_a_round, res_b_round);
          const __m256i res_unsigned = _mm256_add_epi16(res_16b, offset_const);

          if (do_average) {
            const __m256i data_ref_0 =
                load_line2_avx2(&dst[i * dst_stride + j],
                                &dst[i * dst_stride + j + dst_stride]);
            const __m256i comp_avg_res = comp_avg(&data_ref_0, &res_unsigned,
                                                  &wt, use_dist_wtd_comp_avg);

            const __m256i round_result = convolve_rounding(
                &comp_avg_res, &offset_const, &rounding_const, rounding_shift);

            const __m256i res_8 =
                _mm256_packus_epi16(round_result, round_result);
            const __m128i res_0 = _mm256_castsi256_si128(res_8);
            const __m128i res_1 = _mm256_extracti128_si256(res_8, 1);

            _mm_storel_epi64((__m128i *)(&dst0[i * dst_stride0 + j]), res_0);
            _mm_storel_epi64(
                (__m128i *)((&dst0[i * dst_stride0 + j + dst_stride0])), res_1);
          } else {
            const __m128i res_0 = _mm256_castsi256_si128(res_unsigned);
            _mm_store_si128((__m128i *)(&dst[i * dst_stride + j]), res_0);

            const __m128i res_1 = _mm256_extracti128_si256(res_unsigned, 1);
            _mm_store_si128((__m128i *)(&dst[i * dst_stride + j + dst_stride]),
                            res_1);
          }
        } else {
          const __m256i res_16b = _mm256_packs_epi32(res_a_round, res_a_round);
          const __m256i res_unsigned = _mm256_add_epi16(res_16b, offset_const);

          if (do_average) {
            const __m256i data_ref_0 =
                load_line2_avx2(&dst[i * dst_stride + j],
                                &dst[i * dst_stride + j + dst_stride]);

            const __m256i comp_avg_res = comp_avg(&data_ref_0, &res_unsigned,
                                                  &wt, use_dist_wtd_comp_avg);

            const __m256i round_result = convolve_rounding(
                &comp_avg_res, &offset_const, &rounding_const, rounding_shift);

            const __m256i res_8 =
                _mm256_packus_epi16(round_result, round_result);
            const __m128i res_0 = _mm256_castsi256_si128(res_8);
            const __m128i res_1 = _mm256_extracti128_si256(res_8, 1);

            *(int *)(&dst0[i * dst_stride0 + j]) = _mm_cvtsi128_si32(res_0);
            *(int *)(&dst0[i * dst_stride0 + j + dst_stride0]) =
                _mm_cvtsi128_si32(res_1);

          } else {
            const __m128i res_0 = _mm256_castsi256_si128(res_unsigned);
            _mm_store_si128((__m128i *)(&dst[i * dst_stride + j]), res_0);

            const __m128i res_1 = _mm256_extracti128_si256(res_unsigned, 1);
            _mm_store_si128((__m128i *)(&dst[i * dst_stride + j + dst_stride]),
                            res_1);
          }
        }
        s[0] = s[1];
        s[1] = s[2];
        s[3] = s[4];
        s[4] = s[5];
      }
    }
  } else {
    int im_h = h + filter_params_y->taps - 1;
    const int fo_vert = filter_params_y->taps / 2 - 1;
    const int fo_horiz = filter_params_x->taps / 2 - 1;
    const uint8_t *const src_ptr = src - fo_vert * src_stride - fo_horiz;

    filt[2] = _mm256_load_si256((__m256i const *)(filt_global_avx2 + 32 * 2));
    filt[3] = _mm256_load_si256((__m256i const *)(filt_global_avx2 + 32 * 3));

    for (int j = 0; j < w; j += 8) {
      /* Horizontal filter */
      const uint8_t *src_h = src_ptr + j;
      DIST_WTD_CONVOLVE_HORIZONTAL_FILTER_8TAP;

      DIST_WTD_CONVOLVE_VERTICAL_FILTER_8TAP;
    }
  }
}

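// Copy helper for the no-average path: loads four groups of 16 pixels at the
// given (row, column) offsets, widens them to 16 bits, applies LEFT_SHIFT and
// the compound offset, and stores them to the CONV_BUF_TYPE destination.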
#define DO_NO_AVG_2D_COPY_4X16(r0, c0, r1, c1, r2, c2, r3, c3)          \
  do {                                                                  \
    src_0 = _mm256_cvtepu8_epi16(                                       \
        _mm_loadu_si128((__m128i *)(&src[r0 * src_stride + c0])));      \
    src_1 = _mm256_cvtepu8_epi16(                                       \
        _mm_loadu_si128((__m128i *)(&src[r1 * src_stride + c1])));      \
    src_2 = _mm256_cvtepu8_epi16(                                       \
        _mm_loadu_si128((__m128i *)(&src[r2 * src_stride + c2])));      \
    src_3 = _mm256_cvtepu8_epi16(                                       \
        _mm_loadu_si128((__m128i *)(&src[r3 * src_stride + c3])));      \
                                                                        \
    src_0 = _mm256_slli_epi16(src_0, LEFT_SHIFT);                       \
    src_1 = _mm256_slli_epi16(src_1, LEFT_SHIFT);                       \
    src_2 = _mm256_slli_epi16(src_2, LEFT_SHIFT);                       \
    src_3 = _mm256_slli_epi16(src_3, LEFT_SHIFT);                       \
                                                                        \
    src_0 = _mm256_add_epi16(src_0, offset_const);                      \
    src_1 = _mm256_add_epi16(src_1, offset_const);                      \
    src_2 = _mm256_add_epi16(src_2, offset_const);                      \
    src_3 = _mm256_add_epi16(src_3, offset_const);                      \
                                                                        \
    _mm256_store_si256((__m256i *)(&dst[r0 * dst_stride + c0]), src_0); \
    _mm256_store_si256((__m256i *)(&dst[r1 * dst_stride + c1]), src_1); \
    _mm256_store_si256((__m256i *)(&dst[r2 * dst_stride + c2]), src_2); \
    _mm256_store_si256((__m256i *)(&dst[r3 * dst_stride + c3]), src_3); \
  } while (0)

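// Equal to 2 * FILTER_BITS - round_0 - round_1 for the round_0 == 3 and
// round_1 == 7 values asserted in av1_dist_wtd_convolve_2d_copy_avx2() below.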
#define LEFT_SHIFT (2 * FILTER_BITS - 3 - 7)
static inline void av1_dist_wtd_convolve_2d_no_avg_copy_avx2(
    const uint8_t *src, int src_stride, CONV_BUF_TYPE *dst, int dst_stride,
    int w, int h, const __m256i offset_const) {
  int i = h;
  if (w >= 16) {
    __m256i src_0, src_1, src_2, src_3;
    if (w == 128) {
      do {
        DO_NO_AVG_2D_COPY_4X16(0, 0, 0, 16, 0, 32, 0, 48);
        DO_NO_AVG_2D_COPY_4X16(0, 64, 0, 80, 0, 96, 0, 112);
        src += 1 * src_stride;
        dst += 1 * dst_stride;
        i -= 1;
      } while (i);
    } else if (w == 64) {
      do {
        DO_NO_AVG_2D_COPY_4X16(0, 0, 0, 16, 0, 32, 0, 48);
        src += 1 * src_stride;
        dst += 1 * dst_stride;
        i -= 1;
      } while (i);
    } else if (w == 32) {
      do {
        DO_NO_AVG_2D_COPY_4X16(0, 0, 1, 0, 0, 16, 1, 16);
        src += 2 * src_stride;
        dst += 2 * dst_stride;
        i -= 2;
      } while (i);
    } else if (w == 16) {
      do {
        DO_NO_AVG_2D_COPY_4X16(0, 0, 1, 0, 2, 0, 3, 0);
        src += 4 * src_stride;
        dst += 4 * dst_stride;
        i -= 4;
      } while (i);
    }
  } else {
    const __m256i zero = _mm256_setzero_si256();
    do {
      const __m128i src_row_0 =
          _mm_loadl_epi64((__m128i *)(&src[0 * src_stride]));
      const __m128i src_row_1 =
          _mm_loadl_epi64((__m128i *)(&src[1 * src_stride]));
      const __m128i src_row_2 =
          _mm_loadl_epi64((__m128i *)(&src[2 * src_stride]));
      const __m128i src_row_3 =
          _mm_loadl_epi64((__m128i *)(&src[3 * src_stride]));

      __m256i src_10 = _mm256_insertf128_si256(
          _mm256_castsi128_si256(src_row_0), src_row_1, 1);
      __m256i src_32 = _mm256_insertf128_si256(
          _mm256_castsi128_si256(src_row_2), src_row_3, 1);

      src_10 = _mm256_unpacklo_epi8(src_10, zero);
      src_32 = _mm256_unpacklo_epi8(src_32, zero);

      src_10 = _mm256_slli_epi16(src_10, LEFT_SHIFT);
      src_32 = _mm256_slli_epi16(src_32, LEFT_SHIFT);

      src_10 = _mm256_add_epi16(src_10, offset_const);
      src_32 = _mm256_add_epi16(src_32, offset_const);

      // Accumulate values into the destination buffer
      _mm_store_si128((__m128i *)(&dst[0 * dst_stride]),
                      _mm256_castsi256_si128(src_10));
      _mm_store_si128((__m128i *)(&dst[1 * dst_stride]),
                      _mm256_extracti128_si256(src_10, 1));
      _mm_store_si128((__m128i *)(&dst[2 * dst_stride]),
                      _mm256_castsi256_si128(src_32));
      _mm_store_si128((__m128i *)(&dst[3 * dst_stride]),
                      _mm256_extracti128_si256(src_32, 1));

      src += 4 * src_stride;
      dst += 4 * dst_stride;
      i -= 4;
    } while (i);
  }
}

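// Averaging counterpart of DO_NO_AVG_2D_COPY_4X16: the shifted and offset
// source pixels are blended with the reference values already in dst via
// comp_avg(), rounded back down, packed to 8 bits and written to dst0.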
#define DO_AVG_2D_COPY_4X16(USE_DIST_WEIGHTED, r0, c0, r1, c1, r2, c2, r3, c3) \
  do {                                                                         \
    src_0 = _mm256_cvtepu8_epi16(                                              \
        _mm_loadu_si128((__m128i *)(&src[r0 * src_stride + c0])));             \
    src_1 = _mm256_cvtepu8_epi16(                                              \
        _mm_loadu_si128((__m128i *)(&src[r1 * src_stride + c1])));             \
    src_2 = _mm256_cvtepu8_epi16(                                              \
        _mm_loadu_si128((__m128i *)(&src[r2 * src_stride + c2])));             \
    src_3 = _mm256_cvtepu8_epi16(                                              \
        _mm_loadu_si128((__m128i *)(&src[r3 * src_stride + c3])));             \
                                                                               \
    src_0 = _mm256_slli_epi16(src_0, LEFT_SHIFT);                              \
    src_1 = _mm256_slli_epi16(src_1, LEFT_SHIFT);                              \
    src_2 = _mm256_slli_epi16(src_2, LEFT_SHIFT);                              \
    src_3 = _mm256_slli_epi16(src_3, LEFT_SHIFT);                              \
    src_0 = _mm256_add_epi16(src_0, offset_const);                             \
    src_1 = _mm256_add_epi16(src_1, offset_const);                             \
    src_2 = _mm256_add_epi16(src_2, offset_const);                             \
    src_3 = _mm256_add_epi16(src_3, offset_const);                             \
                                                                               \
    ref_0 = _mm256_loadu_si256((__m256i *)(&dst[r0 * dst_stride + c0]));       \
    ref_1 = _mm256_loadu_si256((__m256i *)(&dst[r1 * dst_stride + c1]));       \
    ref_2 = _mm256_loadu_si256((__m256i *)(&dst[r2 * dst_stride + c2]));       \
    ref_3 = _mm256_loadu_si256((__m256i *)(&dst[r3 * dst_stride + c3]));       \
                                                                               \
    res_0 = comp_avg(&ref_0, &src_0, &wt, USE_DIST_WEIGHTED);                  \
    res_1 = comp_avg(&ref_1, &src_1, &wt, USE_DIST_WEIGHTED);                  \
    res_2 = comp_avg(&ref_2, &src_2, &wt, USE_DIST_WEIGHTED);                  \
    res_3 = comp_avg(&ref_3, &src_3, &wt, USE_DIST_WEIGHTED);                  \
                                                                               \
    res_0 = convolve_rounding(&res_0, &offset_const, &rounding_const,          \
                              rounding_shift);                                 \
    res_1 = convolve_rounding(&res_1, &offset_const, &rounding_const,          \
                              rounding_shift);                                 \
    res_2 = convolve_rounding(&res_2, &offset_const, &rounding_const,          \
                              rounding_shift);                                 \
    res_3 = convolve_rounding(&res_3, &offset_const, &rounding_const,          \
                              rounding_shift);                                 \
                                                                               \
    res_10 = _mm256_packus_epi16(res_0, res_1);                                \
    res_32 = _mm256_packus_epi16(res_2, res_3);                                \
    res_10 = _mm256_permute4x64_epi64(res_10, 0xD8);                           \
    res_32 = _mm256_permute4x64_epi64(res_32, 0xD8);                           \
                                                                               \
    _mm_store_si128((__m128i *)(&dst0[r0 * dst_stride0 + c0]),                 \
                    _mm256_castsi256_si128(res_10));                           \
    _mm_store_si128((__m128i *)(&dst0[r1 * dst_stride0 + c1]),                 \
                    _mm256_extracti128_si256(res_10, 1));                      \
    _mm_store_si128((__m128i *)(&dst0[r2 * dst_stride0 + c2]),                 \
                    _mm256_castsi256_si128(res_32));                           \
    _mm_store_si128((__m128i *)(&dst0[r3 * dst_stride0 + c3]),                 \
                    _mm256_extracti128_si256(res_32, 1));                      \
  } while (0)

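// Dispatches the averaging copy over block width: widths of 16 and above use
// DO_AVG_2D_COPY_4X16, width 8 uses 64-bit row loads, and width 4 gathers four
// rows into a single 256-bit register.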
#define DO_AVG_2D_COPY(USE_DIST_WEIGHTED)                                     \
  int i = h;                                                                  \
  if (w >= 16) {                                                              \
    __m256i src_0, src_1, src_2, src_3;                                       \
    __m256i ref_0, ref_1, ref_2, ref_3;                                       \
    __m256i res_0, res_1, res_2, res_3;                                       \
    __m256i res_10, res_32;                                                   \
    if (w == 128) {                                                           \
      do {                                                                    \
        DO_AVG_2D_COPY_4X16(USE_DIST_WEIGHTED, 0, 0, 0, 16, 0, 32, 0, 48);    \
        DO_AVG_2D_COPY_4X16(USE_DIST_WEIGHTED, 0, 64, 0, 80, 0, 96, 0, 112);  \
        i -= 1;                                                               \
        src += 1 * src_stride;                                                \
        dst += 1 * dst_stride;                                                \
        dst0 += 1 * dst_stride0;                                              \
      } while (i);                                                            \
    } else if (w == 64) {                                                     \
      do {                                                                    \
        DO_AVG_2D_COPY_4X16(USE_DIST_WEIGHTED, 0, 0, 0, 16, 0, 32, 0, 48);    \
                                                                              \
        i -= 1;                                                               \
        src += 1 * src_stride;                                                \
        dst += 1 * dst_stride;                                                \
        dst0 += 1 * dst_stride0;                                              \
      } while (i);                                                            \
    } else if (w == 32) {                                                     \
      do {                                                                    \
        DO_AVG_2D_COPY_4X16(USE_DIST_WEIGHTED, 0, 0, 1, 0, 0, 16, 1, 16);     \
                                                                              \
        i -= 2;                                                               \
        src += 2 * src_stride;                                                \
        dst += 2 * dst_stride;                                                \
        dst0 += 2 * dst_stride0;                                              \
      } while (i);                                                            \
    } else {                                                                  \
      assert(w == 16);                                                        \
      do {                                                                    \
        DO_AVG_2D_COPY_4X16(USE_DIST_WEIGHTED, 0, 0, 1, 0, 2, 0, 3, 0);       \
                                                                              \
        i -= 4;                                                               \
        src += 4 * src_stride;                                                \
        dst += 4 * dst_stride;                                                \
        dst0 += 4 * dst_stride0;                                              \
      } while (i);                                                            \
    }                                                                         \
  } else if (w == 8) {                                                        \
    do {                                                                      \
      const __m128i src_0 =                                                   \
          _mm_loadl_epi64((__m128i *)(&src[0 * src_stride]));                 \
      const __m128i src_1 =                                                   \
          _mm_loadl_epi64((__m128i *)(&src[1 * src_stride]));                 \
      const __m128i src_2 =                                                   \
          _mm_loadl_epi64((__m128i *)(&src[2 * src_stride]));                 \
      const __m128i src_3 =                                                   \
          _mm_loadl_epi64((__m128i *)(&src[3 * src_stride]));                 \
      __m256i src_10 =                                                        \
          _mm256_insertf128_si256(_mm256_castsi128_si256(src_0), src_1, 1);   \
      __m256i src_32 =                                                        \
          _mm256_insertf128_si256(_mm256_castsi128_si256(src_2), src_3, 1);   \
                                                                              \
      src_10 = _mm256_unpacklo_epi8(src_10, zero);                            \
      src_32 = _mm256_unpacklo_epi8(src_32, zero);                            \
                                                                              \
      src_10 = _mm256_slli_epi16(src_10, LEFT_SHIFT);                         \
      src_32 = _mm256_slli_epi16(src_32, LEFT_SHIFT);                         \
                                                                              \
      src_10 = _mm256_add_epi16(src_10, offset_const);                        \
      src_32 = _mm256_add_epi16(src_32, offset_const);                        \
                                                                              \
      const __m256i ref_10 =                                                  \
          load_line2_avx2(&dst[0 * dst_stride], &dst[1 * dst_stride]);        \
      const __m256i ref_32 =                                                  \
          load_line2_avx2(&dst[2 * dst_stride], &dst[3 * dst_stride]);        \
      __m256i res_10 = comp_avg(&ref_10, &src_10, &wt, USE_DIST_WEIGHTED);    \
      __m256i res_32 = comp_avg(&ref_32, &src_32, &wt, USE_DIST_WEIGHTED);    \
                                                                              \
      res_10 = convolve_rounding(&res_10, &offset_const, &rounding_const,     \
                                 rounding_shift);                             \
      res_32 = convolve_rounding(&res_32, &offset_const, &rounding_const,     \
                                 rounding_shift);                             \
                                                                              \
      __m256i res = _mm256_packus_epi16(res_10, res_32);                      \
      const __m128i res_20 = _mm256_castsi256_si128(res);                     \
      const __m128i res_31 = _mm256_extracti128_si256(res, 1);                \
                                                                              \
      _mm_storel_epi64((__m128i *)(&dst0[0 * dst_stride0]), res_20);          \
      _mm_storel_epi64((__m128i *)((&dst0[1 * dst_stride0])), res_31);        \
      _mm_storeh_epi64((__m128i *)(&dst0[2 * dst_stride0]), res_20);          \
      _mm_storeh_epi64((__m128i *)((&dst0[3 * dst_stride0])), res_31);        \
      i -= 4;                                                                 \
      src += 4 * src_stride;                                                  \
      dst += 4 * dst_stride;                                                  \
      dst0 += 4 * dst_stride0;                                                \
    } while (i);                                                              \
  } else {                                                                    \
    assert(w == 4);                                                           \
    do {                                                                      \
      __m256i src_3210_8bit =                                                 \
          _mm256_setr_epi32(loadu_int32(src + 0 * src_stride),                \
                            loadu_int32(src + 1 * src_stride), 0, 0,          \
                            loadu_int32(src + 2 * src_stride),                \
                            loadu_int32(src + 3 * src_stride), 0, 0);         \
                                                                              \
      __m256i src_3210 = _mm256_unpacklo_epi8(src_3210_8bit, zero);           \
      src_3210 = _mm256_slli_epi16(src_3210, LEFT_SHIFT);                     \
      src_3210 = _mm256_add_epi16(src_3210, offset_const);                    \
                                                                              \
      __m256i ref_3210 =                                                      \
          _mm256_setr_epi64x(*(int64_t *)(dst + 0 * dst_stride),              \
                             *(int64_t *)(dst + 1 * dst_stride),              \
                             *(int64_t *)(dst + 2 * dst_stride),              \
                             *(int64_t *)(dst + 3 * dst_stride));             \
      __m256i res_3210 =                                                      \
          comp_avg(&ref_3210, &src_3210, &wt, USE_DIST_WEIGHTED);             \
                                                                              \
      res_3210 = convolve_rounding(&res_3210, &offset_const, &rounding_const, \
                                   rounding_shift);                           \
                                                                              \
      res_3210 = _mm256_packus_epi16(res_3210, res_3210);                     \
      const __m128i res_10 = _mm256_castsi256_si128(res_3210);                \
      const __m128i res_32 = _mm256_extracti128_si256(res_3210, 1);           \
                                                                              \
      *(int *)(&dst0[0 * dst_stride0]) = _mm_cvtsi128_si32(res_10);           \
      *(int *)(&dst0[2 * dst_stride0]) = _mm_cvtsi128_si32(res_32);           \
      *(int *)(&dst0[1 * dst_stride0]) = _mm_extract_epi32(res_10, 1);        \
      *(int *)(&dst0[3 * dst_stride0]) = _mm_extract_epi32(res_32, 1);        \
      i -= 4;                                                                 \
      src += 4 * src_stride;                                                  \
      dst += 4 * dst_stride;                                                  \
      dst0 += 4 * dst_stride0;                                                \
    } while (i);                                                              \
  }

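// Compound "copy" path: no filtering is applied here; source pixels are only
// shifted into the compound domain (LEFT_SHIFT plus the compound offset) and
// then either averaged with the CONV_BUF_TYPE buffer or stored to it.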
void av1_dist_wtd_convolve_2d_copy_avx2(const uint8_t *src, int src_stride,
                                        uint8_t *dst0, int dst_stride0, int w,
                                        int h, ConvolveParams *conv_params) {
  const int bd = 8;
  CONV_BUF_TYPE *dst = conv_params->dst;
  int dst_stride = conv_params->dst_stride;
  assert(conv_params->round_0 == 3);
  assert(conv_params->round_1 == 7);
  assert(w % 4 == 0);
  assert(h % 4 == 0);

  const int do_average = conv_params->do_average;
  const int use_dist_wtd_comp_avg = conv_params->use_dist_wtd_comp_avg;
  const __m256i wt = unpack_weights_avx2(conv_params);
  const __m256i zero = _mm256_setzero_si256();

  const int offset_0 =
      bd + 2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1;
  const int offset = (1 << offset_0) + (1 << (offset_0 - 1));
  const __m256i offset_const = _mm256_set1_epi16(offset);
  const int rounding_shift =
      2 * FILTER_BITS - conv_params->round_0 - conv_params->round_1;
  const __m256i rounding_const = _mm256_set1_epi16((1 << rounding_shift) >> 1);

  if (do_average) {
    if (use_dist_wtd_comp_avg) {
      DO_AVG_2D_COPY(1)
    } else {
      DO_AVG_2D_COPY(0)
    }
  } else {
    av1_dist_wtd_convolve_2d_no_avg_copy_avx2(src, src_stride, dst, dst_stride,
                                              w, h, offset_const);
  }
}
#undef LEFT_SHIFT