xref: /aosp_15_r20/external/libaom/aom_dsp/x86/blk_sse_sum_avx2.c (revision 77c1e3ccc04c968bd2bc212e87364f250e820521)
1 /*
2  * Copyright (c) 2019, Alliance for Open Media. All rights reserved.
3  *
4  * This source code is subject to the terms of the BSD 2 Clause License and
5  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6  * was not distributed with this source code in the LICENSE file, you can
7  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8  * Media Patent License 1.0 was not distributed with this source code in the
9  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10  */
11 
12 #include <immintrin.h>
13 
14 #include "config/aom_dsp_rtcd.h"
15 
// Horizontally reduces the two accumulators and adds the totals to the
// outputs:
//   regx_sum  - eight 32-bit lanes; their (wrapping) sum is added to *x_sum.
//   regx2_sum - four 64-bit lanes; their sum is added to *x2_sum.
// The outputs are accumulated into, not overwritten.
static inline void accumulate_sse_sum(__m256i regx_sum, __m256i regx2_sum,
                                      int *x_sum, int64_t *x2_sum) {
  // Fold the upper 128-bit half of each accumulator onto its lower half.
  __m128i sum128 = _mm_add_epi32(_mm256_castsi256_si128(regx_sum),
                                 _mm256_extracti128_si256(regx_sum, 1));
  __m128i sse128 = _mm_add_epi64(_mm256_castsi256_si128(regx2_sum),
                                 _mm256_extracti128_si256(regx2_sum, 1));

  // Finish the reduction inside 128 bits so the total lands in element 0.
  sum128 = _mm_add_epi32(sum128, _mm_srli_si128(sum128, 8));
  sum128 = _mm_add_epi32(sum128, _mm_srli_si128(sum128, 4));
  sse128 = _mm_add_epi64(sse128, _mm_srli_si128(sse128, 8));

  *x_sum += _mm_cvtsi128_si32(sum128);
#if AOM_ARCH_X86_64
  *x2_sum += _mm_cvtsi128_si64(sse128);
#else
  {
    // _mm_cvtsi128_si64 is unavailable on 32-bit x86; spill the low 64 bits.
    int64_t low64;
    _mm_storel_epi64((__m128i *)&low64, sse128);
    *x2_sum += low64;
  }
#endif
}
44 
// Adds the sum and the sum of squares of a 4-wide column of int16 values to
// *x_sum / *x2_sum. Four rows (4 x 4 values) are packed into one 256-bit
// register per iteration, so only the first (bh >> 2) * 4 rows are read.
// Per-lane accumulation is 32-bit; the square sums are zero-extended to
// 64 bits before the final reduction.
static inline void sse_sum_wd4_avx2(const int16_t *data, int stride, int bh,
                                    int *x_sum, int64_t *x2_sum) {
  const __m256i ones = _mm256_set1_epi16(1);
  const __m256i zero = _mm256_setzero_si256();
  __m256i acc_sum = zero;  // pairwise pixel sums, 32 bits per lane
  __m256i acc_sq = zero;   // pairwise squared-pixel sums, 32 bits per lane
  const int16_t *src = data;

  for (int quads = bh >> 2; quads > 0; --quads) {
    // Gather rows 0..3 (64 bits each) into a single 256-bit vector.
    const __m128i rows01 =
        _mm_unpacklo_epi64(_mm_loadl_epi64((__m128i const *)src),
                           _mm_loadl_epi64((__m128i const *)(src + stride)));
    const __m128i rows23 = _mm_unpacklo_epi64(
        _mm_loadl_epi64((__m128i const *)(src + 2 * stride)),
        _mm_loadl_epi64((__m128i const *)(src + 3 * stride)));
    const __m256i px =
        _mm256_insertf128_si256(_mm256_castsi128_si256(rows01), rows23, 1);

    acc_sum = _mm256_add_epi32(acc_sum, _mm256_madd_epi16(px, ones));
    acc_sq = _mm256_add_epi32(acc_sq, _mm256_madd_epi16(px, px));
    src += 4 * stride;
  }

  // Zero-extend the 32-bit square sums to 64 bits to avoid overflow when
  // the lanes are combined.
  const __m256i sq64 = _mm256_add_epi64(_mm256_unpacklo_epi32(acc_sq, zero),
                                        _mm256_unpackhi_epi32(acc_sq, zero));
  accumulate_sse_sum(acc_sum, sq64, x_sum, x2_sum);
}
84 
// Adds the sum and the sum of squares of an 8-wide column of int16 values to
// *x_sum / *x2_sum. Two rows (2 x 8 values) fill one 256-bit register per
// iteration, so only the first (bh >> 1) * 2 rows are read. Per-lane
// accumulation is 32-bit; the square sums are zero-extended to 64 bits
// before the final reduction.
static inline void sse_sum_wd8_avx2(const int16_t *data, int stride, int bh,
                                    int *x_sum, int64_t *x2_sum) {
  const __m256i ones = _mm256_set1_epi16(1);
  const __m256i zero = _mm256_setzero_si256();
  __m256i acc_sum = zero;  // pairwise pixel sums, 32 bits per lane
  __m256i acc_sq = zero;   // pairwise squared-pixel sums, 32 bits per lane
  const int16_t *src = data;

  for (int pairs = bh >> 1; pairs > 0; --pairs) {
    const __m128i top = _mm_loadu_si128((__m128i const *)src);
    const __m128i bottom = _mm_loadu_si128((__m128i const *)(src + stride));
    const __m256i px =
        _mm256_insertf128_si256(_mm256_castsi128_si256(top), bottom, 1);

    acc_sum = _mm256_add_epi32(acc_sum, _mm256_madd_epi16(px, ones));
    acc_sq = _mm256_add_epi32(acc_sq, _mm256_madd_epi16(px, px));
    src += 2 * stride;
  }

  // Zero-extend the 32-bit square sums to 64 bits before combining lanes.
  const __m256i sq64 = _mm256_add_epi64(_mm256_unpacklo_epi32(acc_sq, zero),
                                        _mm256_unpackhi_epi32(acc_sq, zero));
  accumulate_sse_sum(acc_sum, sq64, x_sum, x2_sum);
}
119 
// Adds the sum and the sum of squares of a (16 * loop_count)-wide, bh-tall
// block of int16 values to *x_sum / *x2_sum. The block is walked as
// loop_count vertical strips of 16 values; each row of a strip is one
// unaligned 256-bit load.
// Accumulation is 32 bits per lane: each lane receives 2 squares per strip
// per row, so the caller must bound bh * loop_count to keep the lane totals
// from overflowing.
static inline void sse_sum_wd16_avx2(const int16_t *data, int stride, int bh,
                                     int *x_sum, int64_t *x2_sum,
                                     int loop_count) {
  const __m256i ones = _mm256_set1_epi16(1);
  const __m256i zero = _mm256_setzero_si256();
  __m256i acc_sum = zero;  // pairwise pixel sums, 32 bits per lane
  __m256i acc_sq = zero;   // pairwise squared-pixel sums, 32 bits per lane

  for (int strip = 0; strip < loop_count; ++strip) {
    const int16_t *src = data + 16 * strip;
    for (int row = 0; row < bh; ++row, src += stride) {
      const __m256i px = _mm256_lddqu_si256((__m256i const *)src);
      acc_sum = _mm256_add_epi32(acc_sum, _mm256_madd_epi16(px, ones));
      acc_sq = _mm256_add_epi32(acc_sq, _mm256_madd_epi16(px, px));
    }
  }

  // Zero-extend the 32-bit square sums to 64 bits before combining lanes.
  const __m256i sq64 = _mm256_add_epi64(_mm256_unpacklo_epi32(acc_sq, zero),
                                        _mm256_unpackhi_epi32(acc_sq, zero));
  accumulate_sse_sum(acc_sum, sq64, x_sum, x2_sum);
}
153 
// Computes the sum (*x_sum) and the sum of squares (*x2_sum) of the bw x bh
// block of int16 values at data, whose rows are stride elements apart.
// Both outputs are reset to 0 before accumulation.
//
// AVX2 kernels handle widths 4, 8, 16, 32 and 64 when bh is a multiple of 4.
// Every other shape is delegated to aom_get_blk_sse_sum_c().
void aom_get_blk_sse_sum_avx2(const int16_t *data, int stride, int bw, int bh,
                              int *x_sum, int64_t *x2_sum) {
  *x_sum = 0;
  *x2_sum = 0;

  if ((bh & 3) != 0) {
    aom_get_blk_sse_sum_c(data, stride, bw, bh, x_sum, x2_sum);
    return;
  }

  switch (bw) {
    // For smaller block widths, compute multiple rows simultaneously.
    case 4: sse_sum_wd4_avx2(data, stride, bh, x_sum, x2_sum); break;
    case 8: sse_sum_wd8_avx2(data, stride, bh, x_sum, x2_sum); break;
    case 16:
    case 32:
    case 64: {
      // sse_sum_wd16_avx2 keeps its square sums in 32-bit lanes, and each
      // lane receives 2 squares per 16-wide strip per row. That is
      // rows * bw / 8 squares per call.
      // Capping rows at 2048 / bw bounds this at 256 squares per lane
      // (32 rows for bw == 64). The lane total is then exact while
      // |x| < 4096. NOTE(review): assumes inputs are bounded residuals of
      // that magnitude - confirm with callers.
      // Walking the block in chunks of that size also guarantees that every
      // row of taller blocks (e.g. 64x128) is visited.
      const int max_rows = 2048 / bw;
      for (int row = 0; row < bh; row += max_rows) {
        const int rows = (bh - row < max_rows) ? bh - row : max_rows;
        sse_sum_wd16_avx2(data + row * stride, stride, rows, x_sum, x2_sum,
                          bw >> 4);
      }
      break;
    }
    default: aom_get_blk_sse_sum_c(data, stride, bw, bh, x_sum, x2_sum); break;
  }
}
186