xref: /aosp_15_r20/external/libaom/av1/encoder/arm/encodetxb_neon.c (revision 77c1e3ccc04c968bd2bc212e87364f250e820521)
1 /*
2  * Copyright (c) 2016, Alliance for Open Media. All rights reserved.
3  *
4  * This source code is subject to the terms of the BSD 2 Clause License and
5  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6  * was not distributed with this source code in the LICENSE file, you can
7  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8  * Media Patent License 1.0 was not distributed with this source code in the
9  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10  */
11 
12 #include <arm_neon.h>
13 #include <assert.h>
14 #include <math.h>
15 
16 #include "config/aom_config.h"
17 
18 #include "aom_dsp/arm/mem_neon.h"
19 #include "av1/common/txb_common.h"
20 #include "av1/encoder/encodetxb.h"
21 
av1_txb_init_levels_neon(const tran_low_t * const coeff,const int width,const int height,uint8_t * const levels)22 void av1_txb_init_levels_neon(const tran_low_t *const coeff, const int width,
23                               const int height, uint8_t *const levels) {
24   const int stride = height + TX_PAD_HOR;
25   memset(levels - TX_PAD_TOP * stride, 0,
26          sizeof(*levels) * TX_PAD_TOP * stride);
27   memset(levels + stride * width, 0,
28          sizeof(*levels) * (TX_PAD_BOTTOM * stride + TX_PAD_END));
29 
30   const int32x4_t zeros = vdupq_n_s32(0);
31   int i = 0;
32   uint8_t *ls = levels;
33   const tran_low_t *cf = coeff;
34   if (height == 4) {
35     do {
36       const int32x4_t coeffA = vld1q_s32(cf);
37       const int32x4_t coeffB = vld1q_s32(cf + height);
38       const int16x8_t coeffAB =
39           vcombine_s16(vqmovn_s32(coeffA), vqmovn_s32(coeffB));
40       const int16x8_t absAB = vqabsq_s16(coeffAB);
41       const int8x8_t absABs = vqmovn_s16(absAB);
42 #if AOM_ARCH_AARCH64
43       const int8x16_t absAB8 =
44           vcombine_s8(absABs, vreinterpret_s8_s32(vget_low_s32(zeros)));
45       const uint8x16_t lsAB =
46           vreinterpretq_u8_s32(vzip1q_s32(vreinterpretq_s32_s8(absAB8), zeros));
47 #else
48       const int32x2x2_t absAB8 =
49           vzip_s32(vreinterpret_s32_s8(absABs), vget_low_s32(zeros));
50       const uint8x16_t lsAB =
51           vreinterpretq_u8_s32(vcombine_s32(absAB8.val[0], absAB8.val[1]));
52 #endif
53       vst1q_u8(ls, lsAB);
54       ls += (stride << 1);
55       cf += (height << 1);
56       i += 2;
57     } while (i < width);
58   } else if (height == 8) {
59     do {
60       const int16x8_t coeffAB = load_tran_low_to_s16q(cf);
61       const int16x8_t absAB = vqabsq_s16(coeffAB);
62       const uint8x16_t absAB8 = vreinterpretq_u8_s8(vcombine_s8(
63           vqmovn_s16(absAB), vreinterpret_s8_s32(vget_low_s32(zeros))));
64       vst1q_u8(ls, absAB8);
65       ls += stride;
66       cf += height;
67       i += 1;
68     } while (i < width);
69   } else {
70     do {
71       int j = 0;
72       do {
73         const int16x8_t coeffAB = load_tran_low_to_s16q(cf);
74         const int16x8_t coeffCD = load_tran_low_to_s16q(cf + 8);
75         const int16x8_t absAB = vqabsq_s16(coeffAB);
76         const int16x8_t absCD = vqabsq_s16(coeffCD);
77         const uint8x16_t absABCD = vreinterpretq_u8_s8(
78             vcombine_s8(vqmovn_s16(absAB), vqmovn_s16(absCD)));
79         vst1q_u8((ls + j), absABCD);
80         j += 16;
81         cf += 16;
82       } while (j < height);
83       *(int32_t *)(ls + height) = 0;
84       ls += stride;
85       i += 1;
86     } while (i < width);
87   }
88 }
89 
90 // get_4_nz_map_contexts_2d coefficients:
91 static const DECLARE_ALIGNED(16, uint8_t, c_4_po_2d[2][16]) = {
92   { 0, 1, 6, 6, 1, 6, 6, 21, 6, 6, 21, 21, 6, 21, 21, 21 },
93   { 0, 16, 16, 16, 16, 16, 16, 16, 6, 6, 21, 21, 6, 21, 21, 21 }
94 };
95 
// get_4_nz_map_contexts_ver coefficients:
// The four offsets SIG_COEF_CONTEXTS_2D + {0, 5, 10, 10}, packed one per byte
// (lowest byte first) into a single 32-bit value.
/* clang-format off */
#define SIG_COEF_CONTEXTS_2D_X4_051010     \
  (((SIG_COEF_CONTEXTS_2D + 0) << 0) +     \
   ((SIG_COEF_CONTEXTS_2D + 5) << 8) +     \
   ((SIG_COEF_CONTEXTS_2D + 10) << 16) +   \
   ((SIG_COEF_CONTEXTS_2D + 10) << 24))
/* clang-format on */
102 
103 // get_4_nz_map_contexts_ver coefficients:
104 static const DECLARE_ALIGNED(16, uint8_t, c_4_po_hor[16]) = {
105   SIG_COEF_CONTEXTS_2D + 0,  SIG_COEF_CONTEXTS_2D + 0,
106   SIG_COEF_CONTEXTS_2D + 0,  SIG_COEF_CONTEXTS_2D + 0,
107   SIG_COEF_CONTEXTS_2D + 5,  SIG_COEF_CONTEXTS_2D + 5,
108   SIG_COEF_CONTEXTS_2D + 5,  SIG_COEF_CONTEXTS_2D + 5,
109   SIG_COEF_CONTEXTS_2D + 10, SIG_COEF_CONTEXTS_2D + 10,
110   SIG_COEF_CONTEXTS_2D + 10, SIG_COEF_CONTEXTS_2D + 10,
111   SIG_COEF_CONTEXTS_2D + 10, SIG_COEF_CONTEXTS_2D + 10,
112   SIG_COEF_CONTEXTS_2D + 10, SIG_COEF_CONTEXTS_2D + 10
113 };
114 
115 // get_8_coeff_contexts_2d coefficients:
116 // if (width == 8)
117 static const DECLARE_ALIGNED(16, uint8_t, c_8_po_2d_8[2][16]) = {
118   { 0, 1, 6, 6, 21, 21, 21, 21, 1, 6, 6, 21, 21, 21, 21, 21 },
119   { 6, 6, 21, 21, 21, 21, 21, 21, 6, 21, 21, 21, 21, 21, 21, 21 }
120 };
121 // if (width < 8)
122 static const DECLARE_ALIGNED(16, uint8_t, c_8_po_2d_l[2][16]) = {
123   { 0, 11, 6, 6, 21, 21, 21, 21, 11, 11, 6, 21, 21, 21, 21, 21 },
124   { 11, 11, 21, 21, 21, 21, 21, 21, 11, 11, 21, 21, 21, 21, 21, 21 }
125 };
126 
127 // if (width > 8)
128 static const DECLARE_ALIGNED(16, uint8_t, c_8_po_2d_g[2][16]) = {
129   { 0, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16 },
130   { 6, 6, 21, 21, 21, 21, 21, 21, 6, 21, 21, 21, 21, 21, 21, 21 }
131 };
132 
133 // get_4_nz_map_contexts_ver coefficients:
134 static const DECLARE_ALIGNED(16, uint8_t, c_8_po_ver[16]) = {
135   SIG_COEF_CONTEXTS_2D + 0,  SIG_COEF_CONTEXTS_2D + 5,
136   SIG_COEF_CONTEXTS_2D + 10, SIG_COEF_CONTEXTS_2D + 10,
137   SIG_COEF_CONTEXTS_2D + 10, SIG_COEF_CONTEXTS_2D + 10,
138   SIG_COEF_CONTEXTS_2D + 10, SIG_COEF_CONTEXTS_2D + 10,
139   SIG_COEF_CONTEXTS_2D + 0,  SIG_COEF_CONTEXTS_2D + 5,
140   SIG_COEF_CONTEXTS_2D + 10, SIG_COEF_CONTEXTS_2D + 10,
141   SIG_COEF_CONTEXTS_2D + 10, SIG_COEF_CONTEXTS_2D + 10,
142   SIG_COEF_CONTEXTS_2D + 10, SIG_COEF_CONTEXTS_2D + 10
143 };
144 
145 // get_16n_coeff_contexts_2d coefficients:
146 // real_width == real_height
147 static const DECLARE_ALIGNED(16, uint8_t, c_16_po_2d_e[4][16]) = {
148   { 0, 1, 6, 6, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21 },
149   { 1, 6, 6, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21 },
150   { 6, 6, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21 },
151   { 6, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21 }
152 };
153 
154 // real_width < real_height
155 static const DECLARE_ALIGNED(16, uint8_t, c_16_po_2d_g[3][16]) = {
156   { 0, 11, 6, 6, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21 },
157   { 11, 11, 6, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21 },
158   { 11, 11, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21 }
159 };
160 
161 // real_width > real_height
162 static const DECLARE_ALIGNED(16, uint8_t, c_16_po_2d_l[3][16]) = {
163   { 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16, 16 },
164   { 6, 6, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21 },
165   { 6, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21 }
166 };
167 
168 // get_16n_coeff_contexts_hor coefficients:
169 static const DECLARE_ALIGNED(16, uint8_t, c_16_po_ver[16]) = {
170   SIG_COEF_CONTEXTS_2D + 0,  SIG_COEF_CONTEXTS_2D + 5,
171   SIG_COEF_CONTEXTS_2D + 10, SIG_COEF_CONTEXTS_2D + 10,
172   SIG_COEF_CONTEXTS_2D + 10, SIG_COEF_CONTEXTS_2D + 10,
173   SIG_COEF_CONTEXTS_2D + 10, SIG_COEF_CONTEXTS_2D + 10,
174   SIG_COEF_CONTEXTS_2D + 10, SIG_COEF_CONTEXTS_2D + 10,
175   SIG_COEF_CONTEXTS_2D + 10, SIG_COEF_CONTEXTS_2D + 10,
176   SIG_COEF_CONTEXTS_2D + 10, SIG_COEF_CONTEXTS_2D + 10,
177   SIG_COEF_CONTEXTS_2D + 10, SIG_COEF_CONTEXTS_2D + 10
178 };
179 
180 // end of coefficients declaration area
181 
load_8bit_4x4_to_1_reg(const uint8_t * const src,const int byte_stride)182 static inline uint8x16_t load_8bit_4x4_to_1_reg(const uint8_t *const src,
183                                                 const int byte_stride) {
184 #if AOM_ARCH_AARCH64
185   uint32x4_t v_data = vld1q_u32((uint32_t *)src);
186   v_data = vld1q_lane_u32((uint32_t *)(src + 1 * byte_stride), v_data, 1);
187   v_data = vld1q_lane_u32((uint32_t *)(src + 2 * byte_stride), v_data, 2);
188   v_data = vld1q_lane_u32((uint32_t *)(src + 3 * byte_stride), v_data, 3);
189 
190   return vreinterpretq_u8_u32(v_data);
191 #else
192   return load_unaligned_u8q(src, byte_stride);
193 #endif
194 }
195 
load_8bit_8x2_to_1_reg(const uint8_t * const src,const int byte_stride)196 static inline uint8x16_t load_8bit_8x2_to_1_reg(const uint8_t *const src,
197                                                 const int byte_stride) {
198 #if AOM_ARCH_AARCH64
199   uint64x2_t v_data = vld1q_u64((uint64_t *)src);
200   v_data = vld1q_lane_u64((uint64_t *)(src + 1 * byte_stride), v_data, 1);
201 
202   return vreinterpretq_u8_u64(v_data);
203 #else
204   uint8x8_t v_data_low = vld1_u8(src);
205   uint8x8_t v_data_high = vld1_u8(src + byte_stride);
206 
207   return vcombine_u8(v_data_low, v_data_high);
208 #endif
209 }
210 
load_8bit_16x1_to_1_reg(const uint8_t * const src,const int byte_stride)211 static inline uint8x16_t load_8bit_16x1_to_1_reg(const uint8_t *const src,
212                                                  const int byte_stride) {
213   (void)byte_stride;
214   return vld1q_u8(src);
215 }
216 
// Loads the five neighbour vectors for a 4x4 group of positions:
// level[0] from src + 1, level[1] from src + stride and level[2..4] from
// src + offsets[0..2]. Each vector is gathered by load_8bit_4x4_to_1_reg.
static inline void load_levels_4x4x5(const uint8_t *const src, const int stride,
                                     const ptrdiff_t *const offsets,
                                     uint8x16_t *const level) {
  level[0] = load_8bit_4x4_to_1_reg(src + 1, stride);
  level[1] = load_8bit_4x4_to_1_reg(src + stride, stride);
  for (int k = 0; k < 3; ++k) {
    level[2 + k] = load_8bit_4x4_to_1_reg(src + offsets[k], stride);
  }
}
226 
// Loads the five neighbour vectors for an 8x2 group of positions:
// level[0] from src + 1, level[1] from src + stride and level[2..4] from
// src + offsets[0..2]. Each vector is gathered by load_8bit_8x2_to_1_reg.
static inline void load_levels_8x2x5(const uint8_t *const src, const int stride,
                                     const ptrdiff_t *const offsets,
                                     uint8x16_t *const level) {
  level[0] = load_8bit_8x2_to_1_reg(src + 1, stride);
  level[1] = load_8bit_8x2_to_1_reg(src + stride, stride);
  for (int k = 0; k < 3; ++k) {
    level[2 + k] = load_8bit_8x2_to_1_reg(src + offsets[k], stride);
  }
}
236 
// Loads the five neighbour vectors for 16 consecutive positions of one line:
// level[0] from src + 1, level[1] from src + stride and level[2..4] from
// src + offsets[0..2]. Each vector is read by load_8bit_16x1_to_1_reg.
static inline void load_levels_16x1x5(const uint8_t *const src,
                                      const int stride,
                                      const ptrdiff_t *const offsets,
                                      uint8x16_t *const level) {
  level[0] = load_8bit_16x1_to_1_reg(src + 1, stride);
  level[1] = load_8bit_16x1_to_1_reg(src + stride, stride);
  for (int k = 0; k < 3; ++k) {
    level[2 + k] = load_8bit_16x1_to_1_reg(src + offsets[k], stride);
  }
}
247 
get_coeff_contexts_kernel(uint8x16_t * const level)248 static inline uint8x16_t get_coeff_contexts_kernel(uint8x16_t *const level) {
249   const uint8x16_t const_3 = vdupq_n_u8(3);
250   const uint8x16_t const_4 = vdupq_n_u8(4);
251   uint8x16_t count;
252 
253   count = vminq_u8(level[0], const_3);
254   level[1] = vminq_u8(level[1], const_3);
255   level[2] = vminq_u8(level[2], const_3);
256   level[3] = vminq_u8(level[3], const_3);
257   level[4] = vminq_u8(level[4], const_3);
258   count = vaddq_u8(count, level[1]);
259   count = vaddq_u8(count, level[2]);
260   count = vaddq_u8(count, level[3]);
261   count = vaddq_u8(count, level[4]);
262 
263   count = vrshrq_n_u8(count, 1);
264   count = vminq_u8(count, const_4);
265   return count;
266 }
267 
get_4_nz_map_contexts_2d(const uint8_t * levels,const int width,const ptrdiff_t * const offsets,uint8_t * const coeff_contexts)268 static inline void get_4_nz_map_contexts_2d(const uint8_t *levels,
269                                             const int width,
270                                             const ptrdiff_t *const offsets,
271                                             uint8_t *const coeff_contexts) {
272   const int stride = 4 + TX_PAD_HOR;
273   const uint8x16_t pos_to_offset_large = vdupq_n_u8(21);
274 
275   uint8x16_t pos_to_offset =
276       (width == 4) ? vld1q_u8(c_4_po_2d[0]) : vld1q_u8(c_4_po_2d[1]);
277 
278   uint8x16_t count;
279   uint8x16_t level[5];
280   uint8_t *cc = coeff_contexts;
281 
282   assert(!(width % 4));
283 
284   int col = width;
285   do {
286     load_levels_4x4x5(levels, stride, offsets, level);
287     count = get_coeff_contexts_kernel(level);
288     count = vaddq_u8(count, pos_to_offset);
289     vst1q_u8(cc, count);
290     pos_to_offset = pos_to_offset_large;
291     levels += 4 * stride;
292     cc += 16;
293     col -= 4;
294   } while (col);
295 
296   coeff_contexts[0] = 0;
297 }
298 
get_4_nz_map_contexts_ver(const uint8_t * levels,const int width,const ptrdiff_t * const offsets,uint8_t * coeff_contexts)299 static inline void get_4_nz_map_contexts_ver(const uint8_t *levels,
300                                              const int width,
301                                              const ptrdiff_t *const offsets,
302                                              uint8_t *coeff_contexts) {
303   const int stride = 4 + TX_PAD_HOR;
304 
305   const uint8x16_t pos_to_offset =
306       vreinterpretq_u8_u32(vdupq_n_u32(SIG_COEF_CONTEXTS_2D_X4_051010));
307 
308   uint8x16_t count;
309   uint8x16_t level[5];
310 
311   assert(!(width % 4));
312 
313   int col = width;
314   do {
315     load_levels_4x4x5(levels, stride, offsets, level);
316     count = get_coeff_contexts_kernel(level);
317     count = vaddq_u8(count, pos_to_offset);
318     vst1q_u8(coeff_contexts, count);
319     levels += 4 * stride;
320     coeff_contexts += 16;
321     col -= 4;
322   } while (col);
323 }
324 
get_4_nz_map_contexts_hor(const uint8_t * levels,const int width,const ptrdiff_t * const offsets,uint8_t * coeff_contexts)325 static inline void get_4_nz_map_contexts_hor(const uint8_t *levels,
326                                              const int width,
327                                              const ptrdiff_t *const offsets,
328                                              uint8_t *coeff_contexts) {
329   const int stride = 4 + TX_PAD_HOR;
330   const uint8x16_t pos_to_offset_large = vdupq_n_u8(SIG_COEF_CONTEXTS_2D + 10);
331 
332   uint8x16_t pos_to_offset = vld1q_u8(c_4_po_hor);
333 
334   uint8x16_t count;
335   uint8x16_t level[5];
336 
337   assert(!(width % 4));
338 
339   int col = width;
340   do {
341     load_levels_4x4x5(levels, stride, offsets, level);
342     count = get_coeff_contexts_kernel(level);
343     count = vaddq_u8(count, pos_to_offset);
344     vst1q_u8(coeff_contexts, count);
345     pos_to_offset = pos_to_offset_large;
346     levels += 4 * stride;
347     coeff_contexts += 16;
348     col -= 4;
349   } while (col);
350 }
351 
get_8_coeff_contexts_2d(const uint8_t * levels,const int width,const ptrdiff_t * const offsets,uint8_t * coeff_contexts)352 static inline void get_8_coeff_contexts_2d(const uint8_t *levels,
353                                            const int width,
354                                            const ptrdiff_t *const offsets,
355                                            uint8_t *coeff_contexts) {
356   const int stride = 8 + TX_PAD_HOR;
357   uint8_t *cc = coeff_contexts;
358   uint8x16_t count;
359   uint8x16_t level[5];
360   uint8x16_t pos_to_offset[3];
361 
362   assert(!(width % 2));
363 
364   if (width == 8) {
365     pos_to_offset[0] = vld1q_u8(c_8_po_2d_8[0]);
366     pos_to_offset[1] = vld1q_u8(c_8_po_2d_8[1]);
367   } else if (width < 8) {
368     pos_to_offset[0] = vld1q_u8(c_8_po_2d_l[0]);
369     pos_to_offset[1] = vld1q_u8(c_8_po_2d_l[1]);
370   } else {
371     pos_to_offset[0] = vld1q_u8(c_8_po_2d_g[0]);
372     pos_to_offset[1] = vld1q_u8(c_8_po_2d_g[1]);
373   }
374   pos_to_offset[2] = vdupq_n_u8(21);
375 
376   int col = width;
377   do {
378     load_levels_8x2x5(levels, stride, offsets, level);
379     count = get_coeff_contexts_kernel(level);
380     count = vaddq_u8(count, pos_to_offset[0]);
381     vst1q_u8(cc, count);
382     pos_to_offset[0] = pos_to_offset[1];
383     pos_to_offset[1] = pos_to_offset[2];
384     levels += 2 * stride;
385     cc += 16;
386     col -= 2;
387   } while (col);
388 
389   coeff_contexts[0] = 0;
390 }
391 
get_8_coeff_contexts_ver(const uint8_t * levels,const int width,const ptrdiff_t * const offsets,uint8_t * coeff_contexts)392 static inline void get_8_coeff_contexts_ver(const uint8_t *levels,
393                                             const int width,
394                                             const ptrdiff_t *const offsets,
395                                             uint8_t *coeff_contexts) {
396   const int stride = 8 + TX_PAD_HOR;
397 
398   const uint8x16_t pos_to_offset = vld1q_u8(c_8_po_ver);
399 
400   uint8x16_t count;
401   uint8x16_t level[5];
402 
403   assert(!(width % 2));
404 
405   int col = width;
406   do {
407     load_levels_8x2x5(levels, stride, offsets, level);
408     count = get_coeff_contexts_kernel(level);
409     count = vaddq_u8(count, pos_to_offset);
410     vst1q_u8(coeff_contexts, count);
411     levels += 2 * stride;
412     coeff_contexts += 16;
413     col -= 2;
414   } while (col);
415 }
416 
get_8_coeff_contexts_hor(const uint8_t * levels,const int width,const ptrdiff_t * const offsets,uint8_t * coeff_contexts)417 static inline void get_8_coeff_contexts_hor(const uint8_t *levels,
418                                             const int width,
419                                             const ptrdiff_t *const offsets,
420                                             uint8_t *coeff_contexts) {
421   const int stride = 8 + TX_PAD_HOR;
422   const uint8x16_t pos_to_offset_large = vdupq_n_u8(SIG_COEF_CONTEXTS_2D + 10);
423 
424   uint8x16_t pos_to_offset = vcombine_u8(vdup_n_u8(SIG_COEF_CONTEXTS_2D + 0),
425                                          vdup_n_u8(SIG_COEF_CONTEXTS_2D + 5));
426 
427   uint8x16_t count;
428   uint8x16_t level[5];
429 
430   assert(!(width % 2));
431 
432   int col = width;
433   do {
434     load_levels_8x2x5(levels, stride, offsets, level);
435     count = get_coeff_contexts_kernel(level);
436     count = vaddq_u8(count, pos_to_offset);
437     vst1q_u8(coeff_contexts, count);
438     pos_to_offset = pos_to_offset_large;
439     levels += 2 * stride;
440     coeff_contexts += 16;
441     col -= 2;
442   } while (col);
443 }
444 
get_16n_coeff_contexts_2d(const uint8_t * levels,const int real_width,const int real_height,const int width,const int height,const ptrdiff_t * const offsets,uint8_t * coeff_contexts)445 static inline void get_16n_coeff_contexts_2d(const uint8_t *levels,
446                                              const int real_width,
447                                              const int real_height,
448                                              const int width, const int height,
449                                              const ptrdiff_t *const offsets,
450                                              uint8_t *coeff_contexts) {
451   const int stride = height + TX_PAD_HOR;
452   uint8_t *cc = coeff_contexts;
453   int col = width;
454   uint8x16_t pos_to_offset[5];
455   uint8x16_t pos_to_offset_large[3];
456   uint8x16_t count;
457   uint8x16_t level[5];
458 
459   assert(!(height % 16));
460 
461   pos_to_offset_large[2] = vdupq_n_u8(21);
462   if (real_width == real_height) {
463     pos_to_offset[0] = vld1q_u8(c_16_po_2d_e[0]);
464     pos_to_offset[1] = vld1q_u8(c_16_po_2d_e[1]);
465     pos_to_offset[2] = vld1q_u8(c_16_po_2d_e[2]);
466     pos_to_offset[3] = vld1q_u8(c_16_po_2d_e[3]);
467     pos_to_offset[4] = pos_to_offset_large[0] = pos_to_offset_large[1] =
468         pos_to_offset_large[2];
469   } else if (real_width < real_height) {
470     pos_to_offset[0] = vld1q_u8(c_16_po_2d_g[0]);
471     pos_to_offset[1] = vld1q_u8(c_16_po_2d_g[1]);
472     pos_to_offset[2] = pos_to_offset[3] = pos_to_offset[4] =
473         vld1q_u8(c_16_po_2d_g[2]);
474     pos_to_offset_large[0] = pos_to_offset_large[1] = pos_to_offset_large[2];
475   } else {  // real_width > real_height
476     pos_to_offset[0] = pos_to_offset[1] = vld1q_u8(c_16_po_2d_l[0]);
477     pos_to_offset[2] = vld1q_u8(c_16_po_2d_l[1]);
478     pos_to_offset[3] = vld1q_u8(c_16_po_2d_l[2]);
479     pos_to_offset[4] = pos_to_offset_large[2];
480     pos_to_offset_large[0] = pos_to_offset_large[1] = vdupq_n_u8(16);
481   }
482 
483   do {
484     int h = height;
485 
486     do {
487       load_levels_16x1x5(levels, stride, offsets, level);
488       count = get_coeff_contexts_kernel(level);
489       count = vaddq_u8(count, pos_to_offset[0]);
490       vst1q_u8(cc, count);
491       levels += 16;
492       cc += 16;
493       h -= 16;
494       pos_to_offset[0] = pos_to_offset_large[0];
495     } while (h);
496 
497     pos_to_offset[0] = pos_to_offset[1];
498     pos_to_offset[1] = pos_to_offset[2];
499     pos_to_offset[2] = pos_to_offset[3];
500     pos_to_offset[3] = pos_to_offset[4];
501     pos_to_offset_large[0] = pos_to_offset_large[1];
502     pos_to_offset_large[1] = pos_to_offset_large[2];
503     levels += TX_PAD_HOR;
504   } while (--col);
505 
506   coeff_contexts[0] = 0;
507 }
508 
get_16n_coeff_contexts_ver(const uint8_t * levels,const int width,const int height,const ptrdiff_t * const offsets,uint8_t * coeff_contexts)509 static inline void get_16n_coeff_contexts_ver(const uint8_t *levels,
510                                               const int width, const int height,
511                                               const ptrdiff_t *const offsets,
512                                               uint8_t *coeff_contexts) {
513   const int stride = height + TX_PAD_HOR;
514 
515   const uint8x16_t pos_to_offset_large = vdupq_n_u8(SIG_COEF_CONTEXTS_2D + 10);
516 
517   uint8x16_t count;
518   uint8x16_t level[5];
519 
520   assert(!(height % 16));
521 
522   int col = width;
523   do {
524     uint8x16_t pos_to_offset = vld1q_u8(c_16_po_ver);
525 
526     int h = height;
527     do {
528       load_levels_16x1x5(levels, stride, offsets, level);
529       count = get_coeff_contexts_kernel(level);
530       count = vaddq_u8(count, pos_to_offset);
531       vst1q_u8(coeff_contexts, count);
532       pos_to_offset = pos_to_offset_large;
533       levels += 16;
534       coeff_contexts += 16;
535       h -= 16;
536     } while (h);
537 
538     levels += TX_PAD_HOR;
539   } while (--col);
540 }
541 
get_16n_coeff_contexts_hor(const uint8_t * levels,const int width,const int height,const ptrdiff_t * const offsets,uint8_t * coeff_contexts)542 static inline void get_16n_coeff_contexts_hor(const uint8_t *levels,
543                                               const int width, const int height,
544                                               const ptrdiff_t *const offsets,
545                                               uint8_t *coeff_contexts) {
546   const int stride = height + TX_PAD_HOR;
547 
548   uint8x16_t pos_to_offset[3];
549   uint8x16_t count;
550   uint8x16_t level[5];
551 
552   assert(!(height % 16));
553 
554   pos_to_offset[0] = vdupq_n_u8(SIG_COEF_CONTEXTS_2D + 0);
555   pos_to_offset[1] = vdupq_n_u8(SIG_COEF_CONTEXTS_2D + 5);
556   pos_to_offset[2] = vdupq_n_u8(SIG_COEF_CONTEXTS_2D + 10);
557 
558   int col = width;
559   do {
560     int h = height;
561     do {
562       load_levels_16x1x5(levels, stride, offsets, level);
563       count = get_coeff_contexts_kernel(level);
564       count = vaddq_u8(count, pos_to_offset[0]);
565       vst1q_u8(coeff_contexts, count);
566       levels += 16;
567       coeff_contexts += 16;
568       h -= 16;
569     } while (h);
570 
571     pos_to_offset[0] = pos_to_offset[1];
572     pos_to_offset[1] = pos_to_offset[2];
573     levels += TX_PAD_HOR;
574   } while (--col);
575 }
576 
577 // Note: levels[] must be in the range [0, 127], inclusive.
av1_get_nz_map_contexts_neon(const uint8_t * const levels,const int16_t * const scan,const uint16_t eob,const TX_SIZE tx_size,const TX_CLASS tx_class,int8_t * const coeff_contexts)578 void av1_get_nz_map_contexts_neon(const uint8_t *const levels,
579                                   const int16_t *const scan, const uint16_t eob,
580                                   const TX_SIZE tx_size,
581                                   const TX_CLASS tx_class,
582                                   int8_t *const coeff_contexts) {
583   const int last_idx = eob - 1;
584   if (!last_idx) {
585     coeff_contexts[0] = 0;
586     return;
587   }
588 
589   uint8_t *const coefficients = (uint8_t *const)coeff_contexts;
590 
591   const int real_width = tx_size_wide[tx_size];
592   const int real_height = tx_size_high[tx_size];
593   const int width = get_txb_wide(tx_size);
594   const int height = get_txb_high(tx_size);
595   const int stride = height + TX_PAD_HOR;
596   ptrdiff_t offsets[3];
597 
598   /* coeff_contexts must be 16 byte aligned. */
599   assert(!((intptr_t)coeff_contexts & 0xf));
600 
601   if (tx_class == TX_CLASS_2D) {
602     offsets[0] = 0 * stride + 2;
603     offsets[1] = 1 * stride + 1;
604     offsets[2] = 2 * stride + 0;
605 
606     if (height == 4) {
607       get_4_nz_map_contexts_2d(levels, width, offsets, coefficients);
608     } else if (height == 8) {
609       get_8_coeff_contexts_2d(levels, width, offsets, coefficients);
610     } else {
611       get_16n_coeff_contexts_2d(levels, real_width, real_height, width, height,
612                                 offsets, coefficients);
613     }
614   } else if (tx_class == TX_CLASS_HORIZ) {
615     offsets[0] = 2 * stride;
616     offsets[1] = 3 * stride;
617     offsets[2] = 4 * stride;
618     if (height == 4) {
619       get_4_nz_map_contexts_hor(levels, width, offsets, coefficients);
620     } else if (height == 8) {
621       get_8_coeff_contexts_hor(levels, width, offsets, coefficients);
622     } else {
623       get_16n_coeff_contexts_hor(levels, width, height, offsets, coefficients);
624     }
625   } else {  // TX_CLASS_VERT
626     offsets[0] = 2;
627     offsets[1] = 3;
628     offsets[2] = 4;
629     if (height == 4) {
630       get_4_nz_map_contexts_ver(levels, width, offsets, coefficients);
631     } else if (height == 8) {
632       get_8_coeff_contexts_ver(levels, width, offsets, coefficients);
633     } else {
634       get_16n_coeff_contexts_ver(levels, width, height, offsets, coefficients);
635     }
636   }
637 
638   const int bhl = get_txb_bhl(tx_size);
639   const int pos = scan[last_idx];
640   if (last_idx <= (width << bhl) / 8)
641     coeff_contexts[pos] = 1;
642   else if (last_idx <= (width << bhl) / 4)
643     coeff_contexts[pos] = 2;
644   else
645     coeff_contexts[pos] = 3;
646 }
647