xref: /aosp_15_r20/external/libaom/av1/common/arm/convolve_neon_i8mm.h (revision 77c1e3ccc04c968bd2bc212e87364f250e820521)
1 /*
2  * Copyright (c) 2024, Alliance for Open Media. All rights reserved.
3  *
4  * This source code is subject to the terms of the BSD 2 Clause License and
5  * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
6  * was not distributed with this source code in the LICENSE file, you can
7  * obtain it at www.aomedia.org/license/software. If the Alliance for Open
8  * Media Patent License 1.0 was not distributed with this source code in the
9  * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
10  */
11 
12 #ifndef AOM_AV1_COMMON_ARM_CONVOLVE_NEON_I8MM_H_
13 #define AOM_AV1_COMMON_ARM_CONVOLVE_NEON_I8MM_H_
14 
15 #include <arm_neon.h>
16 #include <assert.h>
17 
18 #include "config/aom_config.h"
19 #include "config/av1_rtcd.h"
20 
21 #include "aom/aom_integer.h"
22 #include "aom_dsp/aom_dsp_common.h"
23 #include "aom_dsp/arm/mem_neon.h"
24 #include "aom_ports/mem.h"
25 
26 DECLARE_ALIGNED(16, static const uint8_t, kDotProdPermuteTbl[48]) = {
27   0, 1, 2,  3,  1, 2,  3,  4,  2,  3,  4,  5,  3,  4,  5,  6,
28   4, 5, 6,  7,  5, 6,  7,  8,  6,  7,  8,  9,  7,  8,  9,  10,
29   8, 9, 10, 11, 9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14
30 };
31 
32 DECLARE_ALIGNED(16, static const uint8_t, kMatMulPermuteTbl[32]) = {
33   // clang-format off
34   0,  1,  2,  3,  4,  5,  6,  7,  2,  3,  4,  5,  6,  7,  8,  9,
35   4,  5,  6,  7,  8,  9, 10, 11,  6,  7,  8,  9, 10, 11, 12, 13
36   // clang-format on
37 };
38 
// Computes four horizontally-filtered outputs of a 12-tap convolution for one
// row and returns them as 16-bit intermediate values.
//
// samples[0]  - 16 source bytes starting at the first tap of output 0.
// samples[1]  - 16 source bytes starting 6 bytes later (used for taps 6-11).
// filter      - taps 0-5 (filter[0]) and taps 6-11 (filter[1]), each laid out
//               as a staggered 8x2 matrix { f0..f5, 0, 0, 0, f0..f5, 0 }.
// permute_tbl - the first 16 bytes of kMatMulPermuteTbl.
// horiz_const - accumulator seed (constant offset plus rounding shim).
static inline int16x4_t convolve12_4_2d_h(uint8x16_t samples[2],
                                          const int8x16_t filter[2],
                                          const uint8x16_t permute_tbl,
                                          int32x4_t horiz_const) {
  // Permute samples ready for matrix multiply. The indices below are byte
  // offsets from the start of samples[0]; samples[1] begins at offset 6.
  // {  0,  1,  2,  3,  4,  5,  6,  7,  2,  3,  4,  5,  6,  7,  8,  9 }
  // {  6,  7,  8,  9, 10, 11, 12, 13,  8,  9, 10, 11, 12, 13, 14, 15 }
  uint8x16_t perm_samples[2] = { vqtbl1q_u8(samples[0], permute_tbl),
                                 vqtbl1q_u8(samples[1], permute_tbl) };

  // These instructions multiply a 2x8 matrix (samples) by an 8x2 matrix
  // (filter), destructively accumulating into the destination register.
  // The 2x2 product is stored row-major, so the lanes hold outputs 0, 1, 2, 3.
  int32x4_t sum = vusmmlaq_s32(horiz_const, perm_samples[0], filter[0]);
  sum = vusmmlaq_s32(sum, perm_samples[1], filter[1]);

  // Narrow and re-pack. A plain (truncating) shift suffices because
  // horiz_const already contains the rounding term.
  return vshrn_n_s32(sum, ROUND0_BITS);
}
57 
// Computes eight horizontally-filtered outputs of a 12-tap convolution for one
// row and returns them as 16-bit intermediate values.
//
// samples[0] holds the 16 source bytes starting at the first tap of output 0,
// and samples[1] holds the 16 bytes starting 6 bytes later. filter[0] and
// filter[1] carry taps 0-5 and 6-11 as staggered 8x2 matrices. permute_tbl is
// both vectors of kMatMulPermuteTbl. horiz_const seeds both accumulators.
static inline int16x8_t convolve12_8_2d_h(uint8x16_t samples[2],
                                          const int8x16_t filter[2],
                                          const uint8x16x2_t permute_tbl,
                                          const int32x4_t horiz_const) {
  // Outputs 0-3. Byte offsets relative to samples[0]:
  //   taps 0-5:  {  0..7,  2..9  }
  //   taps 6-11: {  6..13, 8..15 }
  const uint8x16_t near_0123 = vqtbl1q_u8(samples[0], permute_tbl.val[0]);
  const uint8x16_t far_0123 = vqtbl1q_u8(samples[1], permute_tbl.val[0]);
  int32x4_t acc_0123 = vusmmlaq_s32(horiz_const, near_0123, filter[0]);
  acc_0123 = vusmmlaq_s32(acc_0123, far_0123, filter[1]);

  // Outputs 4-7. Byte offsets relative to samples[0]:
  //   taps 0-5:  {  4..11,  6..13 }
  //   taps 6-11: { 10..17, 12..19 }
  const uint8x16_t near_4567 = vqtbl1q_u8(samples[0], permute_tbl.val[1]);
  const uint8x16_t far_4567 = vqtbl1q_u8(samples[1], permute_tbl.val[1]);
  int32x4_t acc_4567 = vusmmlaq_s32(horiz_const, near_4567, filter[0]);
  acc_4567 = vusmmlaq_s32(acc_4567, far_4567, filter[1]);

  // Truncating narrow; rounding was pre-added through horiz_const.
  const int16x4_t res_0123 = vshrn_n_s32(acc_0123, ROUND0_BITS);
  const int16x4_t res_4567 = vshrn_n_s32(acc_4567, ROUND0_BITS);
  return vcombine_s16(res_0123, res_4567);
}
83 
// Horizontal (first) pass of the 2D 12-tap sub-pixel convolution for 8-bit
// input, built on the I8MM vusmmlaq_s32 (USMMLA) matrix-multiply instruction.
//
// src_ptr      - source pixels; the first tap of output (0, 0) is src_ptr[0].
// src_stride   - source row stride in bytes.
// dst_ptr      - 16-bit intermediate output.
// dst_stride   - output row stride in int16_t elements.
// w            - output width. Must be <= 4 or a multiple of 8. When w <= 4,
//                four values are always written per row.
// h            - number of output rows; any value is handled (h <= 0 is a
//                no-op).
// x_filter_ptr - 12 filter taps. Each must fit in int8_t, because they are
//                narrowed with vmovn_s16.
//
// Every load is 16 bytes wide, taken at offsets 0 and 6 of each chunk of
// outputs. Per row this reads up to src_ptr[21] when w <= 4, or up to
// src_ptr[w + 13] when w > 4. That is a few bytes beyond the w + 11 bytes the
// filter strictly needs.
static inline void convolve_2d_sr_horiz_12tap_neon_i8mm(
    const uint8_t *src_ptr, int src_stride, int16_t *dst_ptr,
    const int dst_stride, int w, int h, const int16_t *x_filter_ptr) {
  // The no-op filter should never be used here: its centre tap of 128 does not
  // fit in int8_t and would wrap to -128 when narrowed below.
  assert(x_filter_ptr[5] != 128);
  // The wide path produces 8 outputs per step and has no partial-width tail.
  assert(w <= 4 || w % 8 == 0);

  const int bd = 8;

  // Split the 12-tap filter into two 6-tap filters:
  //   filter_0 = { f0, f1, f2, f3,  f4,  f5, 0, 0 }  (top two lanes masked)
  //   filter_1 = { f6, f7, f8, f9, f10, f11, 0, 0 }  (shifted in zeros)
  // mask = { 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0 }
  const int8x8_t mask = vcreate_s8(0x0000ffffffffffff);
  const int8x8_t filter_0 = vand_s8(vmovn_s16(vld1q_s16(x_filter_ptr)), mask);
  const int8x8_t filter_1 =
      vext_s8(vmovn_s16(vld1q_s16(x_filter_ptr + 4)), vdup_n_s8(0), 2);

  // Stagger each 6-tap filter to enable use of matrix multiply instructions.
  // { f0, f1, f2, f3, f4, f5,  0,  0,  0, f0, f1, f2, f3, f4, f5,  0 }
  const int8x16_t filter[2] = {
    vcombine_s8(filter_0, vext_s8(filter_0, filter_0, 7)),
    vcombine_s8(filter_1, vext_s8(filter_1, filter_1, 7))
  };

  // 1 << (bd + FILTER_BITS - 1) is a constant offset added to every
  // intermediate value. This presumably mirrors the C reference so the
  // vertical pass can remove it - confirm.
  // The 1 << (ROUND0_BITS - 1) shim lets the kernels use non-rounding
  // shifts, which are generally faster than rounding shifts on modern CPUs.
  const int32x4_t horiz_const =
      vdupq_n_s32((1 << (bd + FILTER_BITS - 1)) + (1 << (ROUND0_BITS - 1)));

  if (w <= 4) {
    const uint8x16_t permute_tbl = vld1q_u8(kMatMulPermuteTbl);

    // Four rows per iteration while at least four remain.
    while (h >= 4) {
      uint8x16_t s0[2], s1[2], s2[2], s3[2];
      load_u8_16x4(src_ptr, src_stride, &s0[0], &s1[0], &s2[0], &s3[0]);
      load_u8_16x4(src_ptr + 6, src_stride, &s0[1], &s1[1], &s2[1], &s3[1]);

      int16x4_t d0 = convolve12_4_2d_h(s0, filter, permute_tbl, horiz_const);
      int16x4_t d1 = convolve12_4_2d_h(s1, filter, permute_tbl, horiz_const);
      int16x4_t d2 = convolve12_4_2d_h(s2, filter, permute_tbl, horiz_const);
      int16x4_t d3 = convolve12_4_2d_h(s3, filter, permute_tbl, horiz_const);

      store_s16_4x4(dst_ptr, dst_stride, d0, d1, d2, d3);

      src_ptr += 4 * src_stride;
      dst_ptr += 4 * dst_stride;
      h -= 4;
    }

    // Remaining 0-3 rows, one at a time.
    while (h > 0) {
      uint8x16_t s0[2];
      s0[0] = vld1q_u8(src_ptr);
      s0[1] = vld1q_u8(src_ptr + 6);
      int16x4_t d0 = convolve12_4_2d_h(s0, filter, permute_tbl, horiz_const);
      vst1_s16(dst_ptr, d0);

      src_ptr += src_stride;
      dst_ptr += dst_stride;
      h--;
    }
  } else {
    const uint8x16x2_t permute_tbl = vld1q_u8_x2(kMatMulPermuteTbl);

    // Four rows per iteration while at least four remain.
    while (h >= 4) {
      for (int x = 0; x < w; x += 8) {
        const uint8_t *s = src_ptr + x;
        uint8x16_t s0[2], s1[2], s2[2], s3[2];
        load_u8_16x4(s, src_stride, &s0[0], &s1[0], &s2[0], &s3[0]);
        load_u8_16x4(s + 6, src_stride, &s0[1], &s1[1], &s2[1], &s3[1]);

        int16x8_t d0 = convolve12_8_2d_h(s0, filter, permute_tbl, horiz_const);
        int16x8_t d1 = convolve12_8_2d_h(s1, filter, permute_tbl, horiz_const);
        int16x8_t d2 = convolve12_8_2d_h(s2, filter, permute_tbl, horiz_const);
        int16x8_t d3 = convolve12_8_2d_h(s3, filter, permute_tbl, horiz_const);

        store_s16_8x4(dst_ptr + x, dst_stride, d0, d1, d2, d3);
      }

      src_ptr += 4 * src_stride;
      dst_ptr += 4 * dst_stride;
      h -= 4;
    }

    // Remaining 0-3 rows, one at a time.
    while (h > 0) {
      for (int x = 0; x < w; x += 8) {
        uint8x16_t s0[2];
        s0[0] = vld1q_u8(src_ptr + x);
        s0[1] = vld1q_u8(src_ptr + x + 6);
        int16x8_t d0 = convolve12_8_2d_h(s0, filter, permute_tbl, horiz_const);
        vst1q_s16(dst_ptr + x, d0);
      }

      src_ptr += src_stride;
      dst_ptr += dst_stride;
      h--;
    }
  }
}
194 
195 #endif  // AOM_AV1_COMMON_ARM_CONVOLVE_NEON_I8MM_H_
196