1 /*
2  * Copyright (c) 2020, Alliance for Open Media. All rights reserved.
3  *
4  * This source code is subject to the terms of the BSD 2 Clause License and
5  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6  * was not distributed with this source code in the LICENSE file, you can
7  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8  * Media Patent License 1.0 was not distributed with this source code in the
9  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10  */
11 
12 #include <arm_neon.h>
13 
14 #include "config/aom_config.h"
15 #include "config/av1_rtcd.h"
16 
17 #include "aom_dsp/arm/mem_neon.h"
18 #include "aom_dsp/arm/sum_neon.h"
19 #include "aom_dsp/arm/transpose_neon.h"
20 #include "av1/common/restoration.h"
21 #include "av1/encoder/arm/pickrst_neon.h"
22 #include "av1/encoder/pickrst.h"
23 
24 int64_t av1_lowbd_pixel_proj_error_neon(
25     const uint8_t *src, int width, int height, int src_stride,
26     const uint8_t *dat, int dat_stride, int32_t *flt0, int flt0_stride,
27     int32_t *flt1, int flt1_stride, int xq[2], const sgr_params_type *params) {
28   int64_t sse = 0;
29   int64x2_t sse_s64 = vdupq_n_s64(0);
30 
31   if (params->r[0] > 0 && params->r[1] > 0) {
32     int32x2_t xq_v = vld1_s32(xq);
33     int32x2_t xq_sum_v = vshl_n_s32(vpadd_s32(xq_v, xq_v), SGRPROJ_RST_BITS);
34 
35     do {
36       int j = 0;
37       int32x4_t sse_s32 = vdupq_n_s32(0);
38 
39       do {
40         const uint8x8_t d = vld1_u8(&dat[j]);
41         const uint8x8_t s = vld1_u8(&src[j]);
42         int32x4_t flt0_0 = vld1q_s32(&flt0[j]);
43         int32x4_t flt0_1 = vld1q_s32(&flt0[j + 4]);
44         int32x4_t flt1_0 = vld1q_s32(&flt1[j]);
45         int32x4_t flt1_1 = vld1q_s32(&flt1[j + 4]);
46 
47         int32x4_t offset =
48             vdupq_n_s32(1 << (SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS - 1));
49         int32x4_t v0 = vmlaq_lane_s32(offset, flt0_0, xq_v, 0);
50         int32x4_t v1 = vmlaq_lane_s32(offset, flt0_1, xq_v, 0);
51 
52         v0 = vmlaq_lane_s32(v0, flt1_0, xq_v, 1);
53         v1 = vmlaq_lane_s32(v1, flt1_1, xq_v, 1);
54 
55         int16x8_t d_s16 = vreinterpretq_s16_u16(vmovl_u8(d));
56         v0 = vmlsl_lane_s16(v0, vget_low_s16(d_s16),
57                             vreinterpret_s16_s32(xq_sum_v), 0);
58         v1 = vmlsl_lane_s16(v1, vget_high_s16(d_s16),
59                             vreinterpret_s16_s32(xq_sum_v), 0);
60 
61         int16x4_t vr0 = vshrn_n_s32(v0, SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS);
62         int16x4_t vr1 = vshrn_n_s32(v1, SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS);
63 
64         int16x8_t diff = vreinterpretq_s16_u16(vsubl_u8(d, s));
65         int16x8_t e = vaddq_s16(vcombine_s16(vr0, vr1), diff);
66         int16x4_t e_lo = vget_low_s16(e);
67         int16x4_t e_hi = vget_high_s16(e);
68 
69         sse_s32 = vmlal_s16(sse_s32, e_lo, e_lo);
70         sse_s32 = vmlal_s16(sse_s32, e_hi, e_hi);
71 
72         j += 8;
73       } while (j <= width - 8);
74 
75       for (int k = j; k < width; ++k) {
76         int32_t u = (dat[k] << SGRPROJ_RST_BITS);
77         int32_t v = (1 << (SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS - 1)) +
78                     xq[0] * flt0[k] + xq[1] * flt1[k] - u * (xq[0] + xq[1]);
79         int32_t e =
80             (v >> (SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS)) + dat[k] - src[k];
81         sse += e * e;
82       }
83 
84       sse_s64 = vpadalq_s32(sse_s64, sse_s32);
85 
86       dat += dat_stride;
87       src += src_stride;
88       flt0 += flt0_stride;
89       flt1 += flt1_stride;
90     } while (--height != 0);
91   } else if (params->r[0] > 0 || params->r[1] > 0) {
92     int xq_active = (params->r[0] > 0) ? xq[0] : xq[1];
93     int32_t *flt = (params->r[0] > 0) ? flt0 : flt1;
94     int flt_stride = (params->r[0] > 0) ? flt0_stride : flt1_stride;
95     int32x2_t xq_v = vdup_n_s32(xq_active);
96 
97     do {
98       int32x4_t sse_s32 = vdupq_n_s32(0);
99       int j = 0;
100 
101       do {
102         const uint8x8_t d = vld1_u8(&dat[j]);
103         const uint8x8_t s = vld1_u8(&src[j]);
104         int32x4_t flt_0 = vld1q_s32(&flt[j]);
105         int32x4_t flt_1 = vld1q_s32(&flt[j + 4]);
106         int16x8_t d_s16 =
107             vreinterpretq_s16_u16(vshll_n_u8(d, SGRPROJ_RST_BITS));
108 
109         int32x4_t sub_0 = vsubw_s16(flt_0, vget_low_s16(d_s16));
110         int32x4_t sub_1 = vsubw_s16(flt_1, vget_high_s16(d_s16));
111 
112         int32x4_t offset =
113             vdupq_n_s32(1 << (SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS - 1));
114         int32x4_t v0 = vmlaq_lane_s32(offset, sub_0, xq_v, 0);
115         int32x4_t v1 = vmlaq_lane_s32(offset, sub_1, xq_v, 0);
116 
117         int16x4_t vr0 = vshrn_n_s32(v0, SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS);
118         int16x4_t vr1 = vshrn_n_s32(v1, SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS);
119 
120         int16x8_t diff = vreinterpretq_s16_u16(vsubl_u8(d, s));
121         int16x8_t e = vaddq_s16(vcombine_s16(vr0, vr1), diff);
122         int16x4_t e_lo = vget_low_s16(e);
123         int16x4_t e_hi = vget_high_s16(e);
124 
125         sse_s32 = vmlal_s16(sse_s32, e_lo, e_lo);
126         sse_s32 = vmlal_s16(sse_s32, e_hi, e_hi);
127 
128         j += 8;
129       } while (j <= width - 8);
130 
131       for (int k = j; k < width; ++k) {
132         int32_t u = dat[k] << SGRPROJ_RST_BITS;
133         int32_t v = xq_active * (flt[k] - u);
134         int32_t e = ROUND_POWER_OF_TWO(v, SGRPROJ_RST_BITS + SGRPROJ_PRJ_BITS) +
135                     dat[k] - src[k];
136         sse += e * e;
137       }
138 
139       sse_s64 = vpadalq_s32(sse_s64, sse_s32);
140 
141       dat += dat_stride;
142       src += src_stride;
143       flt += flt_stride;
144     } while (--height != 0);
145   } else {
146     uint32x4_t sse_s32 = vdupq_n_u32(0);
147 
148     do {
149       int j = 0;
150 
151       do {
152         const uint8x16_t d = vld1q_u8(&dat[j]);
153         const uint8x16_t s = vld1q_u8(&src[j]);
154 
155         uint8x16_t diff = vabdq_u8(d, s);
156         uint8x8_t diff_lo = vget_low_u8(diff);
157         uint8x8_t diff_hi = vget_high_u8(diff);
158 
159         sse_s32 = vpadalq_u16(sse_s32, vmull_u8(diff_lo, diff_lo));
160         sse_s32 = vpadalq_u16(sse_s32, vmull_u8(diff_hi, diff_hi));
161 
162         j += 16;
163       } while (j <= width - 16);
164 
165       for (int k = j; k < width; ++k) {
166         int32_t e = dat[k] - src[k];
167         sse += e * e;
168       }
169 
170       dat += dat_stride;
171       src += src_stride;
172     } while (--height != 0);
173 
174     sse_s64 = vreinterpretq_s64_u64(vpaddlq_u32(sse_s32));
175   }
176 
177   sse += horizontal_add_s64x2(sse_s64);
178   return sse;
179 }
180 
181 // We can accumulate up to 32768 8-bit multiplication results in a signed
182 // 32-bit integer. Since we process 2 pixels at a time, the accumulator
183 // maximum for the compute-stats functions is 16384.
184 #define STAT_ACCUMULATOR_MAX 16384
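// As an informal bound check: |dgd - avg| and |src - avg| are each at most
// 255, so every product is at most 255 * 255 = 65025, and
// 32768 * 65025 = 2130739200 still fits below INT32_MAX (2147483647). With 2
// pixels processed per iteration, that leaves 32768 / 2 = 16384 iterations
// before the accumulators must be widened to 64 bits.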
185 
186 static inline uint8x8_t tbl2(uint8x16_t a, uint8x16_t b, uint8x8_t idx) {
187 #if AOM_ARCH_AARCH64
188   uint8x16x2_t table = { { a, b } };
189   return vqtbl2_u8(table, idx);
190 #else
191   uint8x8x4_t table = { { vget_low_u8(a), vget_high_u8(a), vget_low_u8(b),
192                           vget_high_u8(b) } };
193   return vtbl4_u8(table, idx);
194 #endif
195 }
196 
197 static inline uint8x16_t tbl2q(uint8x16_t a, uint8x16_t b, uint8x16_t idx) {
198 #if AOM_ARCH_AARCH64
199   uint8x16x2_t table = { { a, b } };
200   return vqtbl2q_u8(table, idx);
201 #else
202   uint8x8x4_t table = { { vget_low_u8(a), vget_high_u8(a), vget_low_u8(b),
203                           vget_high_u8(b) } };
204   return vcombine_u8(vtbl4_u8(table, vget_low_u8(idx)),
205                      vtbl4_u8(table, vget_high_u8(idx)));
206 #endif
207 }
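// Note: the tbl2/tbl2q helpers above emulate the AArch64 two-register table
// lookups (vqtbl2_u8/vqtbl2q_u8) on Armv7-A by splitting the two 128-bit
// inputs into four 64-bit halves for vtbl4_u8. Both variants index the same
// 32-byte table and return 0 for out-of-range indices, so the two code paths
// behave identically for the shuffle patterns used below.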
208 
209 // The M matrix is accumulated in STAT_ACCUMULATOR_MAX steps to speed-up the
210 // computation. This function computes the final M from the accumulated
211 // (src_s64) and the residual parts (src_s32). It also transposes the result as
212 // the output needs to be column-major.
213 static inline void acc_transpose_M(int64_t *dst, const int64_t *src_s64,
214                                    const int32_t *src_s32, const int wiener_win,
215                                    int scale) {
216   for (int i = 0; i < wiener_win; ++i) {
217     for (int j = 0; j < wiener_win; ++j) {
218       int tr_idx = j * wiener_win + i;
219       *dst++ += (int64_t)(src_s64[tr_idx] + src_s32[tr_idx]) * scale;
220     }
221   }
222 }
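// For example, with wiener_win == 3 the loops above write dst at flat index
// i * 3 + j while reading the accumulators at tr_idx == j * 3 + i, so output
// element (i, j) receives accumulator element (j, i) scaled by `scale`, which
// is the transpose mentioned above.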
223 
224 // The resulting H is a column-major matrix accumulated from the transposed
225 // (column-major) samples of the filter kernel (5x5 or 7x7) viewed as a single
226 // vector. For the 7x7 filter case: H(49x49) = [49 x 1] x [1 x 49]. This
227 // function transforms back to the originally expected format (double
228 // transpose). The H matrix is accumulated in STAT_ACCUMULATOR_MAX steps to
229 // speed-up the computation. This function computes the final H from the
230 // accumulated (src_s64) and the residual parts (src_s32). The computed H is
231 // only an upper triangle matrix; this function also fills the lower triangle of
232 // the resulting matrix.
233 static void update_H(int64_t *dst, const int64_t *src_s64,
234                      const int32_t *src_s32, const int wiener_win, int stride,
235                      int scale) {
236   // For a simplified theoretical 3x3 case where `wiener_win` is 3 and
237   // `wiener_win2` is 9, the M matrix is 3x3:
238   // 0, 3, 6
239   // 1, 4, 7
240   // 2, 5, 8
241   //
242   // This is viewed as a vector to compute H (9x9) by vector outer product:
243   // 0, 3, 6, 1, 4, 7, 2, 5, 8
244   //
245   // Double transpose and upper triangle remapping for 3x3 -> 9x9 case:
246   // 0,    3,    6,    1,    4,    7,    2,    5,    8,
247   // 3,   30,   33,   12,   31,   34,   21,   32,   35,
248   // 6,   33,   60,   15,   42,   61,   24,   51,   62,
249   // 1,   12,   15,   10,   13,   16,   11,   14,   17,
250   // 4,   31,   42,   13,   40,   43,   22,   41,   44,
251   // 7,   34,   61,   16,   43,   70,   25,   52,   71,
252   // 2,   21,   24,   11,   22,   25,   20,   23,   26,
253   // 5,   32,   51,   14,   41,   52,   23,   50,   53,
254   // 8,   35,   62,   17,   44,   71,   26,   53,   80,
255   const int wiener_win2 = wiener_win * wiener_win;
256 
257   // Loop through the indices according to the remapping above, along the
258   // columns:
259   // 0, wiener_win, 2 * wiener_win, ..., 1, 1 + 2 * wiener_win, ...,
260   // wiener_win - 1, wiener_win - 1 + wiener_win, ...
261   // For the 3x3 case `j` will be: 0, 3, 6, 1, 4, 7, 2, 5, 8.
262   for (int i = 0; i < wiener_win; ++i) {
263     for (int j = i; j < wiener_win2; j += wiener_win) {
264       // These two inner loops are the same as the two outer loops, but running
265       // along rows instead of columns. For the 3x3 case `l` will be:
266       // 0, 3, 6, 1, 4, 7, 2, 5, 8.
267       for (int k = 0; k < wiener_win; ++k) {
268         for (int l = k; l < wiener_win2; l += wiener_win) {
269           // The nominal double transpose indexing would be:
270           // int idx = stride * j + l;
271           // However, we need the upper-triangle indices; these are easy to
272           // obtain with min/max operations.
273           int tr_idx = stride * AOMMIN(j, l) + AOMMAX(j, l);
274 
275           // Resulting matrix is filled by combining the 64-bit and the residual
276           // 32-bit matrices together with scaling.
277           *dst++ += (int64_t)(src_s64[tr_idx] + src_s32[tr_idx]) * scale;
278         }
279       }
280     }
281   }
282 }
283 
284 // Load a 7x7 matrix into 3 and a half 128-bit vectors from consecutive rows;
285 // the last load address is offset to prevent out-of-bounds access.
286 static inline void load_and_pack_u8_8x7(uint8x16_t dst[4], const uint8_t *src,
287                                         ptrdiff_t stride) {
288   dst[0] = vcombine_u8(vld1_u8(src), vld1_u8(src + stride));
289   src += 2 * stride;
290   dst[1] = vcombine_u8(vld1_u8(src), vld1_u8(src + stride));
291   src += 2 * stride;
292   dst[2] = vcombine_u8(vld1_u8(src), vld1_u8(src + stride));
293   src += 2 * stride;
294   dst[3] = vcombine_u8(vld1_u8(src - 1), vdup_n_u8(0));
295 }
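// Note on the resulting layout: dst[0] holds rows 0-1, dst[1] rows 2-3, dst[2]
// rows 4-5, and the low half of dst[3] holds row 6 loaded from src - 1, so the
// row-6 samples start at byte 1 of that vector; the shuffle tables used by the
// caller account for this one-byte offset. The 49th element of each 7x7 window
// is handled separately as a scalar by the caller.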
296 
297 static inline void compute_stats_win7_downsampled_neon(
298     const uint8_t *dgd, const uint8_t *src, int width, int height,
299     int dgd_stride, int src_stride, int avg, int64_t *M, int64_t *H,
300     int downsample_factor) {
301   // Matrix names are capitalized to help readability.
302   DECLARE_ALIGNED(64, int16_t, DGD_AVG0[WIENER_WIN2_ALIGN3]);
303   DECLARE_ALIGNED(64, int16_t, DGD_AVG1[WIENER_WIN2_ALIGN3]);
304   DECLARE_ALIGNED(64, int32_t, M_s32[WIENER_WIN2_ALIGN3]);
305   DECLARE_ALIGNED(64, int64_t, M_s64[WIENER_WIN2_ALIGN3]);
306   DECLARE_ALIGNED(64, int32_t, H_s32[WIENER_WIN2 * WIENER_WIN2_ALIGN2]);
307   DECLARE_ALIGNED(64, int64_t, H_s64[WIENER_WIN2 * WIENER_WIN2_ALIGN2]);
308 
309   memset(M_s32, 0, sizeof(M_s32));
310   memset(M_s64, 0, sizeof(M_s64));
311   memset(H_s32, 0, sizeof(H_s32));
312   memset(H_s64, 0, sizeof(H_s64));
313 
314   // Look-up tables to create an 8x6 matrix with consecutive elements from two
315   // 7x7 matrices.
316   // clang-format off
317   DECLARE_ALIGNED(16, static const uint8_t, shuffle_stats7[96]) = {
318     0,  1,  2,  3,  4,  5,  6,  8,  9, 10, 11, 12, 13, 14, 16, 17,
319     2,  3,  4,  5,  6,  8,  9, 10, 11, 12, 13, 14, 16, 17, 18, 19,
320     4,  5,  6,  8,  9, 10, 11, 12, 13, 14, 17, 18, 19, 20, 21, 22,
321     1,  2,  3,  4,  5,  6,  7,  9, 10, 11, 12, 13, 14, 15, 17, 18,
322     3,  4,  5,  6,  7,  9, 10, 11, 12, 13, 14, 15, 17, 18, 19, 20,
323     5,  6,  7,  9, 10, 11, 12, 13, 14, 15, 18, 19, 20, 21, 22, 23,
324   };
325   // clang-format on
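  // Given the packing produced by load_and_pack_u8_8x7, lut0/lut1/lut2 gather
  // the first 48 elements of the 7x7 window for the first pixel (row 0 col 0
  // through row 6 col 5 in row-major order), while lut3/lut4/lut5 gather the
  // same elements shifted right by one column for the second pixel. The 49th
  // element of each window is added separately as a scalar below.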
326 
327   const uint8x16_t lut0 = vld1q_u8(shuffle_stats7 + 0);
328   const uint8x16_t lut1 = vld1q_u8(shuffle_stats7 + 16);
329   const uint8x16_t lut2 = vld1q_u8(shuffle_stats7 + 32);
330   const uint8x16_t lut3 = vld1q_u8(shuffle_stats7 + 48);
331   const uint8x16_t lut4 = vld1q_u8(shuffle_stats7 + 64);
332   const uint8x16_t lut5 = vld1q_u8(shuffle_stats7 + 80);
333 
334   int acc_cnt = STAT_ACCUMULATOR_MAX;
335   const int src_next = downsample_factor * src_stride - width;
336   const int dgd_next = downsample_factor * dgd_stride - width;
337   const uint8x8_t avg_u8 = vdup_n_u8(avg);
338 
339   do {
340     int j = width;
341     while (j >= 2) {
342       // Load two adjacent, overlapping 7x7 matrices: an 8x7 matrix with the
343       // middle 6x7 elements being shared.
344       uint8x16_t dgd_rows[4];
345       load_and_pack_u8_8x7(dgd_rows, dgd, dgd_stride);
346 
347       const uint8_t *dgd_ptr = dgd + dgd_stride * 6;
348       dgd += 2;
349 
350       // Re-arrange (and widen) the combined 8x7 matrix to have the 2 whole 7x7
351       // matrices (1 for each of the 2 pixels) separated into distinct
352       // int16x8_t[6] arrays. These arrays contain 48 elements of the 49 (7x7).
353       // Compute `dgd - avg` for both buffers. Each DGD_AVG buffer contains 49
354       // consecutive elements.
355       int16x8_t dgd_avg0[6];
356       int16x8_t dgd_avg1[6];
357       uint8x16_t dgd_shuf0 = tbl2q(dgd_rows[0], dgd_rows[1], lut0);
358       uint8x16_t dgd_shuf3 = tbl2q(dgd_rows[0], dgd_rows[1], lut3);
359 
360       dgd_avg0[0] =
361           vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(dgd_shuf0), avg_u8));
362       dgd_avg0[1] =
363           vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(dgd_shuf0), avg_u8));
364       dgd_avg1[0] =
365           vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(dgd_shuf3), avg_u8));
366       dgd_avg1[1] =
367           vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(dgd_shuf3), avg_u8));
368 
369       vst1q_s16(DGD_AVG0, dgd_avg0[0]);
370       vst1q_s16(DGD_AVG0 + 8, dgd_avg0[1]);
371       vst1q_s16(DGD_AVG1, dgd_avg1[0]);
372       vst1q_s16(DGD_AVG1 + 8, dgd_avg1[1]);
373 
374       uint8x16_t dgd_shuf1 = tbl2q(dgd_rows[1], dgd_rows[2], lut1);
375       uint8x16_t dgd_shuf4 = tbl2q(dgd_rows[1], dgd_rows[2], lut4);
376 
377       dgd_avg0[2] =
378           vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(dgd_shuf1), avg_u8));
379       dgd_avg0[3] =
380           vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(dgd_shuf1), avg_u8));
381       dgd_avg1[2] =
382           vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(dgd_shuf4), avg_u8));
383       dgd_avg1[3] =
384           vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(dgd_shuf4), avg_u8));
385 
386       vst1q_s16(DGD_AVG0 + 16, dgd_avg0[2]);
387       vst1q_s16(DGD_AVG0 + 24, dgd_avg0[3]);
388       vst1q_s16(DGD_AVG1 + 16, dgd_avg1[2]);
389       vst1q_s16(DGD_AVG1 + 24, dgd_avg1[3]);
390 
391       uint8x16_t dgd_shuf2 = tbl2q(dgd_rows[2], dgd_rows[3], lut2);
392       uint8x16_t dgd_shuf5 = tbl2q(dgd_rows[2], dgd_rows[3], lut5);
393 
394       dgd_avg0[4] =
395           vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(dgd_shuf2), avg_u8));
396       dgd_avg0[5] =
397           vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(dgd_shuf2), avg_u8));
398       dgd_avg1[4] =
399           vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(dgd_shuf5), avg_u8));
400       dgd_avg1[5] =
401           vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(dgd_shuf5), avg_u8));
402 
403       vst1q_s16(DGD_AVG0 + 32, dgd_avg0[4]);
404       vst1q_s16(DGD_AVG0 + 40, dgd_avg0[5]);
405       vst1q_s16(DGD_AVG1 + 32, dgd_avg1[4]);
406       vst1q_s16(DGD_AVG1 + 40, dgd_avg1[5]);
407 
408       // The remaining last (49th) elements of `dgd - avg`.
409       DGD_AVG0[48] = dgd_ptr[6] - avg;
410       DGD_AVG1[48] = dgd_ptr[7] - avg;
411 
412       // Accumulate into row-major variant of matrix M (cross-correlation) for 2
413       // output pixels at a time. M is of size 7 * 7. It needs to be filled such
414       // that multiplying one element from src with each element of a row of the
415       // wiener window will fill one column of M. However this is not very
416       // convenient in terms of memory access, as it means we do contiguous
417       // loads of dgd but strided stores to M. As a result, we use an
418       // intermediate matrix M_s32 which is instead filled such that one row of
419       // the wiener window gives one row of M_s32. Once fully computed, M_s32 is
420       // then transposed to return M.
421       int src_avg0 = *src++ - avg;
422       int src_avg1 = *src++ - avg;
423       int16x4_t src_avg0_s16 = vdup_n_s16(src_avg0);
424       int16x4_t src_avg1_s16 = vdup_n_s16(src_avg1);
425       update_M_2pixels(M_s32 + 0, src_avg0_s16, src_avg1_s16, dgd_avg0[0],
426                        dgd_avg1[0]);
427       update_M_2pixels(M_s32 + 8, src_avg0_s16, src_avg1_s16, dgd_avg0[1],
428                        dgd_avg1[1]);
429       update_M_2pixels(M_s32 + 16, src_avg0_s16, src_avg1_s16, dgd_avg0[2],
430                        dgd_avg1[2]);
431       update_M_2pixels(M_s32 + 24, src_avg0_s16, src_avg1_s16, dgd_avg0[3],
432                        dgd_avg1[3]);
433       update_M_2pixels(M_s32 + 32, src_avg0_s16, src_avg1_s16, dgd_avg0[4],
434                        dgd_avg1[4]);
435       update_M_2pixels(M_s32 + 40, src_avg0_s16, src_avg1_s16, dgd_avg0[5],
436                        dgd_avg1[5]);
437 
438       // Last (49th) element of M_s32 can be computed as scalar more efficiently
439       // for 2 output pixels.
440       M_s32[48] += DGD_AVG0[48] * src_avg0 + DGD_AVG1[48] * src_avg1;
441 
442       // Start accumulating into row-major version of matrix H
443       // (auto-covariance); it expects the DGD_AVG[01] matrices to also be
444       // row-major. H is of size 49 * 49. It is filled by multiplying every pair
445       // of elements of the wiener window together (vector outer product). Since
446       // it is a symmetric matrix, we only compute the upper-right triangle, and
447       // then copy it down to the lower-left later. The upper triangle is
448       // covered by 4x4 tiles. The original algorithm assumes the M matrix is
449       // column-major and the resulting H matrix is also expected to be
450       // column-major. It is not efficient to work with column-major matrices,
451       // so we accumulate into a row-major matrix H_s32. At the end of the
452       // algorithm a double transpose transformation will convert H_s32 back to
453       // the expected output layout.
454       update_H_7x7_2pixels(H_s32, DGD_AVG0, DGD_AVG1);
455 
456       // The last element of the triangle of H_s32 matrix can be computed as a
457       // scalar more efficiently.
458       H_s32[48 * WIENER_WIN2_ALIGN2 + 48] +=
459           DGD_AVG0[48] * DGD_AVG0[48] + DGD_AVG1[48] * DGD_AVG1[48];
460 
461       // Accumulate into 64-bit after STAT_ACCUMULATOR_MAX iterations to prevent
462       // overflow.
463       if (--acc_cnt == 0) {
464         acc_cnt = STAT_ACCUMULATOR_MAX;
465 
466         accumulate_and_clear(M_s64, M_s32, WIENER_WIN2_ALIGN2);
467 
468         // The widening accumulation is only needed for the upper triangle part
469         // of the matrix.
470         int64_t *lh = H_s64;
471         int32_t *lh32 = H_s32;
472         for (int k = 0; k < WIENER_WIN2; ++k) {
473           // The widening accumulation is only run for the relevant parts
474           // (upper-right triangle) in a row 4-element aligned.
475           int k4 = k / 4 * 4;
476           accumulate_and_clear(lh + k4, lh32 + k4, 48 - k4);
477 
478           // Last element of the row is computed separately.
479           lh[48] += lh32[48];
480           lh32[48] = 0;
481 
482           lh += WIENER_WIN2_ALIGN2;
483           lh32 += WIENER_WIN2_ALIGN2;
484         }
485       }
486 
487       j -= 2;
488     }
489 
490     // Computations for odd pixel in the row.
491     if (width & 1) {
492       // Load two adjacent, overlapping 7x7 matrices: an 8x7 matrix with the
493       // middle 6x7 elements being shared.
494       uint8x16_t dgd_rows[4];
495       load_and_pack_u8_8x7(dgd_rows, dgd, dgd_stride);
496 
497       const uint8_t *dgd_ptr = dgd + dgd_stride * 6;
498       ++dgd;
499 
500       // Re-arrange (and widen) the combined 8x7 matrix to have a whole 7x7
501       // matrix tightly packed into an int16x8_t[6] array. This array contains
502       // 48 elements of the 49 (7x7). Compute `dgd - avg` for the whole buffer.
503       // The DGD_AVG buffer contains 49 consecutive elements.
504       int16x8_t dgd_avg0[6];
505       uint8x16_t dgd_shuf0 = tbl2q(dgd_rows[0], dgd_rows[1], lut0);
506       dgd_avg0[0] =
507           vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(dgd_shuf0), avg_u8));
508       dgd_avg0[1] =
509           vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(dgd_shuf0), avg_u8));
510       vst1q_s16(DGD_AVG0, dgd_avg0[0]);
511       vst1q_s16(DGD_AVG0 + 8, dgd_avg0[1]);
512 
513       uint8x16_t dgd_shuf1 = tbl2q(dgd_rows[1], dgd_rows[2], lut1);
514       dgd_avg0[2] =
515           vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(dgd_shuf1), avg_u8));
516       dgd_avg0[3] =
517           vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(dgd_shuf1), avg_u8));
518       vst1q_s16(DGD_AVG0 + 16, dgd_avg0[2]);
519       vst1q_s16(DGD_AVG0 + 24, dgd_avg0[3]);
520 
521       uint8x16_t dgd_shuf2 = tbl2q(dgd_rows[2], dgd_rows[3], lut2);
522       dgd_avg0[4] =
523           vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(dgd_shuf2), avg_u8));
524       dgd_avg0[5] =
525           vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(dgd_shuf2), avg_u8));
526       vst1q_s16(DGD_AVG0 + 32, dgd_avg0[4]);
527       vst1q_s16(DGD_AVG0 + 40, dgd_avg0[5]);
528 
529       // The remaining last (49th) element of `dgd - avg`.
530       DGD_AVG0[48] = dgd_ptr[6] - avg;
531 
532       // Accumulate into row-major order variant of matrix M (cross-correlation)
533       // for 1 output pixel at a time. M is of size 7 * 7. It needs to be filled
534       // such that multiplying one element from src with each element of a row
535       // of the wiener window will fill one column of M. However this is not
536       // very convenient in terms of memory access, as it means we do
537       // contiguous loads of dgd but strided stores to M. As a result, we use an
538       // intermediate matrix M_s32 which is instead filled such that one row of
539       // the wiener window gives one row of M_s32. Once fully computed, M_s32 is
540       // then transposed to return M.
541       int src_avg0 = *src++ - avg;
542       int16x4_t src_avg0_s16 = vdup_n_s16(src_avg0);
543       update_M_1pixel(M_s32 + 0, src_avg0_s16, dgd_avg0[0]);
544       update_M_1pixel(M_s32 + 8, src_avg0_s16, dgd_avg0[1]);
545       update_M_1pixel(M_s32 + 16, src_avg0_s16, dgd_avg0[2]);
546       update_M_1pixel(M_s32 + 24, src_avg0_s16, dgd_avg0[3]);
547       update_M_1pixel(M_s32 + 32, src_avg0_s16, dgd_avg0[4]);
548       update_M_1pixel(M_s32 + 40, src_avg0_s16, dgd_avg0[5]);
549 
550       // Last (49th) element of M_s32 can be computed as scalar more efficiently
551       // for 1 output pixel.
552       M_s32[48] += DGD_AVG0[48] * src_avg0;
553 
554       // Start accumulating into row-major order version of matrix H
555       // (auto-covariance); it expects the DGD_AVG0 matrix to also be row-major.
556       // H is of size 49 * 49. It is filled by multiplying every pair of
557       // elements of the wiener window together (vector outer product). Since it
558       // is a symmetric matrix, we only compute the upper-right triangle, and
559       // then copy it down to the lower-left later. The upper triangle is
560       // covered by 4x4 tiles. The original algorithm assumes the M matrix is
561       // column-major and the resulting H matrix is also expected to be
562       // column-major. It is not efficient to work with column-major matrices,
563       // so we accumulate into a row-major matrix H_s32. At the end of the
564       // algorithm a double transpose transformation will convert H_s32 back to
565       // the expected output layout.
566       update_H_1pixel(H_s32, DGD_AVG0, WIENER_WIN2_ALIGN2, 48);
567 
568       // The last element of the triangle of H_s32 matrix can be computed as
569       // scalar more efficiently.
570       H_s32[48 * WIENER_WIN2_ALIGN2 + 48] += DGD_AVG0[48] * DGD_AVG0[48];
571     }
572 
573     src += src_next;
574     dgd += dgd_next;
575   } while (--height != 0);
576 
577   acc_transpose_M(M, M_s64, M_s32, WIENER_WIN, downsample_factor);
578 
579   update_H(H, H_s64, H_s32, WIENER_WIN, WIENER_WIN2_ALIGN2, downsample_factor);
580 }
581 
582 // Load a 5x5 matrix into 2 and a half 128-bit vectors from consecutive rows;
583 // the last load address is offset to prevent out-of-bounds access.
584 static inline void load_and_pack_u8_6x5(uint8x16_t dst[3], const uint8_t *src,
585                                         ptrdiff_t stride) {
586   dst[0] = vcombine_u8(vld1_u8(src), vld1_u8(src + stride));
587   src += 2 * stride;
588   dst[1] = vcombine_u8(vld1_u8(src), vld1_u8(src + stride));
589   src += 2 * stride;
590   dst[2] = vcombine_u8(vld1_u8(src - 3), vdup_n_u8(0));
591 }
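// Note on the resulting layout: dst[0] holds rows 0-1, dst[1] rows 2-3, and
// the low half of dst[2] holds row 4 loaded from src - 3, so the row-4 samples
// start at byte 3 of that vector; the shuffle table used by the caller
// accounts for this three-byte offset. The 25th element of each 5x5 window is
// handled separately as a scalar by the caller.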
592 
593 static inline void compute_stats_win5_downsampled_neon(
594     const uint8_t *dgd, const uint8_t *src, int width, int height,
595     int dgd_stride, int src_stride, int avg, int64_t *M, int64_t *H,
596     int downsample_factor) {
597   // Matrix names are capitalized to help readability.
598   DECLARE_ALIGNED(64, int16_t, DGD_AVG0[WIENER_WIN2_REDUCED_ALIGN3]);
599   DECLARE_ALIGNED(64, int16_t, DGD_AVG1[WIENER_WIN2_REDUCED_ALIGN3]);
600   DECLARE_ALIGNED(64, int32_t, M_s32[WIENER_WIN2_REDUCED_ALIGN3]);
601   DECLARE_ALIGNED(64, int64_t, M_s64[WIENER_WIN2_REDUCED_ALIGN3]);
602   DECLARE_ALIGNED(64, int32_t,
603                   H_s32[WIENER_WIN2_REDUCED * WIENER_WIN2_REDUCED_ALIGN2]);
604   DECLARE_ALIGNED(64, int64_t,
605                   H_s64[WIENER_WIN2_REDUCED * WIENER_WIN2_REDUCED_ALIGN2]);
606 
607   memset(M_s32, 0, sizeof(M_s32));
608   memset(M_s64, 0, sizeof(M_s64));
609   memset(H_s32, 0, sizeof(H_s32));
610   memset(H_s64, 0, sizeof(H_s64));
611 
612   // Look-up tables to create an 8x3 matrix with consecutive elements from two
613   // 5x5 matrices.
614   // clang-format off
615   DECLARE_ALIGNED(16, static const uint8_t, shuffle_stats5[48]) = {
616     0,  1,  2,  3,  4,  8,  9, 10, 11, 12, 16, 17, 18, 19, 20, 24,
617     1,  2,  3,  4,  5,  9, 10, 11, 12, 13, 17, 18, 19, 20, 21, 25,
618     9, 10, 11, 12, 19, 20, 21, 22, 10, 11, 12, 13, 20, 21, 22, 23,
619   };
620   // clang-format on
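  // Given the packing produced by load_and_pack_u8_6x5, lut0 gathers the first
  // 16 elements of the 5x5 window for the first pixel (rows 0-2 cols 0-4 plus
  // row 3 col 0) and lut1 gathers the same elements shifted right by one
  // column for the second pixel. lut2 provides the next 8 elements for both
  // pixels (low half: first pixel, high half: second pixel); the 25th element
  // is added separately as a scalar below.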
621 
622   const uint8x16_t lut0 = vld1q_u8(shuffle_stats5 + 0);
623   const uint8x16_t lut1 = vld1q_u8(shuffle_stats5 + 16);
624   const uint8x16_t lut2 = vld1q_u8(shuffle_stats5 + 32);
625 
626   int acc_cnt = STAT_ACCUMULATOR_MAX;
627   const int src_next = downsample_factor * src_stride - width;
628   const int dgd_next = downsample_factor * dgd_stride - width;
629   const uint8x8_t avg_u8 = vdup_n_u8(avg);
630 
631   do {
632     int j = width;
633     while (j >= 2) {
634       // Load two adjacent, overlapping 5x5 matrices: a 6x5 matrix with the
635       // middle 4x5 elements being shared.
636       uint8x16_t dgd_rows[3];
637       load_and_pack_u8_6x5(dgd_rows, dgd, dgd_stride);
638 
639       const uint8_t *dgd_ptr = dgd + dgd_stride * 4;
640       dgd += 2;
641 
642       // Re-arrange (and widen) the combined 6x5 matrix to have the 2 whole 5x5
643       // matrices (1 for each of the 2 pixels) separated into distinct
644       // int16x8_t[3] arrays. These arrays contain 24 elements of the 25 (5x5).
645       // Compute `dgd - avg` for both buffers. Each DGD_AVG buffer contains 25
646       // consecutive elements.
647       int16x8_t dgd_avg0[3];
648       int16x8_t dgd_avg1[3];
649       uint8x16_t dgd_shuf0 = tbl2q(dgd_rows[0], dgd_rows[1], lut0);
650       uint8x16_t dgd_shuf1 = tbl2q(dgd_rows[0], dgd_rows[1], lut1);
651       uint8x16_t dgd_shuf2 = tbl2q(dgd_rows[1], dgd_rows[2], lut2);
652 
653       dgd_avg0[0] =
654           vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(dgd_shuf0), avg_u8));
655       dgd_avg0[1] =
656           vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(dgd_shuf0), avg_u8));
657       dgd_avg0[2] =
658           vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(dgd_shuf2), avg_u8));
659       dgd_avg1[0] =
660           vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(dgd_shuf1), avg_u8));
661       dgd_avg1[1] =
662           vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(dgd_shuf1), avg_u8));
663       dgd_avg1[2] =
664           vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(dgd_shuf2), avg_u8));
665 
666       vst1q_s16(DGD_AVG0 + 0, dgd_avg0[0]);
667       vst1q_s16(DGD_AVG0 + 8, dgd_avg0[1]);
668       vst1q_s16(DGD_AVG0 + 16, dgd_avg0[2]);
669       vst1q_s16(DGD_AVG1 + 0, dgd_avg1[0]);
670       vst1q_s16(DGD_AVG1 + 8, dgd_avg1[1]);
671       vst1q_s16(DGD_AVG1 + 16, dgd_avg1[2]);
672 
673       // The remaining last (25th) elements of `dgd - avg`.
674       DGD_AVG0[24] = dgd_ptr[4] - avg;
675       DGD_AVG1[24] = dgd_ptr[5] - avg;
676 
677       // Accumulate into row-major variant of matrix M (cross-correlation) for 2
678       // output pixels at a time. M is of size 5 * 5. It needs to be filled such
679       // that multiplying one element from src with each element of a row of the
680       // wiener window will fill one column of M. However this is not very
681       // convenient in terms of memory access, as it means we do contiguous
682       // loads of dgd but strided stores to M. As a result, we use an
683       // intermediate matrix M_s32 which is instead filled such that one row of
684       // the wiener window gives one row of M_s32. Once fully computed, M_s32 is
685       // then transposed to return M.
686       int src_avg0 = *src++ - avg;
687       int src_avg1 = *src++ - avg;
688       int16x4_t src_avg0_s16 = vdup_n_s16(src_avg0);
689       int16x4_t src_avg1_s16 = vdup_n_s16(src_avg1);
690       update_M_2pixels(M_s32 + 0, src_avg0_s16, src_avg1_s16, dgd_avg0[0],
691                        dgd_avg1[0]);
692       update_M_2pixels(M_s32 + 8, src_avg0_s16, src_avg1_s16, dgd_avg0[1],
693                        dgd_avg1[1]);
694       update_M_2pixels(M_s32 + 16, src_avg0_s16, src_avg1_s16, dgd_avg0[2],
695                        dgd_avg1[2]);
696 
697       // Last (25th) element of M_s32 can be computed as scalar more efficiently
698       // for 2 output pixels.
699       M_s32[24] += DGD_AVG0[24] * src_avg0 + DGD_AVG1[24] * src_avg1;
700 
701       // Start accumulating into row-major version of matrix H
702       // (auto-covariance); it expects the DGD_AVG[01] matrices to also be
703       // row-major. H is of size 25 * 25. It is filled by multiplying every pair
704       // of elements of the wiener window together (vector outer product). Since
705       // it is a symmetric matrix, we only compute the upper-right triangle, and
706       // then copy it down to the lower-left later. The upper triangle is
707       // covered by 4x4 tiles. The original algorithm assumes the M matrix is
708       // column-major and the resulting H matrix is also expected to be
709       // column-major. It is not efficient to work with column-major matrices,
710       // so we accumulate into a row-major matrix H_s32. At the end of the
711       // algorithm a double transpose transformation will convert H_s32 back to
712       // the expected output layout.
713       update_H_5x5_2pixels(H_s32, DGD_AVG0, DGD_AVG1);
714 
715       // The last element of the triangle of H_s32 matrix can be computed as a
716       // scalar more efficiently.
717       H_s32[24 * WIENER_WIN2_REDUCED_ALIGN2 + 24] +=
718           DGD_AVG0[24] * DGD_AVG0[24] + DGD_AVG1[24] * DGD_AVG1[24];
719 
720       // Accumulate into 64-bit after STAT_ACCUMULATOR_MAX iterations to prevent
721       // overflow.
722       if (--acc_cnt == 0) {
723         acc_cnt = STAT_ACCUMULATOR_MAX;
724 
725         accumulate_and_clear(M_s64, M_s32, WIENER_WIN2_REDUCED_ALIGN2);
726 
727         // The widening accumulation is only needed for the upper triangle part
728         // of the matrix.
729         int64_t *lh = H_s64;
730         int32_t *lh32 = H_s32;
731         for (int k = 0; k < WIENER_WIN2_REDUCED; ++k) {
732           // The widening accumulation is only run for the relevant parts
733           // (upper-right triangle) in a row 4-element aligned.
734           int k4 = k / 4 * 4;
735           accumulate_and_clear(lh + k4, lh32 + k4, 24 - k4);
736 
737           // Last element of the row is computed separately.
738           lh[24] += lh32[24];
739           lh32[24] = 0;
740 
741           lh += WIENER_WIN2_REDUCED_ALIGN2;
742           lh32 += WIENER_WIN2_REDUCED_ALIGN2;
743         }
744       }
745 
746       j -= 2;
747     }
748 
749     // Computations for odd pixel in the row.
750     if (width & 1) {
751       // Load two adjacent, overlapping 5x5 matrices: a 6x5 matrix with the
752       // middle 4x5 elements being shared.
753       uint8x16_t dgd_rows[3];
754       load_and_pack_u8_6x5(dgd_rows, dgd, dgd_stride);
755 
756       const uint8_t *dgd_ptr = dgd + dgd_stride * 4;
757       ++dgd;
758 
759       // Re-arrange (and widen) the combined 6x5 matrix to have a whole 5x5
760       // matrix tightly packed into an int16x8_t[3] array. This array contains
761       // 24 elements of the 25 (5x5). Compute `dgd - avg` for the whole buffer.
762       // The DGD_AVG buffer contains 25 consecutive elements.
763       int16x8_t dgd_avg0[3];
764       uint8x16_t dgd_shuf0 = tbl2q(dgd_rows[0], dgd_rows[1], lut0);
765       uint8x8_t dgd_shuf1 = tbl2(dgd_rows[1], dgd_rows[2], vget_low_u8(lut2));
766 
767       dgd_avg0[0] =
768           vreinterpretq_s16_u16(vsubl_u8(vget_low_u8(dgd_shuf0), avg_u8));
769       dgd_avg0[1] =
770           vreinterpretq_s16_u16(vsubl_u8(vget_high_u8(dgd_shuf0), avg_u8));
771       dgd_avg0[2] = vreinterpretq_s16_u16(vsubl_u8(dgd_shuf1, avg_u8));
772 
773       vst1q_s16(DGD_AVG0 + 0, dgd_avg0[0]);
774       vst1q_s16(DGD_AVG0 + 8, dgd_avg0[1]);
775       vst1q_s16(DGD_AVG0 + 16, dgd_avg0[2]);
776 
777       // The remaining last (25th) element of `dgd - avg`.
778       DGD_AVG0[24] = dgd_ptr[4] - avg;
779 
780       // Accumulate into row-major order variant of matrix M (cross-correlation)
781       // for 1 output pixel at a time. M is of size 5 * 5. It needs to be filled
782       // such that multiplying one element from src with each element of a row
783       // of the wiener window will fill one column of M. However this is not
784       // very convenient in terms of memory access, as it means we do
785       // contiguous loads of dgd but strided stores to M. As a result, we use an
786       // intermediate matrix M_s32 which is instead filled such that one row of
787       // the wiener window gives one row of M_s32. Once fully computed, M_s32 is
788       // then transposed to return M.
789       int src_avg0 = *src++ - avg;
790       int16x4_t src_avg0_s16 = vdup_n_s16(src_avg0);
791       update_M_1pixel(M_s32 + 0, src_avg0_s16, dgd_avg0[0]);
792       update_M_1pixel(M_s32 + 8, src_avg0_s16, dgd_avg0[1]);
793       update_M_1pixel(M_s32 + 16, src_avg0_s16, dgd_avg0[2]);
794 
795       // Last (25th) element of M_s32 can be computed as scalar more efficiently
796       // for 1 output pixel.
797       M_s32[24] += DGD_AVG0[24] * src_avg0;
798 
799       // Start accumulating into row-major order version of matrix H
800       // (auto-covariance); it expects the DGD_AVG0 matrix to also be row-major.
801       // H is of size 25 * 25. It is filled by multiplying every pair of
802       // elements of the wiener window together (vector outer product). Since it
803       // is a symmetric matrix, we only compute the upper-right triangle, and
804       // then copy it down to the lower-left later. The upper triangle is
805       // covered by 4x4 tiles. The original algorithm assumes the M matrix is
806       // column-major and the resulting H matrix is also expected to be
807       // column-major. It is not efficient to work with column-major matrices,
808       // so we accumulate into a row-major matrix H_s32. At the end of the
809       // algorithm a double transpose transformation will convert H_s32 back to
810       // the expected output layout.
811       update_H_1pixel(H_s32, DGD_AVG0, WIENER_WIN2_REDUCED_ALIGN2, 24);
812 
813       // The last element of the triangle of H_s32 matrix can be computed as a
814       // scalar more efficiently.
815       H_s32[24 * WIENER_WIN2_REDUCED_ALIGN2 + 24] +=
816           DGD_AVG0[24] * DGD_AVG0[24];
817     }
818 
819     src += src_next;
820     dgd += dgd_next;
821   } while (--height != 0);
822 
823   acc_transpose_M(M, M_s64, M_s32, WIENER_WIN_REDUCED, downsample_factor);
824 
825   update_H(H, H_s64, H_s32, WIENER_WIN_REDUCED, WIENER_WIN2_REDUCED_ALIGN2,
826            downsample_factor);
827 }
828 
829 static inline void hadd_update_6_stats_neon(const int64_t *const src,
830                                             const int32x4_t *deltas,
831                                             int64_t *const dst) {
832   int32x4_t delta01 = horizontal_add_2d_s32(deltas[0], deltas[1]);
833   int32x4_t delta23 = horizontal_add_2d_s32(deltas[2], deltas[3]);
834   int32x4_t delta45 = horizontal_add_2d_s32(deltas[4], deltas[5]);
835 
836   int64x2_t delta01_s64 = vpaddlq_s32(delta01);
837   int64x2_t delta23_s64 = vpaddlq_s32(delta23);
838   int64x2_t delta45_s64 = vpaddlq_s32(delta45);
839 
840   int64x2_t src0 = vld1q_s64(src);
841   int64x2_t src1 = vld1q_s64(src + 2);
842   int64x2_t src2 = vld1q_s64(src + 4);
843 
844   vst1q_s64(dst, vaddq_s64(src0, delta01_s64));
845   vst1q_s64(dst + 2, vaddq_s64(src1, delta23_s64));
846   vst1q_s64(dst + 4, vaddq_s64(src2, delta45_s64));
847 }
848 
849 static inline void hadd_update_4_stats_neon(const int64_t *const src,
850                                             const int32x4_t *deltas,
851                                             int64_t *const dst) {
852   int32x4_t delta01 = horizontal_add_2d_s32(deltas[0], deltas[1]);
853   int32x4_t delta23 = horizontal_add_2d_s32(deltas[2], deltas[3]);
854   int64x2_t delta01_s64 = vpaddlq_s32(delta01);
855   int64x2_t delta23_s64 = vpaddlq_s32(delta23);
856 
857   int64x2_t src0 = vld1q_s64(src);
858   int64x2_t src1 = vld1q_s64(src + 2);
859   vst1q_s64(dst, vaddq_s64(src0, delta01_s64));
860   vst1q_s64(dst + 2, vaddq_s64(src1, delta23_s64));
861 }
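// The two helpers above rely on horizontal_add_2d_s32() combining its two
// 4-lane arguments so that the subsequent vpaddlq_s32() yields the full
// horizontal sum of the first argument in lane 0 and of the second argument in
// lane 1 of the 64-bit result, which is then added to the corresponding pair
// of stats.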
862 
863 static inline void compute_stats_win5_neon(
864     const int16_t *const d, const int32_t d_stride, const int16_t *const s,
865     const int32_t s_stride, const int32_t width, const int32_t height,
866     int64_t *const M, int64_t *const H) {
867   const int32_t wiener_win = WIENER_WIN_CHROMA;
868   const int32_t wiener_win2 = wiener_win * wiener_win;
869   const int32_t w16 = width & ~15;
870   const int32_t h8 = height & ~7;
871   int16x8_t mask[2];
872   mask[0] = vld1q_s16(&(mask_16bit[16]) - width % 16);
873   mask[1] = vld1q_s16(&(mask_16bit[16]) - width % 16 + 8);
874   const int bit_depth = 8;
875   int32_t i, j, x, y;
876 
877   const int32_t num_bit_left =
878       32 - 1 /* sign */ - 2 * bit_depth /* energy */ + 2 /* SIMD */;
879   const int32_t h_allowed =
880       (1 << num_bit_left) / (w16 + ((w16 != width) ? 16 : 0));
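  // For bit_depth == 8 this gives num_bit_left == 32 - 1 - 16 + 2 == 17, so
  // h_allowed is roughly (1 << 17) divided by the (padded) row width: the
  // number of rows whose products can be accumulated in the 32-bit row_m/row_h
  // accumulators before they are widened into the 64-bit sum_m/sum_h sums.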
881 
882   // Step 1: Calculate the top edge of the whole matrix, i.e., the top
883   // edge of each triangle and square on the top row.
884   j = 0;
885   do {
886     const int16_t *s_t = s;
887     const int16_t *d_t = d;
888     int32_t height_t = 0;
889     int64x2_t sum_m[WIENER_WIN_CHROMA] = { vdupq_n_s64(0) };
890     int64x2_t sum_h[WIENER_WIN_CHROMA] = { vdupq_n_s64(0) };
891     int16x8_t src[2], dgd[2];
892 
893     do {
894       const int32_t h_t =
895           ((height - height_t) < h_allowed) ? (height - height_t) : h_allowed;
896       int32x4_t row_m[WIENER_WIN_CHROMA] = { vdupq_n_s32(0) };
897       int32x4_t row_h[WIENER_WIN_CHROMA] = { vdupq_n_s32(0) };
898 
899       y = h_t;
900       do {
901         x = 0;
902         while (x < w16) {
903           src[0] = vld1q_s16(s_t + x + 0);
904           src[1] = vld1q_s16(s_t + x + 8);
905           dgd[0] = vld1q_s16(d_t + x + 0);
906           dgd[1] = vld1q_s16(d_t + x + 8);
907           stats_top_win5_neon(src, dgd, d_t + j + x, d_stride, row_m, row_h);
908           x += 16;
909         }
910 
911         if (w16 != width) {
912           src[0] = vld1q_s16(s_t + w16 + 0);
913           src[1] = vld1q_s16(s_t + w16 + 8);
914           dgd[0] = vld1q_s16(d_t + w16 + 0);
915           dgd[1] = vld1q_s16(d_t + w16 + 8);
916           src[0] = vandq_s16(src[0], mask[0]);
917           src[1] = vandq_s16(src[1], mask[1]);
918           dgd[0] = vandq_s16(dgd[0], mask[0]);
919           dgd[1] = vandq_s16(dgd[1], mask[1]);
920           stats_top_win5_neon(src, dgd, d_t + j + w16, d_stride, row_m, row_h);
921         }
922 
923         s_t += s_stride;
924         d_t += d_stride;
925       } while (--y);
926 
927       sum_m[0] = vpadalq_s32(sum_m[0], row_m[0]);
928       sum_m[1] = vpadalq_s32(sum_m[1], row_m[1]);
929       sum_m[2] = vpadalq_s32(sum_m[2], row_m[2]);
930       sum_m[3] = vpadalq_s32(sum_m[3], row_m[3]);
931       sum_m[4] = vpadalq_s32(sum_m[4], row_m[4]);
932       sum_h[0] = vpadalq_s32(sum_h[0], row_h[0]);
933       sum_h[1] = vpadalq_s32(sum_h[1], row_h[1]);
934       sum_h[2] = vpadalq_s32(sum_h[2], row_h[2]);
935       sum_h[3] = vpadalq_s32(sum_h[3], row_h[3]);
936       sum_h[4] = vpadalq_s32(sum_h[4], row_h[4]);
937 
938       height_t += h_t;
939     } while (height_t < height);
940 
941 #if AOM_ARCH_AARCH64
942     int64x2_t sum_m0 = vpaddq_s64(sum_m[0], sum_m[1]);
943     int64x2_t sum_m2 = vpaddq_s64(sum_m[2], sum_m[3]);
944     vst1q_s64(&M[wiener_win * j + 0], sum_m0);
945     vst1q_s64(&M[wiener_win * j + 2], sum_m2);
946     M[wiener_win * j + 4] = vaddvq_s64(sum_m[4]);
947 
948     int64x2_t sum_h0 = vpaddq_s64(sum_h[0], sum_h[1]);
949     int64x2_t sum_h2 = vpaddq_s64(sum_h[2], sum_h[3]);
950     vst1q_s64(&H[wiener_win * j + 0], sum_h0);
951     vst1q_s64(&H[wiener_win * j + 2], sum_h2);
952     H[wiener_win * j + 4] = vaddvq_s64(sum_h[4]);
953 #else
954     M[wiener_win * j + 0] = horizontal_add_s64x2(sum_m[0]);
955     M[wiener_win * j + 1] = horizontal_add_s64x2(sum_m[1]);
956     M[wiener_win * j + 2] = horizontal_add_s64x2(sum_m[2]);
957     M[wiener_win * j + 3] = horizontal_add_s64x2(sum_m[3]);
958     M[wiener_win * j + 4] = horizontal_add_s64x2(sum_m[4]);
959 
960     H[wiener_win * j + 0] = horizontal_add_s64x2(sum_h[0]);
961     H[wiener_win * j + 1] = horizontal_add_s64x2(sum_h[1]);
962     H[wiener_win * j + 2] = horizontal_add_s64x2(sum_h[2]);
963     H[wiener_win * j + 3] = horizontal_add_s64x2(sum_h[3]);
964     H[wiener_win * j + 4] = horizontal_add_s64x2(sum_h[4]);
965 #endif  // AOM_ARCH_AARCH64
966   } while (++j < wiener_win);
967 
968   // Step 2: Calculate the left edge of each square on the top row.
969   j = 1;
970   do {
971     const int16_t *d_t = d;
972     int32_t height_t = 0;
973     int64x2_t sum_h[WIENER_WIN_CHROMA - 1] = { vdupq_n_s64(0) };
974     int16x8_t dgd[2];
975 
976     do {
977       const int32_t h_t =
978           ((height - height_t) < h_allowed) ? (height - height_t) : h_allowed;
979       int32x4_t row_h[WIENER_WIN_CHROMA - 1] = { vdupq_n_s32(0) };
980 
981       y = h_t;
982       do {
983         x = 0;
984         while (x < w16) {
985           dgd[0] = vld1q_s16(d_t + j + x + 0);
986           dgd[1] = vld1q_s16(d_t + j + x + 8);
987           stats_left_win5_neon(dgd, d_t + x, d_stride, row_h);
988           x += 16;
989         }
990 
991         if (w16 != width) {
992           dgd[0] = vld1q_s16(d_t + j + x + 0);
993           dgd[1] = vld1q_s16(d_t + j + x + 8);
994           dgd[0] = vandq_s16(dgd[0], mask[0]);
995           dgd[1] = vandq_s16(dgd[1], mask[1]);
996           stats_left_win5_neon(dgd, d_t + x, d_stride, row_h);
997         }
998 
999         d_t += d_stride;
1000       } while (--y);
1001 
1002       sum_h[0] = vpadalq_s32(sum_h[0], row_h[0]);
1003       sum_h[1] = vpadalq_s32(sum_h[1], row_h[1]);
1004       sum_h[2] = vpadalq_s32(sum_h[2], row_h[2]);
1005       sum_h[3] = vpadalq_s32(sum_h[3], row_h[3]);
1006 
1007       height_t += h_t;
1008     } while (height_t < height);
1009 
1010 #if AOM_ARCH_AARCH64
1011     int64x2_t sum_h0 = vpaddq_s64(sum_h[0], sum_h[1]);
1012     int64x2_t sum_h1 = vpaddq_s64(sum_h[2], sum_h[3]);
1013     vst1_s64(&H[1 * wiener_win2 + j * wiener_win], vget_low_s64(sum_h0));
1014     vst1_s64(&H[2 * wiener_win2 + j * wiener_win], vget_high_s64(sum_h0));
1015     vst1_s64(&H[3 * wiener_win2 + j * wiener_win], vget_low_s64(sum_h1));
1016     vst1_s64(&H[4 * wiener_win2 + j * wiener_win], vget_high_s64(sum_h1));
1017 #else
1018     H[1 * wiener_win2 + j * wiener_win] = horizontal_add_s64x2(sum_h[0]);
1019     H[2 * wiener_win2 + j * wiener_win] = horizontal_add_s64x2(sum_h[1]);
1020     H[3 * wiener_win2 + j * wiener_win] = horizontal_add_s64x2(sum_h[2]);
1021     H[4 * wiener_win2 + j * wiener_win] = horizontal_add_s64x2(sum_h[3]);
1022 #endif  // AOM_ARCH_AARCH64
1023   } while (++j < wiener_win);
1024 
1025   // Step 3: Derive the top edge of each triangle along the diagonal. No
1026   // triangle in top row.
1027   {
1028     const int16_t *d_t = d;
1029 
1030     if (height % 2) {
1031       int32x4_t deltas[(WIENER_WIN + 1) * 2] = { vdupq_n_s32(0) };
1032       int32x4_t deltas_tr[(WIENER_WIN + 1) * 2] = { vdupq_n_s32(0) };
1033       int16x8_t ds[WIENER_WIN * 2];
1034 
1035       load_s16_8x4(d_t, d_stride, &ds[0], &ds[2], &ds[4], &ds[6]);
1036       load_s16_8x4(d_t + width, d_stride, &ds[1], &ds[3], &ds[5], &ds[7]);
1037       d_t += 4 * d_stride;
1038 
1039       step3_win5_oneline_neon(&d_t, d_stride, width, height, ds, deltas);
1040       transpose_arrays_s32_8x8(deltas, deltas_tr);
1041 
1042       update_5_stats_neon(H + 0 * wiener_win * wiener_win2 + 0 * wiener_win,
1043                           deltas_tr[0], vgetq_lane_s32(deltas_tr[4], 0),
1044                           H + 1 * wiener_win * wiener_win2 + 1 * wiener_win);
1045 
1046       update_5_stats_neon(H + 1 * wiener_win * wiener_win2 + 1 * wiener_win,
1047                           deltas_tr[1], vgetq_lane_s32(deltas_tr[5], 0),
1048                           H + 2 * wiener_win * wiener_win2 + 2 * wiener_win);
1049 
1050       update_5_stats_neon(H + 2 * wiener_win * wiener_win2 + 2 * wiener_win,
1051                           deltas_tr[2], vgetq_lane_s32(deltas_tr[6], 0),
1052                           H + 3 * wiener_win * wiener_win2 + 3 * wiener_win);
1053 
1054       update_5_stats_neon(H + 3 * wiener_win * wiener_win2 + 3 * wiener_win,
1055                           deltas_tr[3], vgetq_lane_s32(deltas_tr[7], 0),
1056                           H + 4 * wiener_win * wiener_win2 + 4 * wiener_win);
1057 
1058     } else {
1059       int32x4_t deltas[WIENER_WIN_CHROMA * 2] = { vdupq_n_s32(0) };
1060       int16x8_t ds[WIENER_WIN_CHROMA * 2];
1061 
1062       ds[0] = load_unaligned_s16_4x2(d_t + 0 * d_stride, width);
1063       ds[1] = load_unaligned_s16_4x2(d_t + 1 * d_stride, width);
1064       ds[2] = load_unaligned_s16_4x2(d_t + 2 * d_stride, width);
1065       ds[3] = load_unaligned_s16_4x2(d_t + 3 * d_stride, width);
1066 
1067       step3_win5_neon(d_t + 4 * d_stride, d_stride, width, height, ds, deltas);
1068 
1069       transpose_elems_inplace_s32_4x4(&deltas[0], &deltas[1], &deltas[2],
1070                                       &deltas[3]);
1071 
1072       update_5_stats_neon(H + 0 * wiener_win * wiener_win2 + 0 * wiener_win,
1073                           deltas[0], vgetq_lane_s32(deltas[4], 0),
1074                           H + 1 * wiener_win * wiener_win2 + 1 * wiener_win);
1075 
1076       update_5_stats_neon(H + 1 * wiener_win * wiener_win2 + 1 * wiener_win,
1077                           deltas[1], vgetq_lane_s32(deltas[4], 1),
1078                           H + 2 * wiener_win * wiener_win2 + 2 * wiener_win);
1079 
1080       update_5_stats_neon(H + 2 * wiener_win * wiener_win2 + 2 * wiener_win,
1081                           deltas[2], vgetq_lane_s32(deltas[4], 2),
1082                           H + 3 * wiener_win * wiener_win2 + 3 * wiener_win);
1083 
1084       update_5_stats_neon(H + 3 * wiener_win * wiener_win2 + 3 * wiener_win,
1085                           deltas[3], vgetq_lane_s32(deltas[4], 3),
1086                           H + 4 * wiener_win * wiener_win2 + 4 * wiener_win);
1087     }
1088   }
1089 
1090   // Step 4: Derive the top and left edge of each square. No square in top and
1091   // bottom row.
1092 
1093   {
1094     y = h8;
1095 
1096     int16x4_t d_s[12];
1097     int16x4_t d_e[12];
1098     const int16_t *d_t = d;
1099     int16x4_t zeros = vdup_n_s16(0);
1100     load_s16_4x4(d_t, d_stride, &d_s[0], &d_s[1], &d_s[2], &d_s[3]);
1101     load_s16_4x4(d_t + width, d_stride, &d_e[0], &d_e[1], &d_e[2], &d_e[3]);
1102     int32x4_t deltas[6][18] = { { vdupq_n_s32(0) }, { vdupq_n_s32(0) } };
1103 
1104     while (y >= 8) {
1105       load_s16_4x8(d_t + 4 * d_stride, d_stride, &d_s[4], &d_s[5], &d_s[6],
1106                    &d_s[7], &d_s[8], &d_s[9], &d_s[10], &d_s[11]);
1107       load_s16_4x8(d_t + width + 4 * d_stride, d_stride, &d_e[4], &d_e[5],
1108                    &d_e[6], &d_e[7], &d_e[8], &d_e[9], &d_e[10], &d_e[11]);
1109 
1110       int16x8_t s_tr[8], e_tr[8];
1111       transpose_elems_s16_4x8(d_s[0], d_s[1], d_s[2], d_s[3], d_s[4], d_s[5],
1112                               d_s[6], d_s[7], &s_tr[0], &s_tr[1], &s_tr[2],
1113                               &s_tr[3]);
1114       transpose_elems_s16_4x8(d_s[8], d_s[9], d_s[10], d_s[11], zeros, zeros,
1115                               zeros, zeros, &s_tr[4], &s_tr[5], &s_tr[6],
1116                               &s_tr[7]);
1117 
1118       transpose_elems_s16_4x8(d_e[0], d_e[1], d_e[2], d_e[3], d_e[4], d_e[5],
1119                               d_e[6], d_e[7], &e_tr[0], &e_tr[1], &e_tr[2],
1120                               &e_tr[3]);
1121       transpose_elems_s16_4x8(d_e[8], d_e[9], d_e[10], d_e[11], zeros, zeros,
1122                               zeros, zeros, &e_tr[4], &e_tr[5], &e_tr[6],
1123                               &e_tr[7]);
1124 
1125       int16x8_t start_col0[5], start_col1[5], start_col2[5], start_col3[5];
1126       start_col0[0] = s_tr[0];
1127       start_col0[1] = vextq_s16(s_tr[0], s_tr[4], 1);
1128       start_col0[2] = vextq_s16(s_tr[0], s_tr[4], 2);
1129       start_col0[3] = vextq_s16(s_tr[0], s_tr[4], 3);
1130       start_col0[4] = vextq_s16(s_tr[0], s_tr[4], 4);
1131 
1132       start_col1[0] = s_tr[1];
1133       start_col1[1] = vextq_s16(s_tr[1], s_tr[5], 1);
1134       start_col1[2] = vextq_s16(s_tr[1], s_tr[5], 2);
1135       start_col1[3] = vextq_s16(s_tr[1], s_tr[5], 3);
1136       start_col1[4] = vextq_s16(s_tr[1], s_tr[5], 4);
1137 
1138       start_col2[0] = s_tr[2];
1139       start_col2[1] = vextq_s16(s_tr[2], s_tr[6], 1);
1140       start_col2[2] = vextq_s16(s_tr[2], s_tr[6], 2);
1141       start_col2[3] = vextq_s16(s_tr[2], s_tr[6], 3);
1142       start_col2[4] = vextq_s16(s_tr[2], s_tr[6], 4);
1143 
1144       start_col3[0] = s_tr[3];
1145       start_col3[1] = vextq_s16(s_tr[3], s_tr[7], 1);
1146       start_col3[2] = vextq_s16(s_tr[3], s_tr[7], 2);
1147       start_col3[3] = vextq_s16(s_tr[3], s_tr[7], 3);
1148       start_col3[4] = vextq_s16(s_tr[3], s_tr[7], 4);
1149 
1150       // i = 1, j = 2;
1151       sub_deltas_step4(start_col0, start_col1, deltas[0]);
1152 
1153       // i = 1, j = 3;
1154       sub_deltas_step4(start_col0, start_col2, deltas[1]);
1155 
1156       // i = 1, j = 4
1157       sub_deltas_step4(start_col0, start_col3, deltas[2]);
1158 
1159       // i = 2, j = 3
1160       sub_deltas_step4(start_col1, start_col2, deltas[3]);
1161 
1162       // i = 2, j = 4
1163       sub_deltas_step4(start_col1, start_col3, deltas[4]);
1164 
1165       // i = 3, j = 4
1166       sub_deltas_step4(start_col2, start_col3, deltas[5]);
1167 
1168       int16x8_t end_col0[5], end_col1[5], end_col2[5], end_col3[5];
1169       end_col0[0] = e_tr[0];
1170       end_col0[1] = vextq_s16(e_tr[0], e_tr[4], 1);
1171       end_col0[2] = vextq_s16(e_tr[0], e_tr[4], 2);
1172       end_col0[3] = vextq_s16(e_tr[0], e_tr[4], 3);
1173       end_col0[4] = vextq_s16(e_tr[0], e_tr[4], 4);
1174 
1175       end_col1[0] = e_tr[1];
1176       end_col1[1] = vextq_s16(e_tr[1], e_tr[5], 1);
1177       end_col1[2] = vextq_s16(e_tr[1], e_tr[5], 2);
1178       end_col1[3] = vextq_s16(e_tr[1], e_tr[5], 3);
1179       end_col1[4] = vextq_s16(e_tr[1], e_tr[5], 4);
1180 
1181       end_col2[0] = e_tr[2];
1182       end_col2[1] = vextq_s16(e_tr[2], e_tr[6], 1);
1183       end_col2[2] = vextq_s16(e_tr[2], e_tr[6], 2);
1184       end_col2[3] = vextq_s16(e_tr[2], e_tr[6], 3);
1185       end_col2[4] = vextq_s16(e_tr[2], e_tr[6], 4);
1186 
1187       end_col3[0] = e_tr[3];
1188       end_col3[1] = vextq_s16(e_tr[3], e_tr[7], 1);
1189       end_col3[2] = vextq_s16(e_tr[3], e_tr[7], 2);
1190       end_col3[3] = vextq_s16(e_tr[3], e_tr[7], 3);
1191       end_col3[4] = vextq_s16(e_tr[3], e_tr[7], 4);
1192 
1193       // i = 1, j = 2;
1194       add_deltas_step4(end_col0, end_col1, deltas[0]);
1195 
1196       // i = 1, j = 3;
1197       add_deltas_step4(end_col0, end_col2, deltas[1]);
1198 
1199       // i = 1, j = 4
1200       add_deltas_step4(end_col0, end_col3, deltas[2]);
1201 
1202       // i = 2, j = 3
1203       add_deltas_step4(end_col1, end_col2, deltas[3]);
1204 
1205       // i = 2, j = 4
1206       add_deltas_step4(end_col1, end_col3, deltas[4]);
1207 
1208       // i = 3, j = 4
1209       add_deltas_step4(end_col2, end_col3, deltas[5]);
1210 
1211       d_s[0] = d_s[8];
1212       d_s[1] = d_s[9];
1213       d_s[2] = d_s[10];
1214       d_s[3] = d_s[11];
1215       d_e[0] = d_e[8];
1216       d_e[1] = d_e[9];
1217       d_e[2] = d_e[10];
1218       d_e[3] = d_e[11];
1219 
1220       d_t += 8 * d_stride;
1221       y -= 8;
1222     }
1223 
1224     if (h8 != height) {
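      // Handle the remaining height % 8 rows: mask_h zeroes the lanes that
      // correspond to rows beyond the valid height before accumulating.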
1225       const int16x8_t mask_h = vld1q_s16(&mask_16bit[16] - (height % 8));
1226 
1227       load_s16_4x8(d_t + 4 * d_stride, d_stride, &d_s[4], &d_s[5], &d_s[6],
1228                    &d_s[7], &d_s[8], &d_s[9], &d_s[10], &d_s[11]);
1229       load_s16_4x8(d_t + width + 4 * d_stride, d_stride, &d_e[4], &d_e[5],
1230                    &d_e[6], &d_e[7], &d_e[8], &d_e[9], &d_e[10], &d_e[11]);
1231       int16x8_t s_tr[8], e_tr[8];
1232       transpose_elems_s16_4x8(d_s[0], d_s[1], d_s[2], d_s[3], d_s[4], d_s[5],
1233                               d_s[6], d_s[7], &s_tr[0], &s_tr[1], &s_tr[2],
1234                               &s_tr[3]);
1235       transpose_elems_s16_4x8(d_s[8], d_s[9], d_s[10], d_s[11], zeros, zeros,
1236                               zeros, zeros, &s_tr[4], &s_tr[5], &s_tr[6],
1237                               &s_tr[7]);
1238       transpose_elems_s16_4x8(d_e[0], d_e[1], d_e[2], d_e[3], d_e[4], d_e[5],
1239                               d_e[6], d_e[7], &e_tr[0], &e_tr[1], &e_tr[2],
1240                               &e_tr[3]);
1241       transpose_elems_s16_4x8(d_e[8], d_e[9], d_e[10], d_e[11], zeros, zeros,
1242                               zeros, zeros, &e_tr[4], &e_tr[5], &e_tr[6],
1243                               &e_tr[7]);
1244 
1245       int16x8_t start_col0[5], start_col1[5], start_col2[5], start_col3[5];
1246       start_col0[0] = vandq_s16(s_tr[0], mask_h);
1247       start_col0[1] = vandq_s16(vextq_s16(s_tr[0], s_tr[4], 1), mask_h);
1248       start_col0[2] = vandq_s16(vextq_s16(s_tr[0], s_tr[4], 2), mask_h);
1249       start_col0[3] = vandq_s16(vextq_s16(s_tr[0], s_tr[4], 3), mask_h);
1250       start_col0[4] = vandq_s16(vextq_s16(s_tr[0], s_tr[4], 4), mask_h);
1251 
1252       start_col1[0] = vandq_s16(s_tr[1], mask_h);
1253       start_col1[1] = vandq_s16(vextq_s16(s_tr[1], s_tr[5], 1), mask_h);
1254       start_col1[2] = vandq_s16(vextq_s16(s_tr[1], s_tr[5], 2), mask_h);
1255       start_col1[3] = vandq_s16(vextq_s16(s_tr[1], s_tr[5], 3), mask_h);
1256       start_col1[4] = vandq_s16(vextq_s16(s_tr[1], s_tr[5], 4), mask_h);
1257 
1258       start_col2[0] = vandq_s16(s_tr[2], mask_h);
1259       start_col2[1] = vandq_s16(vextq_s16(s_tr[2], s_tr[6], 1), mask_h);
1260       start_col2[2] = vandq_s16(vextq_s16(s_tr[2], s_tr[6], 2), mask_h);
1261       start_col2[3] = vandq_s16(vextq_s16(s_tr[2], s_tr[6], 3), mask_h);
1262       start_col2[4] = vandq_s16(vextq_s16(s_tr[2], s_tr[6], 4), mask_h);
1263 
1264       start_col3[0] = vandq_s16(s_tr[3], mask_h);
1265       start_col3[1] = vandq_s16(vextq_s16(s_tr[3], s_tr[7], 1), mask_h);
1266       start_col3[2] = vandq_s16(vextq_s16(s_tr[3], s_tr[7], 2), mask_h);
1267       start_col3[3] = vandq_s16(vextq_s16(s_tr[3], s_tr[7], 3), mask_h);
1268       start_col3[4] = vandq_s16(vextq_s16(s_tr[3], s_tr[7], 4), mask_h);
1269 
1270       // i = 1, j = 2;
1271       sub_deltas_step4(start_col0, start_col1, deltas[0]);
1272 
1273       // i = 1, j = 3;
1274       sub_deltas_step4(start_col0, start_col2, deltas[1]);
1275 
1276       // i = 1, j = 4
1277       sub_deltas_step4(start_col0, start_col3, deltas[2]);
1278 
1279       // i = 2, j = 3
1280       sub_deltas_step4(start_col1, start_col2, deltas[3]);
1281 
1282       // i = 2, j = 4
1283       sub_deltas_step4(start_col1, start_col3, deltas[4]);
1284 
1285       // i = 3, j = 4
1286       sub_deltas_step4(start_col2, start_col3, deltas[5]);
1287 
1288       int16x8_t end_col0[5], end_col1[5], end_col2[5], end_col3[5];
1289       end_col0[0] = vandq_s16(e_tr[0], mask_h);
1290       end_col0[1] = vandq_s16(vextq_s16(e_tr[0], e_tr[4], 1), mask_h);
1291       end_col0[2] = vandq_s16(vextq_s16(e_tr[0], e_tr[4], 2), mask_h);
1292       end_col0[3] = vandq_s16(vextq_s16(e_tr[0], e_tr[4], 3), mask_h);
1293       end_col0[4] = vandq_s16(vextq_s16(e_tr[0], e_tr[4], 4), mask_h);
1294 
1295       end_col1[0] = vandq_s16(e_tr[1], mask_h);
1296       end_col1[1] = vandq_s16(vextq_s16(e_tr[1], e_tr[5], 1), mask_h);
1297       end_col1[2] = vandq_s16(vextq_s16(e_tr[1], e_tr[5], 2), mask_h);
1298       end_col1[3] = vandq_s16(vextq_s16(e_tr[1], e_tr[5], 3), mask_h);
1299       end_col1[4] = vandq_s16(vextq_s16(e_tr[1], e_tr[5], 4), mask_h);
1300 
1301       end_col2[0] = vandq_s16(e_tr[2], mask_h);
1302       end_col2[1] = vandq_s16(vextq_s16(e_tr[2], e_tr[6], 1), mask_h);
1303       end_col2[2] = vandq_s16(vextq_s16(e_tr[2], e_tr[6], 2), mask_h);
1304       end_col2[3] = vandq_s16(vextq_s16(e_tr[2], e_tr[6], 3), mask_h);
1305       end_col2[4] = vandq_s16(vextq_s16(e_tr[2], e_tr[6], 4), mask_h);
1306 
1307       end_col3[0] = vandq_s16(e_tr[3], mask_h);
1308       end_col3[1] = vandq_s16(vextq_s16(e_tr[3], e_tr[7], 1), mask_h);
1309       end_col3[2] = vandq_s16(vextq_s16(e_tr[3], e_tr[7], 2), mask_h);
1310       end_col3[3] = vandq_s16(vextq_s16(e_tr[3], e_tr[7], 3), mask_h);
1311       end_col3[4] = vandq_s16(vextq_s16(e_tr[3], e_tr[7], 4), mask_h);
1312 
1313       // i = 1, j = 2;
1314       add_deltas_step4(end_col0, end_col1, deltas[0]);
1315 
1316       // i = 1, j = 3;
1317       add_deltas_step4(end_col0, end_col2, deltas[1]);
1318 
1319       // i = 1, j = 4
1320       add_deltas_step4(end_col0, end_col3, deltas[2]);
1321 
1322       // i = 2, j = 3
1323       add_deltas_step4(end_col1, end_col2, deltas[3]);
1324 
1325       // i = 2, j = 4
1326       add_deltas_step4(end_col1, end_col3, deltas[4]);
1327 
1328       // i = 3, j = 4
1329       add_deltas_step4(end_col2, end_col3, deltas[5]);
1330     }
1331 
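    // Reduce the vector accumulators: delta[idx][0] feeds update_4_stats_neon,
    // delta[idx][1] supplies the four single-lane H updates, and
    // single_delta[idx] the remaining scalar term.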
1332     int32x4_t delta[6][2];
1333     int32_t single_delta[6];
1334 
1335     delta[0][0] = horizontal_add_4d_s32x4(&deltas[0][0]);
1336     delta[1][0] = horizontal_add_4d_s32x4(&deltas[1][0]);
1337     delta[2][0] = horizontal_add_4d_s32x4(&deltas[2][0]);
1338     delta[3][0] = horizontal_add_4d_s32x4(&deltas[3][0]);
1339     delta[4][0] = horizontal_add_4d_s32x4(&deltas[4][0]);
1340     delta[5][0] = horizontal_add_4d_s32x4(&deltas[5][0]);
1341 
1342     delta[0][1] = horizontal_add_4d_s32x4(&deltas[0][5]);
1343     delta[1][1] = horizontal_add_4d_s32x4(&deltas[1][5]);
1344     delta[2][1] = horizontal_add_4d_s32x4(&deltas[2][5]);
1345     delta[3][1] = horizontal_add_4d_s32x4(&deltas[3][5]);
1346     delta[4][1] = horizontal_add_4d_s32x4(&deltas[4][5]);
1347     delta[5][1] = horizontal_add_4d_s32x4(&deltas[5][5]);
1348 
1349     single_delta[0] = horizontal_add_s32x4(deltas[0][4]);
1350     single_delta[1] = horizontal_add_s32x4(deltas[1][4]);
1351     single_delta[2] = horizontal_add_s32x4(deltas[2][4]);
1352     single_delta[3] = horizontal_add_s32x4(deltas[3][4]);
1353     single_delta[4] = horizontal_add_s32x4(deltas[4][4]);
1354     single_delta[5] = horizontal_add_s32x4(deltas[5][4]);
1355 
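    // Each square (i, j) inherits its stats from square (i - 1, j - 1) plus
    // the accumulated delta; idx walks the upper-triangle pairs in the same
    // order the deltas were computed above.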
1356     int idx = 0;
1357     for (i = 1; i < wiener_win - 1; i++) {
1358       for (j = i + 1; j < wiener_win; j++) {
1359         update_4_stats_neon(
1360             H + (i - 1) * wiener_win * wiener_win2 + (j - 1) * wiener_win,
1361             delta[idx][0], H + i * wiener_win * wiener_win2 + j * wiener_win);
1362         H[i * wiener_win * wiener_win2 + j * wiener_win + 4] =
1363             H[(i - 1) * wiener_win * wiener_win2 + (j - 1) * wiener_win + 4] +
1364             single_delta[idx];
1365 
1366         H[(i * wiener_win + 1) * wiener_win2 + j * wiener_win] =
1367             H[((i - 1) * wiener_win + 1) * wiener_win2 + (j - 1) * wiener_win] +
1368             vgetq_lane_s32(delta[idx][1], 0);
1369         H[(i * wiener_win + 2) * wiener_win2 + j * wiener_win] =
1370             H[((i - 1) * wiener_win + 2) * wiener_win2 + (j - 1) * wiener_win] +
1371             vgetq_lane_s32(delta[idx][1], 1);
1372         H[(i * wiener_win + 3) * wiener_win2 + j * wiener_win] =
1373             H[((i - 1) * wiener_win + 3) * wiener_win2 + (j - 1) * wiener_win] +
1374             vgetq_lane_s32(delta[idx][1], 2);
1375         H[(i * wiener_win + 4) * wiener_win2 + j * wiener_win] =
1376             H[((i - 1) * wiener_win + 4) * wiener_win2 + (j - 1) * wiener_win] +
1377             vgetq_lane_s32(delta[idx][1], 3);
1378 
1379         idx++;
1380       }
1381     }
1382   }
1383 
1384   // Step 5: Derive other points of each square. No square in bottom row.
1385   i = 0;
1386   do {
1387     const int16_t *const di = d + i;
1388 
1389     j = i + 1;
1390     do {
1391       const int16_t *const dj = d + j;
1392       int32x4_t deltas[WIENER_WIN_CHROMA - 1][WIENER_WIN_CHROMA - 1] = {
1393         { vdupq_n_s32(0) }, { vdupq_n_s32(0) }
1394       };
1395       int16x8_t d_is[WIN_CHROMA], d_ie[WIN_CHROMA];
1396       int16x8_t d_js[WIN_CHROMA], d_je[WIN_CHROMA];
1397 
1398       x = 0;
1399       while (x < w16) {
1400         load_square_win5_neon(di + x, dj + x, d_stride, height, d_is, d_ie,
1401                               d_js, d_je);
1402         derive_square_win5_neon(d_is, d_ie, d_js, d_je, deltas);
1403         x += 16;
1404       }
1405 
1406       if (w16 != width) {
1407         load_square_win5_neon(di + x, dj + x, d_stride, height, d_is, d_ie,
1408                               d_js, d_je);
1409         d_is[0] = vandq_s16(d_is[0], mask[0]);
1410         d_is[1] = vandq_s16(d_is[1], mask[1]);
1411         d_is[2] = vandq_s16(d_is[2], mask[0]);
1412         d_is[3] = vandq_s16(d_is[3], mask[1]);
1413         d_is[4] = vandq_s16(d_is[4], mask[0]);
1414         d_is[5] = vandq_s16(d_is[5], mask[1]);
1415         d_is[6] = vandq_s16(d_is[6], mask[0]);
1416         d_is[7] = vandq_s16(d_is[7], mask[1]);
1417         d_ie[0] = vandq_s16(d_ie[0], mask[0]);
1418         d_ie[1] = vandq_s16(d_ie[1], mask[1]);
1419         d_ie[2] = vandq_s16(d_ie[2], mask[0]);
1420         d_ie[3] = vandq_s16(d_ie[3], mask[1]);
1421         d_ie[4] = vandq_s16(d_ie[4], mask[0]);
1422         d_ie[5] = vandq_s16(d_ie[5], mask[1]);
1423         d_ie[6] = vandq_s16(d_ie[6], mask[0]);
1424         d_ie[7] = vandq_s16(d_ie[7], mask[1]);
1425         derive_square_win5_neon(d_is, d_ie, d_js, d_je, deltas);
1426       }
1427 
1428       hadd_update_4_stats_neon(
1429           H + (i * wiener_win + 0) * wiener_win2 + j * wiener_win, deltas[0],
1430           H + (i * wiener_win + 1) * wiener_win2 + j * wiener_win + 1);
1431       hadd_update_4_stats_neon(
1432           H + (i * wiener_win + 1) * wiener_win2 + j * wiener_win, deltas[1],
1433           H + (i * wiener_win + 2) * wiener_win2 + j * wiener_win + 1);
1434       hadd_update_4_stats_neon(
1435           H + (i * wiener_win + 2) * wiener_win2 + j * wiener_win, deltas[2],
1436           H + (i * wiener_win + 3) * wiener_win2 + j * wiener_win + 1);
1437       hadd_update_4_stats_neon(
1438           H + (i * wiener_win + 3) * wiener_win2 + j * wiener_win, deltas[3],
1439           H + (i * wiener_win + 4) * wiener_win2 + j * wiener_win + 1);
1440     } while (++j < wiener_win);
1441   } while (++i < wiener_win - 1);
1442 
1443   // Step 6: Derive other points of each upper triangle along the diagonal.
1444   i = 0;
1445   do {
1446     const int16_t *const di = d + i;
1447     int32x4_t deltas[WIENER_WIN_CHROMA * 2 + 1] = { vdupq_n_s32(0) };
1448     int16x8_t d_is[WIN_CHROMA], d_ie[WIN_CHROMA];
1449 
1450     x = 0;
1451     while (x < w16) {
1452       load_triangle_win5_neon(di + x, d_stride, height, d_is, d_ie);
1453       derive_triangle_win5_neon(d_is, d_ie, deltas);
1454       x += 16;
1455     }
1456 
1457     if (w16 != width) {
1458       load_triangle_win5_neon(di + x, d_stride, height, d_is, d_ie);
1459       d_is[0] = vandq_s16(d_is[0], mask[0]);
1460       d_is[1] = vandq_s16(d_is[1], mask[1]);
1461       d_is[2] = vandq_s16(d_is[2], mask[0]);
1462       d_is[3] = vandq_s16(d_is[3], mask[1]);
1463       d_is[4] = vandq_s16(d_is[4], mask[0]);
1464       d_is[5] = vandq_s16(d_is[5], mask[1]);
1465       d_is[6] = vandq_s16(d_is[6], mask[0]);
1466       d_is[7] = vandq_s16(d_is[7], mask[1]);
1467       d_ie[0] = vandq_s16(d_ie[0], mask[0]);
1468       d_ie[1] = vandq_s16(d_ie[1], mask[1]);
1469       d_ie[2] = vandq_s16(d_ie[2], mask[0]);
1470       d_ie[3] = vandq_s16(d_ie[3], mask[1]);
1471       d_ie[4] = vandq_s16(d_ie[4], mask[0]);
1472       d_ie[5] = vandq_s16(d_ie[5], mask[1]);
1473       d_ie[6] = vandq_s16(d_ie[6], mask[0]);
1474       d_ie[7] = vandq_s16(d_ie[7], mask[1]);
1475       derive_triangle_win5_neon(d_is, d_ie, deltas);
1476     }
1477 
1478     // Row 1: 4 points
1479     hadd_update_4_stats_neon(
1480         H + (i * wiener_win + 0) * wiener_win2 + i * wiener_win, deltas,
1481         H + (i * wiener_win + 1) * wiener_win2 + i * wiener_win + 1);
1482 
1483     // Row 2: 3 points
1484     int32x4_t deltas45 = horizontal_add_2d_s32(deltas[4], deltas[5]);
1485     int32x4_t deltas78 = horizontal_add_2d_s32(deltas[7], deltas[8]);
1486 
1487     int64x2_t deltas45_s64 = vpaddlq_s32(deltas45);
1488     int64x2_t deltas78_s64 = vpaddlq_s32(deltas78);
1489 
1490     int64x2_t src =
1491         vld1q_s64(H + (i * wiener_win + 1) * wiener_win2 + i * wiener_win + 1);
1492     int64x2_t dst = vaddq_s64(src, deltas45_s64);
1493     vst1q_s64(H + (i * wiener_win + 2) * wiener_win2 + i * wiener_win + 2, dst);
1494 
1495     int32x4_t delta69 = horizontal_add_2d_s32(deltas[6], deltas[9]);
1496     int64x2_t delta69_s64 = vpaddlq_s32(delta69);
1497     H[(i * wiener_win + 2) * wiener_win2 + i * wiener_win + 4] =
1498         H[(i * wiener_win + 1) * wiener_win2 + i * wiener_win + 3] +
1499         vgetq_lane_s64(delta69_s64, 0);
1500 
1501     // Row 3: 2 points
1502     vst1q_s64(H + (i * wiener_win + 3) * wiener_win2 + i * wiener_win + 3,
1503               vaddq_s64(dst, deltas78_s64));
1504 
1505     // Row 4: 1 point
1506     H[(i * wiener_win + 4) * wiener_win2 + i * wiener_win + 4] =
1507         H[(i * wiener_win + 3) * wiener_win2 + i * wiener_win + 3] +
1508         vgetq_lane_s64(delta69_s64, 1);
1509   } while (++i < wiener_win);
1510 }
1511 
1512 static inline void compute_stats_win7_neon(
1513     const int16_t *const d, const int32_t d_stride, const int16_t *const s,
1514     const int32_t s_stride, const int32_t width, const int32_t height,
1515     int64_t *const M, int64_t *const H) {
1516   const int32_t wiener_win = WIENER_WIN;
1517   const int32_t wiener_win2 = wiener_win * wiener_win;
1518   const int32_t w16 = width & ~15;
1519   const int32_t h8 = height & ~7;
1520   int16x8_t mask[2];
1521   mask[0] = vld1q_s16(&(mask_16bit[16]) - width % 16);
1522   mask[1] = vld1q_s16(&(mask_16bit[16]) - width % 16 + 8);
1523   const int bit_depth = 8;
1524   int32_t i, j, x, y;
1525 
1526   const int32_t num_bit_left =
1527       32 - 1 /* sign */ - 2 * bit_depth /* energy */ + 2 /* SIMD */;
1528   const int32_t h_allowed =
1529       (1 << num_bit_left) / (w16 + ((w16 != width) ? 16 : 0));
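  // h_allowed caps how many rows can be accumulated in 32-bit lanes before
  // the sums must be widened to 64 bits to avoid overflow.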
1530 
1531   // Step 1: Calculate the top edge of the whole matrix, i.e., the top
1532   // edge of each triangle and square on the top row.
1533   j = 0;
1534   do {
1535     const int16_t *s_t = s;
1536     const int16_t *d_t = d;
1537     int32_t height_t = 0;
1538     int64x2_t sum_m[WIENER_WIN] = { vdupq_n_s64(0) };
1539     int64x2_t sum_h[WIENER_WIN] = { vdupq_n_s64(0) };
1540     int16x8_t src[2], dgd[2];
1541 
1542     do {
1543       const int32_t h_t =
1544           ((height - height_t) < h_allowed) ? (height - height_t) : h_allowed;
1545       int32x4_t row_m[WIENER_WIN * 2] = { vdupq_n_s32(0) };
1546       int32x4_t row_h[WIENER_WIN * 2] = { vdupq_n_s32(0) };
1547 
1548       y = h_t;
1549       do {
1550         x = 0;
1551         while (x < w16) {
1552           src[0] = vld1q_s16(s_t + x);
1553           src[1] = vld1q_s16(s_t + x + 8);
1554           dgd[0] = vld1q_s16(d_t + x);
1555           dgd[1] = vld1q_s16(d_t + x + 8);
1556           stats_top_win7_neon(src, dgd, d_t + j + x, d_stride, row_m, row_h);
1557           x += 16;
1558         }
1559 
1560         if (w16 != width) {
1561           src[0] = vld1q_s16(s_t + w16);
1562           src[1] = vld1q_s16(s_t + w16 + 8);
1563           dgd[0] = vld1q_s16(d_t + w16);
1564           dgd[1] = vld1q_s16(d_t + w16 + 8);
1565           src[0] = vandq_s16(src[0], mask[0]);
1566           src[1] = vandq_s16(src[1], mask[1]);
1567           dgd[0] = vandq_s16(dgd[0], mask[0]);
1568           dgd[1] = vandq_s16(dgd[1], mask[1]);
1569           stats_top_win7_neon(src, dgd, d_t + j + w16, d_stride, row_m, row_h);
1570         }
1571 
1572         s_t += s_stride;
1573         d_t += d_stride;
1574       } while (--y);
1575 
1576       sum_m[0] = vpadalq_s32(sum_m[0], row_m[0]);
1577       sum_m[1] = vpadalq_s32(sum_m[1], row_m[1]);
1578       sum_m[2] = vpadalq_s32(sum_m[2], row_m[2]);
1579       sum_m[3] = vpadalq_s32(sum_m[3], row_m[3]);
1580       sum_m[4] = vpadalq_s32(sum_m[4], row_m[4]);
1581       sum_m[5] = vpadalq_s32(sum_m[5], row_m[5]);
1582       sum_m[6] = vpadalq_s32(sum_m[6], row_m[6]);
1583 
1584       sum_h[0] = vpadalq_s32(sum_h[0], row_h[0]);
1585       sum_h[1] = vpadalq_s32(sum_h[1], row_h[1]);
1586       sum_h[2] = vpadalq_s32(sum_h[2], row_h[2]);
1587       sum_h[3] = vpadalq_s32(sum_h[3], row_h[3]);
1588       sum_h[4] = vpadalq_s32(sum_h[4], row_h[4]);
1589       sum_h[5] = vpadalq_s32(sum_h[5], row_h[5]);
1590       sum_h[6] = vpadalq_s32(sum_h[6], row_h[6]);
1591 
1592       height_t += h_t;
1593     } while (height_t < height);
1594 
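    // Reduce the 64-bit accumulators to one value per filter tap; AArch64 can
    // use pairwise 64-bit adds, other targets fall back to scalar reductions.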
1595 #if AOM_ARCH_AARCH64
1596     vst1q_s64(M + wiener_win * j + 0, vpaddq_s64(sum_m[0], sum_m[1]));
1597     vst1q_s64(M + wiener_win * j + 2, vpaddq_s64(sum_m[2], sum_m[3]));
1598     vst1q_s64(M + wiener_win * j + 4, vpaddq_s64(sum_m[4], sum_m[5]));
1599     M[wiener_win * j + 6] = vaddvq_s64(sum_m[6]);
1600 
1601     vst1q_s64(H + wiener_win * j + 0, vpaddq_s64(sum_h[0], sum_h[1]));
1602     vst1q_s64(H + wiener_win * j + 2, vpaddq_s64(sum_h[2], sum_h[3]));
1603     vst1q_s64(H + wiener_win * j + 4, vpaddq_s64(sum_h[4], sum_h[5]));
1604     H[wiener_win * j + 6] = vaddvq_s64(sum_h[6]);
1605 #else
1606     M[wiener_win * j + 0] = horizontal_add_s64x2(sum_m[0]);
1607     M[wiener_win * j + 1] = horizontal_add_s64x2(sum_m[1]);
1608     M[wiener_win * j + 2] = horizontal_add_s64x2(sum_m[2]);
1609     M[wiener_win * j + 3] = horizontal_add_s64x2(sum_m[3]);
1610     M[wiener_win * j + 4] = horizontal_add_s64x2(sum_m[4]);
1611     M[wiener_win * j + 5] = horizontal_add_s64x2(sum_m[5]);
1612     M[wiener_win * j + 6] = horizontal_add_s64x2(sum_m[6]);
1613 
1614     H[wiener_win * j + 0] = horizontal_add_s64x2(sum_h[0]);
1615     H[wiener_win * j + 1] = horizontal_add_s64x2(sum_h[1]);
1616     H[wiener_win * j + 2] = horizontal_add_s64x2(sum_h[2]);
1617     H[wiener_win * j + 3] = horizontal_add_s64x2(sum_h[3]);
1618     H[wiener_win * j + 4] = horizontal_add_s64x2(sum_h[4]);
1619     H[wiener_win * j + 5] = horizontal_add_s64x2(sum_h[5]);
1620     H[wiener_win * j + 6] = horizontal_add_s64x2(sum_h[6]);
1621 #endif  // AOM_ARCH_AARCH64
1622   } while (++j < wiener_win);
1623 
1624   // Step 2: Calculate the left edge of each square on the top row.
1625   j = 1;
1626   do {
1627     const int16_t *d_t = d;
1628     int32_t height_t = 0;
1629     int64x2_t sum_h[WIENER_WIN - 1] = { vdupq_n_s64(0) };
1630     int16x8_t dgd[2];
1631 
1632     do {
1633       const int32_t h_t =
1634           ((height - height_t) < h_allowed) ? (height - height_t) : h_allowed;
1635       int32x4_t row_h[WIENER_WIN - 1] = { vdupq_n_s32(0) };
1636 
1637       y = h_t;
1638       do {
1639         x = 0;
1640         while (x < w16) {
1641           dgd[0] = vld1q_s16(d_t + j + x + 0);
1642           dgd[1] = vld1q_s16(d_t + j + x + 8);
1643           stats_left_win7_neon(dgd, d_t + x, d_stride, row_h);
1644           x += 16;
1645         }
1646 
1647         if (w16 != width) {
1648           dgd[0] = vld1q_s16(d_t + j + x + 0);
1649           dgd[1] = vld1q_s16(d_t + j + x + 8);
1650           dgd[0] = vandq_s16(dgd[0], mask[0]);
1651           dgd[1] = vandq_s16(dgd[1], mask[1]);
1652           stats_left_win7_neon(dgd, d_t + x, d_stride, row_h);
1653         }
1654 
1655         d_t += d_stride;
1656       } while (--y);
1657 
1658       sum_h[0] = vpadalq_s32(sum_h[0], row_h[0]);
1659       sum_h[1] = vpadalq_s32(sum_h[1], row_h[1]);
1660       sum_h[2] = vpadalq_s32(sum_h[2], row_h[2]);
1661       sum_h[3] = vpadalq_s32(sum_h[3], row_h[3]);
1662       sum_h[4] = vpadalq_s32(sum_h[4], row_h[4]);
1663       sum_h[5] = vpadalq_s32(sum_h[5], row_h[5]);
1664 
1665       height_t += h_t;
1666     } while (height_t < height);
1667 
1668 #if AOM_ARCH_AARCH64
1669     int64x2_t sum_h0 = vpaddq_s64(sum_h[0], sum_h[1]);
1670     int64x2_t sum_h2 = vpaddq_s64(sum_h[2], sum_h[3]);
1671     int64x2_t sum_h4 = vpaddq_s64(sum_h[4], sum_h[5]);
1672     vst1_s64(&H[1 * wiener_win2 + j * wiener_win], vget_low_s64(sum_h0));
1673     vst1_s64(&H[2 * wiener_win2 + j * wiener_win], vget_high_s64(sum_h0));
1674     vst1_s64(&H[3 * wiener_win2 + j * wiener_win], vget_low_s64(sum_h2));
1675     vst1_s64(&H[4 * wiener_win2 + j * wiener_win], vget_high_s64(sum_h2));
1676     vst1_s64(&H[5 * wiener_win2 + j * wiener_win], vget_low_s64(sum_h4));
1677     vst1_s64(&H[6 * wiener_win2 + j * wiener_win], vget_high_s64(sum_h4));
1678 #else
1679     H[1 * wiener_win2 + j * wiener_win] = horizontal_add_s64x2(sum_h[0]);
1680     H[2 * wiener_win2 + j * wiener_win] = horizontal_add_s64x2(sum_h[1]);
1681     H[3 * wiener_win2 + j * wiener_win] = horizontal_add_s64x2(sum_h[2]);
1682     H[4 * wiener_win2 + j * wiener_win] = horizontal_add_s64x2(sum_h[3]);
1683     H[5 * wiener_win2 + j * wiener_win] = horizontal_add_s64x2(sum_h[4]);
1684     H[6 * wiener_win2 + j * wiener_win] = horizontal_add_s64x2(sum_h[5]);
1685 #endif  // AOM_ARCH_AARCH64
1686   } while (++j < wiener_win);
1687 
1688   // Step 3: Derive the top edge of each triangle along the diagonal. No
1689   // triangle in top row.
1690   {
1691     const int16_t *d_t = d;
1692     // Pad to call transpose function.
1693     int32x4_t deltas[(WIENER_WIN + 1) * 2] = { vdupq_n_s32(0) };
1694     int32x4_t deltas_tr[(WIENER_WIN + 1) * 2] = { vdupq_n_s32(0) };
1695     int16x8_t ds[WIENER_WIN * 2];
1696 
1697     load_s16_8x6(d_t, d_stride, &ds[0], &ds[2], &ds[4], &ds[6], &ds[8],
1698                  &ds[10]);
1699     load_s16_8x6(d_t + width, d_stride, &ds[1], &ds[3], &ds[5], &ds[7], &ds[9],
1700                  &ds[11]);
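    // ds[2 * k] holds row k at the window start and ds[2 * k + 1] the same
    // row at the window end (d + width); step3_win7_neon consumes them as
    // start/end pairs.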
1701 
1702     d_t += 6 * d_stride;
1703 
1704     step3_win7_neon(d_t, d_stride, width, height, ds, deltas);
1705     transpose_arrays_s32_8x8(deltas, deltas_tr);
1706 
1707     update_8_stats_neon(H + 0 * wiener_win * wiener_win2 + 0 * wiener_win,
1708                         deltas_tr[0], deltas_tr[4],
1709                         H + 1 * wiener_win * wiener_win2 + 1 * wiener_win);
1710     update_8_stats_neon(H + 1 * wiener_win * wiener_win2 + 1 * wiener_win,
1711                         deltas_tr[1], deltas_tr[5],
1712                         H + 2 * wiener_win * wiener_win2 + 2 * wiener_win);
1713     update_8_stats_neon(H + 2 * wiener_win * wiener_win2 + 2 * wiener_win,
1714                         deltas_tr[2], deltas_tr[6],
1715                         H + 3 * wiener_win * wiener_win2 + 3 * wiener_win);
1716     update_8_stats_neon(H + 3 * wiener_win * wiener_win2 + 3 * wiener_win,
1717                         deltas_tr[3], deltas_tr[7],
1718                         H + 4 * wiener_win * wiener_win2 + 4 * wiener_win);
1719     update_8_stats_neon(H + 4 * wiener_win * wiener_win2 + 4 * wiener_win,
1720                         deltas_tr[8], deltas_tr[12],
1721                         H + 5 * wiener_win * wiener_win2 + 5 * wiener_win);
1722     update_8_stats_neon(H + 5 * wiener_win * wiener_win2 + 5 * wiener_win,
1723                         deltas_tr[9], deltas_tr[13],
1724                         H + 6 * wiener_win * wiener_win2 + 6 * wiener_win);
1725   }
1726 
1727   // Step 4: Derive the top and left edge of each square. No square in top and
1728   // bottom row.
1729 
1730   i = 1;
1731   do {
1732     j = i + 1;
1733     do {
1734       const int16_t *di = d + i - 1;
1735       const int16_t *dj = d + j - 1;
1736       int32x4_t deltas[(2 * WIENER_WIN - 1) * 2] = { vdupq_n_s32(0) };
1737       int16x8_t dd[WIENER_WIN * 2], ds[WIENER_WIN * 2];
1738 
1739       dd[5] = vdupq_n_s16(0);  // Initialize to avoid warning.
1740       const int16_t dd0_values[] = { di[0 * d_stride],
1741                                      di[1 * d_stride],
1742                                      di[2 * d_stride],
1743                                      di[3 * d_stride],
1744                                      di[4 * d_stride],
1745                                      di[5 * d_stride],
1746                                      0,
1747                                      0 };
1748       dd[0] = vld1q_s16(dd0_values);
1749       const int16_t dd1_values[] = { di[0 * d_stride + width],
1750                                      di[1 * d_stride + width],
1751                                      di[2 * d_stride + width],
1752                                      di[3 * d_stride + width],
1753                                      di[4 * d_stride + width],
1754                                      di[5 * d_stride + width],
1755                                      0,
1756                                      0 };
1757       dd[1] = vld1q_s16(dd1_values);
1758       const int16_t ds0_values[] = { dj[0 * d_stride],
1759                                      dj[1 * d_stride],
1760                                      dj[2 * d_stride],
1761                                      dj[3 * d_stride],
1762                                      dj[4 * d_stride],
1763                                      dj[5 * d_stride],
1764                                      0,
1765                                      0 };
1766       ds[0] = vld1q_s16(ds0_values);
1767       int16_t ds1_values[] = { dj[0 * d_stride + width],
1768                                dj[1 * d_stride + width],
1769                                dj[2 * d_stride + width],
1770                                dj[3 * d_stride + width],
1771                                dj[4 * d_stride + width],
1772                                dj[5 * d_stride + width],
1773                                0,
1774                                0 };
1775       ds[1] = vld1q_s16(ds1_values);
1776 
1777       y = 0;
1778       while (y < h8) {
1779         // 00s 10s 20s 30s 40s 50s 60s 70s  00e 10e 20e 30e 40e 50e 60e 70e
1780         dd[0] = vsetq_lane_s16(di[6 * d_stride], dd[0], 6);
1781         dd[0] = vsetq_lane_s16(di[7 * d_stride], dd[0], 7);
1782         dd[1] = vsetq_lane_s16(di[6 * d_stride + width], dd[1], 6);
1783         dd[1] = vsetq_lane_s16(di[7 * d_stride + width], dd[1], 7);
1784 
1785         // 00s 10s 20s 30s 40s 50s 60s 70s  00e 10e 20e 30e 40e 50e 60e 70e
1786         // 01s 11s 21s 31s 41s 51s 61s 71s  01e 11e 21e 31e 41e 51e 61e 71e
1787         ds[0] = vsetq_lane_s16(dj[6 * d_stride], ds[0], 6);
1788         ds[0] = vsetq_lane_s16(dj[7 * d_stride], ds[0], 7);
1789         ds[1] = vsetq_lane_s16(dj[6 * d_stride + width], ds[1], 6);
1790         ds[1] = vsetq_lane_s16(dj[7 * d_stride + width], ds[1], 7);
1791 
1792         load_more_16_neon(di + 8 * d_stride, width, &dd[0], &dd[2]);
1793         load_more_16_neon(dj + 8 * d_stride, width, &ds[0], &ds[2]);
1794         load_more_16_neon(di + 9 * d_stride, width, &dd[2], &dd[4]);
1795         load_more_16_neon(dj + 9 * d_stride, width, &ds[2], &ds[4]);
1796         load_more_16_neon(di + 10 * d_stride, width, &dd[4], &dd[6]);
1797         load_more_16_neon(dj + 10 * d_stride, width, &ds[4], &ds[6]);
1798         load_more_16_neon(di + 11 * d_stride, width, &dd[6], &dd[8]);
1799         load_more_16_neon(dj + 11 * d_stride, width, &ds[6], &ds[8]);
1800         load_more_16_neon(di + 12 * d_stride, width, &dd[8], &dd[10]);
1801         load_more_16_neon(dj + 12 * d_stride, width, &ds[8], &ds[10]);
1802         load_more_16_neon(di + 13 * d_stride, width, &dd[10], &dd[12]);
1803         load_more_16_neon(dj + 13 * d_stride, width, &ds[10], &ds[12]);
1804 
1805         madd_neon(&deltas[0], dd[0], ds[0]);
1806         madd_neon(&deltas[1], dd[1], ds[1]);
1807         madd_neon(&deltas[2], dd[0], ds[2]);
1808         madd_neon(&deltas[3], dd[1], ds[3]);
1809         madd_neon(&deltas[4], dd[0], ds[4]);
1810         madd_neon(&deltas[5], dd[1], ds[5]);
1811         madd_neon(&deltas[6], dd[0], ds[6]);
1812         madd_neon(&deltas[7], dd[1], ds[7]);
1813         madd_neon(&deltas[8], dd[0], ds[8]);
1814         madd_neon(&deltas[9], dd[1], ds[9]);
1815         madd_neon(&deltas[10], dd[0], ds[10]);
1816         madd_neon(&deltas[11], dd[1], ds[11]);
1817         madd_neon(&deltas[12], dd[0], ds[12]);
1818         madd_neon(&deltas[13], dd[1], ds[13]);
1819         madd_neon(&deltas[14], dd[2], ds[0]);
1820         madd_neon(&deltas[15], dd[3], ds[1]);
1821         madd_neon(&deltas[16], dd[4], ds[0]);
1822         madd_neon(&deltas[17], dd[5], ds[1]);
1823         madd_neon(&deltas[18], dd[6], ds[0]);
1824         madd_neon(&deltas[19], dd[7], ds[1]);
1825         madd_neon(&deltas[20], dd[8], ds[0]);
1826         madd_neon(&deltas[21], dd[9], ds[1]);
1827         madd_neon(&deltas[22], dd[10], ds[0]);
1828         madd_neon(&deltas[23], dd[11], ds[1]);
1829         madd_neon(&deltas[24], dd[12], ds[0]);
1830         madd_neon(&deltas[25], dd[13], ds[1]);
1831 
1832         dd[0] = vextq_s16(dd[12], vdupq_n_s16(0), 2);
1833         dd[1] = vextq_s16(dd[13], vdupq_n_s16(0), 2);
1834         ds[0] = vextq_s16(ds[12], vdupq_n_s16(0), 2);
1835         ds[1] = vextq_s16(ds[13], vdupq_n_s16(0), 2);
1836 
1837         di += 8 * d_stride;
1838         dj += 8 * d_stride;
1839         y += 8;
1840       }
1841 
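      // Collapse the per-row accumulators: subtracting the start-column sums
      // from the end-column sums leaves the deltas consumed by
      // update_8_stats_neon and the scalar H updates below.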
1842       deltas[0] = hadd_four_32_neon(deltas[0], deltas[2], deltas[4], deltas[6]);
1843       deltas[1] = hadd_four_32_neon(deltas[1], deltas[3], deltas[5], deltas[7]);
1844       deltas[2] =
1845           hadd_four_32_neon(deltas[8], deltas[10], deltas[12], deltas[12]);
1846       deltas[3] =
1847           hadd_four_32_neon(deltas[9], deltas[11], deltas[13], deltas[13]);
1848       deltas[4] =
1849           hadd_four_32_neon(deltas[14], deltas[16], deltas[18], deltas[20]);
1850       deltas[5] =
1851           hadd_four_32_neon(deltas[15], deltas[17], deltas[19], deltas[21]);
1852       deltas[6] =
1853           hadd_four_32_neon(deltas[22], deltas[24], deltas[22], deltas[24]);
1854       deltas[7] =
1855           hadd_four_32_neon(deltas[23], deltas[25], deltas[23], deltas[25]);
1856       deltas[0] = vsubq_s32(deltas[1], deltas[0]);
1857       deltas[1] = vsubq_s32(deltas[3], deltas[2]);
1858       deltas[2] = vsubq_s32(deltas[5], deltas[4]);
1859       deltas[3] = vsubq_s32(deltas[7], deltas[6]);
1860 
1861       if (h8 != height) {
1862         const int16_t ds0_vals[] = {
1863           dj[0 * d_stride], dj[0 * d_stride + width],
1864           dj[1 * d_stride], dj[1 * d_stride + width],
1865           dj[2 * d_stride], dj[2 * d_stride + width],
1866           dj[3 * d_stride], dj[3 * d_stride + width]
1867         };
1868         ds[0] = vld1q_s16(ds0_vals);
1869 
1870         ds[1] = vsetq_lane_s16(dj[4 * d_stride], ds[1], 0);
1871         ds[1] = vsetq_lane_s16(dj[4 * d_stride + width], ds[1], 1);
1872         ds[1] = vsetq_lane_s16(dj[5 * d_stride], ds[1], 2);
1873         ds[1] = vsetq_lane_s16(dj[5 * d_stride + width], ds[1], 3);
1874         const int16_t dd4_vals[] = {
1875           -di[1 * d_stride], di[1 * d_stride + width],
1876           -di[2 * d_stride], di[2 * d_stride + width],
1877           -di[3 * d_stride], di[3 * d_stride + width],
1878           -di[4 * d_stride], di[4 * d_stride + width]
1879         };
1880         dd[4] = vld1q_s16(dd4_vals);
1881 
1882         dd[5] = vsetq_lane_s16(-di[5 * d_stride], dd[5], 0);
1883         dd[5] = vsetq_lane_s16(di[5 * d_stride + width], dd[5], 1);
1884         do {
1885           dd[0] = vdupq_n_s16(-di[0 * d_stride]);
1886           dd[2] = dd[3] = vdupq_n_s16(di[0 * d_stride + width]);
1887           dd[0] = dd[1] = vzipq_s16(dd[0], dd[2]).val[0];
1888 
1889           ds[4] = vdupq_n_s16(dj[0 * d_stride]);
1890           ds[6] = ds[7] = vdupq_n_s16(dj[0 * d_stride + width]);
1891           ds[4] = ds[5] = vzipq_s16(ds[4], ds[6]).val[0];
1892 
1893           dd[5] = vsetq_lane_s16(-di[6 * d_stride], dd[5], 2);
1894           dd[5] = vsetq_lane_s16(di[6 * d_stride + width], dd[5], 3);
1895           ds[1] = vsetq_lane_s16(dj[6 * d_stride], ds[1], 4);
1896           ds[1] = vsetq_lane_s16(dj[6 * d_stride + width], ds[1], 5);
1897 
1898           madd_neon_pairwise(&deltas[0], dd[0], ds[0]);
1899           madd_neon_pairwise(&deltas[1], dd[1], ds[1]);
1900           madd_neon_pairwise(&deltas[2], dd[4], ds[4]);
1901           madd_neon_pairwise(&deltas[3], dd[5], ds[5]);
1902 
1903           int32_t tmp0 = vgetq_lane_s32(vreinterpretq_s32_s16(ds[0]), 0);
1904           ds[0] = vextq_s16(ds[0], ds[1], 2);
1905           ds[1] = vextq_s16(ds[1], ds[0], 2);
1906           ds[1] = vreinterpretq_s16_s32(
1907               vsetq_lane_s32(tmp0, vreinterpretq_s32_s16(ds[1]), 3));
1908           int32_t tmp1 = vgetq_lane_s32(vreinterpretq_s32_s16(dd[4]), 0);
1909           dd[4] = vextq_s16(dd[4], dd[5], 2);
1910           dd[5] = vextq_s16(dd[5], dd[4], 2);
1911           dd[5] = vreinterpretq_s16_s32(
1912               vsetq_lane_s32(tmp1, vreinterpretq_s32_s16(dd[5]), 3));
1913           di += d_stride;
1914           dj += d_stride;
1915         } while (++y < height);
1916       }
1917 
1918       // Writing one more element on the top edge of a square spills into
1919       // the next square in the same row or into the first element of the
1920       // next row; either way it is simply overwritten later.
1921       update_8_stats_neon(
1922           H + (i - 1) * wiener_win * wiener_win2 + (j - 1) * wiener_win,
1923           deltas[0], deltas[1],
1924           H + i * wiener_win * wiener_win2 + j * wiener_win);
1925 
1926       H[(i * wiener_win + 1) * wiener_win2 + j * wiener_win] =
1927           H[((i - 1) * wiener_win + 1) * wiener_win2 + (j - 1) * wiener_win] +
1928           vgetq_lane_s32(deltas[2], 0);
1929       H[(i * wiener_win + 2) * wiener_win2 + j * wiener_win] =
1930           H[((i - 1) * wiener_win + 2) * wiener_win2 + (j - 1) * wiener_win] +
1931           vgetq_lane_s32(deltas[2], 1);
1932       H[(i * wiener_win + 3) * wiener_win2 + j * wiener_win] =
1933           H[((i - 1) * wiener_win + 3) * wiener_win2 + (j - 1) * wiener_win] +
1934           vgetq_lane_s32(deltas[2], 2);
1935       H[(i * wiener_win + 4) * wiener_win2 + j * wiener_win] =
1936           H[((i - 1) * wiener_win + 4) * wiener_win2 + (j - 1) * wiener_win] +
1937           vgetq_lane_s32(deltas[2], 3);
1938       H[(i * wiener_win + 5) * wiener_win2 + j * wiener_win] =
1939           H[((i - 1) * wiener_win + 5) * wiener_win2 + (j - 1) * wiener_win] +
1940           vgetq_lane_s32(deltas[3], 0);
1941       H[(i * wiener_win + 6) * wiener_win2 + j * wiener_win] =
1942           H[((i - 1) * wiener_win + 6) * wiener_win2 + (j - 1) * wiener_win] +
1943           vgetq_lane_s32(deltas[3], 1);
1944     } while (++j < wiener_win);
1945   } while (++i < wiener_win - 1);
1946 
1947   // Step 5: Derive other points of each square. No square in bottom row.
1948   i = 0;
1949   do {
1950     const int16_t *const di = d + i;
1951 
1952     j = i + 1;
1953     do {
1954       const int16_t *const dj = d + j;
1955       int32x4_t deltas[WIENER_WIN - 1][WIN_7] = { { vdupq_n_s32(0) },
1956                                                   { vdupq_n_s32(0) } };
1957       int16x8_t d_is[WIN_7];
1958       int16x8_t d_ie[WIN_7];
1959       int16x8_t d_js[WIN_7];
1960       int16x8_t d_je[WIN_7];
1961 
1962       x = 0;
1963       while (x < w16) {
1964         load_square_win7_neon(di + x, dj + x, d_stride, height, d_is, d_ie,
1965                               d_js, d_je);
1966         derive_square_win7_neon(d_is, d_ie, d_js, d_je, deltas);
1967         x += 16;
1968       }
1969 
1970       if (w16 != width) {
1971         load_square_win7_neon(di + x, dj + x, d_stride, height, d_is, d_ie,
1972                               d_js, d_je);
1973         d_is[0] = vandq_s16(d_is[0], mask[0]);
1974         d_is[1] = vandq_s16(d_is[1], mask[1]);
1975         d_is[2] = vandq_s16(d_is[2], mask[0]);
1976         d_is[3] = vandq_s16(d_is[3], mask[1]);
1977         d_is[4] = vandq_s16(d_is[4], mask[0]);
1978         d_is[5] = vandq_s16(d_is[5], mask[1]);
1979         d_is[6] = vandq_s16(d_is[6], mask[0]);
1980         d_is[7] = vandq_s16(d_is[7], mask[1]);
1981         d_is[8] = vandq_s16(d_is[8], mask[0]);
1982         d_is[9] = vandq_s16(d_is[9], mask[1]);
1983         d_is[10] = vandq_s16(d_is[10], mask[0]);
1984         d_is[11] = vandq_s16(d_is[11], mask[1]);
1985         d_ie[0] = vandq_s16(d_ie[0], mask[0]);
1986         d_ie[1] = vandq_s16(d_ie[1], mask[1]);
1987         d_ie[2] = vandq_s16(d_ie[2], mask[0]);
1988         d_ie[3] = vandq_s16(d_ie[3], mask[1]);
1989         d_ie[4] = vandq_s16(d_ie[4], mask[0]);
1990         d_ie[5] = vandq_s16(d_ie[5], mask[1]);
1991         d_ie[6] = vandq_s16(d_ie[6], mask[0]);
1992         d_ie[7] = vandq_s16(d_ie[7], mask[1]);
1993         d_ie[8] = vandq_s16(d_ie[8], mask[0]);
1994         d_ie[9] = vandq_s16(d_ie[9], mask[1]);
1995         d_ie[10] = vandq_s16(d_ie[10], mask[0]);
1996         d_ie[11] = vandq_s16(d_ie[11], mask[1]);
1997         derive_square_win7_neon(d_is, d_ie, d_js, d_je, deltas);
1998       }
1999 
2000       hadd_update_6_stats_neon(
2001           H + (i * wiener_win + 0) * wiener_win2 + j * wiener_win, deltas[0],
2002           H + (i * wiener_win + 1) * wiener_win2 + j * wiener_win + 1);
2003       hadd_update_6_stats_neon(
2004           H + (i * wiener_win + 1) * wiener_win2 + j * wiener_win, deltas[1],
2005           H + (i * wiener_win + 2) * wiener_win2 + j * wiener_win + 1);
2006       hadd_update_6_stats_neon(
2007           H + (i * wiener_win + 2) * wiener_win2 + j * wiener_win, deltas[2],
2008           H + (i * wiener_win + 3) * wiener_win2 + j * wiener_win + 1);
2009       hadd_update_6_stats_neon(
2010           H + (i * wiener_win + 3) * wiener_win2 + j * wiener_win, deltas[3],
2011           H + (i * wiener_win + 4) * wiener_win2 + j * wiener_win + 1);
2012       hadd_update_6_stats_neon(
2013           H + (i * wiener_win + 4) * wiener_win2 + j * wiener_win, deltas[4],
2014           H + (i * wiener_win + 5) * wiener_win2 + j * wiener_win + 1);
2015       hadd_update_6_stats_neon(
2016           H + (i * wiener_win + 5) * wiener_win2 + j * wiener_win, deltas[5],
2017           H + (i * wiener_win + 6) * wiener_win2 + j * wiener_win + 1);
2018     } while (++j < wiener_win);
2019   } while (++i < wiener_win - 1);
2020 
2021   // Step 6: Derive other points of each upper triangle along the diagonal.
2022   i = 0;
2023   do {
2024     const int16_t *const di = d + i;
2025     int32x4_t deltas[WIENER_WIN * (WIENER_WIN - 1)] = { vdupq_n_s32(0) };
2026     int16x8_t d_is[WIN_7], d_ie[WIN_7];
2027 
2028     x = 0;
2029     while (x < w16) {
2030       load_triangle_win7_neon(di + x, d_stride, height, d_is, d_ie);
2031       derive_triangle_win7_neon(d_is, d_ie, deltas);
2032       x += 16;
2033     }
2034 
2035     if (w16 != width) {
2036       load_triangle_win7_neon(di + x, d_stride, height, d_is, d_ie);
2037       d_is[0] = vandq_s16(d_is[0], mask[0]);
2038       d_is[1] = vandq_s16(d_is[1], mask[1]);
2039       d_is[2] = vandq_s16(d_is[2], mask[0]);
2040       d_is[3] = vandq_s16(d_is[3], mask[1]);
2041       d_is[4] = vandq_s16(d_is[4], mask[0]);
2042       d_is[5] = vandq_s16(d_is[5], mask[1]);
2043       d_is[6] = vandq_s16(d_is[6], mask[0]);
2044       d_is[7] = vandq_s16(d_is[7], mask[1]);
2045       d_is[8] = vandq_s16(d_is[8], mask[0]);
2046       d_is[9] = vandq_s16(d_is[9], mask[1]);
2047       d_is[10] = vandq_s16(d_is[10], mask[0]);
2048       d_is[11] = vandq_s16(d_is[11], mask[1]);
2049       d_ie[0] = vandq_s16(d_ie[0], mask[0]);
2050       d_ie[1] = vandq_s16(d_ie[1], mask[1]);
2051       d_ie[2] = vandq_s16(d_ie[2], mask[0]);
2052       d_ie[3] = vandq_s16(d_ie[3], mask[1]);
2053       d_ie[4] = vandq_s16(d_ie[4], mask[0]);
2054       d_ie[5] = vandq_s16(d_ie[5], mask[1]);
2055       d_ie[6] = vandq_s16(d_ie[6], mask[0]);
2056       d_ie[7] = vandq_s16(d_ie[7], mask[1]);
2057       d_ie[8] = vandq_s16(d_ie[8], mask[0]);
2058       d_ie[9] = vandq_s16(d_ie[9], mask[1]);
2059       d_ie[10] = vandq_s16(d_ie[10], mask[0]);
2060       d_ie[11] = vandq_s16(d_ie[11], mask[1]);
2061       derive_triangle_win7_neon(d_is, d_ie, deltas);
2062     }
2063 
2064     // Row 1: 6 points
2065     hadd_update_6_stats_neon(
2066         H + (i * wiener_win + 0) * wiener_win2 + i * wiener_win, deltas,
2067         H + (i * wiener_win + 1) * wiener_win2 + i * wiener_win + 1);
2068 
2069     int32x4_t delta1710 = horizontal_add_2d_s32(deltas[17], deltas[10]);
2070     int32x4_t delta1516 = horizontal_add_2d_s32(deltas[15], deltas[16]);
2071 
2072     int64x2_t delta1710_s64 = vpaddlq_s32(delta1710);
2073     int64x2_t delta1516_s64 = vpaddlq_s32(delta1516);
2074 
2075     // Row 2: 5 points
2076     hadd_update_4_stats_neon(
2077         H + (i * wiener_win + 1) * wiener_win2 + i * wiener_win + 1, deltas + 6,
2078         H + (i * wiener_win + 2) * wiener_win2 + i * wiener_win + 2);
2079     H[(i * wiener_win + 2) * wiener_win2 + i * wiener_win + 6] =
2080         H[(i * wiener_win + 1) * wiener_win2 + i * wiener_win + 5] +
2081         vgetq_lane_s64(delta1710_s64, 1);
2082 
2083     // Row 3: 4 points
2084     hadd_update_4_stats_neon(
2085         H + (i * wiener_win + 2) * wiener_win2 + i * wiener_win + 2,
2086         deltas + 11,
2087         H + (i * wiener_win + 3) * wiener_win2 + i * wiener_win + 3);
2088 
2089     // Row 4: 3 points
2090     int64x2_t h0 =
2091         vld1q_s64(H + (i * wiener_win + 3) * wiener_win2 + i * wiener_win + 3);
2092     vst1q_s64(H + (i * wiener_win + 4) * wiener_win2 + i * wiener_win + 4,
2093               vaddq_s64(h0, delta1516_s64));
2094     H[(i * wiener_win + 4) * wiener_win2 + i * wiener_win + 6] =
2095         H[(i * wiener_win + 3) * wiener_win2 + i * wiener_win + 5] +
2096         vgetq_lane_s64(delta1710_s64, 0);
2097 
2098     int32x4_t delta1819 = horizontal_add_2d_s32(deltas[18], deltas[19]);
2099     int64x2_t delta1819_s64 = vpaddlq_s32(delta1819);
2100 
2101     // Row 5: 2 points
2102     int64x2_t h1 =
2103         vld1q_s64(H + (i * wiener_win + 4) * wiener_win2 + i * wiener_win + 4);
2104     vst1q_s64(H + (i * wiener_win + 5) * wiener_win2 + i * wiener_win + 5,
2105               vaddq_s64(h1, delta1819_s64));
2106 
2107     // Row 6: 1 point
2108     H[(i * wiener_win + 6) * wiener_win2 + i * wiener_win + 6] =
2109         H[(i * wiener_win + 5) * wiener_win2 + i * wiener_win + 5] +
2110         horizontal_long_add_s32x4(deltas[20]);
2111   } while (++i < wiener_win);
2112 }
2113 
2114 static inline uint8_t find_average_neon(const uint8_t *src, int src_stride,
2115                                         int width, int height) {
2116   uint64_t sum = 0;
2117 
2118   if (width >= 16) {
2119     int h = 0;
2120     // We can accumulate up to 257 8-bit values in a 16-bit element. Given
2121     // that each 16-bit vector has 8 elements, this means we can process up
2122     // to int(257*8/width) rows before we need to widen to 32-bit vector
2123     // elements.
2124     int h_overflow = 257 * 8 / width;
2125     int h_limit = height > h_overflow ? h_overflow : height;
2126     uint32x4_t avg_u32 = vdupq_n_u32(0);
2127     do {
2128       uint16x8_t avg_u16 = vdupq_n_u16(0);
2129       do {
2130         int j = width;
2131         const uint8_t *src_ptr = src;
2132         do {
2133           uint8x16_t s = vld1q_u8(src_ptr);
2134           avg_u16 = vpadalq_u8(avg_u16, s);
2135           j -= 16;
2136           src_ptr += 16;
2137         } while (j >= 16);
2138         if (j >= 8) {
2139           uint8x8_t s = vld1_u8(src_ptr);
2140           avg_u16 = vaddw_u8(avg_u16, s);
2141           j -= 8;
2142           src_ptr += 8;
2143         }
2144         // Scalar tail case.
2145         while (j > 0) {
2146           sum += src[width - j];
2147           j--;
2148         }
2149         src += src_stride;
2150       } while (++h < h_limit);
2151       avg_u32 = vpadalq_u16(avg_u32, avg_u16);
2152 
2153       h_limit += h_overflow;
2154       h_limit = height > h_limit ? h_limit : height;
2155     } while (h < height);
2156     return (uint8_t)((horizontal_long_add_u32x4(avg_u32) + sum) /
2157                      (width * height));
2158   }
2159   if (width >= 8) {
2160     int h = 0;
2161     // We can accumulate up to 257 8-bit values in a 16-bit element. Given
2162     // that each 16-bit vector has 4 elements, this means we can process up
2163     // to int(257*4/width) rows before we need to widen to 32-bit vector
2164     // elements.
2165     int h_overflow = 257 * 4 / width;
2166     int h_limit = height > h_overflow ? h_overflow : height;
2167     uint32x2_t avg_u32 = vdup_n_u32(0);
2168     do {
2169       uint16x4_t avg_u16 = vdup_n_u16(0);
2170       do {
2171         int j = width;
2172         const uint8_t *src_ptr = src;
2173         uint8x8_t s = vld1_u8(src_ptr);
2174         avg_u16 = vpadal_u8(avg_u16, s);
2175         j -= 8;
2176         src_ptr += 8;
2177         // Scalar tail case.
2178         while (j > 0) {
2179           sum += src[width - j];
2180           j--;
2181         }
2182         src += src_stride;
2183       } while (++h < h_limit);
2184       avg_u32 = vpadal_u16(avg_u32, avg_u16);
2185 
2186       h_limit += h_overflow;
2187       h_limit = height > h_limit ? h_limit : height;
2188     } while (h < height);
2189     return (uint8_t)((horizontal_long_add_u32x2(avg_u32) + sum) /
2190                      (width * height));
2191   }
2192   int i = height;
2193   do {
2194     int j = 0;
2195     do {
2196       sum += src[j];
2197     } while (++j < width);
2198     src += src_stride;
2199   } while (--i != 0);
2200   return (uint8_t)(sum / (width * height));
2201 }
2202 
2203 static inline void compute_sub_avg(const uint8_t *buf, int buf_stride, int avg,
2204                                    int16_t *buf_avg, int buf_avg_stride,
2205                                    int width, int height,
2206                                    int downsample_factor) {
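  // Subtract the frame average from each pixel, widening the result to int16;
  // the row counter advances by downsample_factor for every row processed.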
2207   uint8x8_t avg_u8 = vdup_n_u8(avg);
2208 
2209   if (width > 8) {
2210     int i = 0;
2211     do {
2212       int j = width;
2213       const uint8_t *buf_ptr = buf;
2214       int16_t *buf_avg_ptr = buf_avg;
2215       do {
2216         uint8x8_t d = vld1_u8(buf_ptr);
2217         vst1q_s16(buf_avg_ptr, vreinterpretq_s16_u16(vsubl_u8(d, avg_u8)));
2218 
2219         j -= 8;
2220         buf_ptr += 8;
2221         buf_avg_ptr += 8;
2222       } while (j >= 8);
2223       while (j > 0) {
2224         *buf_avg_ptr = (int16_t)buf[width - j] - (int16_t)avg;
2225         buf_avg_ptr++;
2226         j--;
2227       }
2228       buf += buf_stride;
2229       buf_avg += buf_avg_stride;
2230       i += downsample_factor;
2231     } while (i < height);
2232   } else {
2233     // For width < 8, don't use Neon.
2234     for (int i = 0; i < height; i = i + downsample_factor) {
2235       for (int j = 0; j < width; j++) {
2236         buf_avg[j] = (int16_t)buf[j] - (int16_t)avg;
2237       }
2238       buf += buf_stride;
2239       buf_avg += buf_avg_stride;
2240     }
2241   }
2242 }
2243 
2244 static inline void av1_compute_stats_downsampled_neon(
2245     int wiener_win, const uint8_t *dgd, const uint8_t *src, int16_t *dgd_avg,
2246     int16_t *src_avg, int h_start, int h_end, int v_start, int v_end,
2247     int dgd_stride, int src_stride, int64_t *M, int64_t *H,
2248     int use_downsampled_wiener_stats) {
2249   assert(wiener_win == WIENER_WIN || wiener_win == WIENER_WIN_CHROMA);
2250   assert(WIENER_STATS_DOWNSAMPLE_FACTOR == 4);
2251   (void)dgd_avg;
2252   (void)src_avg;
2253 
2254   const int wiener_win2 = wiener_win * wiener_win;
2255   const int wiener_halfwin = wiener_win >> 1;
2256   const int width = h_end - h_start;
2257   const int height = v_end - v_start;
2258 
2259   const uint8_t *dgd_start = dgd + h_start + v_start * dgd_stride;
2260   const uint8_t *src_start = src + h_start + v_start * src_stride;
2261 
2262   // The wiener window will slide along the dgd frame, centered on each pixel.
2263   // For the top-left pixel and all the pixels at the edges of the frame, this
2264   // means half of the window will be outside of the frame. As such, the actual
2265   // buffer that we need to subtract the avg from will be 2 * wiener_halfwin
2266   // wider and 2 * wiener_halfwin higher than the original dgd buffer.
2267   const int vert_offset = v_start - wiener_halfwin;
2268   const int horiz_offset = h_start - wiener_halfwin;
2269   const uint8_t *dgd_win = dgd + horiz_offset + vert_offset * dgd_stride;
2270 
2271   uint8_t avg = find_average_neon(dgd_start, dgd_stride, width, height);
2272 
2273   // Since the height is not necessarily a multiple of the downsample factor,
2274   // the last line of src will be scaled according to how many rows remain.
2275   int downsample_factor =
2276       use_downsampled_wiener_stats ? WIENER_STATS_DOWNSAMPLE_FACTOR : 1;
2277 
2278   int downsampled_height = height / downsample_factor;
2279   int downsample_remainder = height % downsample_factor;
2280 
2281   memset(M, 0, wiener_win2 * sizeof(*M));
2282   memset(H, 0, wiener_win2 * wiener_win2 * sizeof(*H));
2283 
2284   // Calculate the M and H matrices for the normal and downsampled cases.
2285   if (downsampled_height > 0) {
2286     if (wiener_win == WIENER_WIN) {
2287       compute_stats_win7_downsampled_neon(
2288           dgd_win, src_start, width, downsampled_height, dgd_stride, src_stride,
2289           avg, M, H, downsample_factor);
2290     } else {
2291       compute_stats_win5_downsampled_neon(
2292           dgd_win, src_start, width, downsampled_height, dgd_stride, src_stride,
2293           avg, M, H, downsample_factor);
2294     }
2295   }
2296 
2297   // Accumulate the remaining last rows in the downsampled case.
2298   if (downsample_remainder > 0) {
2299     int remainder_offset = height - downsample_remainder;
2300     if (wiener_win == WIENER_WIN) {
2301       compute_stats_win7_downsampled_neon(
2302           dgd_win + remainder_offset * dgd_stride,
2303           src_start + remainder_offset * src_stride, width, 1, dgd_stride,
2304           src_stride, avg, M, H, downsample_remainder);
2305     } else {
2306       compute_stats_win5_downsampled_neon(
2307           dgd_win + remainder_offset * dgd_stride,
2308           src_start + remainder_offset * src_stride, width, 1, dgd_stride,
2309           src_stride, avg, M, H, downsample_remainder);
2310     }
2311   }
2312 }
2313 
2314 void av1_compute_stats_neon(int32_t wiener_win, const uint8_t *dgd,
2315                             const uint8_t *src, int16_t *dgd_avg,
2316                             int16_t *src_avg, int32_t h_start, int32_t h_end,
2317                             int32_t v_start, int32_t v_end, int32_t dgd_stride,
2318                             int32_t src_stride, int64_t *M, int64_t *H,
2319                             int use_downsampled_wiener_stats) {
2320   assert(WIENER_STATS_DOWNSAMPLE_FACTOR == 4);
2321   if (use_downsampled_wiener_stats) {
2322     av1_compute_stats_downsampled_neon(
2323         wiener_win, dgd, src, dgd_avg, src_avg, h_start, h_end, v_start, v_end,
2324         dgd_stride, src_stride, M, H, use_downsampled_wiener_stats);
2325     return;
2326   }
2327 
2328   const int32_t wiener_win2 = wiener_win * wiener_win;
2329   const int32_t wiener_halfwin = (wiener_win >> 1);
2330   const int32_t width = h_end - h_start;
2331   const int32_t height = v_end - v_start;
2332   const uint8_t *dgd_start = dgd + h_start + v_start * dgd_stride;
2333   const uint8_t avg = find_average_neon(dgd_start, dgd_stride, width, height);
2334   const int32_t d_stride = (width + 2 * wiener_halfwin + 15) & ~15;
2335   const int32_t s_stride = (width + 15) & ~15;
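  // Round the working strides up to multiples of 16 so the vector kernels can
  // always operate on whole 16-element blocks.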
2336 
2337   compute_sub_avg(src + v_start * src_stride + h_start, src_stride, avg,
2338                   src_avg, s_stride, width, height, 1);
2339   compute_sub_avg(
2340       dgd + (v_start - wiener_halfwin) * dgd_stride + h_start - wiener_halfwin,
2341       dgd_stride, avg, dgd_avg, d_stride, width + 2 * wiener_halfwin,
2342       height + 2 * wiener_halfwin, 1);
2343 
2344   if (wiener_win == WIENER_WIN) {
2345     compute_stats_win7_neon(dgd_avg, d_stride, src_avg, s_stride, width, height,
2346                             M, H);
2347   } else if (wiener_win == WIENER_WIN_CHROMA) {
2348     compute_stats_win5_neon(dgd_avg, d_stride, src_avg, s_stride, width, height,
2349                             M, H);
2350   }
2351 
2352   // H is a symmetric matrix, so we only need to fill out the upper triangle.
2353   // We can copy it down to the lower triangle outside the (i, j) loops.
2354   diagonal_copy_stats_neon(wiener_win2, H);
2355 }
2356 
2357 static inline void calc_proj_params_r0_r1_neon(
2358     const uint8_t *src8, int width, int height, int src_stride,
2359     const uint8_t *dat8, int dat_stride, int32_t *flt0, int flt0_stride,
2360     int32_t *flt1, int flt1_stride, int64_t H[2][2], int64_t C[2]) {
2361   assert(width % 8 == 0);
2362   const int size = width * height;
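  // Accumulate the normal equations of the projection fit, with all terms
  // expressed relative to u = dat << SGRPROJ_RST_BITS:
  //   H[i][j] = sum((flt_i - u) * (flt_j - u)) / size
  //   C[i]    = sum((flt_i - u) * ((src << SGRPROJ_RST_BITS) - u)) / size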
2363 
2364   int64x2_t h00_lo = vdupq_n_s64(0);
2365   int64x2_t h00_hi = vdupq_n_s64(0);
2366   int64x2_t h11_lo = vdupq_n_s64(0);
2367   int64x2_t h11_hi = vdupq_n_s64(0);
2368   int64x2_t h01_lo = vdupq_n_s64(0);
2369   int64x2_t h01_hi = vdupq_n_s64(0);
2370   int64x2_t c0_lo = vdupq_n_s64(0);
2371   int64x2_t c0_hi = vdupq_n_s64(0);
2372   int64x2_t c1_lo = vdupq_n_s64(0);
2373   int64x2_t c1_hi = vdupq_n_s64(0);
2374 
2375   do {
2376     const uint8_t *src_ptr = src8;
2377     const uint8_t *dat_ptr = dat8;
2378     int32_t *flt0_ptr = flt0;
2379     int32_t *flt1_ptr = flt1;
2380     int w = width;
2381 
2382     do {
2383       uint8x8_t s = vld1_u8(src_ptr);
2384       uint8x8_t d = vld1_u8(dat_ptr);
2385       int32x4_t f0_lo = vld1q_s32(flt0_ptr);
2386       int32x4_t f0_hi = vld1q_s32(flt0_ptr + 4);
2387       int32x4_t f1_lo = vld1q_s32(flt1_ptr);
2388       int32x4_t f1_hi = vld1q_s32(flt1_ptr + 4);
2389 
2390       int16x8_t u = vreinterpretq_s16_u16(vshll_n_u8(d, SGRPROJ_RST_BITS));
2391       int16x8_t s_s16 = vreinterpretq_s16_u16(vshll_n_u8(s, SGRPROJ_RST_BITS));
2392 
2393       int32x4_t s_lo = vsubl_s16(vget_low_s16(s_s16), vget_low_s16(u));
2394       int32x4_t s_hi = vsubl_s16(vget_high_s16(s_s16), vget_high_s16(u));
2395       f0_lo = vsubw_s16(f0_lo, vget_low_s16(u));
2396       f0_hi = vsubw_s16(f0_hi, vget_high_s16(u));
2397       f1_lo = vsubw_s16(f1_lo, vget_low_s16(u));
2398       f1_hi = vsubw_s16(f1_hi, vget_high_s16(u));
2399 
2400       h00_lo = vmlal_s32(h00_lo, vget_low_s32(f0_lo), vget_low_s32(f0_lo));
2401       h00_lo = vmlal_s32(h00_lo, vget_high_s32(f0_lo), vget_high_s32(f0_lo));
2402       h00_hi = vmlal_s32(h00_hi, vget_low_s32(f0_hi), vget_low_s32(f0_hi));
2403       h00_hi = vmlal_s32(h00_hi, vget_high_s32(f0_hi), vget_high_s32(f0_hi));
2404 
2405       h11_lo = vmlal_s32(h11_lo, vget_low_s32(f1_lo), vget_low_s32(f1_lo));
2406       h11_lo = vmlal_s32(h11_lo, vget_high_s32(f1_lo), vget_high_s32(f1_lo));
2407       h11_hi = vmlal_s32(h11_hi, vget_low_s32(f1_hi), vget_low_s32(f1_hi));
2408       h11_hi = vmlal_s32(h11_hi, vget_high_s32(f1_hi), vget_high_s32(f1_hi));
2409 
2410       h01_lo = vmlal_s32(h01_lo, vget_low_s32(f0_lo), vget_low_s32(f1_lo));
2411       h01_lo = vmlal_s32(h01_lo, vget_high_s32(f0_lo), vget_high_s32(f1_lo));
2412       h01_hi = vmlal_s32(h01_hi, vget_low_s32(f0_hi), vget_low_s32(f1_hi));
2413       h01_hi = vmlal_s32(h01_hi, vget_high_s32(f0_hi), vget_high_s32(f1_hi));
2414 
2415       c0_lo = vmlal_s32(c0_lo, vget_low_s32(f0_lo), vget_low_s32(s_lo));
2416       c0_lo = vmlal_s32(c0_lo, vget_high_s32(f0_lo), vget_high_s32(s_lo));
2417       c0_hi = vmlal_s32(c0_hi, vget_low_s32(f0_hi), vget_low_s32(s_hi));
2418       c0_hi = vmlal_s32(c0_hi, vget_high_s32(f0_hi), vget_high_s32(s_hi));
2419 
2420       c1_lo = vmlal_s32(c1_lo, vget_low_s32(f1_lo), vget_low_s32(s_lo));
2421       c1_lo = vmlal_s32(c1_lo, vget_high_s32(f1_lo), vget_high_s32(s_lo));
2422       c1_hi = vmlal_s32(c1_hi, vget_low_s32(f1_hi), vget_low_s32(s_hi));
2423       c1_hi = vmlal_s32(c1_hi, vget_high_s32(f1_hi), vget_high_s32(s_hi));
2424 
2425       src_ptr += 8;
2426       dat_ptr += 8;
2427       flt0_ptr += 8;
2428       flt1_ptr += 8;
2429       w -= 8;
2430     } while (w != 0);
2431 
2432     src8 += src_stride;
2433     dat8 += dat_stride;
2434     flt0 += flt0_stride;
2435     flt1 += flt1_stride;
2436   } while (--height != 0);
2437 
2438   H[0][0] = horizontal_add_s64x2(vaddq_s64(h00_lo, h00_hi)) / size;
2439   H[0][1] = horizontal_add_s64x2(vaddq_s64(h01_lo, h01_hi)) / size;
2440   H[1][1] = horizontal_add_s64x2(vaddq_s64(h11_lo, h11_hi)) / size;
2441   H[1][0] = H[0][1];
2442   C[0] = horizontal_add_s64x2(vaddq_s64(c0_lo, c0_hi)) / size;
2443   C[1] = horizontal_add_s64x2(vaddq_s64(c1_lo, c1_hi)) / size;
2444 }
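
// For reference, a scalar sketch of what the vector loop above accumulates
// (kept under #if 0 so it is not built): per pixel, the source sample and the
// two filter outputs are re-centred on the scaled degraded sample u, H
// collects the averaged auto/cross products of (flt0 - u, flt1 - u), and C
// collects their correlations with (src - u). Intended to mirror the generic
// C path rather than replace it.
#if 0
static inline void calc_proj_params_r0_r1_scalar(
    const uint8_t *src8, int width, int height, int src_stride,
    const uint8_t *dat8, int dat_stride, int32_t *flt0, int flt0_stride,
    int32_t *flt1, int flt1_stride, int64_t H[2][2], int64_t C[2]) {
  const int size = width * height;
  int64_t h00 = 0, h01 = 0, h11 = 0, c0 = 0, c1 = 0;
  for (int i = 0; i < height; ++i) {
    for (int j = 0; j < width; ++j) {
      // u: degraded sample, s: source sample, both scaled to SGRPROJ precision
      // and re-centred on u, matching the widening subtracts above.
      const int32_t u = (int32_t)dat8[i * dat_stride + j] << SGRPROJ_RST_BITS;
      const int32_t s =
          ((int32_t)src8[i * src_stride + j] << SGRPROJ_RST_BITS) - u;
      const int32_t f0 = flt0[i * flt0_stride + j] - u;
      const int32_t f1 = flt1[i * flt1_stride + j] - u;
      h00 += (int64_t)f0 * f0;  // H[0][0]
      h11 += (int64_t)f1 * f1;  // H[1][1]
      h01 += (int64_t)f0 * f1;  // H[0][1] == H[1][0]
      c0 += (int64_t)f0 * s;    // C[0]
      c1 += (int64_t)f1 * s;    // C[1]
    }
  }
  H[0][0] = h00 / size;
  H[0][1] = h01 / size;
  H[1][1] = h11 / size;
  H[1][0] = H[0][1];
  C[0] = c0 / size;
  C[1] = c1 / size;
}
#endif  // 0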

static inline void calc_proj_params_r0_neon(const uint8_t *src8, int width,
                                            int height, int src_stride,
                                            const uint8_t *dat8, int dat_stride,
                                            int32_t *flt0, int flt0_stride,
                                            int64_t H[2][2], int64_t C[2]) {
  assert(width % 8 == 0);
  const int size = width * height;

  int64x2_t h00_lo = vdupq_n_s64(0);
  int64x2_t h00_hi = vdupq_n_s64(0);
  int64x2_t c0_lo = vdupq_n_s64(0);
  int64x2_t c0_hi = vdupq_n_s64(0);

  do {
    const uint8_t *src_ptr = src8;
    const uint8_t *dat_ptr = dat8;
    int32_t *flt0_ptr = flt0;
    int w = width;

    do {
      uint8x8_t s = vld1_u8(src_ptr);
      uint8x8_t d = vld1_u8(dat_ptr);
      int32x4_t f0_lo = vld1q_s32(flt0_ptr);
      int32x4_t f0_hi = vld1q_s32(flt0_ptr + 4);

      int16x8_t u = vreinterpretq_s16_u16(vshll_n_u8(d, SGRPROJ_RST_BITS));
      int16x8_t s_s16 = vreinterpretq_s16_u16(vshll_n_u8(s, SGRPROJ_RST_BITS));

      int32x4_t s_lo = vsubl_s16(vget_low_s16(s_s16), vget_low_s16(u));
      int32x4_t s_hi = vsubl_s16(vget_high_s16(s_s16), vget_high_s16(u));
      f0_lo = vsubw_s16(f0_lo, vget_low_s16(u));
      f0_hi = vsubw_s16(f0_hi, vget_high_s16(u));

      h00_lo = vmlal_s32(h00_lo, vget_low_s32(f0_lo), vget_low_s32(f0_lo));
      h00_lo = vmlal_s32(h00_lo, vget_high_s32(f0_lo), vget_high_s32(f0_lo));
      h00_hi = vmlal_s32(h00_hi, vget_low_s32(f0_hi), vget_low_s32(f0_hi));
      h00_hi = vmlal_s32(h00_hi, vget_high_s32(f0_hi), vget_high_s32(f0_hi));

      c0_lo = vmlal_s32(c0_lo, vget_low_s32(f0_lo), vget_low_s32(s_lo));
      c0_lo = vmlal_s32(c0_lo, vget_high_s32(f0_lo), vget_high_s32(s_lo));
      c0_hi = vmlal_s32(c0_hi, vget_low_s32(f0_hi), vget_low_s32(s_hi));
      c0_hi = vmlal_s32(c0_hi, vget_high_s32(f0_hi), vget_high_s32(s_hi));

      src_ptr += 8;
      dat_ptr += 8;
      flt0_ptr += 8;
      w -= 8;
    } while (w != 0);

    src8 += src_stride;
    dat8 += dat_stride;
    flt0 += flt0_stride;
  } while (--height != 0);

  H[0][0] = horizontal_add_s64x2(vaddq_s64(h00_lo, h00_hi)) / size;
  C[0] = horizontal_add_s64x2(vaddq_s64(c0_lo, c0_hi)) / size;
}

static inline void calc_proj_params_r1_neon(const uint8_t *src8, int width,
                                            int height, int src_stride,
                                            const uint8_t *dat8, int dat_stride,
                                            int32_t *flt1, int flt1_stride,
                                            int64_t H[2][2], int64_t C[2]) {
  assert(width % 8 == 0);
  const int size = width * height;

  int64x2_t h11_lo = vdupq_n_s64(0);
  int64x2_t h11_hi = vdupq_n_s64(0);
  int64x2_t c1_lo = vdupq_n_s64(0);
  int64x2_t c1_hi = vdupq_n_s64(0);

  do {
    const uint8_t *src_ptr = src8;
    const uint8_t *dat_ptr = dat8;
    int32_t *flt1_ptr = flt1;
    int w = width;

    do {
      uint8x8_t s = vld1_u8(src_ptr);
      uint8x8_t d = vld1_u8(dat_ptr);
      int32x4_t f1_lo = vld1q_s32(flt1_ptr);
      int32x4_t f1_hi = vld1q_s32(flt1_ptr + 4);

      int16x8_t u = vreinterpretq_s16_u16(vshll_n_u8(d, SGRPROJ_RST_BITS));
      int16x8_t s_s16 = vreinterpretq_s16_u16(vshll_n_u8(s, SGRPROJ_RST_BITS));

      int32x4_t s_lo = vsubl_s16(vget_low_s16(s_s16), vget_low_s16(u));
      int32x4_t s_hi = vsubl_s16(vget_high_s16(s_s16), vget_high_s16(u));
      f1_lo = vsubw_s16(f1_lo, vget_low_s16(u));
      f1_hi = vsubw_s16(f1_hi, vget_high_s16(u));

      h11_lo = vmlal_s32(h11_lo, vget_low_s32(f1_lo), vget_low_s32(f1_lo));
      h11_lo = vmlal_s32(h11_lo, vget_high_s32(f1_lo), vget_high_s32(f1_lo));
      h11_hi = vmlal_s32(h11_hi, vget_low_s32(f1_hi), vget_low_s32(f1_hi));
      h11_hi = vmlal_s32(h11_hi, vget_high_s32(f1_hi), vget_high_s32(f1_hi));

      c1_lo = vmlal_s32(c1_lo, vget_low_s32(f1_lo), vget_low_s32(s_lo));
      c1_lo = vmlal_s32(c1_lo, vget_high_s32(f1_lo), vget_high_s32(s_lo));
      c1_hi = vmlal_s32(c1_hi, vget_low_s32(f1_hi), vget_low_s32(s_hi));
      c1_hi = vmlal_s32(c1_hi, vget_high_s32(f1_hi), vget_high_s32(s_hi));

      src_ptr += 8;
      dat_ptr += 8;
      flt1_ptr += 8;
      w -= 8;
    } while (w != 0);

    src8 += src_stride;
    dat8 += dat_stride;
    flt1 += flt1_stride;
  } while (--height != 0);

  H[1][1] = horizontal_add_s64x2(vaddq_s64(h11_lo, h11_hi)) / size;
  C[1] = horizontal_add_s64x2(vaddq_s64(c1_lo, c1_hi)) / size;
}

// This function calls one of three subfunctions, covering the following cases:
// 1) When params->r[0] > 0 and params->r[1] > 0. In this case all elements
//    of C and H need to be computed.
// 2) When only params->r[0] > 0. In this case only H[0][0] and C[0] are
//    non-zero and need to be computed.
// 3) When only params->r[1] > 0. In this case only H[1][1] and C[1] are
//    non-zero and need to be computed.
void av1_calc_proj_params_neon(const uint8_t *src8, int width, int height,
                               int src_stride, const uint8_t *dat8,
                               int dat_stride, int32_t *flt0, int flt0_stride,
                               int32_t *flt1, int flt1_stride, int64_t H[2][2],
                               int64_t C[2], const sgr_params_type *params) {
  if ((params->r[0] > 0) && (params->r[1] > 0)) {
    calc_proj_params_r0_r1_neon(src8, width, height, src_stride, dat8,
                                dat_stride, flt0, flt0_stride, flt1,
                                flt1_stride, H, C);
  } else if (params->r[0] > 0) {
    calc_proj_params_r0_neon(src8, width, height, src_stride, dat8, dat_stride,
                             flt0, flt0_stride, H, C);
  } else if (params->r[1] > 0) {
    calc_proj_params_r1_neon(src8, width, height, src_stride, dat8, dat_stride,
                             flt1, flt1_stride, H, C);
  }
}
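
// The H and C accumulators form the normal equations for the two self-guided
// projection coefficients; the caller in pickrst.c solves the 2x2 system
// H * x = C. A minimal sketch of that step is shown below for reference only
// (solve_proj_coeffs_sketch is a hypothetical helper, not part of this file's
// API, and the real solver also handles the single-filter cases and
// quantizes/clamps the result). Kept under #if 0 so it is not built.
#if 0
static inline int solve_proj_coeffs_sketch(const int64_t H[2][2],
                                           const int64_t C[2], double x[2]) {
  const double det =
      (double)H[0][0] * (double)H[1][1] - (double)H[0][1] * (double)H[1][0];
  if (det == 0.0) return 0;  // Singular system: caller falls back to defaults.
  // Cramer's rule for the 2x2 normal equations H * x = C.
  x[0] =
      ((double)H[1][1] * (double)C[0] - (double)H[0][1] * (double)C[1]) / det;
  x[1] =
      ((double)H[0][0] * (double)C[1] - (double)H[1][0] * (double)C[0]) / det;
  return 1;
}
#endif  // 0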