/*
 * Copyright (c) 2023 The WebM project authors. All Rights Reserved.
 *
 * Use of this source code is governed by a BSD-style license
 * that can be found in the LICENSE file in the root of the source
 * tree. An additional intellectual property rights grant can be found
 * in the file PATENTS. All contributing project authors may
 * be found in the AUTHORS file in the root of the source tree.
 */

#include <assert.h>
#include <immintrin.h>

#include "./vpx_dsp_rtcd.h"

void vpx_comp_avg_pred_avx2(uint8_t *comp_pred, const uint8_t *pred, int width,
                            int height, const uint8_t *ref, int ref_stride) {
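  // Per pixel this computes the rounded average
  //   comp_pred[i] = (pred[i] + ref[i] + 1) >> 1,
  // matching the scalar vpx_comp_avg_pred_c reference; _mm256_avg_epu8
  // evaluates exactly this for 32 unsigned bytes at once without
  // intermediate overflow.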
  int row = 0;
  // comp_pred and pred must be 32-byte aligned.
  assert(((intptr_t)comp_pred % 32) == 0);
  assert(((intptr_t)pred % 32) == 0);
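  // ref carries no alignment requirement; it is only ever read with
  // unaligned loads.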

  if (width == 8) {
    assert(height % 4 == 0);
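    // Gather four 8-byte rows of ref into one 256-bit register: each
    // movq/movhps pair packs two rows into an xmm (r1 = rows 0 and 1,
    // r2 = rows 2 and 3), then the inserti128 joins the halves so a
    // single byte-average handles all four rows.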
    do {
      const __m256i p = _mm256_load_si256((const __m256i *)pred);
      const __m128i r_0 = _mm_loadl_epi64((const __m128i *)ref);
      const __m128i r_1 =
          _mm_loadl_epi64((const __m128i *)(ref + 2 * ref_stride));

      const __m128i r1 = _mm_castps_si128(_mm_loadh_pi(
          _mm_castsi128_ps(r_0), (const __m64 *)(ref + ref_stride)));
      const __m128i r2 = _mm_castps_si128(_mm_loadh_pi(
          _mm_castsi128_ps(r_1), (const __m64 *)(ref + 3 * ref_stride)));

      const __m256i ref_0123 =
          _mm256_inserti128_si256(_mm256_castsi128_si256(r1), r2, 1);
      const __m256i avg = _mm256_avg_epu8(p, ref_0123);

      _mm256_store_si256((__m256i *)comp_pred, avg);

      row += 4;
      pred += 32;
      comp_pred += 32;
      ref += 4 * ref_stride;
    } while (row < height);
  } else if (width == 16) {
    assert(height % 4 == 0);
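    // Two 16-byte rows fill each 256-bit register, so each iteration
    // averages four rows: ref_0 holds rows 0 and 1, ref_1 rows 2 and 3.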
    do {
      const __m256i pred_0 = _mm256_load_si256((const __m256i *)pred);
      const __m256i pred_1 = _mm256_load_si256((const __m256i *)(pred + 32));
      const __m256i tmp0 =
          _mm256_castsi128_si256(_mm_loadu_si128((const __m128i *)ref));
      const __m256i ref_0 = _mm256_inserti128_si256(
          tmp0, _mm_loadu_si128((const __m128i *)(ref + ref_stride)), 1);
      const __m256i tmp1 = _mm256_castsi128_si256(
          _mm_loadu_si128((const __m128i *)(ref + 2 * ref_stride)));
      const __m256i ref_1 = _mm256_inserti128_si256(
          tmp1, _mm_loadu_si128((const __m128i *)(ref + 3 * ref_stride)), 1);
      const __m256i average_0 = _mm256_avg_epu8(pred_0, ref_0);
      const __m256i average_1 = _mm256_avg_epu8(pred_1, ref_1);
      _mm256_store_si256((__m256i *)comp_pred, average_0);
      _mm256_store_si256((__m256i *)(comp_pred + 32), average_1);

      row += 4;
      pred += 64;
      comp_pred += 64;
      ref += 4 * ref_stride;
    } while (row < height);
  } else if (width == 32) {
    assert(height % 2 == 0);
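    // A 32-byte row fills one register exactly, so rows are averaged in
    // pairs with no cross-row packing needed.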
    do {
      const __m256i pred_0 = _mm256_load_si256((const __m256i *)pred);
      const __m256i pred_1 = _mm256_load_si256((const __m256i *)(pred + 32));
      const __m256i ref_0 = _mm256_loadu_si256((const __m256i *)ref);
      const __m256i ref_1 =
          _mm256_loadu_si256((const __m256i *)(ref + ref_stride));
      const __m256i average_0 = _mm256_avg_epu8(pred_0, ref_0);
      const __m256i average_1 = _mm256_avg_epu8(pred_1, ref_1);
      _mm256_store_si256((__m256i *)comp_pred, average_0);
      _mm256_store_si256((__m256i *)(comp_pred + 32), average_1);

      row += 2;
      pred += 64;
      comp_pred += 64;
      ref += 2 * ref_stride;
    } while (row < height);
  } else if (width % 64 == 0) {
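    // Wider blocks are handled one row at a time, 64 bytes (two registers)
    // per inner-loop step.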
    do {
      int x;
      for (x = 0; x < width; x += 64) {
        const __m256i pred_0 = _mm256_load_si256((const __m256i *)(pred + x));
        const __m256i pred_1 =
            _mm256_load_si256((const __m256i *)(pred + x + 32));
        const __m256i ref_0 = _mm256_loadu_si256((const __m256i *)(ref + x));
        const __m256i ref_1 =
            _mm256_loadu_si256((const __m256i *)(ref + x + 32));
        const __m256i average_0 = _mm256_avg_epu8(pred_0, ref_0);
        const __m256i average_1 = _mm256_avg_epu8(pred_1, ref_1);
        _mm256_store_si256((__m256i *)(comp_pred + x), average_0);
        _mm256_store_si256((__m256i *)(comp_pred + x + 32), average_1);
      }
      row++;
      pred += width;
      comp_pred += width;
      ref += ref_stride;
    } while (row < height);
  } else {
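    // Remaining widths (the 4-wide case, for instance) are too narrow to
    // gain from 256-bit operations; defer to the SSE2 implementation.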
    vpx_comp_avg_pred_sse2(comp_pred, pred, width, height, ref, ref_stride);
  }
}