1// Copyright 2021 Google LLC 2// 3// This source code is licensed under the BSD-style license found in the 4// LICENSE file in the root directory of this source tree. 5 6$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ" 7$assert NR % 8 == 0 8$assert 8 <= NR <= 16 9$assert REQUANTIZATION == "RNDNU" 10#include <assert.h> 11 12#include <arm_neon.h> 13 14#include <xnnpack/common.h> 15#include <xnnpack/gemm.h> 16 17 18void xnn_qs8_igemm_minmax_rndnu_ukernel_${MR}x${NR}__neon_mull_addw_dup( 19 size_t mr, 20 size_t nc, 21 size_t kc, 22 size_t ks, 23 const int8_t** restrict a, 24 const void* restrict w, 25 int8_t* restrict c, 26 size_t cm_stride, 27 size_t cn_stride, 28 size_t a_offset, 29 const int8_t* zero, 30 const union xnn_qs8_conv_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS 31{ 32 assert(mr != 0); 33 assert(mr <= ${MR}); 34 assert(nc != 0); 35 assert(kc != 0); 36 assert(ks != 0); 37 assert(ks % (${MR} * sizeof(void*)) == 0); 38 assert(a_offset % sizeof(int8_t) == 0); 39 assert(a != NULL); 40 assert(w != NULL); 41 assert(c != NULL); 42 43 int8_t* c0 = c; 44 $for M in range(1, MR): 45 int8_t* c${M} = (int8_t*) ((uintptr_t) c${M-1} + cm_stride); 46 $if M % 2 == 0: 47 if XNN_UNPREDICTABLE(mr <= ${M}) { 48 c${M} = c${M-1}; 49 } 50 $elif M + 1 == MR: 51 if XNN_UNPREDICTABLE(mr != ${M+1}) { 52 c${M} = c${M-1}; 53 } 54 $else: 55 if XNN_UNPREDICTABLE(mr < ${M+1}) { 56 c${M} = c${M-1}; 57 } 58 59 do { 60 $for N in range(0, NR, 4): 61 int32x4_t vacc0x${ABC[N:N+4]} = vld1q_s32(w); w = (const void*) ((uintptr_t) w + 4 * sizeof(int32_t)); 62 $for M in range(1, MR): 63 $for N in range(0, NR, 4): 64 int32x4_t vacc${M}x${ABC[N:N+4]} = vacc0x${ABC[N:N+4]}; 65 66 size_t p = ks; 67 do { 68 $for M in range(MR): 69 const int8_t* restrict a${M} = a[${M}]; 70 if XNN_UNPREDICTABLE(a${M} != zero) { 71 a${M} = (const int8_t*) ((uintptr_t) a${M} + a_offset); 72 } 73 a += ${MR}; 74 75 size_t k = kc; 76 while (k >= 8 * sizeof(int8_t)) { 77 $for M in range(MR): 78 const int8x8_t va${M} = vld1_s8(a${M}); a${M} += 8; 79 80 $for K in range(8): 81 $for N in range(0, NR, 8): 82 const int8x8_t vb${ABC[N:N+8]}c${K} = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); 83 84 $for M in range(MR): 85 const int16x8_t vprod${M}x${ABC[N:N+8]}c${K} = vmull_s8(vb${ABC[N:N+8]}c${K}, vdup_lane_s8(va${M}, ${K})); 86 vacc${M}x${ABC[N:N+4]} = vaddw_s16(vacc${M}x${ABC[N:N+4]}, vget_low_s16(vprod${M}x${ABC[N:N+8]}c${K})); 87 vacc${M}x${ABC[N+4:N+8]} = vaddw_s16(vacc${M}x${ABC[N+4:N+8]}, vget_high_s16(vprod${M}x${ABC[N:N+8]}c${K})); 88 89 k -= 8 * sizeof(int8_t); 90 } 91 if XNN_UNLIKELY(k != 0) { 92 $for M in range(MR): 93 const int8x8_t va${M} = vld1_s8(a${M}); a${M} = (const int8_t*) ((uintptr_t) a${M} + k); 94 95 $for N in range(0, NR, 8): 96 const int8x8_t vb${ABC[N:N+8]}c0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); 97 98 $for M in range(MR): 99 $for N in range(0, NR, 8): 100 const int16x8_t vprod${M}x${ABC[N:N+8]}c0 = vmull_s8(vb${ABC[N:N+8]}c0, vdup_lane_s8(va${M}, 0)); 101 vacc${M}x${ABC[N:N+4]} = vaddw_s16(vacc${M}x${ABC[N:N+4]}, vget_low_s16(vprod${M}x${ABC[N:N+8]}c0)); 102 vacc${M}x${ABC[N+4:N+8]} = vaddw_s16(vacc${M}x${ABC[N+4:N+8]}, vget_high_s16(vprod${M}x${ABC[N:N+8]}c0)); 103 104 if (k >= 2 * sizeof(int8_t)) { 105 $for N in range(0, NR, 8): 106 const int8x8_t vb${ABC[N:N+8]}c1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); 107 108 $for M in range(MR): 109 $for N in range(0, NR, 8): 110 const int16x8_t vprod${M}x${ABC[N:N+8]}c1 = vmull_s8(vb${ABC[N:N+8]}c1, vdup_lane_s8(va${M}, 1)); 111 vacc${M}x${ABC[N:N+4]} = vaddw_s16(vacc${M}x${ABC[N:N+4]}, vget_low_s16(vprod${M}x${ABC[N:N+8]}c1)); 112 vacc${M}x${ABC[N+4:N+8]} = vaddw_s16(vacc${M}x${ABC[N+4:N+8]}, vget_high_s16(vprod${M}x${ABC[N:N+8]}c1)); 113 114 if (k > 2 * sizeof(int8_t)) { 115 $for N in range(0, NR, 8): 116 const int8x8_t vb${ABC[N:N+8]}c2 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); 117 118 $for M in range(MR): 119 $for N in range(0, NR, 8): 120 const int16x8_t vprod${M}x${ABC[N:N+8]}c2 = vmull_s8(vb${ABC[N:N+8]}c2, vdup_lane_s8(va${M}, 2)); 121 vacc${M}x${ABC[N:N+4]} = vaddw_s16(vacc${M}x${ABC[N:N+4]}, vget_low_s16(vprod${M}x${ABC[N:N+8]}c2)); 122 vacc${M}x${ABC[N+4:N+8]} = vaddw_s16(vacc${M}x${ABC[N+4:N+8]}, vget_high_s16(vprod${M}x${ABC[N:N+8]}c2)); 123 124 if (k >= 4 * sizeof(int8_t)) { 125 $for N in range(0, NR, 8): 126 const int8x8_t vb${ABC[N:N+8]}c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); 127 128 $for M in range(MR): 129 $for N in range(0, NR, 8): 130 const int16x8_t vprod${M}x${ABC[N:N+8]}c3 = vmull_s8(vb${ABC[N:N+8]}c3, vdup_lane_s8(va${M}, 3)); 131 vacc${M}x${ABC[N:N+4]} = vaddw_s16(vacc${M}x${ABC[N:N+4]}, vget_low_s16(vprod${M}x${ABC[N:N+8]}c3)); 132 vacc${M}x${ABC[N+4:N+8]} = vaddw_s16(vacc${M}x${ABC[N+4:N+8]}, vget_high_s16(vprod${M}x${ABC[N:N+8]}c3)); 133 134 if (k > 4 * sizeof(int8_t)) { 135 $for N in range(0, NR, 8): 136 const int8x8_t vb${ABC[N:N+8]}c4 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); 137 138 $for M in range(MR): 139 $for N in range(0, NR, 8): 140 const int16x8_t vprod${M}x${ABC[N:N+8]}c4 = vmull_s8(vb${ABC[N:N+8]}c4, vdup_lane_s8(va${M}, 4)); 141 vacc${M}x${ABC[N:N+4]} = vaddw_s16(vacc${M}x${ABC[N:N+4]}, vget_low_s16(vprod${M}x${ABC[N:N+8]}c4)); 142 vacc${M}x${ABC[N+4:N+8]} = vaddw_s16(vacc${M}x${ABC[N+4:N+8]}, vget_high_s16(vprod${M}x${ABC[N:N+8]}c4)); 143 144 if (k >= 6 * sizeof(int8_t)) { 145 $for N in range(0, NR, 8): 146 const int8x8_t vb${ABC[N:N+8]}c5 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); 147 148 $for M in range(MR): 149 $for N in range(0, NR, 8): 150 const int16x8_t vprod${M}x${ABC[N:N+8]}c5 = vmull_s8(vb${ABC[N:N+8]}c5, vdup_lane_s8(va${M}, 5)); 151 vacc${M}x${ABC[N:N+4]} = vaddw_s16(vacc${M}x${ABC[N:N+4]}, vget_low_s16(vprod${M}x${ABC[N:N+8]}c5)); 152 vacc${M}x${ABC[N+4:N+8]} = vaddw_s16(vacc${M}x${ABC[N+4:N+8]}, vget_high_s16(vprod${M}x${ABC[N:N+8]}c5)); 153 154 if (k > 6 * sizeof(int8_t)) { 155 $for N in range(0, NR, 8): 156 const int8x8_t vb${ABC[N:N+8]}c6 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t)); 157 158 $for M in range(MR): 159 $for N in range(0, NR, 8): 160 const int16x8_t vprod${M}x${ABC[N:N+8]}c6 = vmull_s8(vb${ABC[N:N+8]}c6, vdup_lane_s8(va${M}, 6)); 161 vacc${M}x${ABC[N:N+4]} = vaddw_s16(vacc${M}x${ABC[N:N+4]}, vget_low_s16(vprod${M}x${ABC[N:N+8]}c6)); 162 vacc${M}x${ABC[N+4:N+8]} = vaddw_s16(vacc${M}x${ABC[N+4:N+8]}, vget_high_s16(vprod${M}x${ABC[N:N+8]}c6)); 163 } 164 } 165 } 166 } 167 } 168 } 169 } 170 p -= ${MR} * sizeof(void*); 171 } while (p != 0); 172 173 const int32x4_t vright_pre_shift = vld1q_dup_s32(¶ms->rndnu_neon.right_pre_shift); 174 const int32x4_t vmultiplier = vld1q_dup_s32(¶ms->rndnu_neon.multiplier); 175 const int32x4_t vright_post_shift = vld1q_dup_s32(¶ms->rndnu_neon.right_post_shift); 176 177 $for M in range(MR): 178 $for N in range(0, NR, 4): 179 vacc${M}x${ABC[N:N+4]} = vqshlq_s32(vacc${M}x${ABC[N:N+4]}, vright_pre_shift); 180 181 $for M in range(MR): 182 $for N in range(0, NR, 4): 183 vacc${M}x${ABC[N:N+4]} = vqdmulhq_s32(vacc${M}x${ABC[N:N+4]}, vmultiplier); 184 185 $for M in range(MR): 186 $for N in range(0, NR, 4): 187 vacc${M}x${ABC[N:N+4]} = vrshlq_s32(vacc${M}x${ABC[N:N+4]}, vright_post_shift); 188 189 const int16x8_t voutput_zero_point = vld1q_dup_s16(¶ms->rndnu_neon.output_zero_point); 190#if XNN_ARCH_ARM64 191 $for M in range(MR): 192 $for N in range(0, NR, 8): 193 const int16x8_t vacc${M}x${ABC[N:N+8]} = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc${M}x${ABC[N:N+4]}), vacc${M}x${ABC[N+4:N+8]}), voutput_zero_point); 194 195 $for M in range(MR): 196 $for N in range(0, NR, 16): 197 $if N + 8 < NR: 198 int8x16_t vout${M}x${ABC[N:N+16]} = vqmovn_high_s16(vqmovn_s16(vacc${M}x${ABC[N:N+8]}), vacc${M}x${ABC[N+8:N+16]}); 199 $elif M % 2 == 1: 200 int8x16_t vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]} = vqmovn_high_s16(vqmovn_s16(vacc${M-1}x${ABC[N:N+8]}), vacc${M}x${ABC[N:N+8]}); 201 $elif M + 1 == MR: 202 int8x8_t vout${M}x${ABC[N:N+8]} = vqmovn_s16(vacc${M}x${ABC[N:N+8]}); 203#else 204 $for M in range(MR): 205 $for N in range(0, NR, 8): 206 const int16x8_t vacc${M}x${ABC[N:N+8]} = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc${M}x${ABC[N:N+4]}), vqmovn_s32(vacc${M}x${ABC[N+4:N+8]})), voutput_zero_point); 207 208 $for M in range(MR): 209 $for N in range(0, NR, 16): 210 $if N + 8 < NR: 211 int8x16_t vout${M}x${ABC[N:N+16]} = vcombine_s8(vqmovn_s16(vacc${M}x${ABC[N:N+8]}), vqmovn_s16(vacc${M}x${ABC[N+8:N+16]})); 212 $elif M % 2 == 1: 213 int8x16_t vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]} = vcombine_s8(vqmovn_s16(vacc${M-1}x${ABC[N:N+8]}), vqmovn_s16(vacc${M}x${ABC[N:N+8]})); 214 $elif M + 1 == MR: 215 int8x8_t vout${M}x${ABC[N:N+8]} = vqmovn_s16(vacc${M}x${ABC[N:N+8]}); 216#endif 217 $if NR == 8 and MR == 1: 218 const int8x8_t voutput_min = vld1_dup_s8(¶ms->rndnu_neon.output_min); 219 const int8x8_t voutput_max = vld1_dup_s8(¶ms->rndnu_neon.output_max); 220 $else: 221 const int8x16_t voutput_min = vld1q_dup_s8(¶ms->rndnu_neon.output_min); 222 const int8x16_t voutput_max = vld1q_dup_s8(¶ms->rndnu_neon.output_max); 223 224 $for M in reversed(range(MR)): 225 $for N in range(0, NR, 16): 226 $if N + 8 < NR: 227 vout${M}x${ABC[N:N+16]} = vmaxq_s8(vout${M}x${ABC[N:N+16]}, voutput_min); 228 $elif M % 2 == 1: 229 vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]} = vmaxq_s8(vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]}, voutput_min); 230 $elif M + 1 == MR: 231 $if NR == 8 and MR == 1: 232 vout${M}x${ABC[N:N+8]} = vmax_s8(vout${M}x${ABC[N:N+8]}, voutput_min); 233 $else: 234 vout${M}x${ABC[N:N+8]} = vmax_s8(vout${M}x${ABC[N:N+8]}, vget_low_s8(voutput_min)); 235 236 $for M in reversed(range(MR)): 237 $for N in range(0, NR, 16): 238 $if N + 8 < NR: 239 vout${M}x${ABC[N:N+16]} = vminq_s8(vout${M}x${ABC[N:N+16]}, voutput_max); 240 $elif M % 2 == 1: 241 vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]} = vminq_s8(vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]}, voutput_max); 242 $elif M + 1 == MR: 243 $if NR == 8 and MR == 1: 244 vout${M}x${ABC[N:N+8]} = vmin_s8(vout${M}x${ABC[N:N+8]}, voutput_max); 245 $else: 246 vout${M}x${ABC[N:N+8]} = vmin_s8(vout${M}x${ABC[N:N+8]}, vget_low_s8(voutput_max)); 247 248 if (nc >= ${NR}) { 249 $for M in reversed(range(MR)): 250 $for N in range(0, NR, 16): 251 $if N + 8 < NR: 252 vst1q_s8(c${M} + ${N}, vout${M}x${ABC[N:N+16]}); 253 $elif M % 2 == 1: 254 vst1_s8(c${M} + ${N}, vget_high_s8(vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]})); 255 vst1_s8(c${M-1} + ${N}, vget_low_s8(vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]})); 256 $elif M + 1 == MR: 257 vst1_s8(c${M} + ${N}, vout${M}x${ABC[N:N+8]}); 258 259 $for M in reversed(range(MR)): 260 c${M} = (int8_t*) ((uintptr_t) c${M} + cn_stride); 261 262 a = (const int8_t**restrict) ((uintptr_t) a - ks); 263 264 nc -= ${NR}; 265 } else { 266 $if NR == 16: 267 $for M in reversed(range(MR)): 268 $if M % 2 == 1: 269 int8x16_t vout${M-1}x01234567_${M}x01234567 = vcombine_s8(vget_low_s8(vout${M-1}x0123456789ABCDEF), vget_low_s8(vout${M}x0123456789ABCDEF)); 270 $elif M + 1 == MR: 271 int8x8_t vout${M}x01234567 = vget_low_s8(vout${M}x0123456789ABCDEF); 272 if (nc & 8) { 273 $for M in reversed(range(MR)): 274 $if M % 2 == 1: 275 vst1_s8(c${M}, vget_high_s8(vout${M-1}x01234567_${M}x01234567)); c${M} += 8; 276 vst1_s8(c${M-1}, vget_low_s8(vout${M-1}x01234567_${M}x01234567)); c${M-1} += 8; 277 $elif M + 1 == MR: 278 vst1_s8(c${M}, vout${M}x01234567); c${M} += 8; 279 $for M in reversed(range(MR)): 280 $if M % 2 == 1: 281 vout${M-1}x01234567_${M}x01234567 = vcombine_s8(vget_high_s8(vout${M-1}x0123456789ABCDEF), vget_high_s8(vout${M}x0123456789ABCDEF)); 282 $elif M + 1 == MR: 283 vout${M}x01234567 = vget_high_s8(vout${M}x0123456789ABCDEF); 284 } 285 if (nc & 4) { 286 $for M in reversed(range(MR)): 287 $if M % 2 == 1: 288 vst1q_lane_u32((void*) c${M}, vreinterpretq_u32_s8(vout${M-1}x01234567_${M}x01234567), 2); c${M} += 4; 289 vst1q_lane_u32((void*) c${M-1}, vreinterpretq_u32_s8(vout${M-1}x01234567_${M}x01234567), 0); c${M-1} += 4; 290 $elif M + 1 == MR: 291 vst1_lane_u32((void*) c${M}, vreinterpret_u32_s8(vout${M}x01234567), 0); c${M} += 4; 292 $for M in reversed(range(MR)): 293 $if M % 2 == 1: 294 vout${M-1}x01234567_${M}x01234567 = vextq_s8(vout${M-1}x01234567_${M}x01234567, vout${M-1}x01234567_${M}x01234567, 4); 295 $elif M + 1 == MR: 296 vout${M}x01234567 = vext_s8(vout${M}x01234567, vout${M}x01234567, 4); 297 } 298 if (nc & 2) { 299 $for M in reversed(range(MR)): 300 $if M % 2 == 1: 301 vst1q_lane_u16((void*) c${M}, vreinterpretq_u16_s8(vout${M-1}x01234567_${M}x01234567), 4); c${M} += 2; 302 vst1q_lane_u16((void*) c${M-1}, vreinterpretq_u16_s8(vout${M-1}x01234567_${M}x01234567), 0); c${M-1} += 2; 303 $elif M + 1 == MR: 304 vst1_lane_u16((void*) c${M}, vreinterpret_u16_s8(vout${M}x01234567), 0); c${M} += 2; 305 $for M in reversed(range(MR)): 306 $if M % 2 == 1: 307 vout${M-1}x01234567_${M}x01234567 = vextq_s8(vout${M-1}x01234567_${M}x01234567, vout${M-1}x01234567_${M}x01234567, 2); 308 $elif M + 1 == MR: 309 vout${M}x01234567 = vext_s8(vout${M}x01234567, vout${M}x01234567, 2); 310 } 311 if (nc & 1) { 312 $for M in reversed(range(MR)): 313 $if M % 2 == 1: 314 vst1q_lane_s8(c${M}, vout${M-1}x01234567_${M}x01234567, 8); 315 vst1q_lane_s8(c${M-1}, vout${M-1}x01234567_${M}x01234567, 0); 316 $elif M + 1 == MR: 317 vst1_lane_s8(c${M}, vout${M}x01234567, 0); 318 } 319 320 nc = 0; 321 } 322 } while (nc != 0); 323} 324