/*
 *  Copyright (c) 2023 The WebM project authors. All Rights Reserved.
 *
 *  Use of this source code is governed by a BSD-style license
 *  that can be found in the LICENSE file in the root of the source
 *  tree. An additional intellectual property rights grant can be found
 *  in the file PATENTS.  All contributing project authors may
 *  be found in the AUTHORS file in the root of the source tree.
 */

#include <assert.h>
#include <immintrin.h>

#include "./vpx_dsp_rtcd.h"

void vpx_comp_avg_pred_avx2(uint8_t *comp_pred, const uint8_t *pred, int width,
                            int height, const uint8_t *ref, int ref_stride) {
  int row = 0;
  // comp_pred and pred must be 32 byte aligned.
  assert(((intptr_t)comp_pred % 32) == 0);
  assert(((intptr_t)pred % 32) == 0);

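  // Each branch averages pred and ref with _mm256_avg_epu8, which computes
  // the rounded average (a + b + 1) >> 1 of unsigned bytes. Widths 8, 16 and
  // 32 pack multiple rows into 256-bit registers; widths that are a multiple
  // of 64 are processed one row at a time in 64-byte chunks; any other width
  // falls back to the SSE2 version.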
  if (width == 8) {
    assert(height % 4 == 0);
    do {
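      // Gather four 8-pixel rows of ref into one 256-bit register: each
      // _mm_loadl_epi64 fills the low 64 bits of a 128-bit lane, each
      // _mm_loadh_pi fills the high 64 bits with the next row, and the two
      // lanes are then joined with _mm256_inserti128_si256.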
      const __m256i p = _mm256_load_si256((const __m256i *)pred);
      const __m128i r_0 = _mm_loadl_epi64((const __m128i *)ref);
      const __m128i r_1 =
          _mm_loadl_epi64((const __m128i *)(ref + 2 * ref_stride));

      const __m128i r1 = _mm_castps_si128(_mm_loadh_pi(
          _mm_castsi128_ps(r_0), (const __m64 *)(ref + ref_stride)));
      const __m128i r2 = _mm_castps_si128(_mm_loadh_pi(
          _mm_castsi128_ps(r_1), (const __m64 *)(ref + 3 * ref_stride)));

      const __m256i ref_0123 =
          _mm256_inserti128_si256(_mm256_castsi128_si256(r1), r2, 1);
      const __m256i avg = _mm256_avg_epu8(p, ref_0123);

      _mm256_store_si256((__m256i *)comp_pred, avg);

      row += 4;
      pred += 32;
      comp_pred += 32;
      ref += 4 * ref_stride;
    } while (row < height);
  } else if (width == 16) {
    assert(height % 4 == 0);
    do {
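      // Two 16-pixel rows of ref fit in each 256-bit register, so four rows
      // are averaged per iteration against two aligned loads of pred.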
      const __m256i pred_0 = _mm256_load_si256((const __m256i *)pred);
      const __m256i pred_1 = _mm256_load_si256((const __m256i *)(pred + 32));
      const __m256i tmp0 =
          _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)ref));
      const __m256i ref_0 = _mm256_inserti128_si256(
          tmp0, _mm_loadu_si128((const __m128i *)(ref + ref_stride)), 1);
      const __m256i tmp1 = _mm256_castsi128_si256(
          _mm_loadu_si128((const __m128i *)(ref + 2 * ref_stride)));
      const __m256i ref_1 = _mm256_inserti128_si256(
          tmp1, _mm_loadu_si128((const __m128i *)(ref + 3 * ref_stride)), 1);
      const __m256i average_0 = _mm256_avg_epu8(pred_0, ref_0);
      const __m256i average_1 = _mm256_avg_epu8(pred_1, ref_1);
      _mm256_store_si256((__m256i *)comp_pred, average_0);
      _mm256_store_si256((__m256i *)(comp_pred + 32), average_1);

      row += 4;
      pred += 64;
      comp_pred += 64;
      ref += 4 * ref_stride;
    } while (row < height);
  } else if (width == 32) {
    assert(height % 2 == 0);
    do {
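      // Each 32-pixel row is exactly one 256-bit load, so two rows are
      // averaged per iteration. ref may be unaligned, hence the loadu.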
      const __m256i pred_0 = _mm256_load_si256((const __m256i *)pred);
      const __m256i pred_1 = _mm256_load_si256((const __m256i *)(pred + 32));
      const __m256i ref_0 = _mm256_loadu_si256((const __m256i *)ref);
      const __m256i ref_1 =
          _mm256_loadu_si256((const __m256i *)(ref + ref_stride));
      const __m256i average_0 = _mm256_avg_epu8(pred_0, ref_0);
      const __m256i average_1 = _mm256_avg_epu8(pred_1, ref_1);
      _mm256_store_si256((__m256i *)comp_pred, average_0);
      _mm256_store_si256((__m256i *)(comp_pred + 32), average_1);

      row += 2;
      pred += 64;
      comp_pred += 64;
      ref += 2 * ref_stride;
    } while (row < height);
  } else if (width % 64 == 0) {
    do {
      int x;
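      // Process the row in 64-byte chunks, averaging two 256-bit registers
      // per chunk.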
      for (x = 0; x < width; x += 64) {
        const __m256i pred_0 = _mm256_load_si256((const __m256i *)(pred + x));
        const __m256i pred_1 =
            _mm256_load_si256((const __m256i *)(pred + x + 32));
        const __m256i ref_0 = _mm256_loadu_si256((const __m256i *)(ref + x));
        const __m256i ref_1 =
            _mm256_loadu_si256((const __m256i *)(ref + x + 32));
        const __m256i average_0 = _mm256_avg_epu8(pred_0, ref_0);
        const __m256i average_1 = _mm256_avg_epu8(pred_1, ref_1);
        _mm256_store_si256((__m256i *)(comp_pred + x), average_0);
        _mm256_store_si256((__m256i *)(comp_pred + x + 32), average_1);
      }
      row++;
      pred += width;
      comp_pred += width;
      ref += ref_stride;
    } while (row < height);
  } else {
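    // Remaining widths (e.g. 4) are handled by the SSE2 implementation.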
    vpx_comp_avg_pred_sse2(comp_pred, pred, width, height, ref, ref_stride);
  }
}