1// Copyright 2020 Google LLC 2// 3// This source code is licensed under the BSD-style license found in the 4// LICENSE file in the root directory of this source tree. 5 6$assert REQUANTIZATION in ["FP32", "RNDNU"] 7$assert DATATYPE in ["QC8", "QS8", "QU8"] 8$assert DATATYPE != "QC8" or REQUANTIZATION == "FP32" 9$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ" 10$assert CHANNEL_TILE % 8 == 0 11$assert CHANNEL_TILE >= 8 12$assert KERNEL_TILE >= 2 13#include <assert.h> 14 15#include <arm_neon.h> 16 17#include <xnnpack/dwconv.h> 18$if REQUANTIZATION == "FP32" and ARMV8: 19 #include <xnnpack/intrinsics-polyfill.h> 20 21 22$PARAMS_STRUCT = REQUANTIZATION.lower() + "_" + ("neonv8" if ARMV8 else "neon") 23$PARAMS_UNION = "xnn_%s_conv_minmax_params" % DATATYPE.lower() 24$XINT8_T = "uint8_t" if DATATYPE == "QU8" else "int8_t" 25$XINT8X8_T = "uint8x8_t" if DATATYPE == "QU8" else "int8x8_t" 26$XINT8X16_T = "uint8x16_t" if DATATYPE == "QU8" else "int8x16_t" 27$VGET_LOW_X8 = "vget_low_u8" if DATATYPE == "QU8" else "vget_low_s8" 28$VCOMBINE_X8 = "vcombine_u8" if DATATYPE == "QU8" else "vcombine_s8" 29$VREINTERPRET_U32_X8 = "vreinterpret_u32_u8" if DATATYPE == "QU8" else "vreinterpret_u32_s8" 30$VREINTERPRET_U16_X8 = "vreinterpret_u16_u8" if DATATYPE == "QU8" else "vreinterpret_u16_s8" 31$VLD1_X8 = "vld1_u8" if DATATYPE == "QU8" else "vld1_s8" 32$VLD1_DUP_X8 = "vld1_dup_u8" if DATATYPE == "QU8" else "vld1_dup_s8" 33$VLD1Q_DUP_X8 = "vld1q_dup_u8" if DATATYPE == "QU8" else "vld1q_dup_s8" 34$VST1_X8 = "vst1_u8" if DATATYPE == "QU8" else "vst1_s8" 35$VST1Q_X8 = "vst1q_u8" if DATATYPE == "QU8" else "vst1q_s8" 36$VST1_LANE_X8 = "vst1_lane_u8" if DATATYPE == "QU8" else "vst1_lane_s8" 37$VST1Q_LANE_X8 = "vst1q_lane_u8" if DATATYPE == "QU8" else "vst1q_lane_s8" 38$VMIN_X8 = "vmin_u8" if DATATYPE == "QU8" else "vmin_s8" 39$VMAX_X8 = "vmax_u8" if DATATYPE == "QU8" else "vmax_s8" 40$VMINQ_X8 = "vminq_u8" if DATATYPE == "QU8" else "vminq_s8" 41$VMAXQ_X8 = "vmaxq_u8" if DATATYPE == "QU8" else "vmaxq_s8" 42$VEXT_X8 = "vext_u8" if DATATYPE == "QU8" else "vext_s8" 43$VQMOVXN_S16 = "vqmovun_s16" if DATATYPE == "QU8" else "vqmovn_s16" 44$VQMOVXN_HIGH_S16 = "vqmovun_high_s16" if DATATYPE == "QU8" else "vqmovn_high_s16" 45$ISA = "neonv8" if ARMV8 else "neon" 46void xnn_${DATATYPE.lower()}_dwconv_minmax_${REQUANTIZATION.lower()}_ukernel_up${CHANNEL_TILE}x${KERNEL_TILE}__${ISA}_mul16( 47 size_t channels, 48 size_t output_width, 49 const ${XINT8_T}** input, 50 const void* weights, 51 ${XINT8_T}* output, 52 size_t input_stride, 53 size_t output_increment, 54 size_t input_offset, 55 const ${XINT8_T}* zero, 56 const union ${PARAMS_UNION} params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS 57{ 58 assert(channels != 0); 59 assert(output_width != 0); 60 61 $if DATATYPE == "QU8": 62 const uint8x8_t vkernel_zero_point = vld1_dup_u8(params->${PARAMS_STRUCT}.kernel_zero_point); 63 $if REQUANTIZATION == "RNDNU": 64 const int32x4_t vright_pre_shift = vld1q_dup_s32(¶ms->${PARAMS_STRUCT}.right_pre_shift); 65 const int32x4_t vmultiplier = vld1q_dup_s32(¶ms->${PARAMS_STRUCT}.multiplier); 66 const int32x4_t vright_post_shift = vld1q_dup_s32(¶ms->${PARAMS_STRUCT}.right_post_shift); 67 $elif REQUANTIZATION == "FP32": 68 $if DATATYPE != "QC8": 69 const float32x4_t vscale = vld1q_dup_f32(¶ms->${PARAMS_STRUCT}.scale); 70 $if not ARMV8: 71 const float32x4_t vmagic_bias = vld1q_dup_f32(¶ms->${PARAMS_STRUCT}.magic_bias); 72 const int32x4_t vmagic_bias_less_output_zero_point = vld1q_dup_s32(¶ms->${PARAMS_STRUCT}.magic_bias_less_output_zero_point); 73 $if REQUANTIZATION != "FP32" or ARMV8: 74 const int16x8_t voutput_zero_point = vld1q_dup_s16(¶ms->${PARAMS_STRUCT}.output_zero_point); 75 $if CHANNEL_TILE == 8: 76 const ${XINT8X8_T} voutput_min = ${VLD1_DUP_X8}(¶ms->${PARAMS_STRUCT}.output_min); 77 const ${XINT8X8_T} voutput_max = ${VLD1_DUP_X8}(¶ms->${PARAMS_STRUCT}.output_max); 78 $else: 79 const ${XINT8X16_T} voutput_min = ${VLD1Q_DUP_X8}(¶ms->${PARAMS_STRUCT}.output_min); 80 const ${XINT8X16_T} voutput_max = ${VLD1Q_DUP_X8}(¶ms->${PARAMS_STRUCT}.output_max); 81 do { 82 $for K in range(KERNEL_TILE): 83 const ${XINT8_T}* i${K} = input[${K}]; 84 assert(i${K} != NULL); 85 if XNN_UNPREDICTABLE(i${K} != zero) { 86 i${K} = (const ${XINT8_T}*) ((uintptr_t) i${K} + input_offset); 87 } 88 input = (const ${XINT8_T}**) ((uintptr_t) input + input_stride); 89 90 size_t c = channels; 91 const void* w = weights; 92 for (; c >= ${CHANNEL_TILE}; c -= ${CHANNEL_TILE}) { 93 $for C in range(0, CHANNEL_TILE, 4): 94 int32x4_t vacc${ABC[C:C+4]} = vld1q_s32(w); w = (const void*) ((const int32_t*) w + 4); 95 96 $for K in range(KERNEL_TILE): 97 98 $for C in range(0, CHANNEL_TILE, 8): 99 $if DATATYPE == "QU8": 100 const int16x8_t vi${K}x${ABC[C:C+8]} = vreinterpretq_s16_u16(vmovl_u8(vld1_u8(i${K}))); i${K} += 8; 101 const int16x8_t vk${K}x${ABC[C:C+8]} = vreinterpretq_s16_u16(vsubl_u8(vld1_u8(w), vkernel_zero_point)); w = (const void*) ((const ${XINT8_T}*) w + 8); 102 $else: 103 const int16x8_t vi${K}x${ABC[C:C+8]} = vmovl_s8(vld1_s8(i${K})); i${K} += 8; 104 const int16x8_t vk${K}x${ABC[C:C+8]} = vmovl_s8(vld1_s8(w)); w = (const void*) ((const ${XINT8_T}*) w + 8); 105 106 $for C in range(0, CHANNEL_TILE, 8): 107 vacc${ABC[C:C+4]} = vmlal_s16(vacc${ABC[C:C+4]}, vget_low_s16(vi${K}x${ABC[C:C+8]}), vget_low_s16(vk${K}x${ABC[C:C+8]})); 108 vacc${ABC[C+4:C+8]} = vmlal_s16(vacc${ABC[C+4:C+8]}, vget_high_s16(vi${K}x${ABC[C:C+8]}), vget_high_s16(vk${K}x${ABC[C:C+8]})); 109 110 $if REQUANTIZATION == "RNDNU": 111 $for C in range(0, CHANNEL_TILE, 4): 112 vacc${ABC[C:C+4]} = vqshlq_s32(vacc${ABC[C:C+4]}, vright_pre_shift); 113 114 $for C in range(0, CHANNEL_TILE, 4): 115 vacc${ABC[C:C+4]} = vqdmulhq_s32(vacc${ABC[C:C+4]}, vmultiplier); 116 117 $for C in range(0, CHANNEL_TILE, 4): 118 vacc${ABC[C:C+4]} = vrshlq_s32(vacc${ABC[C:C+4]}, vright_post_shift); 119 $elif REQUANTIZATION == "FP32": 120 $for C in range(0, CHANNEL_TILE, 4): 121 float32x4_t vfpacc${ABC[C:C+4]} = vcvtq_f32_s32(vacc${ABC[C:C+4]}); 122 123 $if DATATYPE == "QC8": 124 $for C in range(0, CHANNEL_TILE, 4): 125 const float32x4_t vscale${ABC[C:C+4]} = vld1q_f32((const float*) w); w = (const void*) ((const float*) w + 4); 126 127 $for C in range(0, CHANNEL_TILE, 4): 128 vfpacc${ABC[C:C+4]} = vmulq_f32(vfpacc${ABC[C:C+4]}, vscale${ABC[C:C+4]}); 129 $else: 130 $for C in range(0, CHANNEL_TILE, 4): 131 vfpacc${ABC[C:C+4]} = vmulq_f32(vfpacc${ABC[C:C+4]}, vscale); 132 133 $if ARMV8: 134 $for C in range(0, CHANNEL_TILE, 4): 135 vacc${ABC[C:C+4]} = vcvtnq_s32_f32(vfpacc${ABC[C:C+4]}); 136 $else: 137 $for C in range(0, CHANNEL_TILE, 4): 138 vacc${ABC[C:C+4]} = vreinterpretq_s32_f32(vaddq_f32(vfpacc${ABC[C:C+4]}, vmagic_bias)); 139 140 $for C in range(0, CHANNEL_TILE, 4): 141 vacc${ABC[C:C+4]} = vqsubq_s32(vacc${ABC[C:C+4]}, vmagic_bias_less_output_zero_point); 142 143#if XNN_ARCH_ARM64 144 $for C in range(0, CHANNEL_TILE, 8): 145 int16x8_t vacc${ABC[C:C+8]} = vqmovn_high_s32(vqmovn_s32(vacc${ABC[C:C+4]}), vacc${ABC[C+4:C+8]}); 146 147 $if REQUANTIZATION != "FP32" or ARMV8: 148 $for C in range(0, CHANNEL_TILE, 8): 149 vacc${ABC[C:C+8]} = vqaddq_s16(vacc${ABC[C:C+8]}, voutput_zero_point); 150 151 $for C in range(0, CHANNEL_TILE, 16): 152 $if C + 8 < CHANNEL_TILE: 153 ${XINT8X16_T} vout${ABC[C:C+16]} = ${VQMOVXN_HIGH_S16}(${VQMOVXN_S16}(vacc${ABC[C:C+8]}), vacc${ABC[C+8:C+16]}); 154 $else: 155 ${XINT8X8_T} vout${ABC[C:C+8]} = ${VQMOVXN_S16}(vacc${ABC[C:C+8]}); 156#else // !XNN_ARCH_ARM64 157 $for C in range(0, CHANNEL_TILE, 8): 158 int16x8_t vacc${ABC[C:C+8]} = vcombine_s16(vqmovn_s32(vacc${ABC[C:C+4]}), vqmovn_s32(vacc${ABC[C+4:C+8]})); 159 160 $if REQUANTIZATION != "FP32" or ARMV8: 161 $for C in range(0, CHANNEL_TILE, 8): 162 vacc${ABC[C:C+8]} = vqaddq_s16(vacc${ABC[C:C+8]}, voutput_zero_point); 163 164 $for C in range(0, CHANNEL_TILE, 16): 165 $if C + 8 < CHANNEL_TILE: 166 ${XINT8X16_T} vout${ABC[C:C+16]} = ${VCOMBINE_X8}(${VQMOVXN_S16}(vacc${ABC[C:C+8]}), ${VQMOVXN_S16}(vacc${ABC[C+8:C+16]})); 167 $else: 168 ${XINT8X8_T} vout${ABC[C:C+8]} = ${VQMOVXN_S16}(vacc${ABC[C:C+8]}); 169#endif // !XNN_ARCH_ARM64 170 171 $for C in range(0, CHANNEL_TILE, 16): 172 $if C + 8 < CHANNEL_TILE: 173 vout${ABC[C:C+16]} = ${VMAXQ_X8}(vout${ABC[C:C+16]}, voutput_min); 174 $else: 175 $if CHANNEL_TILE == 8: 176 vout${ABC[C:C+8]} = ${VMAX_X8}(vout${ABC[C:C+8]}, voutput_min); 177 $else: 178 vout${ABC[C:C+8]} = ${VMAX_X8}(vout${ABC[C:C+8]}, ${VGET_LOW_X8}(voutput_min)); 179 180 $for C in range(0, CHANNEL_TILE, 16): 181 $if C + 8 < CHANNEL_TILE: 182 vout${ABC[C:C+16]} = ${VMINQ_X8}(vout${ABC[C:C+16]}, voutput_max); 183 $else: 184 $if CHANNEL_TILE == 8: 185 vout${ABC[C:C+8]} = ${VMIN_X8}(vout${ABC[C:C+8]}, voutput_max); 186 $else: 187 vout${ABC[C:C+8]} = ${VMIN_X8}(vout${ABC[C:C+8]}, ${VGET_LOW_X8}(voutput_max)); 188 189 $for C in range(0, CHANNEL_TILE, 16): 190 $if C + 8 < CHANNEL_TILE: 191 ${VST1Q_X8}(output, vout${ABC[C:C+16]}); output += 16; 192 $else: 193 ${VST1_X8}(output, vout${ABC[C:C+8]}); output += 8; 194 } 195 if XNN_UNLIKELY(c != 0) { 196 $if CHANNEL_TILE > 8: 197 const ${XINT8_T}* k = (const ${XINT8_T}*) ((const int32_t*) w + ${CHANNEL_TILE}); 198 ${"do " if CHANNEL_TILE > 8 else ""}{ 199 int32x4_t vacc${ABC[0:4]} = vld1q_s32(w); w = (const void*) ((const int32_t*) w + 4); 200 int32x4_t vacc${ABC[4:8]} = vld1q_s32(w); w = (const void*) ((const int32_t*) w + 4); 201 202 $for K in range(KERNEL_TILE): 203 $if CHANNEL_TILE > 8: 204 $if DATATYPE == "QU8": 205 const int16x8_t vi${K}x${ABC[0:8]} = vreinterpretq_s16_u16(vmovl_u8(vld1_u8(i${K}))); i${K} += 8; 206 $else: 207 const int16x8_t vi${K}x${ABC[0:8]} = vmovl_s8(vld1_s8(i${K})); i${K} += 8; 208 $else: 209 $if DATATYPE == "QU8": 210 const int16x8_t vi${K}x${ABC[0:8]} = vreinterpretq_s16_u16(vmovl_u8(vld1_u8(i${K}))); 211 $else: 212 const int16x8_t vi${K}x${ABC[0:8]} = vmovl_s8(vld1_s8(i${K})); 213 $if CHANNEL_TILE > 8: 214 $if K == 0: 215 $if DATATYPE == "QU8": 216 const int16x8_t vk${K}x${ABC[0:8]} = vreinterpretq_s16_u16(vsubl_u8(vld1_u8(k), vkernel_zero_point)); k += 8; 217 $else: 218 const int16x8_t vk${K}x${ABC[0:8]} = vmovl_s8(vld1_s8(k)); k += 8; 219 $else: 220 $if DATATYPE == "QU8": 221 const int16x8_t vk${K}x${ABC[0:8]} = vreinterpretq_s16_u16(vsubl_u8(vld1_u8((const void*) (k + ${K * CHANNEL_TILE - 8})), vkernel_zero_point)); 222 $else: 223 const int16x8_t vk${K}x${ABC[0:8]} = vmovl_s8(vld1_s8((const void*) (k + ${K * CHANNEL_TILE - 8}))); 224 $else: 225 $if K == 0: 226 $if DATATYPE == "QU8": 227 const int16x8_t vk${K}x${ABC[0:8]} = vreinterpretq_s16_u16(vsubl_u8(vld1_u8(w), vkernel_zero_point)); 228 $else: 229 const int16x8_t vk${K}x${ABC[0:8]} = vmovl_s8(vld1_s8(w)); 230 $else: 231 $if DATATYPE == "QU8": 232 const int16x8_t vk${K}x${ABC[0:8]} = vreinterpretq_s16_u16(vsubl_u8(vld1_u8((const void*) ((const ${XINT8_T}*) w + ${K * CHANNEL_TILE})), vkernel_zero_point)); 233 $else: 234 const int16x8_t vk${K}x${ABC[0:8]} = vmovl_s8(vld1_s8((const void*) ((const ${XINT8_T}*) w + ${K * CHANNEL_TILE}))); 235 236 vacc${ABC[0:4]} = vmlal_s16(vacc${ABC[0:4]}, vget_low_s16(vi${K}x${ABC[0:8]}), vget_low_s16(vk${K}x${ABC[0:8]})); 237 vacc${ABC[4:8]} = vmlal_s16(vacc${ABC[4:8]}, vget_high_s16(vi${K}x${ABC[0:8]}), vget_high_s16(vk${K}x${ABC[0:8]})); 238 239 $if REQUANTIZATION == "RNDNU": 240 vacc${ABC[0:4]} = vqshlq_s32(vacc${ABC[0:4]}, vright_pre_shift); 241 vacc${ABC[4:8]} = vqshlq_s32(vacc${ABC[4:8]}, vright_pre_shift); 242 243 vacc${ABC[0:4]} = vqdmulhq_s32(vacc${ABC[0:4]}, vmultiplier); 244 vacc${ABC[4:8]} = vqdmulhq_s32(vacc${ABC[4:8]}, vmultiplier); 245 246 vacc${ABC[0:4]} = vrshlq_s32(vacc${ABC[0:4]}, vright_post_shift); 247 vacc${ABC[4:8]} = vrshlq_s32(vacc${ABC[4:8]}, vright_post_shift); 248 $elif REQUANTIZATION == "FP32": 249 float32x4_t vfpacc${ABC[0:4]} = vcvtq_f32_s32(vacc${ABC[0:4]}); 250 float32x4_t vfpacc${ABC[4:8]} = vcvtq_f32_s32(vacc${ABC[4:8]}); 251 252 $if DATATYPE == "QC8": 253 const float32x4_t vscale${ABC[0:4]} = vld1q_f32((const float*) ((uintptr_t) w + ${CHANNEL_TILE - 8} * sizeof(int32_t) + ${CHANNEL_TILE * KERNEL_TILE} * sizeof(${XINT8_T}))); 254 const float32x4_t vscale${ABC[4:8]} = vld1q_f32((const float*) ((uintptr_t) w + ${CHANNEL_TILE - 8} * sizeof(int32_t) + ${CHANNEL_TILE * KERNEL_TILE} * sizeof(${XINT8_T}) + 4 * sizeof(float))); 255 vfpacc${ABC[0:4]} = vmulq_f32(vfpacc${ABC[0:4]}, vscale${ABC[0:4]}); 256 vfpacc${ABC[4:8]} = vmulq_f32(vfpacc${ABC[4:8]}, vscale${ABC[4:8]}); 257 $else: 258 vfpacc${ABC[0:4]} = vmulq_f32(vfpacc${ABC[0:4]}, vscale); 259 vfpacc${ABC[4:8]} = vmulq_f32(vfpacc${ABC[4:8]}, vscale); 260 261 $if ARMV8: 262 vacc${ABC[0:4]} = vcvtnq_s32_f32(vfpacc${ABC[0:4]}); 263 vacc${ABC[4:8]} = vcvtnq_s32_f32(vfpacc${ABC[4:8]}); 264 $else: 265 vacc${ABC[0:4]} = vreinterpretq_s32_f32(vaddq_f32(vfpacc${ABC[0:4]}, vmagic_bias)); 266 vacc${ABC[4:8]} = vreinterpretq_s32_f32(vaddq_f32(vfpacc${ABC[4:8]}, vmagic_bias)); 267 268 vacc${ABC[0:4]} = vqsubq_s32(vacc${ABC[0:4]}, vmagic_bias_less_output_zero_point); 269 vacc${ABC[4:8]} = vqsubq_s32(vacc${ABC[4:8]}, vmagic_bias_less_output_zero_point); 270 271#if XNN_ARCH_ARM64 272 int16x8_t vacc${ABC[0:8]} = vqmovn_high_s32(vqmovn_s32(vacc${ABC[0:4]}), vacc${ABC[4:8]}); 273#else 274 int16x8_t vacc${ABC[0:8]} = vcombine_s16(vqmovn_s32(vacc${ABC[0:4]}), vqmovn_s32(vacc${ABC[4:8]})); 275#endif 276 $if REQUANTIZATION != "FP32" or ARMV8: 277 vacc${ABC[0:8]} = vqaddq_s16(vacc${ABC[0:8]}, voutput_zero_point); 278 279 ${XINT8X8_T} vout${ABC[0:8]} = ${VQMOVXN_S16}(vacc${ABC[0:8]}); 280 $if CHANNEL_TILE == 8: 281 vout${ABC[0:8]} = ${VMAX_X8}(vout${ABC[0:8]}, voutput_min); 282 vout${ABC[0:8]} = ${VMIN_X8}(vout${ABC[0:8]}, voutput_max); 283 $else: 284 vout${ABC[0:8]} = ${VMAX_X8}(vout${ABC[0:8]}, ${VGET_LOW_X8}(voutput_min)); 285 vout${ABC[0:8]} = ${VMIN_X8}(vout${ABC[0:8]}, ${VGET_LOW_X8}(voutput_max)); 286 287 $if CHANNEL_TILE > 8: 288 if XNN_LIKELY(c >= 8) { 289 ${VST1_X8}(output, vout${ABC[0:8]}); output += 8; 290 c -= 8; 291 } else { 292 if (c & 4) { 293 vst1_lane_u32((void*) output, ${VREINTERPRET_U32_X8}(vout${ABC[0:8]}), 0); output += 4; 294 vout${ABC[0:8]} = ${VEXT_X8}(vout${ABC[0:8]}, vout${ABC[0:8]}, 4); 295 } 296 if (c & 2) { 297 vst1_lane_u16((void*) output, ${VREINTERPRET_U16_X8}(vout${ABC[0:8]}), 0); output += 2; 298 vout${ABC[0:8]} = ${VEXT_X8}(vout${ABC[0:8]}, vout${ABC[0:8]}, 2); 299 } 300 if (c & 1) { 301 ${VST1_LANE_X8}(output, vout${ABC[0:8]}, 0); output += 1; 302 } 303 c = 0; 304 } 305 $else: 306 if (c & 4) { 307 vst1_lane_u32((void*) output, ${VREINTERPRET_U32_X8}(vout${ABC[0:8]}), 0); output += 4; 308 vout${ABC[0:8]} = ${VEXT_X8}(vout${ABC[0:8]}, vout${ABC[0:8]}, 4); 309 } 310 if (c & 2) { 311 vst1_lane_u16((void*) output, ${VREINTERPRET_U16_X8}(vout${ABC[0:8]}), 0); output += 2; 312 vout${ABC[0:8]} = ${VEXT_X8}(vout${ABC[0:8]}, vout${ABC[0:8]}, 2); 313 } 314 if (c & 1) { 315 ${VST1_LANE_X8}(output, vout${ABC[0:8]}, 0); output += 1; 316 } 317 }${" while (c != 0);" if CHANNEL_TILE > 8 else ""} 318 } 319 320 output = (${XINT8_T}*) ((uintptr_t) output + output_increment); 321 } while (--output_width != 0); 322} 323