xref: /aosp_15_r20/external/libaom/av1/encoder/arm/pickrst_sve.c (revision 77c1e3ccc04c968bd2bc212e87364f250e820521)
1*77c1e3ccSAndroid Build Coastguard Worker /*
2*77c1e3ccSAndroid Build Coastguard Worker  * Copyright (c) 2024, Alliance for Open Media. All rights reserved.
3*77c1e3ccSAndroid Build Coastguard Worker  *
4*77c1e3ccSAndroid Build Coastguard Worker  * This source code is subject to the terms of the BSD 2 Clause License and
5*77c1e3ccSAndroid Build Coastguard Worker  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6*77c1e3ccSAndroid Build Coastguard Worker  * was not distributed with this source code in the LICENSE file, you can
7*77c1e3ccSAndroid Build Coastguard Worker  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8*77c1e3ccSAndroid Build Coastguard Worker  * Media Patent License 1.0 was not distributed with this source code in the
9*77c1e3ccSAndroid Build Coastguard Worker  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10*77c1e3ccSAndroid Build Coastguard Worker  */
11*77c1e3ccSAndroid Build Coastguard Worker 
12*77c1e3ccSAndroid Build Coastguard Worker #include <arm_neon.h>
13*77c1e3ccSAndroid Build Coastguard Worker #include <arm_sve.h>
14*77c1e3ccSAndroid Build Coastguard Worker #include <assert.h>
15*77c1e3ccSAndroid Build Coastguard Worker #include <string.h>
16*77c1e3ccSAndroid Build Coastguard Worker 
17*77c1e3ccSAndroid Build Coastguard Worker #include "config/aom_config.h"
18*77c1e3ccSAndroid Build Coastguard Worker #include "config/av1_rtcd.h"
19*77c1e3ccSAndroid Build Coastguard Worker 
20*77c1e3ccSAndroid Build Coastguard Worker #include "aom_dsp/arm/aom_neon_sve_bridge.h"
21*77c1e3ccSAndroid Build Coastguard Worker #include "aom_dsp/arm/mem_neon.h"
22*77c1e3ccSAndroid Build Coastguard Worker #include "aom_dsp/arm/sum_neon.h"
23*77c1e3ccSAndroid Build Coastguard Worker #include "aom_dsp/arm/transpose_neon.h"
24*77c1e3ccSAndroid Build Coastguard Worker #include "av1/common/restoration.h"
25*77c1e3ccSAndroid Build Coastguard Worker #include "av1/encoder/pickrst.h"
26*77c1e3ccSAndroid Build Coastguard Worker #include "av1/encoder/arm/pickrst_sve.h"
27*77c1e3ccSAndroid Build Coastguard Worker 
find_average_sve(const uint8_t * src,int src_stride,int width,int height)28*77c1e3ccSAndroid Build Coastguard Worker static inline uint8_t find_average_sve(const uint8_t *src, int src_stride,
29*77c1e3ccSAndroid Build Coastguard Worker                                        int width, int height) {
30*77c1e3ccSAndroid Build Coastguard Worker   uint32x4_t avg_u32 = vdupq_n_u32(0);
31*77c1e3ccSAndroid Build Coastguard Worker   uint8x16_t ones = vdupq_n_u8(1);
32*77c1e3ccSAndroid Build Coastguard Worker 
33*77c1e3ccSAndroid Build Coastguard Worker   // Use a predicate to compute the last columns.
34*77c1e3ccSAndroid Build Coastguard Worker   svbool_t pattern = svwhilelt_b8_u32(0, width % 16);
35*77c1e3ccSAndroid Build Coastguard Worker 
36*77c1e3ccSAndroid Build Coastguard Worker   int h = height;
37*77c1e3ccSAndroid Build Coastguard Worker   do {
38*77c1e3ccSAndroid Build Coastguard Worker     int j = width;
39*77c1e3ccSAndroid Build Coastguard Worker     const uint8_t *src_ptr = src;
40*77c1e3ccSAndroid Build Coastguard Worker     while (j >= 16) {
41*77c1e3ccSAndroid Build Coastguard Worker       uint8x16_t s = vld1q_u8(src_ptr);
42*77c1e3ccSAndroid Build Coastguard Worker       avg_u32 = vdotq_u32(avg_u32, s, ones);
43*77c1e3ccSAndroid Build Coastguard Worker 
44*77c1e3ccSAndroid Build Coastguard Worker       j -= 16;
45*77c1e3ccSAndroid Build Coastguard Worker       src_ptr += 16;
46*77c1e3ccSAndroid Build Coastguard Worker     }
47*77c1e3ccSAndroid Build Coastguard Worker     uint8x16_t s_end = svget_neonq_u8(svld1_u8(pattern, src_ptr));
48*77c1e3ccSAndroid Build Coastguard Worker     avg_u32 = vdotq_u32(avg_u32, s_end, ones);
49*77c1e3ccSAndroid Build Coastguard Worker 
50*77c1e3ccSAndroid Build Coastguard Worker     src += src_stride;
51*77c1e3ccSAndroid Build Coastguard Worker   } while (--h != 0);
52*77c1e3ccSAndroid Build Coastguard Worker   return (uint8_t)(vaddlvq_u32(avg_u32) / (width * height));
53*77c1e3ccSAndroid Build Coastguard Worker }
54*77c1e3ccSAndroid Build Coastguard Worker 
compute_sub_avg(const uint8_t * buf,int buf_stride,int avg,int16_t * buf_avg,int buf_avg_stride,int width,int height,int downsample_factor)55*77c1e3ccSAndroid Build Coastguard Worker static inline void compute_sub_avg(const uint8_t *buf, int buf_stride, int avg,
56*77c1e3ccSAndroid Build Coastguard Worker                                    int16_t *buf_avg, int buf_avg_stride,
57*77c1e3ccSAndroid Build Coastguard Worker                                    int width, int height,
58*77c1e3ccSAndroid Build Coastguard Worker                                    int downsample_factor) {
59*77c1e3ccSAndroid Build Coastguard Worker   uint8x8_t avg_u8 = vdup_n_u8(avg);
60*77c1e3ccSAndroid Build Coastguard Worker 
61*77c1e3ccSAndroid Build Coastguard Worker   // Use a predicate to compute the last columns.
62*77c1e3ccSAndroid Build Coastguard Worker   svbool_t pattern = svwhilelt_b8_u32(0, width % 8);
63*77c1e3ccSAndroid Build Coastguard Worker 
64*77c1e3ccSAndroid Build Coastguard Worker   uint8x8_t avg_end = vget_low_u8(svget_neonq_u8(svdup_n_u8_z(pattern, avg)));
65*77c1e3ccSAndroid Build Coastguard Worker 
66*77c1e3ccSAndroid Build Coastguard Worker   do {
67*77c1e3ccSAndroid Build Coastguard Worker     int j = width;
68*77c1e3ccSAndroid Build Coastguard Worker     const uint8_t *buf_ptr = buf;
69*77c1e3ccSAndroid Build Coastguard Worker     int16_t *buf_avg_ptr = buf_avg;
70*77c1e3ccSAndroid Build Coastguard Worker     while (j >= 8) {
71*77c1e3ccSAndroid Build Coastguard Worker       uint8x8_t d = vld1_u8(buf_ptr);
72*77c1e3ccSAndroid Build Coastguard Worker       vst1q_s16(buf_avg_ptr, vreinterpretq_s16_u16(vsubl_u8(d, avg_u8)));
73*77c1e3ccSAndroid Build Coastguard Worker 
74*77c1e3ccSAndroid Build Coastguard Worker       j -= 8;
75*77c1e3ccSAndroid Build Coastguard Worker       buf_ptr += 8;
76*77c1e3ccSAndroid Build Coastguard Worker       buf_avg_ptr += 8;
77*77c1e3ccSAndroid Build Coastguard Worker     }
78*77c1e3ccSAndroid Build Coastguard Worker     uint8x8_t d_end = vget_low_u8(svget_neonq_u8(svld1_u8(pattern, buf_ptr)));
79*77c1e3ccSAndroid Build Coastguard Worker     vst1q_s16(buf_avg_ptr, vreinterpretq_s16_u16(vsubl_u8(d_end, avg_end)));
80*77c1e3ccSAndroid Build Coastguard Worker 
81*77c1e3ccSAndroid Build Coastguard Worker     buf += buf_stride;
82*77c1e3ccSAndroid Build Coastguard Worker     buf_avg += buf_avg_stride;
83*77c1e3ccSAndroid Build Coastguard Worker     height -= downsample_factor;
84*77c1e3ccSAndroid Build Coastguard Worker   } while (height > 0);
85*77c1e3ccSAndroid Build Coastguard Worker }
86*77c1e3ccSAndroid Build Coastguard Worker 
copy_upper_triangle(int64_t * H,int64_t * H_tmp,const int wiener_win2,const int scale)87*77c1e3ccSAndroid Build Coastguard Worker static inline void copy_upper_triangle(int64_t *H, int64_t *H_tmp,
88*77c1e3ccSAndroid Build Coastguard Worker                                        const int wiener_win2, const int scale) {
89*77c1e3ccSAndroid Build Coastguard Worker   for (int i = 0; i < wiener_win2 - 2; i = i + 2) {
90*77c1e3ccSAndroid Build Coastguard Worker     // Transpose the first 2x2 square. It needs a special case as the element
91*77c1e3ccSAndroid Build Coastguard Worker     // of the bottom left is on the diagonal.
92*77c1e3ccSAndroid Build Coastguard Worker     int64x2_t row0 = vld1q_s64(H_tmp + i * wiener_win2 + i + 1);
93*77c1e3ccSAndroid Build Coastguard Worker     int64x2_t row1 = vld1q_s64(H_tmp + (i + 1) * wiener_win2 + i + 1);
94*77c1e3ccSAndroid Build Coastguard Worker 
95*77c1e3ccSAndroid Build Coastguard Worker     int64x2_t tr_row = aom_vtrn2q_s64(row0, row1);
96*77c1e3ccSAndroid Build Coastguard Worker 
97*77c1e3ccSAndroid Build Coastguard Worker     vst1_s64(H_tmp + (i + 1) * wiener_win2 + i, vget_low_s64(row0));
98*77c1e3ccSAndroid Build Coastguard Worker     vst1q_s64(H_tmp + (i + 2) * wiener_win2 + i, tr_row);
99*77c1e3ccSAndroid Build Coastguard Worker 
100*77c1e3ccSAndroid Build Coastguard Worker     // Transpose and store all the remaining 2x2 squares of the line.
101*77c1e3ccSAndroid Build Coastguard Worker     for (int j = i + 3; j < wiener_win2; j = j + 2) {
102*77c1e3ccSAndroid Build Coastguard Worker       row0 = vld1q_s64(H_tmp + i * wiener_win2 + j);
103*77c1e3ccSAndroid Build Coastguard Worker       row1 = vld1q_s64(H_tmp + (i + 1) * wiener_win2 + j);
104*77c1e3ccSAndroid Build Coastguard Worker 
105*77c1e3ccSAndroid Build Coastguard Worker       int64x2_t tr_row0 = aom_vtrn1q_s64(row0, row1);
106*77c1e3ccSAndroid Build Coastguard Worker       int64x2_t tr_row1 = aom_vtrn2q_s64(row0, row1);
107*77c1e3ccSAndroid Build Coastguard Worker 
108*77c1e3ccSAndroid Build Coastguard Worker       vst1q_s64(H_tmp + j * wiener_win2 + i, tr_row0);
109*77c1e3ccSAndroid Build Coastguard Worker       vst1q_s64(H_tmp + (j + 1) * wiener_win2 + i, tr_row1);
110*77c1e3ccSAndroid Build Coastguard Worker     }
111*77c1e3ccSAndroid Build Coastguard Worker   }
112*77c1e3ccSAndroid Build Coastguard Worker   for (int i = 0; i < wiener_win2 * wiener_win2; i++) {
113*77c1e3ccSAndroid Build Coastguard Worker     H[i] += H_tmp[i] * scale;
114*77c1e3ccSAndroid Build Coastguard Worker   }
115*77c1e3ccSAndroid Build Coastguard Worker }
116*77c1e3ccSAndroid Build Coastguard Worker 
// Adds the transpose of the wiener_win x wiener_win matrix M_trn, multiplied
// by scale, into M: M[r][c] += M_trn[c][r] * scale. Both matrices are stored
// row-major with rows of length wiener_win.
static inline void acc_transpose_M(int64_t *M, const int64_t *M_trn,
                                   const int wiener_win, int scale) {
  for (int r = 0; r < wiener_win; ++r) {
    int64_t *M_row = M + r * wiener_win;
    for (int c = 0; c < wiener_win; ++c) {
      M_row[c] += M_trn[c * wiener_win + r] * scale;
    }
  }
}
127*77c1e3ccSAndroid Build Coastguard Worker 
128*77c1e3ccSAndroid Build Coastguard Worker // This function computes two matrices: the cross-correlation between the src
129*77c1e3ccSAndroid Build Coastguard Worker // buffer and dgd buffer (M), and the auto-covariance of the dgd buffer (H).
130*77c1e3ccSAndroid Build Coastguard Worker //
131*77c1e3ccSAndroid Build Coastguard Worker // M is of size 7 * 7. It needs to be filled such that multiplying one element
132*77c1e3ccSAndroid Build Coastguard Worker // from src with each element of a row of the wiener window will fill one
133*77c1e3ccSAndroid Build Coastguard Worker // column of M. However this is not very convenient in terms of memory
134*77c1e3ccSAndroid Build Coastguard Worker // accesses, as it means we do contiguous loads of dgd but strided stores to M.
135*77c1e3ccSAndroid Build Coastguard Worker // As a result, we use an intermediate matrix M_trn which is instead filled
136*77c1e3ccSAndroid Build Coastguard Worker // such that one row of the wiener window gives one row of M_trn. Once fully
137*77c1e3ccSAndroid Build Coastguard Worker // computed, M_trn is then transposed to return M.
138*77c1e3ccSAndroid Build Coastguard Worker //
139*77c1e3ccSAndroid Build Coastguard Worker // H is of size 49 * 49. It is filled by multiplying every pair of elements of
140*77c1e3ccSAndroid Build Coastguard Worker // the wiener window together. Since it is a symmetric matrix, we only compute
141*77c1e3ccSAndroid Build Coastguard Worker // the upper triangle, and then copy it down to the lower one. Here we fill it
142*77c1e3ccSAndroid Build Coastguard Worker // by taking each different pair of columns, and multiplying all the elements of
143*77c1e3ccSAndroid Build Coastguard Worker // the first one with all the elements of the second one, with a special case
144*77c1e3ccSAndroid Build Coastguard Worker // when multiplying a column by itself.
compute_stats_win7_downsampled_sve(int16_t * dgd_avg,int dgd_avg_stride,int16_t * src_avg,int src_avg_stride,int width,int height,int64_t * M,int64_t * H,int downsample_factor)145*77c1e3ccSAndroid Build Coastguard Worker static inline void compute_stats_win7_downsampled_sve(
146*77c1e3ccSAndroid Build Coastguard Worker     int16_t *dgd_avg, int dgd_avg_stride, int16_t *src_avg, int src_avg_stride,
147*77c1e3ccSAndroid Build Coastguard Worker     int width, int height, int64_t *M, int64_t *H, int downsample_factor) {
148*77c1e3ccSAndroid Build Coastguard Worker   const int wiener_win = 7;
149*77c1e3ccSAndroid Build Coastguard Worker   const int wiener_win2 = wiener_win * wiener_win;
150*77c1e3ccSAndroid Build Coastguard Worker 
151*77c1e3ccSAndroid Build Coastguard Worker   // Use a predicate to compute the last columns of the block for H.
152*77c1e3ccSAndroid Build Coastguard Worker   svbool_t pattern = svwhilelt_b16_u32(0, width % 8);
153*77c1e3ccSAndroid Build Coastguard Worker 
154*77c1e3ccSAndroid Build Coastguard Worker   // Use intermediate matrices for H and M to perform the computation, they
155*77c1e3ccSAndroid Build Coastguard Worker   // will be accumulated into the original H and M at the end.
156*77c1e3ccSAndroid Build Coastguard Worker   int64_t M_trn[49];
157*77c1e3ccSAndroid Build Coastguard Worker   memset(M_trn, 0, sizeof(M_trn));
158*77c1e3ccSAndroid Build Coastguard Worker 
159*77c1e3ccSAndroid Build Coastguard Worker   int64_t H_tmp[49 * 49];
160*77c1e3ccSAndroid Build Coastguard Worker   memset(H_tmp, 0, sizeof(H_tmp));
161*77c1e3ccSAndroid Build Coastguard Worker 
162*77c1e3ccSAndroid Build Coastguard Worker   assert(height > 0);
163*77c1e3ccSAndroid Build Coastguard Worker   do {
164*77c1e3ccSAndroid Build Coastguard Worker     // Cross-correlation (M).
165*77c1e3ccSAndroid Build Coastguard Worker     for (int row = 0; row < wiener_win; row++) {
166*77c1e3ccSAndroid Build Coastguard Worker       int j = 0;
167*77c1e3ccSAndroid Build Coastguard Worker       while (j < width) {
168*77c1e3ccSAndroid Build Coastguard Worker         int16x8_t dgd[7];
169*77c1e3ccSAndroid Build Coastguard Worker         load_s16_8x7(dgd_avg + row * dgd_avg_stride + j, 1, &dgd[0], &dgd[1],
170*77c1e3ccSAndroid Build Coastguard Worker                      &dgd[2], &dgd[3], &dgd[4], &dgd[5], &dgd[6]);
171*77c1e3ccSAndroid Build Coastguard Worker         int16x8_t s = vld1q_s16(src_avg + j);
172*77c1e3ccSAndroid Build Coastguard Worker 
173*77c1e3ccSAndroid Build Coastguard Worker         // Compute all the elements of one row of M.
174*77c1e3ccSAndroid Build Coastguard Worker         compute_M_one_row_win7(s, dgd, M_trn, row);
175*77c1e3ccSAndroid Build Coastguard Worker 
176*77c1e3ccSAndroid Build Coastguard Worker         j += 8;
177*77c1e3ccSAndroid Build Coastguard Worker       }
178*77c1e3ccSAndroid Build Coastguard Worker     }
179*77c1e3ccSAndroid Build Coastguard Worker 
180*77c1e3ccSAndroid Build Coastguard Worker     // Auto-covariance (H).
181*77c1e3ccSAndroid Build Coastguard Worker     int j = 0;
182*77c1e3ccSAndroid Build Coastguard Worker     while (j <= width - 8) {
183*77c1e3ccSAndroid Build Coastguard Worker       for (int col0 = 0; col0 < wiener_win; col0++) {
184*77c1e3ccSAndroid Build Coastguard Worker         int16x8_t dgd0[7];
185*77c1e3ccSAndroid Build Coastguard Worker         load_s16_8x7(dgd_avg + j + col0, dgd_avg_stride, &dgd0[0], &dgd0[1],
186*77c1e3ccSAndroid Build Coastguard Worker                      &dgd0[2], &dgd0[3], &dgd0[4], &dgd0[5], &dgd0[6]);
187*77c1e3ccSAndroid Build Coastguard Worker 
188*77c1e3ccSAndroid Build Coastguard Worker         // Perform computation of the first column with itself (28 elements).
189*77c1e3ccSAndroid Build Coastguard Worker         // For the first column this will fill the upper triangle of the 7x7
190*77c1e3ccSAndroid Build Coastguard Worker         // matrix at the top left of the H matrix. For the next columns this
191*77c1e3ccSAndroid Build Coastguard Worker         // will fill the upper triangle of the other 7x7 matrices around H's
192*77c1e3ccSAndroid Build Coastguard Worker         // diagonal.
193*77c1e3ccSAndroid Build Coastguard Worker         compute_H_one_col(dgd0, col0, H_tmp, wiener_win, wiener_win2);
194*77c1e3ccSAndroid Build Coastguard Worker 
195*77c1e3ccSAndroid Build Coastguard Worker         // All computation next to the matrix diagonal has already been done.
196*77c1e3ccSAndroid Build Coastguard Worker         for (int col1 = col0 + 1; col1 < wiener_win; col1++) {
197*77c1e3ccSAndroid Build Coastguard Worker           // Load second column and scale based on downsampling factor.
198*77c1e3ccSAndroid Build Coastguard Worker           int16x8_t dgd1[7];
199*77c1e3ccSAndroid Build Coastguard Worker           load_s16_8x7(dgd_avg + j + col1, dgd_avg_stride, &dgd1[0], &dgd1[1],
200*77c1e3ccSAndroid Build Coastguard Worker                        &dgd1[2], &dgd1[3], &dgd1[4], &dgd1[5], &dgd1[6]);
201*77c1e3ccSAndroid Build Coastguard Worker 
202*77c1e3ccSAndroid Build Coastguard Worker           // Compute all elements from the combination of both columns (49
203*77c1e3ccSAndroid Build Coastguard Worker           // elements).
204*77c1e3ccSAndroid Build Coastguard Worker           compute_H_two_rows_win7(dgd0, dgd1, col0, col1, H_tmp);
205*77c1e3ccSAndroid Build Coastguard Worker         }
206*77c1e3ccSAndroid Build Coastguard Worker       }
207*77c1e3ccSAndroid Build Coastguard Worker       j += 8;
208*77c1e3ccSAndroid Build Coastguard Worker     }
209*77c1e3ccSAndroid Build Coastguard Worker 
210*77c1e3ccSAndroid Build Coastguard Worker     if (j < width) {
211*77c1e3ccSAndroid Build Coastguard Worker       // Process remaining columns using a predicate to discard excess elements.
212*77c1e3ccSAndroid Build Coastguard Worker       for (int col0 = 0; col0 < wiener_win; col0++) {
213*77c1e3ccSAndroid Build Coastguard Worker         // Load first column.
214*77c1e3ccSAndroid Build Coastguard Worker         int16x8_t dgd0[7];
215*77c1e3ccSAndroid Build Coastguard Worker         dgd0[0] = svget_neonq_s16(
216*77c1e3ccSAndroid Build Coastguard Worker             svld1_s16(pattern, dgd_avg + 0 * dgd_avg_stride + j + col0));
217*77c1e3ccSAndroid Build Coastguard Worker         dgd0[1] = svget_neonq_s16(
218*77c1e3ccSAndroid Build Coastguard Worker             svld1_s16(pattern, dgd_avg + 1 * dgd_avg_stride + j + col0));
219*77c1e3ccSAndroid Build Coastguard Worker         dgd0[2] = svget_neonq_s16(
220*77c1e3ccSAndroid Build Coastguard Worker             svld1_s16(pattern, dgd_avg + 2 * dgd_avg_stride + j + col0));
221*77c1e3ccSAndroid Build Coastguard Worker         dgd0[3] = svget_neonq_s16(
222*77c1e3ccSAndroid Build Coastguard Worker             svld1_s16(pattern, dgd_avg + 3 * dgd_avg_stride + j + col0));
223*77c1e3ccSAndroid Build Coastguard Worker         dgd0[4] = svget_neonq_s16(
224*77c1e3ccSAndroid Build Coastguard Worker             svld1_s16(pattern, dgd_avg + 4 * dgd_avg_stride + j + col0));
225*77c1e3ccSAndroid Build Coastguard Worker         dgd0[5] = svget_neonq_s16(
226*77c1e3ccSAndroid Build Coastguard Worker             svld1_s16(pattern, dgd_avg + 5 * dgd_avg_stride + j + col0));
227*77c1e3ccSAndroid Build Coastguard Worker         dgd0[6] = svget_neonq_s16(
228*77c1e3ccSAndroid Build Coastguard Worker             svld1_s16(pattern, dgd_avg + 6 * dgd_avg_stride + j + col0));
229*77c1e3ccSAndroid Build Coastguard Worker 
230*77c1e3ccSAndroid Build Coastguard Worker         // Perform computation of the first column with itself (28 elements).
231*77c1e3ccSAndroid Build Coastguard Worker         // For the first column this will fill the upper triangle of the 7x7
232*77c1e3ccSAndroid Build Coastguard Worker         // matrix at the top left of the H matrix. For the next columns this
233*77c1e3ccSAndroid Build Coastguard Worker         // will fill the upper triangle of the other 7x7 matrices around H's
234*77c1e3ccSAndroid Build Coastguard Worker         // diagonal.
235*77c1e3ccSAndroid Build Coastguard Worker         compute_H_one_col(dgd0, col0, H_tmp, wiener_win, wiener_win2);
236*77c1e3ccSAndroid Build Coastguard Worker 
237*77c1e3ccSAndroid Build Coastguard Worker         // All computation next to the matrix diagonal has already been done.
238*77c1e3ccSAndroid Build Coastguard Worker         for (int col1 = col0 + 1; col1 < wiener_win; col1++) {
239*77c1e3ccSAndroid Build Coastguard Worker           // Load second column and scale based on downsampling factor.
240*77c1e3ccSAndroid Build Coastguard Worker           int16x8_t dgd1[7];
241*77c1e3ccSAndroid Build Coastguard Worker           load_s16_8x7(dgd_avg + j + col1, dgd_avg_stride, &dgd1[0], &dgd1[1],
242*77c1e3ccSAndroid Build Coastguard Worker                        &dgd1[2], &dgd1[3], &dgd1[4], &dgd1[5], &dgd1[6]);
243*77c1e3ccSAndroid Build Coastguard Worker 
244*77c1e3ccSAndroid Build Coastguard Worker           // Compute all elements from the combination of both columns (49
245*77c1e3ccSAndroid Build Coastguard Worker           // elements).
246*77c1e3ccSAndroid Build Coastguard Worker           compute_H_two_rows_win7(dgd0, dgd1, col0, col1, H_tmp);
247*77c1e3ccSAndroid Build Coastguard Worker         }
248*77c1e3ccSAndroid Build Coastguard Worker       }
249*77c1e3ccSAndroid Build Coastguard Worker     }
250*77c1e3ccSAndroid Build Coastguard Worker     dgd_avg += downsample_factor * dgd_avg_stride;
251*77c1e3ccSAndroid Build Coastguard Worker     src_avg += src_avg_stride;
252*77c1e3ccSAndroid Build Coastguard Worker   } while (--height != 0);
253*77c1e3ccSAndroid Build Coastguard Worker 
254*77c1e3ccSAndroid Build Coastguard Worker   // Transpose M_trn.
255*77c1e3ccSAndroid Build Coastguard Worker   acc_transpose_M(M, M_trn, 7, downsample_factor);
256*77c1e3ccSAndroid Build Coastguard Worker 
257*77c1e3ccSAndroid Build Coastguard Worker   // Copy upper triangle of H in the lower one.
258*77c1e3ccSAndroid Build Coastguard Worker   copy_upper_triangle(H, H_tmp, wiener_win2, downsample_factor);
259*77c1e3ccSAndroid Build Coastguard Worker }
260*77c1e3ccSAndroid Build Coastguard Worker 
261*77c1e3ccSAndroid Build Coastguard Worker // This function computes two matrices: the cross-correlation between the src
262*77c1e3ccSAndroid Build Coastguard Worker // buffer and dgd buffer (M), and the auto-covariance of the dgd buffer (H).
263*77c1e3ccSAndroid Build Coastguard Worker //
264*77c1e3ccSAndroid Build Coastguard Worker // M is of size 5 * 5. It needs to be filled such that multiplying one element
265*77c1e3ccSAndroid Build Coastguard Worker // from src with each element of a row of the wiener window will fill one
266*77c1e3ccSAndroid Build Coastguard Worker // column of M. However this is not very convenient in terms of memory
267*77c1e3ccSAndroid Build Coastguard Worker // accesses, as it means we do contiguous loads of dgd but strided stores to M.
268*77c1e3ccSAndroid Build Coastguard Worker // As a result, we use an intermediate matrix M_trn which is instead filled
269*77c1e3ccSAndroid Build Coastguard Worker // such that one row of the wiener window gives one row of M_trn. Once fully
270*77c1e3ccSAndroid Build Coastguard Worker // computed, M_trn is then transposed to return M.
271*77c1e3ccSAndroid Build Coastguard Worker //
272*77c1e3ccSAndroid Build Coastguard Worker // H is of size 25 * 25. It is filled by multiplying every pair of elements of
273*77c1e3ccSAndroid Build Coastguard Worker // the wiener window together. Since it is a symmetric matrix, we only compute
274*77c1e3ccSAndroid Build Coastguard Worker // the upper triangle, and then copy it down to the lower one. Here we fill it
275*77c1e3ccSAndroid Build Coastguard Worker // by taking each different pair of columns, and multiplying all the elements of
276*77c1e3ccSAndroid Build Coastguard Worker // the first one with all the elements of the second one, with a special case
277*77c1e3ccSAndroid Build Coastguard Worker // when multiplying a column by itself.
compute_stats_win5_downsampled_sve(int16_t * dgd_avg,int dgd_avg_stride,int16_t * src_avg,int src_avg_stride,int width,int height,int64_t * M,int64_t * H,int downsample_factor)278*77c1e3ccSAndroid Build Coastguard Worker static inline void compute_stats_win5_downsampled_sve(
279*77c1e3ccSAndroid Build Coastguard Worker     int16_t *dgd_avg, int dgd_avg_stride, int16_t *src_avg, int src_avg_stride,
280*77c1e3ccSAndroid Build Coastguard Worker     int width, int height, int64_t *M, int64_t *H, int downsample_factor) {
281*77c1e3ccSAndroid Build Coastguard Worker   const int wiener_win = 5;
282*77c1e3ccSAndroid Build Coastguard Worker   const int wiener_win2 = wiener_win * wiener_win;
283*77c1e3ccSAndroid Build Coastguard Worker 
284*77c1e3ccSAndroid Build Coastguard Worker   // Use a predicate to compute the last columns of the block for H.
285*77c1e3ccSAndroid Build Coastguard Worker   svbool_t pattern = svwhilelt_b16_u32(0, width % 8);
286*77c1e3ccSAndroid Build Coastguard Worker 
287*77c1e3ccSAndroid Build Coastguard Worker   // Use intermediate matrices for H and M to perform the computation, they
288*77c1e3ccSAndroid Build Coastguard Worker   // will be accumulated into the original H and M at the end.
289*77c1e3ccSAndroid Build Coastguard Worker   int64_t M_trn[25];
290*77c1e3ccSAndroid Build Coastguard Worker   memset(M_trn, 0, sizeof(M_trn));
291*77c1e3ccSAndroid Build Coastguard Worker 
292*77c1e3ccSAndroid Build Coastguard Worker   int64_t H_tmp[25 * 25];
293*77c1e3ccSAndroid Build Coastguard Worker   memset(H_tmp, 0, sizeof(H_tmp));
294*77c1e3ccSAndroid Build Coastguard Worker 
295*77c1e3ccSAndroid Build Coastguard Worker   assert(height > 0);
296*77c1e3ccSAndroid Build Coastguard Worker   do {
297*77c1e3ccSAndroid Build Coastguard Worker     // Cross-correlation (M).
298*77c1e3ccSAndroid Build Coastguard Worker     for (int row = 0; row < wiener_win; row++) {
299*77c1e3ccSAndroid Build Coastguard Worker       int j = 0;
300*77c1e3ccSAndroid Build Coastguard Worker       while (j < width) {
301*77c1e3ccSAndroid Build Coastguard Worker         int16x8_t dgd[5];
302*77c1e3ccSAndroid Build Coastguard Worker         load_s16_8x5(dgd_avg + row * dgd_avg_stride + j, 1, &dgd[0], &dgd[1],
303*77c1e3ccSAndroid Build Coastguard Worker                      &dgd[2], &dgd[3], &dgd[4]);
304*77c1e3ccSAndroid Build Coastguard Worker         int16x8_t s = vld1q_s16(src_avg + j);
305*77c1e3ccSAndroid Build Coastguard Worker 
306*77c1e3ccSAndroid Build Coastguard Worker         // Compute all the elements of one row of M.
307*77c1e3ccSAndroid Build Coastguard Worker         compute_M_one_row_win5(s, dgd, M_trn, row);
308*77c1e3ccSAndroid Build Coastguard Worker 
309*77c1e3ccSAndroid Build Coastguard Worker         j += 8;
310*77c1e3ccSAndroid Build Coastguard Worker       }
311*77c1e3ccSAndroid Build Coastguard Worker     }
312*77c1e3ccSAndroid Build Coastguard Worker 
313*77c1e3ccSAndroid Build Coastguard Worker     // Auto-covariance (H).
314*77c1e3ccSAndroid Build Coastguard Worker     int j = 0;
315*77c1e3ccSAndroid Build Coastguard Worker     while (j <= width - 8) {
316*77c1e3ccSAndroid Build Coastguard Worker       for (int col0 = 0; col0 < wiener_win; col0++) {
317*77c1e3ccSAndroid Build Coastguard Worker         // Load first column.
318*77c1e3ccSAndroid Build Coastguard Worker         int16x8_t dgd0[5];
319*77c1e3ccSAndroid Build Coastguard Worker         load_s16_8x5(dgd_avg + j + col0, dgd_avg_stride, &dgd0[0], &dgd0[1],
320*77c1e3ccSAndroid Build Coastguard Worker                      &dgd0[2], &dgd0[3], &dgd0[4]);
321*77c1e3ccSAndroid Build Coastguard Worker 
322*77c1e3ccSAndroid Build Coastguard Worker         // Perform computation of the first column with itself (15 elements).
323*77c1e3ccSAndroid Build Coastguard Worker         // For the first column this will fill the upper triangle of the 5x5
324*77c1e3ccSAndroid Build Coastguard Worker         // matrix at the top left of the H matrix. For the next columns this
325*77c1e3ccSAndroid Build Coastguard Worker         // will fill the upper triangle of the other 5x5 matrices around H's
326*77c1e3ccSAndroid Build Coastguard Worker         // diagonal.
327*77c1e3ccSAndroid Build Coastguard Worker         compute_H_one_col(dgd0, col0, H_tmp, wiener_win, wiener_win2);
328*77c1e3ccSAndroid Build Coastguard Worker 
329*77c1e3ccSAndroid Build Coastguard Worker         // All computation next to the matrix diagonal has already been done.
330*77c1e3ccSAndroid Build Coastguard Worker         for (int col1 = col0 + 1; col1 < wiener_win; col1++) {
331*77c1e3ccSAndroid Build Coastguard Worker           // Load second column and scale based on downsampling factor.
332*77c1e3ccSAndroid Build Coastguard Worker           int16x8_t dgd1[5];
333*77c1e3ccSAndroid Build Coastguard Worker           load_s16_8x5(dgd_avg + j + col1, dgd_avg_stride, &dgd1[0], &dgd1[1],
334*77c1e3ccSAndroid Build Coastguard Worker                        &dgd1[2], &dgd1[3], &dgd1[4]);
335*77c1e3ccSAndroid Build Coastguard Worker 
336*77c1e3ccSAndroid Build Coastguard Worker           // Compute all elements from the combination of both columns (25
337*77c1e3ccSAndroid Build Coastguard Worker           // elements).
338*77c1e3ccSAndroid Build Coastguard Worker           compute_H_two_rows_win5(dgd0, dgd1, col0, col1, H_tmp);
339*77c1e3ccSAndroid Build Coastguard Worker         }
340*77c1e3ccSAndroid Build Coastguard Worker       }
341*77c1e3ccSAndroid Build Coastguard Worker       j += 8;
342*77c1e3ccSAndroid Build Coastguard Worker     }
343*77c1e3ccSAndroid Build Coastguard Worker 
344*77c1e3ccSAndroid Build Coastguard Worker     // Process remaining columns using a predicate to discard excess elements.
345*77c1e3ccSAndroid Build Coastguard Worker     if (j < width) {
346*77c1e3ccSAndroid Build Coastguard Worker       for (int col0 = 0; col0 < wiener_win; col0++) {
347*77c1e3ccSAndroid Build Coastguard Worker         int16x8_t dgd0[5];
348*77c1e3ccSAndroid Build Coastguard Worker         dgd0[0] = svget_neonq_s16(
349*77c1e3ccSAndroid Build Coastguard Worker             svld1_s16(pattern, dgd_avg + 0 * dgd_avg_stride + j + col0));
350*77c1e3ccSAndroid Build Coastguard Worker         dgd0[1] = svget_neonq_s16(
351*77c1e3ccSAndroid Build Coastguard Worker             svld1_s16(pattern, dgd_avg + 1 * dgd_avg_stride + j + col0));
352*77c1e3ccSAndroid Build Coastguard Worker         dgd0[2] = svget_neonq_s16(
353*77c1e3ccSAndroid Build Coastguard Worker             svld1_s16(pattern, dgd_avg + 2 * dgd_avg_stride + j + col0));
354*77c1e3ccSAndroid Build Coastguard Worker         dgd0[3] = svget_neonq_s16(
355*77c1e3ccSAndroid Build Coastguard Worker             svld1_s16(pattern, dgd_avg + 3 * dgd_avg_stride + j + col0));
356*77c1e3ccSAndroid Build Coastguard Worker         dgd0[4] = svget_neonq_s16(
357*77c1e3ccSAndroid Build Coastguard Worker             svld1_s16(pattern, dgd_avg + 4 * dgd_avg_stride + j + col0));
358*77c1e3ccSAndroid Build Coastguard Worker 
359*77c1e3ccSAndroid Build Coastguard Worker         // Perform computation of the first column with itself (15 elements).
360*77c1e3ccSAndroid Build Coastguard Worker         // For the first column this will fill the upper triangle of the 5x5
361*77c1e3ccSAndroid Build Coastguard Worker         // matrix at the top left of the H matrix. For the next columns this
362*77c1e3ccSAndroid Build Coastguard Worker         // will fill the upper triangle of the other 5x5 matrices around H's
363*77c1e3ccSAndroid Build Coastguard Worker         // diagonal.
364*77c1e3ccSAndroid Build Coastguard Worker         compute_H_one_col(dgd0, col0, H_tmp, wiener_win, wiener_win2);
365*77c1e3ccSAndroid Build Coastguard Worker 
366*77c1e3ccSAndroid Build Coastguard Worker         // All computation next to the matrix diagonal has already been done.
367*77c1e3ccSAndroid Build Coastguard Worker         for (int col1 = col0 + 1; col1 < wiener_win; col1++) {
368*77c1e3ccSAndroid Build Coastguard Worker           // Load second column and scale based on downsampling factor.
369*77c1e3ccSAndroid Build Coastguard Worker           int16x8_t dgd1[5];
370*77c1e3ccSAndroid Build Coastguard Worker           load_s16_8x5(dgd_avg + j + col1, dgd_avg_stride, &dgd1[0], &dgd1[1],
371*77c1e3ccSAndroid Build Coastguard Worker                        &dgd1[2], &dgd1[3], &dgd1[4]);
372*77c1e3ccSAndroid Build Coastguard Worker 
373*77c1e3ccSAndroid Build Coastguard Worker           // Compute all elements from the combination of both columns (25
374*77c1e3ccSAndroid Build Coastguard Worker           // elements).
375*77c1e3ccSAndroid Build Coastguard Worker           compute_H_two_rows_win5(dgd0, dgd1, col0, col1, H_tmp);
376*77c1e3ccSAndroid Build Coastguard Worker         }
377*77c1e3ccSAndroid Build Coastguard Worker       }
378*77c1e3ccSAndroid Build Coastguard Worker     }
379*77c1e3ccSAndroid Build Coastguard Worker     dgd_avg += downsample_factor * dgd_avg_stride;
380*77c1e3ccSAndroid Build Coastguard Worker     src_avg += src_avg_stride;
381*77c1e3ccSAndroid Build Coastguard Worker   } while (--height != 0);
382*77c1e3ccSAndroid Build Coastguard Worker 
383*77c1e3ccSAndroid Build Coastguard Worker   // Transpose M_trn.
384*77c1e3ccSAndroid Build Coastguard Worker   acc_transpose_M(M, M_trn, 5, downsample_factor);
385*77c1e3ccSAndroid Build Coastguard Worker 
386*77c1e3ccSAndroid Build Coastguard Worker   // Copy upper triangle of H in the lower one.
387*77c1e3ccSAndroid Build Coastguard Worker   copy_upper_triangle(H, H_tmp, wiener_win2, downsample_factor);
388*77c1e3ccSAndroid Build Coastguard Worker }
389*77c1e3ccSAndroid Build Coastguard Worker 
av1_compute_stats_downsampled_sve(int wiener_win,const uint8_t * dgd,const uint8_t * src,int16_t * dgd_avg,int16_t * src_avg,int h_start,int h_end,int v_start,int v_end,int dgd_stride,int src_stride,int64_t * M,int64_t * H)390*77c1e3ccSAndroid Build Coastguard Worker static inline void av1_compute_stats_downsampled_sve(
391*77c1e3ccSAndroid Build Coastguard Worker     int wiener_win, const uint8_t *dgd, const uint8_t *src, int16_t *dgd_avg,
392*77c1e3ccSAndroid Build Coastguard Worker     int16_t *src_avg, int h_start, int h_end, int v_start, int v_end,
393*77c1e3ccSAndroid Build Coastguard Worker     int dgd_stride, int src_stride, int64_t *M, int64_t *H) {
394*77c1e3ccSAndroid Build Coastguard Worker   assert(wiener_win == WIENER_WIN || wiener_win == WIENER_WIN_CHROMA);
395*77c1e3ccSAndroid Build Coastguard Worker 
396*77c1e3ccSAndroid Build Coastguard Worker   const int wiener_win2 = wiener_win * wiener_win;
397*77c1e3ccSAndroid Build Coastguard Worker   const int wiener_halfwin = wiener_win >> 1;
398*77c1e3ccSAndroid Build Coastguard Worker   const int32_t width = h_end - h_start;
399*77c1e3ccSAndroid Build Coastguard Worker   const int32_t height = v_end - v_start;
400*77c1e3ccSAndroid Build Coastguard Worker   const uint8_t *dgd_start = &dgd[v_start * dgd_stride + h_start];
401*77c1e3ccSAndroid Build Coastguard Worker   memset(H, 0, sizeof(*H) * wiener_win2 * wiener_win2);
402*77c1e3ccSAndroid Build Coastguard Worker   memset(M, 0, sizeof(*M) * wiener_win * wiener_win);
403*77c1e3ccSAndroid Build Coastguard Worker 
404*77c1e3ccSAndroid Build Coastguard Worker   const uint8_t avg = find_average_sve(dgd_start, dgd_stride, width, height);
405*77c1e3ccSAndroid Build Coastguard Worker   const int downsample_factor = WIENER_STATS_DOWNSAMPLE_FACTOR;
406*77c1e3ccSAndroid Build Coastguard Worker 
407*77c1e3ccSAndroid Build Coastguard Worker   // dgd_avg and src_avg have been memset to zero before calling this
408*77c1e3ccSAndroid Build Coastguard Worker   // function, so round up the stride to the next multiple of 8 so that we
409*77c1e3ccSAndroid Build Coastguard Worker   // don't have to worry about a tail loop when computing M.
410*77c1e3ccSAndroid Build Coastguard Worker   const int dgd_avg_stride = ((width + 2 * wiener_halfwin) & ~7) + 8;
411*77c1e3ccSAndroid Build Coastguard Worker   const int src_avg_stride = (width & ~7) + 8;
412*77c1e3ccSAndroid Build Coastguard Worker 
413*77c1e3ccSAndroid Build Coastguard Worker   // Compute (dgd - avg) and store it in dgd_avg.
414*77c1e3ccSAndroid Build Coastguard Worker   // The wiener window will slide along the dgd frame, centered on each pixel.
415*77c1e3ccSAndroid Build Coastguard Worker   // For the top left pixel and all the pixels on the side of the frame this
416*77c1e3ccSAndroid Build Coastguard Worker   // means half of the window will be outside of the frame. As such the actual
417*77c1e3ccSAndroid Build Coastguard Worker   // buffer that we need to subtract the avg from will be 2 * wiener_halfwin
418*77c1e3ccSAndroid Build Coastguard Worker   // wider and 2 * wiener_halfwin higher than the original dgd buffer.
419*77c1e3ccSAndroid Build Coastguard Worker   const int vert_offset = v_start - wiener_halfwin;
420*77c1e3ccSAndroid Build Coastguard Worker   const int horiz_offset = h_start - wiener_halfwin;
421*77c1e3ccSAndroid Build Coastguard Worker   const uint8_t *dgd_win = dgd + horiz_offset + vert_offset * dgd_stride;
422*77c1e3ccSAndroid Build Coastguard Worker   compute_sub_avg(dgd_win, dgd_stride, avg, dgd_avg, dgd_avg_stride,
423*77c1e3ccSAndroid Build Coastguard Worker                   width + 2 * wiener_halfwin, height + 2 * wiener_halfwin, 1);
424*77c1e3ccSAndroid Build Coastguard Worker 
425*77c1e3ccSAndroid Build Coastguard Worker   // Compute (src - avg), downsample and store in src-avg.
426*77c1e3ccSAndroid Build Coastguard Worker   const uint8_t *src_start = src + h_start + v_start * src_stride;
427*77c1e3ccSAndroid Build Coastguard Worker   compute_sub_avg(src_start, src_stride * downsample_factor, avg, src_avg,
428*77c1e3ccSAndroid Build Coastguard Worker                   src_avg_stride, width, height, downsample_factor);
429*77c1e3ccSAndroid Build Coastguard Worker 
430*77c1e3ccSAndroid Build Coastguard Worker   const int downsample_height = height / downsample_factor;
431*77c1e3ccSAndroid Build Coastguard Worker 
432*77c1e3ccSAndroid Build Coastguard Worker   // Since the height is not necessarily a multiple of the downsample factor,
433*77c1e3ccSAndroid Build Coastguard Worker   // the last line of src will be scaled according to how many rows remain.
434*77c1e3ccSAndroid Build Coastguard Worker   const int downsample_remainder = height % downsample_factor;
435*77c1e3ccSAndroid Build Coastguard Worker 
436*77c1e3ccSAndroid Build Coastguard Worker   if (downsample_height > 0) {
437*77c1e3ccSAndroid Build Coastguard Worker     if (wiener_win == WIENER_WIN) {
438*77c1e3ccSAndroid Build Coastguard Worker       compute_stats_win7_downsampled_sve(
439*77c1e3ccSAndroid Build Coastguard Worker           dgd_avg, dgd_avg_stride, src_avg, src_avg_stride, width,
440*77c1e3ccSAndroid Build Coastguard Worker           downsample_height, M, H, downsample_factor);
441*77c1e3ccSAndroid Build Coastguard Worker     } else {
442*77c1e3ccSAndroid Build Coastguard Worker       compute_stats_win5_downsampled_sve(
443*77c1e3ccSAndroid Build Coastguard Worker           dgd_avg, dgd_avg_stride, src_avg, src_avg_stride, width,
444*77c1e3ccSAndroid Build Coastguard Worker           downsample_height, M, H, downsample_factor);
445*77c1e3ccSAndroid Build Coastguard Worker     }
446*77c1e3ccSAndroid Build Coastguard Worker   }
447*77c1e3ccSAndroid Build Coastguard Worker 
448*77c1e3ccSAndroid Build Coastguard Worker   if (downsample_remainder > 0) {
449*77c1e3ccSAndroid Build Coastguard Worker     const int remainder_offset = height - downsample_remainder;
450*77c1e3ccSAndroid Build Coastguard Worker     if (wiener_win == WIENER_WIN) {
451*77c1e3ccSAndroid Build Coastguard Worker       compute_stats_win7_downsampled_sve(
452*77c1e3ccSAndroid Build Coastguard Worker           dgd_avg + remainder_offset * dgd_avg_stride, dgd_avg_stride,
453*77c1e3ccSAndroid Build Coastguard Worker           src_avg + downsample_height * src_avg_stride, src_avg_stride, width,
454*77c1e3ccSAndroid Build Coastguard Worker           1, M, H, downsample_remainder);
455*77c1e3ccSAndroid Build Coastguard Worker     } else {
456*77c1e3ccSAndroid Build Coastguard Worker       compute_stats_win5_downsampled_sve(
457*77c1e3ccSAndroid Build Coastguard Worker           dgd_avg + remainder_offset * dgd_avg_stride, dgd_avg_stride,
458*77c1e3ccSAndroid Build Coastguard Worker           src_avg + downsample_height * src_avg_stride, src_avg_stride, width,
459*77c1e3ccSAndroid Build Coastguard Worker           1, M, H, downsample_remainder);
460*77c1e3ccSAndroid Build Coastguard Worker     }
461*77c1e3ccSAndroid Build Coastguard Worker   }
462*77c1e3ccSAndroid Build Coastguard Worker }
463*77c1e3ccSAndroid Build Coastguard Worker 
av1_compute_stats_sve(int wiener_win,const uint8_t * dgd,const uint8_t * src,int16_t * dgd_avg,int16_t * src_avg,int h_start,int h_end,int v_start,int v_end,int dgd_stride,int src_stride,int64_t * M,int64_t * H,int use_downsampled_wiener_stats)464*77c1e3ccSAndroid Build Coastguard Worker void av1_compute_stats_sve(int wiener_win, const uint8_t *dgd,
465*77c1e3ccSAndroid Build Coastguard Worker                            const uint8_t *src, int16_t *dgd_avg,
466*77c1e3ccSAndroid Build Coastguard Worker                            int16_t *src_avg, int h_start, int h_end,
467*77c1e3ccSAndroid Build Coastguard Worker                            int v_start, int v_end, int dgd_stride,
468*77c1e3ccSAndroid Build Coastguard Worker                            int src_stride, int64_t *M, int64_t *H,
469*77c1e3ccSAndroid Build Coastguard Worker                            int use_downsampled_wiener_stats) {
470*77c1e3ccSAndroid Build Coastguard Worker   assert(wiener_win == WIENER_WIN || wiener_win == WIENER_WIN_CHROMA);
471*77c1e3ccSAndroid Build Coastguard Worker 
472*77c1e3ccSAndroid Build Coastguard Worker   if (use_downsampled_wiener_stats) {
473*77c1e3ccSAndroid Build Coastguard Worker     av1_compute_stats_downsampled_sve(wiener_win, dgd, src, dgd_avg, src_avg,
474*77c1e3ccSAndroid Build Coastguard Worker                                       h_start, h_end, v_start, v_end,
475*77c1e3ccSAndroid Build Coastguard Worker                                       dgd_stride, src_stride, M, H);
476*77c1e3ccSAndroid Build Coastguard Worker     return;
477*77c1e3ccSAndroid Build Coastguard Worker   }
478*77c1e3ccSAndroid Build Coastguard Worker 
479*77c1e3ccSAndroid Build Coastguard Worker   const int wiener_win2 = wiener_win * wiener_win;
480*77c1e3ccSAndroid Build Coastguard Worker   const int wiener_halfwin = wiener_win >> 1;
481*77c1e3ccSAndroid Build Coastguard Worker   const int32_t width = h_end - h_start;
482*77c1e3ccSAndroid Build Coastguard Worker   const int32_t height = v_end - v_start;
483*77c1e3ccSAndroid Build Coastguard Worker   const uint8_t *dgd_start = &dgd[v_start * dgd_stride + h_start];
484*77c1e3ccSAndroid Build Coastguard Worker   memset(H, 0, sizeof(*H) * wiener_win2 * wiener_win2);
485*77c1e3ccSAndroid Build Coastguard Worker   memset(M, 0, sizeof(*M) * wiener_win * wiener_win);
486*77c1e3ccSAndroid Build Coastguard Worker 
487*77c1e3ccSAndroid Build Coastguard Worker   const uint8_t avg = find_average_sve(dgd_start, dgd_stride, width, height);
488*77c1e3ccSAndroid Build Coastguard Worker 
489*77c1e3ccSAndroid Build Coastguard Worker   // dgd_avg and src_avg have been memset to zero before calling this
490*77c1e3ccSAndroid Build Coastguard Worker   // function, so round up the stride to the next multiple of 8 so that we
491*77c1e3ccSAndroid Build Coastguard Worker   // don't have to worry about a tail loop when computing M.
492*77c1e3ccSAndroid Build Coastguard Worker   const int dgd_avg_stride = ((width + 2 * wiener_halfwin) & ~7) + 8;
493*77c1e3ccSAndroid Build Coastguard Worker   const int src_avg_stride = (width & ~7) + 8;
494*77c1e3ccSAndroid Build Coastguard Worker 
495*77c1e3ccSAndroid Build Coastguard Worker   // Compute (dgd - avg) and store it in dgd_avg.
496*77c1e3ccSAndroid Build Coastguard Worker   // The wiener window will slide along the dgd frame, centered on each pixel.
497*77c1e3ccSAndroid Build Coastguard Worker   // For the top left pixel and all the pixels on the side of the frame this
498*77c1e3ccSAndroid Build Coastguard Worker   // means half of the window will be outside of the frame. As such the actual
499*77c1e3ccSAndroid Build Coastguard Worker   // buffer that we need to subtract the avg from will be 2 * wiener_halfwin
500*77c1e3ccSAndroid Build Coastguard Worker   // wider and 2 * wiener_halfwin higher than the original dgd buffer.
501*77c1e3ccSAndroid Build Coastguard Worker   const int vert_offset = v_start - wiener_halfwin;
502*77c1e3ccSAndroid Build Coastguard Worker   const int horiz_offset = h_start - wiener_halfwin;
503*77c1e3ccSAndroid Build Coastguard Worker   const uint8_t *dgd_win = dgd + horiz_offset + vert_offset * dgd_stride;
504*77c1e3ccSAndroid Build Coastguard Worker   compute_sub_avg(dgd_win, dgd_stride, avg, dgd_avg, dgd_avg_stride,
505*77c1e3ccSAndroid Build Coastguard Worker                   width + 2 * wiener_halfwin, height + 2 * wiener_halfwin, 1);
506*77c1e3ccSAndroid Build Coastguard Worker 
507*77c1e3ccSAndroid Build Coastguard Worker   // Compute (src - avg), and store in src-avg.
508*77c1e3ccSAndroid Build Coastguard Worker   const uint8_t *src_start = src + h_start + v_start * src_stride;
509*77c1e3ccSAndroid Build Coastguard Worker   compute_sub_avg(src_start, src_stride, avg, src_avg, src_avg_stride, width,
510*77c1e3ccSAndroid Build Coastguard Worker                   height, 1);
511*77c1e3ccSAndroid Build Coastguard Worker 
512*77c1e3ccSAndroid Build Coastguard Worker   if (wiener_win == WIENER_WIN) {
513*77c1e3ccSAndroid Build Coastguard Worker     compute_stats_win7_sve(dgd_avg, dgd_avg_stride, src_avg, src_avg_stride,
514*77c1e3ccSAndroid Build Coastguard Worker                            width, height, M, H);
515*77c1e3ccSAndroid Build Coastguard Worker   } else {
516*77c1e3ccSAndroid Build Coastguard Worker     compute_stats_win5_sve(dgd_avg, dgd_avg_stride, src_avg, src_avg_stride,
517*77c1e3ccSAndroid Build Coastguard Worker                            width, height, M, H);
518*77c1e3ccSAndroid Build Coastguard Worker   }
519*77c1e3ccSAndroid Build Coastguard Worker 
520*77c1e3ccSAndroid Build Coastguard Worker   // H is a symmetric matrix, so we only need to fill out the upper triangle.
521*77c1e3ccSAndroid Build Coastguard Worker   // We can copy it down to the lower triangle outside the (i, j) loops.
522*77c1e3ccSAndroid Build Coastguard Worker   diagonal_copy_stats_neon(wiener_win2, H);
523*77c1e3ccSAndroid Build Coastguard Worker }
524