// xref: /aosp_15_r20/external/libaom/av1/encoder/x86/av1_k_means_sse2.c (revision 77c1e3ccc04c968bd2bc212e87364f250e820521)
/*
 * Copyright (c) 2021, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#include <assert.h>
#include <emmintrin.h>  // SSE2

#include "config/av1_rtcd.h"
#include "aom_dsp/x86/synonyms.h"

/*
 * Returns the sum of the two 64-bit lanes of `a`.
 *
 * The addition is a plain 64-bit lane add, so it wraps on overflow.
 */
static int64_t k_means_horizontal_sum_sse2(__m128i a) {
  // Copy the upper 64-bit lane into both lanes, then add it to `a`; the low
  // lane of the result is lane0 + lane1.
  const __m128i upper = _mm_unpackhi_epi64(a, a);
  const __m128i folded = _mm_add_epi64(a, upper);
  // Only the low 64 bits hold the total; store them into the scalar result.
  int64_t total;
  _mm_storel_epi64((__m128i *)&total, folded);
  return total;
}

av1_calc_indices_dim1_sse2(const int16_t * data,const int16_t * centroids,uint8_t * indices,int64_t * total_dist,int n,int k)25 void av1_calc_indices_dim1_sse2(const int16_t *data, const int16_t *centroids,
26                                 uint8_t *indices, int64_t *total_dist, int n,
27                                 int k) {
28   const __m128i v_zero = _mm_setzero_si128();
29   __m128i sum = _mm_setzero_si128();
30   __m128i cents[PALETTE_MAX_SIZE];
31   for (int j = 0; j < k; ++j) {
32     cents[j] = _mm_set1_epi16(centroids[j]);
33   }
34 
35   for (int i = 0; i < n; i += 8) {
36     const __m128i in = _mm_loadu_si128((__m128i *)data);
37     __m128i ind = _mm_setzero_si128();
38     // Compute the distance to the first centroid.
39     __m128i d1 = _mm_sub_epi16(in, cents[0]);
40     __m128i d2 = _mm_sub_epi16(cents[0], in);
41     __m128i dist_min = _mm_max_epi16(d1, d2);
42 
43     for (int j = 1; j < k; ++j) {
44       // Compute the distance to the centroid.
45       d1 = _mm_sub_epi16(in, cents[j]);
46       d2 = _mm_sub_epi16(cents[j], in);
47       const __m128i dist = _mm_max_epi16(d1, d2);
48       // Compare to the minimal one.
49       const __m128i cmp = _mm_cmpgt_epi16(dist_min, dist);
50       dist_min = _mm_min_epi16(dist_min, dist);
51       const __m128i ind1 = _mm_set1_epi16(j);
52       ind = _mm_or_si128(_mm_andnot_si128(cmp, ind), _mm_and_si128(cmp, ind1));
53     }
54     if (total_dist) {
55       // Square, convert to 32 bit and add together.
56       dist_min = _mm_madd_epi16(dist_min, dist_min);
57       // Convert to 64 bit and add to sum.
58       const __m128i dist1 = _mm_unpacklo_epi32(dist_min, v_zero);
59       const __m128i dist2 = _mm_unpackhi_epi32(dist_min, v_zero);
60       sum = _mm_add_epi64(sum, dist1);
61       sum = _mm_add_epi64(sum, dist2);
62     }
63     __m128i p2 = _mm_packus_epi16(ind, v_zero);
64     _mm_storel_epi64((__m128i *)indices, p2);
65     indices += 8;
66     data += 8;
67   }
68   if (total_dist) {
69     *total_dist = k_means_horizontal_sum_sse2(sum);
70   }
71 }
72 
av1_calc_indices_dim2_sse2(const int16_t * data,const int16_t * centroids,uint8_t * indices,int64_t * total_dist,int n,int k)73 void av1_calc_indices_dim2_sse2(const int16_t *data, const int16_t *centroids,
74                                 uint8_t *indices, int64_t *total_dist, int n,
75                                 int k) {
76   const __m128i v_zero = _mm_setzero_si128();
77   __m128i sum = _mm_setzero_si128();
78   __m128i ind[2];
79   __m128i cents[PALETTE_MAX_SIZE];
80   for (int j = 0; j < k; ++j) {
81     const int16_t cx = centroids[2 * j], cy = centroids[2 * j + 1];
82     cents[j] = _mm_set_epi16(cy, cx, cy, cx, cy, cx, cy, cx);
83   }
84 
85   for (int i = 0; i < n; i += 8) {
86     for (int l = 0; l < 2; ++l) {
87       const __m128i in = _mm_loadu_si128((__m128i *)data);
88       ind[l] = _mm_setzero_si128();
89       // Compute the distance to the first centroid.
90       __m128i d1 = _mm_sub_epi16(in, cents[0]);
91       __m128i dist_min = _mm_madd_epi16(d1, d1);
92 
93       for (int j = 1; j < k; ++j) {
94         // Compute the distance to the centroid.
95         d1 = _mm_sub_epi16(in, cents[j]);
96         const __m128i dist = _mm_madd_epi16(d1, d1);
97         // Compare to the minimal one.
98         const __m128i cmp = _mm_cmpgt_epi32(dist_min, dist);
99         const __m128i dist1 = _mm_andnot_si128(cmp, dist_min);
100         const __m128i dist2 = _mm_and_si128(cmp, dist);
101         dist_min = _mm_or_si128(dist1, dist2);
102         const __m128i ind1 = _mm_set1_epi32(j);
103         ind[l] = _mm_or_si128(_mm_andnot_si128(cmp, ind[l]),
104                               _mm_and_si128(cmp, ind1));
105       }
106       if (total_dist) {
107         // Convert to 64 bit and add to sum.
108         const __m128i dist1 = _mm_unpacklo_epi32(dist_min, v_zero);
109         const __m128i dist2 = _mm_unpackhi_epi32(dist_min, v_zero);
110         sum = _mm_add_epi64(sum, dist1);
111         sum = _mm_add_epi64(sum, dist2);
112       }
113       data += 8;
114     }
115     // Cast to 8 bit and store.
116     const __m128i d2 = _mm_packus_epi16(ind[0], ind[1]);
117     const __m128i d3 = _mm_packus_epi16(d2, v_zero);
118     _mm_storel_epi64((__m128i *)indices, d3);
119     indices += 8;
120   }
121   if (total_dist) {
122     *total_dist = k_means_horizontal_sum_sse2(sum);
123   }
124 }
125