/*
 * Copyright (c) 2024, Alliance for Open Media. All rights reserved.
 *
 * This source code is subject to the terms of the BSD 2 Clause License and
 * the Alliance for Open Media Patent License 1.0. If the BSD 2 Clause License
 * was not distributed with this source code in the LICENSE file, you can
 * obtain it at www.aomedia.org/license/software. If the Alliance for Open
 * Media Patent License 1.0 was not distributed with this source code in the
 * PATENTS file, you can obtain it at www.aomedia.org/license/patent.
 */

#ifndef AOM_AV1_COMMON_ARM_CONVOLVE_NEON_I8MM_H_
#define AOM_AV1_COMMON_ARM_CONVOLVE_NEON_I8MM_H_

#include <arm_neon.h>
#include <assert.h>

#include "config/aom_config.h"
#include "config/av1_rtcd.h"

#include "aom/aom_integer.h"
#include "aom_dsp/aom_dsp_common.h"
#include "aom_dsp/arm/mem_neon.h"
#include "aom_ports/mem.h"
DECLARE_ALIGNED(16, static const uint8_t, kDotProdPermuteTbl[48]) = {
  0, 1, 2, 3, 1, 2, 3, 4, 2, 3, 4, 5, 3, 4, 5, 6,
  4, 5, 6, 7, 5, 6, 7, 8, 6, 7, 8, 9, 7, 8, 9, 10,
  8, 9, 10, 11, 9, 10, 11, 12, 10, 11, 12, 13, 11, 12, 13, 14
};
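// Each 16-byte row of kDotProdPermuteTbl above gathers four overlapping
// 4-sample windows (e.g. row 0 yields { s0..s3, s1..s4, s2..s5, s3..s6 }),
// the input layout expected by the USDOT-based kernels elsewhere in this
// file.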

DECLARE_ALIGNED(16, static const uint8_t, kMatMulPermuteTbl[32]) = {
  // clang-format off
  0,  1,  2,  3,  4,  5,  6,  7,  2,  3,  4,  5,  6,  7,  8,  9,
  4,  5,  6,  7,  8,  9, 10, 11,  6,  7,  8,  9, 10, 11, 12, 13
  // clang-format on
};
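// Each 16-byte row of kMatMulPermuteTbl above builds one 2x8 matrix of
// samples for USMMLA: row 0 gives { s0..s7, s2..s9 } and row 1 gives
// { s4..s11, s6..s13 }, i.e. two convolution positions per matrix.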

static inline int16x4_t convolve12_4_2d_h(uint8x16_t samples[2],
                                          const int8x16_t filter[2],
                                          const uint8x16_t permute_tbl,
                                          int32x4_t horiz_const) {
  // Permute samples ready for matrix multiply.
  // { 0, 1, 2, 3, 4, 5, 6, 7, 2, 3, 4, 5, 6, 7, 8, 9 }
  // { 6, 7, 8, 9, 10, 11, 12, 13, 8, 9, 10, 11, 12, 13, 14, 15 }
  uint8x16_t perm_samples[2] = { vqtbl1q_u8(samples[0], permute_tbl),
                                 vqtbl1q_u8(samples[1], permute_tbl) };

  // These instructions multiply a 2x8 matrix (samples) by an 8x2 matrix
  // (filter), destructively accumulating into the destination register.
  int32x4_t sum = vusmmlaq_s32(horiz_const, perm_samples[0], filter[0]);
  sum = vusmmlaq_s32(sum, perm_samples[1], filter[1]);
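  // With the staggered filter columns, the four lanes of sum now hold the
  // (not yet shifted) horizontal results for output pixels 0-3.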

  // Narrow and re-pack.
  return vshrn_n_s32(sum, ROUND0_BITS);
}

static inline int16x8_t convolve12_8_2d_h(uint8x16_t samples[2],
                                          const int8x16_t filter[2],
                                          const uint8x16x2_t permute_tbl,
                                          const int32x4_t horiz_const) {
  // Permute samples ready for matrix multiply.
  // { 0, 1, 2, 3, 4, 5, 6, 7, 2, 3, 4, 5, 6, 7, 8, 9 }
  // { 4, 5, 6, 7, 8, 9, 10, 11, 6, 7, 8, 9, 10, 11, 12, 13 }
  // { 6, 7, 8, 9, 10, 11, 12, 13, 8, 9, 10, 11, 12, 13, 14, 15 }
  // { 10, 11, 12, 13, 14, 15, 16, 17, 12, 13, 14, 15, 16, 17, 18, 19 }
  uint8x16_t perm_samples[4] = { vqtbl1q_u8(samples[0], permute_tbl.val[0]),
                                 vqtbl1q_u8(samples[0], permute_tbl.val[1]),
                                 vqtbl1q_u8(samples[1], permute_tbl.val[0]),
                                 vqtbl1q_u8(samples[1], permute_tbl.val[1]) };

  // These instructions multiply a 2x8 matrix (samples) by an 8x2 matrix
  // (filter), destructively accumulating into the destination register.
  int32x4_t sum0123 = vusmmlaq_s32(horiz_const, perm_samples[0], filter[0]);
  int32x4_t sum4567 = vusmmlaq_s32(horiz_const, perm_samples[1], filter[0]);
  sum0123 = vusmmlaq_s32(sum0123, perm_samples[2], filter[1]);
  sum4567 = vusmmlaq_s32(sum4567, perm_samples[3], filter[1]);
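  // sum0123 and sum4567 now hold the horizontal results for output pixels
  // 0-3 and 4-7 respectively, one pixel per int32 lane.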

  // Narrow and re-pack.
  return vcombine_s16(vshrn_n_s32(sum0123, ROUND0_BITS),
                      vshrn_n_s32(sum4567, ROUND0_BITS));
}

static inline void convolve_2d_sr_horiz_12tap_neon_i8mm(
    const uint8_t *src_ptr, int src_stride, int16_t *dst_ptr,
    const int dst_stride, int w, int h, const int16_t *x_filter_ptr) {
  // The no-op filter should never be used here.
  assert(x_filter_ptr[5] != 128);
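  // (128 is 1.0 in Q7 fixed point; a lone 128 tap would not fit in int8 when
  // the filter values are narrowed below.)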

  const int bd = 8;

  // Split 12-tap filter into two 6-tap filters, masking the top two elements.
  // { 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0 }
  const int8x8_t mask = vcreate_s8(0x0000ffffffffffff);
  const int8x8_t filter_0 = vand_s8(vmovn_s16(vld1q_s16(x_filter_ptr)), mask);
  const int8x8_t filter_1 =
      vext_s8(vmovn_s16(vld1q_s16(x_filter_ptr + 4)), vdup_n_s8(0), 2);
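  // filter_0 holds taps 0-5 and filter_1 holds taps 6-11, each padded with
  // two trailing zeros.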

  // Stagger each 6-tap filter to enable use of matrix multiply instructions.
  // { f0, f1, f2, f3, f4, f5, 0, 0, 0, f0, f1, f2, f3, f4, f5, 0 }
  const int8x16_t filter[2] = {
    vcombine_s8(filter_0, vext_s8(filter_0, filter_0, 7)),
    vcombine_s8(filter_1, vext_s8(filter_1, filter_1, 7))
  };
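  // Offsetting the second copy of each filter by one element means the two
  // columns of the 8x2 filter matrix compute two adjacent output pixels per
  // USMMLA instruction.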

  // This shim of 1 << (ROUND0_BITS - 1) enables us to use non-rounding shifts
  // in convolution kernels - which are generally faster than rounding shifts
  // on modern CPUs.
  const int32x4_t horiz_const =
      vdupq_n_s32((1 << (bd + FILTER_BITS - 1)) + (1 << (ROUND0_BITS - 1)));
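  // The 1 << (bd + FILTER_BITS - 1) term is the standard AV1 intermediate
  // offset: it biases the horizontal results so that the input to the
  // vertical pass is non-negative.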

  if (w <= 4) {
    const uint8x16_t permute_tbl = vld1q_u8(kMatMulPermuteTbl);

    do {
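      // Each s*[0] vector feeds filter taps 0-5; the second load at an offset
      // of 6 feeds taps 6-11, matching the two halves of the split filter.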
      uint8x16_t s0[2], s1[2], s2[2], s3[2];
      load_u8_16x4(src_ptr, src_stride, &s0[0], &s1[0], &s2[0], &s3[0]);
      load_u8_16x4(src_ptr + 6, src_stride, &s0[1], &s1[1], &s2[1], &s3[1]);

      int16x4_t d0 = convolve12_4_2d_h(s0, filter, permute_tbl, horiz_const);
      int16x4_t d1 = convolve12_4_2d_h(s1, filter, permute_tbl, horiz_const);
      int16x4_t d2 = convolve12_4_2d_h(s2, filter, permute_tbl, horiz_const);
      int16x4_t d3 = convolve12_4_2d_h(s3, filter, permute_tbl, horiz_const);

      store_s16_4x4(dst_ptr, dst_stride, d0, d1, d2, d3);

      src_ptr += 4 * src_stride;
      dst_ptr += 4 * dst_stride;
      h -= 4;
    } while (h > 4);
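
    // Process the remaining rows one at a time.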
    do {
      uint8x16_t s0[2];
      s0[0] = vld1q_u8(src_ptr);
      s0[1] = vld1q_u8(src_ptr + 6);
      int16x4_t d0 = convolve12_4_2d_h(s0, filter, permute_tbl, horiz_const);
      vst1_s16(dst_ptr, d0);

      src_ptr += src_stride;
      dst_ptr += dst_stride;
    } while (--h != 0);
  } else {
    const uint8x16x2_t permute_tbl = vld1q_u8_x2(kMatMulPermuteTbl);

    do {
      const uint8_t *s = src_ptr;
      int16_t *d = dst_ptr;
      int width = w;

      do {
        uint8x16_t s0[2], s1[2], s2[2], s3[2];
        load_u8_16x4(s, src_stride, &s0[0], &s1[0], &s2[0], &s3[0]);
        load_u8_16x4(s + 6, src_stride, &s0[1], &s1[1], &s2[1], &s3[1]);

        int16x8_t d0 = convolve12_8_2d_h(s0, filter, permute_tbl, horiz_const);
        int16x8_t d1 = convolve12_8_2d_h(s1, filter, permute_tbl, horiz_const);
        int16x8_t d2 = convolve12_8_2d_h(s2, filter, permute_tbl, horiz_const);
        int16x8_t d3 = convolve12_8_2d_h(s3, filter, permute_tbl, horiz_const);

        store_s16_8x4(d, dst_stride, d0, d1, d2, d3);

        s += 8;
        d += 8;
        width -= 8;
      } while (width != 0);

      src_ptr += 4 * src_stride;
      dst_ptr += 4 * dst_stride;
      h -= 4;
    } while (h > 4);
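
    // Process the remaining rows one at a time.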
    do {
      const uint8_t *s = src_ptr;
      int16_t *d = dst_ptr;
      int width = w;

      do {
        uint8x16_t s0[2];
        s0[0] = vld1q_u8(s);
        s0[1] = vld1q_u8(s + 6);
        int16x8_t d0 = convolve12_8_2d_h(s0, filter, permute_tbl, horiz_const);
        vst1q_s16(d, d0);

        s += 8;
        d += 8;
        width -= 8;
      } while (width != 0);

      src_ptr += src_stride;
      dst_ptr += dst_stride;
    } while (--h != 0);
  }
}

#endif  // AOM_AV1_COMMON_ARM_CONVOLVE_NEON_I8MM_H_