xref: /aosp_15_r20/external/libvpx/vpx_dsp/loongarch/variance_lsx.c (revision fb1b10ab9aebc7c7068eedab379b749d7e3900be)
1 /*
2  *  Copyright (c) 2022 The WebM project authors. All Rights Reserved.
3  *
4  *  Use of this source code is governed by a BSD-style license
5  *  that can be found in the LICENSE file in the root of the source
6  *  tree. An additional intellectual property rights grant can be found
7  *  in the file PATENTS.  All contributing project authors may
8  *  be found in the AUTHORS file in the root of the source tree.
9  */
10 
11 #include "./vpx_dsp_rtcd.h"
12 #include "vpx_dsp/loongarch/variance_lsx.h"
13 
/* Variance from a block's SSE and difference sum:
 *   variance = sse - floor(sum^2 / (W * H))
 * where `shift` is log2(W * H), so the division is a right shift of the
 * non-negative square.
 *
 * VARIANCE_WxH squares `diff` in 32-bit unsigned arithmetic. It is exact only
 * while diff^2 fits in uint32_t; for 16x16 blocks the worst case
 * |diff| = 255 * 256 = 65280 still fits. VARIANCE_LARGE_WxH squares in
 * int64_t for larger blocks.
 *
 * `diff` is evaluated twice, so pass only side-effect-free expressions. The
 * whole expansion is parenthesized so the macros compose safely inside larger
 * expressions (e.g. `VARIANCE_WxH(s, d, 6) * 2`).
 */
#define VARIANCE_WxH(sse, diff, shift) \
  ((sse) - (((uint32_t)(diff) * (diff)) >> (shift)))

#define VARIANCE_LARGE_WxH(sse, diff, shift) \
  ((sse) - (((int64_t)(diff) * (diff)) >> (shift)))
19 
sse_diff_8width_lsx(const uint8_t * src_ptr,int32_t src_stride,const uint8_t * ref_ptr,int32_t ref_stride,int32_t height,int32_t * diff)20 static uint32_t sse_diff_8width_lsx(const uint8_t *src_ptr, int32_t src_stride,
21                                     const uint8_t *ref_ptr, int32_t ref_stride,
22                                     int32_t height, int32_t *diff) {
23   int32_t res, ht_cnt = (height >> 2);
24   __m128i src0, src1, src2, src3, ref0, ref1, ref2, ref3, vec;
25   __m128i avg = __lsx_vldi(0);
26   __m128i var = avg;
27   int32_t src_stride2 = src_stride << 1;
28   int32_t src_stride3 = src_stride2 + src_stride;
29   int32_t src_stride4 = src_stride2 << 1;
30   int32_t ref_stride2 = ref_stride << 1;
31   int32_t ref_stride3 = ref_stride2 + ref_stride;
32   int32_t ref_stride4 = ref_stride2 << 1;
33 
34   for (; ht_cnt--;) {
35     DUP4_ARG2(__lsx_vld, src_ptr, 0, src_ptr + src_stride, 0,
36               src_ptr + src_stride2, 0, src_ptr + src_stride3, 0, src0, src1,
37               src2, src3);
38     src_ptr += src_stride4;
39     DUP4_ARG2(__lsx_vld, ref_ptr, 0, ref_ptr + ref_stride, 0,
40               ref_ptr + ref_stride2, 0, ref_ptr + ref_stride3, 0, ref0, ref1,
41               ref2, ref3);
42     ref_ptr += ref_stride4;
43 
44     DUP4_ARG2(__lsx_vpickev_d, src1, src0, src3, src2, ref1, ref0, ref3, ref2,
45               src0, src1, ref0, ref1);
46     CALC_MSE_AVG_B(src0, ref0, var, avg);
47     CALC_MSE_AVG_B(src1, ref1, var, avg);
48   }
49 
50   vec = __lsx_vhaddw_w_h(avg, avg);
51   HADD_SW_S32(vec, *diff);
52   HADD_SW_S32(var, res);
53   return res;
54 }
55 
sse_diff_16width_lsx(const uint8_t * src_ptr,int32_t src_stride,const uint8_t * ref_ptr,int32_t ref_stride,int32_t height,int32_t * diff)56 static uint32_t sse_diff_16width_lsx(const uint8_t *src_ptr, int32_t src_stride,
57                                      const uint8_t *ref_ptr, int32_t ref_stride,
58                                      int32_t height, int32_t *diff) {
59   int32_t res, ht_cnt = (height >> 2);
60   __m128i src, ref, vec;
61   __m128i avg = __lsx_vldi(0);
62   __m128i var = avg;
63 
64   for (; ht_cnt--;) {
65     src = __lsx_vld(src_ptr, 0);
66     src_ptr += src_stride;
67     ref = __lsx_vld(ref_ptr, 0);
68     ref_ptr += ref_stride;
69     CALC_MSE_AVG_B(src, ref, var, avg);
70 
71     src = __lsx_vld(src_ptr, 0);
72     src_ptr += src_stride;
73     ref = __lsx_vld(ref_ptr, 0);
74     ref_ptr += ref_stride;
75     CALC_MSE_AVG_B(src, ref, var, avg);
76     src = __lsx_vld(src_ptr, 0);
77     src_ptr += src_stride;
78     ref = __lsx_vld(ref_ptr, 0);
79     ref_ptr += ref_stride;
80     CALC_MSE_AVG_B(src, ref, var, avg);
81 
82     src = __lsx_vld(src_ptr, 0);
83     src_ptr += src_stride;
84     ref = __lsx_vld(ref_ptr, 0);
85     ref_ptr += ref_stride;
86     CALC_MSE_AVG_B(src, ref, var, avg);
87   }
88   vec = __lsx_vhaddw_w_h(avg, avg);
89   HADD_SW_S32(vec, *diff);
90   HADD_SW_S32(var, res);
91   return res;
92 }
93 
sse_diff_32width_lsx(const uint8_t * src_ptr,int32_t src_stride,const uint8_t * ref_ptr,int32_t ref_stride,int32_t height,int32_t * diff)94 static uint32_t sse_diff_32width_lsx(const uint8_t *src_ptr, int32_t src_stride,
95                                      const uint8_t *ref_ptr, int32_t ref_stride,
96                                      int32_t height, int32_t *diff) {
97   int32_t res, ht_cnt = (height >> 2);
98   __m128i avg = __lsx_vldi(0);
99   __m128i src0, src1, ref0, ref1;
100   __m128i vec;
101   __m128i var = avg;
102 
103   for (; ht_cnt--;) {
104     DUP2_ARG2(__lsx_vld, src_ptr, 0, src_ptr, 16, src0, src1);
105     src_ptr += src_stride;
106     DUP2_ARG2(__lsx_vld, ref_ptr, 0, ref_ptr, 16, ref0, ref1);
107     ref_ptr += ref_stride;
108     CALC_MSE_AVG_B(src0, ref0, var, avg);
109     CALC_MSE_AVG_B(src1, ref1, var, avg);
110 
111     DUP2_ARG2(__lsx_vld, src_ptr, 0, src_ptr, 16, src0, src1);
112     src_ptr += src_stride;
113     DUP2_ARG2(__lsx_vld, ref_ptr, 0, ref_ptr, 16, ref0, ref1);
114     ref_ptr += ref_stride;
115     CALC_MSE_AVG_B(src0, ref0, var, avg);
116     CALC_MSE_AVG_B(src1, ref1, var, avg);
117 
118     DUP2_ARG2(__lsx_vld, src_ptr, 0, src_ptr, 16, src0, src1);
119     src_ptr += src_stride;
120     DUP2_ARG2(__lsx_vld, ref_ptr, 0, ref_ptr, 16, ref0, ref1);
121     ref_ptr += ref_stride;
122     CALC_MSE_AVG_B(src0, ref0, var, avg);
123     CALC_MSE_AVG_B(src1, ref1, var, avg);
124 
125     DUP2_ARG2(__lsx_vld, src_ptr, 0, src_ptr, 16, src0, src1);
126     src_ptr += src_stride;
127     DUP2_ARG2(__lsx_vld, ref_ptr, 0, ref_ptr, 16, ref0, ref1);
128     ref_ptr += ref_stride;
129     CALC_MSE_AVG_B(src0, ref0, var, avg);
130     CALC_MSE_AVG_B(src1, ref1, var, avg);
131   }
132 
133   vec = __lsx_vhaddw_w_h(avg, avg);
134   HADD_SW_S32(vec, *diff);
135   HADD_SW_S32(var, res);
136   return res;
137 }
138 
sse_diff_64x64_lsx(const uint8_t * src_ptr,int32_t src_stride,const uint8_t * ref_ptr,int32_t ref_stride,int32_t * diff)139 static uint32_t sse_diff_64x64_lsx(const uint8_t *src_ptr, int32_t src_stride,
140                                    const uint8_t *ref_ptr, int32_t ref_stride,
141                                    int32_t *diff) {
142   int32_t res, ht_cnt = 32;
143   __m128i avg0 = __lsx_vldi(0);
144   __m128i src0, src1, src2, src3;
145   __m128i ref0, ref1, ref2, ref3;
146   __m128i vec0, vec1;
147   __m128i avg1 = avg0;
148   __m128i avg2 = avg0;
149   __m128i avg3 = avg0;
150   __m128i var = avg0;
151 
152   for (; ht_cnt--;) {
153     DUP4_ARG2(__lsx_vld, src_ptr, 0, src_ptr, 16, src_ptr, 32, src_ptr, 48,
154               src0, src1, src2, src3);
155     src_ptr += src_stride;
156     DUP4_ARG2(__lsx_vld, ref_ptr, 0, ref_ptr, 16, ref_ptr, 32, ref_ptr, 48,
157               ref0, ref1, ref2, ref3);
158     ref_ptr += ref_stride;
159 
160     CALC_MSE_AVG_B(src0, ref0, var, avg0);
161     CALC_MSE_AVG_B(src1, ref1, var, avg1);
162     CALC_MSE_AVG_B(src2, ref2, var, avg2);
163     CALC_MSE_AVG_B(src3, ref3, var, avg3);
164     DUP4_ARG2(__lsx_vld, src_ptr, 0, src_ptr, 16, src_ptr, 32, src_ptr, 48,
165               src0, src1, src2, src3);
166     src_ptr += src_stride;
167     DUP4_ARG2(__lsx_vld, ref_ptr, 0, ref_ptr, 16, ref_ptr, 32, ref_ptr, 48,
168               ref0, ref1, ref2, ref3);
169     ref_ptr += ref_stride;
170     CALC_MSE_AVG_B(src0, ref0, var, avg0);
171     CALC_MSE_AVG_B(src1, ref1, var, avg1);
172     CALC_MSE_AVG_B(src2, ref2, var, avg2);
173     CALC_MSE_AVG_B(src3, ref3, var, avg3);
174   }
175   vec0 = __lsx_vhaddw_w_h(avg0, avg0);
176   vec1 = __lsx_vhaddw_w_h(avg1, avg1);
177   vec0 = __lsx_vadd_w(vec0, vec1);
178   vec1 = __lsx_vhaddw_w_h(avg2, avg2);
179   vec0 = __lsx_vadd_w(vec0, vec1);
180   vec1 = __lsx_vhaddw_w_h(avg3, avg3);
181   vec0 = __lsx_vadd_w(vec0, vec1);
182   HADD_SW_S32(vec0, *diff);
183   HADD_SW_S32(var, res);
184   return res;
185 }
186 
/* Per-size variance formulas; the shift argument is log2(W * H).
 * 8x8 and 16x16 square the difference sum in 32 bits. 32x32 and 64x64 use
 * the 64-bit form because their worst-case |sum| (255 * 1024 and 255 * 4096)
 * squares past the uint32_t range. */
#define VARIANCE_8Wx8H(sse, diff) VARIANCE_WxH(sse, diff, 6)
#define VARIANCE_16Wx16H(sse, diff) VARIANCE_WxH(sse, diff, 8)
#define VARIANCE_32Wx32H(sse, diff) VARIANCE_LARGE_WxH(sse, diff, 10)
#define VARIANCE_64Wx64H(sse, diff) VARIANCE_LARGE_WxH(sse, diff, 12)

/* Defines vpx_variance<wd>x<ht>_lsx(). The generated function gets the SSE and
 * difference sum from sse_diff_<wd>width_lsx(), stores the SSE in *sse, and
 * returns the variance computed by VARIANCE_<wd>Wx<ht>H. */
#define VPX_VARIANCE_WDXHT_LSX(wd, ht)                                        \
  uint32_t vpx_variance##wd##x##ht##_lsx(                                     \
      const uint8_t *src, int32_t src_stride, const uint8_t *ref,             \
      int32_t ref_stride, uint32_t *sse) {                                    \
    int32_t sum;                                                              \
    const uint32_t sq_err =                                                   \
        sse_diff_##wd##width_lsx(src, src_stride, ref, ref_stride, ht, &sum); \
    *sse = sq_err;                                                            \
    return VARIANCE_##wd##Wx##ht##H(sq_err, sum);                             \
  }
204 
sse_16width_lsx(const uint8_t * src_ptr,int32_t src_stride,const uint8_t * ref_ptr,int32_t ref_stride,int32_t height)205 static uint32_t sse_16width_lsx(const uint8_t *src_ptr, int32_t src_stride,
206                                 const uint8_t *ref_ptr, int32_t ref_stride,
207                                 int32_t height) {
208   int32_t res, ht_cnt = (height >> 2);
209   __m128i src, ref;
210   __m128i var = __lsx_vldi(0);
211 
212   for (; ht_cnt--;) {
213     DUP2_ARG2(__lsx_vld, src_ptr, 0, ref_ptr, 0, src, ref);
214     src_ptr += src_stride;
215     ref_ptr += ref_stride;
216     CALC_MSE_B(src, ref, var);
217 
218     DUP2_ARG2(__lsx_vld, src_ptr, 0, ref_ptr, 0, src, ref);
219     src_ptr += src_stride;
220     ref_ptr += ref_stride;
221     CALC_MSE_B(src, ref, var);
222 
223     DUP2_ARG2(__lsx_vld, src_ptr, 0, ref_ptr, 0, src, ref);
224     src_ptr += src_stride;
225     ref_ptr += ref_stride;
226     CALC_MSE_B(src, ref, var);
227 
228     DUP2_ARG2(__lsx_vld, src_ptr, 0, ref_ptr, 0, src, ref);
229     src_ptr += src_stride;
230     ref_ptr += ref_stride;
231     CALC_MSE_B(src, ref, var);
232   }
233   HADD_SW_S32(var, res);
234   return res;
235 }
236 
237 VPX_VARIANCE_WDXHT_LSX(8, 8)
238 VPX_VARIANCE_WDXHT_LSX(16, 16)
239 VPX_VARIANCE_WDXHT_LSX(32, 32)
240 
/* Variance of a 64x64 block. Stores the sum of squared differences in *sse
 * and returns sse - ((sum * sum) >> 12), i.e. the sum^2 / 4096 correction. */
uint32_t vpx_variance64x64_lsx(const uint8_t *src, int32_t src_stride,
                               const uint8_t *ref, int32_t ref_stride,
                               uint32_t *sse) {
  int32_t sum;
  const uint32_t sq_err =
      sse_diff_64x64_lsx(src, src_stride, ref, ref_stride, &sum);

  *sse = sq_err;
  return VARIANCE_64Wx64H(sq_err, sum);
}
250 
/* Sum of squared differences of a 16x16 block (not divided by the pixel
 * count). The value is both stored in *sse and returned. */
uint32_t vpx_mse16x16_lsx(const uint8_t *src, int32_t src_stride,
                          const uint8_t *ref, int32_t ref_stride,
                          uint32_t *sse) {
  const uint32_t sq_err =
      sse_16width_lsx(src, src_stride, ref, ref_stride, 16);

  *sse = sq_err;
  return sq_err;
}
258 
/* Reports the raw statistics of a 16x16 block: the sum of squared differences
 * in *sse and the difference sum in *sum. The variance itself is left to the
 * caller. */
void vpx_get16x16var_lsx(const uint8_t *src, int32_t src_stride,
                         const uint8_t *ref, int32_t ref_stride, uint32_t *sse,
                         int32_t *sum) {
  const uint32_t sq_err =
      sse_diff_16width_lsx(src, src_stride, ref, ref_stride, 16, sum);

  *sse = sq_err;
}
264