// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert REQUANTIZATION in ["FP32", "RNDNU"]
$assert DATATYPE in ["QC8", "QS8"]
$assert DATATYPE != "QC8" or REQUANTIZATION == "FP32"
$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
$assert LOAD_VARIANT in ["LD64", "LD128"]
$assert CHANNEL_TILE % {"LD64": 8, "LD128": 16}[LOAD_VARIANT] == 0
$assert CHANNEL_TILE >= 8
$assert KERNEL_TILE >= 2
#include <assert.h>

#include <arm_neon.h>

#include <xnnpack/dwconv.h>
$if REQUANTIZATION == "FP32" and ARMV8:
  #include <xnnpack/intrinsics-polyfill.h>


$PARAMS_STRUCT = REQUANTIZATION.lower() + "_" + ("neonv8" if ARMV8 else "neon")
$PARAMS_UNION = "xnn_%s_conv_minmax_params" % DATATYPE.lower()
$ISA = "neonv8" if ARMV8 else "neon"
// Template for a quantized (QS8/QC8) depthwise-convolution minmax microkernel
// using NEON MUL8/MLA8 arithmetic.
//
// Template parameters:
//   DATATYPE     - "QS8" (per-tensor scale) or "QC8" (per-channel scales read
//                  from the packed weights; FP32 requantization only).
//   REQUANTIZATION - "RNDNU" (shift/qdmulh/shift) or "FP32" (float scaling).
//   CHANNEL_TILE - channels processed per main-loop iteration (multiple of 8).
//   KERNEL_TILE  - number of kernel taps (input rows) accumulated per output.
//   LOAD_VARIANT - "LD64" loads 8 int8 lanes at a time, "LD128" loads 16.
//   MLA          - if true, pairs of taps are fused with vmlal_s8 so the
//                  16-bit products are widened to 32 bits once per pair.
//   ARMV8        - if true, uses vcvtnq_s32_f32 (round-to-nearest) for the
//                  FP32 path instead of the magic-bias trick.
//
// Packed weight layout per channel group (as read by the pointer arithmetic
// below): int32 bias[CHANNEL_TILE], then int8 weights[KERNEL_TILE][CHANNEL_TILE],
// then (QC8 only) float scale[CHANNEL_TILE]. Presumably produced by the
// matching xnn_pack_* routine -- confirm against the packing code.
void xnn_${DATATYPE.lower()}_dwconv_minmax_${REQUANTIZATION.lower()}_ukernel_up${CHANNEL_TILE}x${KERNEL_TILE}__${ISA}_${"mla8" if MLA else "mul8"}_${LOAD_VARIANT.lower()}(
    size_t channels,
    size_t output_width,
    const int8_t** input,
    const void* weights,
    int8_t* output,
    size_t input_stride,
    size_t output_increment,
    size_t input_offset,
    const int8_t* zero,
    const union ${PARAMS_UNION} params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
{
  assert(channels != 0);
  assert(output_width != 0);

  // Broadcast-load the requantization/clamping parameters once per call.
  $if REQUANTIZATION == "RNDNU":
    const int32x4_t vright_pre_shift = vld1q_dup_s32(&params->${PARAMS_STRUCT}.right_pre_shift);
    const int32x4_t vmultiplier = vld1q_dup_s32(&params->${PARAMS_STRUCT}.multiplier);
    const int32x4_t vright_post_shift = vld1q_dup_s32(&params->${PARAMS_STRUCT}.right_post_shift);
  $elif REQUANTIZATION == "FP32":
    $if DATATYPE != "QC8":
      const float32x4_t vscale = vld1q_dup_f32(&params->${PARAMS_STRUCT}.scale);
    $if not ARMV8:
      // Magic-bias float->int conversion constants (pre-ARMv8 rounding trick).
      const float32x4_t vmagic_bias = vld1q_dup_f32(&params->${PARAMS_STRUCT}.magic_bias);
      const int32x4_t vmagic_bias_less_output_zero_point = vld1q_dup_s32(&params->${PARAMS_STRUCT}.magic_bias_less_output_zero_point);
  $if REQUANTIZATION != "FP32" or ARMV8:
    // On the magic-bias path the zero point is folded into the bias constant,
    // so a separate zero-point addition is only needed here.
    const int16x8_t voutput_zero_point = vld1q_dup_s16(&params->${PARAMS_STRUCT}.output_zero_point);
  $if CHANNEL_TILE == 8:
    const int8x8_t voutput_min = vld1_dup_s8(&params->${PARAMS_STRUCT}.output_min);
    const int8x8_t voutput_max = vld1_dup_s8(&params->${PARAMS_STRUCT}.output_max);
  $else:
    const int8x16_t voutput_min = vld1q_dup_s8(&params->${PARAMS_STRUCT}.output_min);
    const int8x16_t voutput_max = vld1q_dup_s8(&params->${PARAMS_STRUCT}.output_max);
  // One iteration of this outer loop produces one output pixel across all
  // channels; `input` advances by input_stride, `output` by output_increment.
  do {
    // Set up the KERNEL_TILE input row pointers. Rows that point at the
    // shared zero buffer are NOT offset, so padding reads stay in-bounds.
    $for K in range(KERNEL_TILE):
      const int8_t* i${K} = input[${K}];
      assert(i${K} != NULL);
      if XNN_UNPREDICTABLE(i${K} != zero) {
        i${K} = (const int8_t*) ((uintptr_t) i${K} + input_offset);
      }
    input = (const int8_t**) ((uintptr_t) input + input_stride);

    size_t c = channels;
    const void* w = weights;
    // Main loop: CHANNEL_TILE channels per iteration.
    for (; c >= ${CHANNEL_TILE}; c -= ${CHANNEL_TILE}) {
      // Initialize accumulators from the packed int32 bias.
      $for C in range(0, CHANNEL_TILE, 4):
        int32x4_t vacc${ABC[C:C+4]} = vld1q_s32(w); w = (const void*) ((const int32_t*) w + 4);

      // Multiply-accumulate over all KERNEL_TILE taps. With MLA, an even tap's
      // vmull_s8 product is kept in 16 bits and the following odd tap is fused
      // in with vmlal_s8; the pair is then widened into vacc once.
      $for K in range(KERNEL_TILE):
        $if LOAD_VARIANT == "LD128":
          $for C in range(0, CHANNEL_TILE, 16):
            const int8x16_t vi${K}x${ABC[C:C+16]} = vld1q_s8(i${K}); i${K} += 16;
            const int8x16_t vk${K}x${ABC[C:C+16]} = vld1q_s8(w); w = (const void*) ((const int8_t*) w + 16);

          $if K == 0:
            $for C in range(0, CHANNEL_TILE, 16):
              int16x8_t vprod${ABC[C:C+8]} = vmull_s8(vget_low_s8(vi${K}x${ABC[C:C+16]}), vget_low_s8(vk${K}x${ABC[C:C+16]}));
              int16x8_t vprod${ABC[C+8:C+16]} = vmull_s8(vget_high_s8(vi${K}x${ABC[C:C+16]}), vget_high_s8(vk${K}x${ABC[C:C+16]}));
          $elif K % 2 == 0 or K + 1 == KERNEL_TILE or not MLA:
            $for C in range(0, CHANNEL_TILE, 16):
              vprod${ABC[C:C+8]} = vmull_s8(vget_low_s8(vi${K}x${ABC[C:C+16]}), vget_low_s8(vk${K}x${ABC[C:C+16]}));
              vprod${ABC[C+8:C+16]} = vmull_s8(vget_high_s8(vi${K}x${ABC[C:C+16]}), vget_high_s8(vk${K}x${ABC[C:C+16]}));
          $else:
            $for C in range(0, CHANNEL_TILE, 16):
              vprod${ABC[C:C+8]} = vmlal_s8(vprod${ABC[C:C+8]}, vget_low_s8(vi${K}x${ABC[C:C+16]}), vget_low_s8(vk${K}x${ABC[C:C+16]}));
              vprod${ABC[C+8:C+16]} = vmlal_s8(vprod${ABC[C+8:C+16]}, vget_high_s8(vi${K}x${ABC[C:C+16]}), vget_high_s8(vk${K}x${ABC[C:C+16]}));
        $else:
          $for C in range(0, CHANNEL_TILE, 8):
            const int8x8_t vi${K}x${ABC[C:C+8]} = vld1_s8(i${K}); i${K} += 8;
            const int8x8_t vk${K}x${ABC[C:C+8]} = vld1_s8(w); w = (const void*) ((const int8_t*) w + 8);

          $if K == 0:
            $for C in range(0, CHANNEL_TILE, 8):
              int16x8_t vprod${ABC[C:C+8]} = vmull_s8(vi${K}x${ABC[C:C+8]}, vk${K}x${ABC[C:C+8]});
          $elif K % 2 == 0 or K + 1 == KERNEL_TILE or not MLA:
            $for C in range(0, CHANNEL_TILE, 8):
              vprod${ABC[C:C+8]} = vmull_s8(vi${K}x${ABC[C:C+8]}, vk${K}x${ABC[C:C+8]});
          $else:
            $for C in range(0, CHANNEL_TILE, 8):
              vprod${ABC[C:C+8]} = vmlal_s8(vprod${ABC[C:C+8]}, vi${K}x${ABC[C:C+8]}, vk${K}x${ABC[C:C+8]});

        // Widen the 16-bit products into the 32-bit accumulators (once per
        // MLA pair, or after every tap when not fusing).
        $if not MLA or K % 2 == 1 or K + 1 == KERNEL_TILE:
          $for C in range(0, CHANNEL_TILE, 8):
            vacc${ABC[C:C+4]} = vaddw_s16(vacc${ABC[C:C+4]}, vget_low_s16(vprod${ABC[C:C+8]}));
            vacc${ABC[C+4:C+8]} = vaddw_s16(vacc${ABC[C+4:C+8]}, vget_high_s16(vprod${ABC[C:C+8]}));

      // Requantize the int32 accumulators back to the int8 domain.
      $if REQUANTIZATION == "RNDNU":
        $for C in range(0, CHANNEL_TILE, 4):
          vacc${ABC[C:C+4]} = vqshlq_s32(vacc${ABC[C:C+4]}, vright_pre_shift);

        $for C in range(0, CHANNEL_TILE, 4):
          vacc${ABC[C:C+4]} = vqdmulhq_s32(vacc${ABC[C:C+4]}, vmultiplier);

        $for C in range(0, CHANNEL_TILE, 4):
          vacc${ABC[C:C+4]} = vrshlq_s32(vacc${ABC[C:C+4]}, vright_post_shift);
      $elif REQUANTIZATION == "FP32":
        $for C in range(0, CHANNEL_TILE, 4):
          float32x4_t vfpacc${ABC[C:C+4]} = vcvtq_f32_s32(vacc${ABC[C:C+4]});

        $if DATATYPE == "QC8":
          // QC8: per-channel scales are stored in the weights right after the
          // int8 kernel taps.
          $for C in range(0, CHANNEL_TILE, 4):
            const float32x4_t vscale${ABC[C:C+4]} = vld1q_f32((const float*) w); w = (const void*) ((const float*) w + 4);

          $for C in range(0, CHANNEL_TILE, 4):
            vfpacc${ABC[C:C+4]} = vmulq_f32(vfpacc${ABC[C:C+4]}, vscale${ABC[C:C+4]});
        $else:
          $for C in range(0, CHANNEL_TILE, 4):
            vfpacc${ABC[C:C+4]} = vmulq_f32(vfpacc${ABC[C:C+4]}, vscale);

        $if ARMV8:
          $for C in range(0, CHANNEL_TILE, 4):
            vacc${ABC[C:C+4]} = vcvtnq_s32_f32(vfpacc${ABC[C:C+4]});
        $else:
          // Magic-bias rounding: add the bias, reinterpret the float bits as
          // int32, then subtract (bias - output_zero_point) with saturation.
          $for C in range(0, CHANNEL_TILE, 4):
            vacc${ABC[C:C+4]} = vreinterpretq_s32_f32(vaddq_f32(vfpacc${ABC[C:C+4]}, vmagic_bias));

          $for C in range(0, CHANNEL_TILE, 4):
            vacc${ABC[C:C+4]} = vqsubq_s32(vacc${ABC[C:C+4]}, vmagic_bias_less_output_zero_point);

#if XNN_ARCH_ARM64
      // Narrow int32 -> int16 -> int8 with saturation (A64 has the _high forms).
      $for C in range(0, CHANNEL_TILE, 8):
        int16x8_t vacc${ABC[C:C+8]} = vqmovn_high_s32(vqmovn_s32(vacc${ABC[C:C+4]}), vacc${ABC[C+4:C+8]});

      $if REQUANTIZATION != "FP32" or ARMV8:
        $for C in range(0, CHANNEL_TILE, 8):
          vacc${ABC[C:C+8]} = vqaddq_s16(vacc${ABC[C:C+8]}, voutput_zero_point);

      $for C in range(0, CHANNEL_TILE, 16):
        $if C + 8 < CHANNEL_TILE:
          int8x16_t vout${ABC[C:C+16]} = vqmovn_high_s16(vqmovn_s16(vacc${ABC[C:C+8]}), vacc${ABC[C+8:C+16]});
        $else:
          int8x8_t vout${ABC[C:C+8]} = vqmovn_s16(vacc${ABC[C:C+8]});
#else  // !XNN_ARCH_ARM64
      $for C in range(0, CHANNEL_TILE, 8):
        int16x8_t vacc${ABC[C:C+8]} = vcombine_s16(vqmovn_s32(vacc${ABC[C:C+4]}), vqmovn_s32(vacc${ABC[C+4:C+8]}));

      $if REQUANTIZATION != "FP32" or ARMV8:
        $for C in range(0, CHANNEL_TILE, 8):
          vacc${ABC[C:C+8]} = vqaddq_s16(vacc${ABC[C:C+8]}, voutput_zero_point);

      $for C in range(0, CHANNEL_TILE, 16):
        $if C + 8 < CHANNEL_TILE:
          int8x16_t vout${ABC[C:C+16]} = vcombine_s8(vqmovn_s16(vacc${ABC[C:C+8]}), vqmovn_s16(vacc${ABC[C+8:C+16]}));
        $else:
          int8x8_t vout${ABC[C:C+8]} = vqmovn_s16(vacc${ABC[C:C+8]});
#endif  // !XNN_ARCH_ARM64

      // Clamp to [output_min, output_max] and store CHANNEL_TILE bytes.
      $for C in range(0, CHANNEL_TILE, 16):
        $if C + 8 < CHANNEL_TILE:
          vout${ABC[C:C+16]} = vmaxq_s8(vout${ABC[C:C+16]}, voutput_min);
        $elif CHANNEL_TILE == 8:
          vout${ABC[C:C+8]} = vmax_s8(vout${ABC[C:C+8]}, voutput_min);
        $else:
          vout${ABC[C:C+8]} = vmax_s8(vout${ABC[C:C+8]}, vget_low_s8(voutput_min));

      $for C in range(0, CHANNEL_TILE, 16):
        $if C + 8 < CHANNEL_TILE:
          vout${ABC[C:C+16]} = vminq_s8(vout${ABC[C:C+16]}, voutput_max);
        $elif CHANNEL_TILE == 8:
          vout${ABC[C:C+8]} = vmin_s8(vout${ABC[C:C+8]}, voutput_max);
        $else:
          vout${ABC[C:C+8]} = vmin_s8(vout${ABC[C:C+8]}, vget_low_s8(voutput_max));

      $for C in range(0, CHANNEL_TILE, 16):
        $if C + 8 < CHANNEL_TILE:
          vst1q_s8(output, vout${ABC[C:C+16]}); output += 16;
        $else:
          vst1_s8(output, vout${ABC[C:C+8]}); output += 8;
    }
    // Remainder path: handle the final c (< CHANNEL_TILE) channels, 8 at a
    // time, with partial stores of 4/2/1 bytes for the last group.
    if XNN_UNLIKELY(c != 0) {
      $if CHANNEL_TILE > 8:
        // `k` walks the int8 taps directly; `w` stays on the bias so each
        // 8-channel group can reload its bias below.
        const int8_t* k = (const int8_t*) ((const int32_t*) w + ${CHANNEL_TILE});
      ${"do " if CHANNEL_TILE > 8 else ""}{
        int32x4_t vacc${ABC[0:4]} = vld1q_s32(w); w = (const void*) ((const int32_t*) w + 4);
        int32x4_t vacc${ABC[4:8]} = vld1q_s32(w); w = (const void*) ((const int32_t*) w + 4);

        $for K in range(KERNEL_TILE):
          $if CHANNEL_TILE > 8:
            const int8x8_t vi${K}x${ABC[0:8]} = vld1_s8(i${K}); i${K} += 8;
          $else:
            const int8x8_t vi${K}x${ABC[0:8]} = vld1_s8(i${K});
          $if CHANNEL_TILE > 8:
            $if K == 0:
              const int8x8_t vk${K}x${ABC[0:8]} = vld1_s8(k); k += 8;
            $else:
              const int8x8_t vk${K}x${ABC[0:8]} = vld1_s8((const void*) (k + ${K * CHANNEL_TILE - 8}));
          $else:
            $if K == 0:
              const int8x8_t vk${K}x${ABC[0:8]} = vld1_s8(w);
            $else:
              const int8x8_t vk${K}x${ABC[0:8]} = vld1_s8((const void*) ((const int8_t*) w + ${K * CHANNEL_TILE}));

          $if K == 0:
            int16x8_t vprod${ABC[0:8]} = vmull_s8(vi${K}x${ABC[0:8]}, vk${K}x${ABC[0:8]});
          $elif K % 2 == 0 or K + 1 == KERNEL_TILE or not MLA:
            vprod${ABC[0:8]} = vmull_s8(vi${K}x${ABC[0:8]}, vk${K}x${ABC[0:8]});
          $else:
            vprod${ABC[0:8]} = vmlal_s8(vprod${ABC[0:8]}, vi${K}x${ABC[0:8]}, vk${K}x${ABC[0:8]});

          $if not MLA or K % 2 == 1 or K + 1 == KERNEL_TILE:
            vacc${ABC[0:4]} = vaddw_s16(vacc${ABC[0:4]}, vget_low_s16(vprod${ABC[0:8]}));
            vacc${ABC[4:8]} = vaddw_s16(vacc${ABC[4:8]}, vget_high_s16(vprod${ABC[0:8]}));

        $if REQUANTIZATION == "RNDNU":
          vacc${ABC[0:4]} = vqshlq_s32(vacc${ABC[0:4]}, vright_pre_shift);
          vacc${ABC[4:8]} = vqshlq_s32(vacc${ABC[4:8]}, vright_pre_shift);

          vacc${ABC[0:4]} = vqdmulhq_s32(vacc${ABC[0:4]}, vmultiplier);
          vacc${ABC[4:8]} = vqdmulhq_s32(vacc${ABC[4:8]}, vmultiplier);

          vacc${ABC[0:4]} = vrshlq_s32(vacc${ABC[0:4]}, vright_post_shift);
          vacc${ABC[4:8]} = vrshlq_s32(vacc${ABC[4:8]}, vright_post_shift);
        $elif REQUANTIZATION == "FP32":
          float32x4_t vfpacc${ABC[0:4]} = vcvtq_f32_s32(vacc${ABC[0:4]});
          float32x4_t vfpacc${ABC[4:8]} = vcvtq_f32_s32(vacc${ABC[4:8]});

          $if DATATYPE == "QC8":
            // Skip past the remaining bias and all int8 taps to reach this
            // group's per-channel scales.
            const float32x4_t vscale${ABC[0:4]} = vld1q_f32((const float*) ((uintptr_t) w + ${CHANNEL_TILE - 8} * sizeof(int32_t) + ${CHANNEL_TILE * KERNEL_TILE} * sizeof(int8_t)));
            const float32x4_t vscale${ABC[4:8]} = vld1q_f32((const float*) ((uintptr_t) w + ${CHANNEL_TILE - 8} * sizeof(int32_t) + ${CHANNEL_TILE * KERNEL_TILE} * sizeof(int8_t) + 4 * sizeof(float)));
            vfpacc${ABC[0:4]} = vmulq_f32(vfpacc${ABC[0:4]}, vscale${ABC[0:4]});
            vfpacc${ABC[4:8]} = vmulq_f32(vfpacc${ABC[4:8]}, vscale${ABC[4:8]});
          $else:
            vfpacc${ABC[0:4]} = vmulq_f32(vfpacc${ABC[0:4]}, vscale);
            vfpacc${ABC[4:8]} = vmulq_f32(vfpacc${ABC[4:8]}, vscale);

          $if ARMV8:
            vacc${ABC[0:4]} = vcvtnq_s32_f32(vfpacc${ABC[0:4]});
            vacc${ABC[4:8]} = vcvtnq_s32_f32(vfpacc${ABC[4:8]});
          $else:
            vacc${ABC[0:4]} = vreinterpretq_s32_f32(vaddq_f32(vfpacc${ABC[0:4]}, vmagic_bias));
            vacc${ABC[4:8]} = vreinterpretq_s32_f32(vaddq_f32(vfpacc${ABC[4:8]}, vmagic_bias));

            vacc${ABC[0:4]} = vqsubq_s32(vacc${ABC[0:4]}, vmagic_bias_less_output_zero_point);
            vacc${ABC[4:8]} = vqsubq_s32(vacc${ABC[4:8]}, vmagic_bias_less_output_zero_point);

#if XNN_ARCH_ARM64
        int16x8_t vacc${ABC[0:8]} = vqmovn_high_s32(vqmovn_s32(vacc${ABC[0:4]}), vacc${ABC[4:8]});
#else
        int16x8_t vacc${ABC[0:8]} = vcombine_s16(vqmovn_s32(vacc${ABC[0:4]}), vqmovn_s32(vacc${ABC[4:8]}));
#endif
        $if REQUANTIZATION != "FP32" or ARMV8:
          vacc${ABC[0:8]} = vqaddq_s16(vacc${ABC[0:8]}, voutput_zero_point);

        int8x8_t vout${ABC[0:8]} = vqmovn_s16(vacc${ABC[0:8]});
        $if CHANNEL_TILE == 8:
          vout${ABC[0:8]} = vmax_s8(vout${ABC[0:8]}, voutput_min);
          vout${ABC[0:8]} = vmin_s8(vout${ABC[0:8]}, voutput_max);
        $else:
          vout${ABC[0:8]} = vmax_s8(vout${ABC[0:8]}, vget_low_s8(voutput_min));
          vout${ABC[0:8]} = vmin_s8(vout${ABC[0:8]}, vget_low_s8(voutput_max));

        $if CHANNEL_TILE > 8:
          if XNN_LIKELY(c >= 8) {
            vst1_s8(output, vout${ABC[0:8]}); output += 8;
            c -= 8;
          } else {
            // Tail store: write 4, 2, then 1 byte(s), rotating vout so the
            // next lanes land at lane 0.
            if (c & 4) {
              vst1_lane_u32((void*) output, vreinterpret_u32_s8(vout${ABC[0:8]}), 0); output += 4;
              vout${ABC[0:8]} = vext_s8(vout${ABC[0:8]}, vout${ABC[0:8]}, 4);
            }
            if (c & 2) {
              vst1_lane_u16((void*) output, vreinterpret_u16_s8(vout${ABC[0:8]}), 0); output += 2;
              vout${ABC[0:8]} = vext_s8(vout${ABC[0:8]}, vout${ABC[0:8]}, 2);
            }
            if (c & 1) {
              vst1_lane_s8(output, vout${ABC[0:8]}, 0); output += 1;
            }
            c = 0;
          }
        $else:
          // Tail store: write 4, 2, then 1 byte(s), rotating vout so the
          // next lanes land at lane 0.
          if (c & 4) {
            vst1_lane_u32((void*) output, vreinterpret_u32_s8(vout${ABC[0:8]}), 0); output += 4;
            vout${ABC[0:8]} = vext_s8(vout${ABC[0:8]}, vout${ABC[0:8]}, 4);
          }
          if (c & 2) {
            vst1_lane_u16((void*) output, vreinterpret_u16_s8(vout${ABC[0:8]}), 0); output += 2;
            vout${ABC[0:8]} = vext_s8(vout${ABC[0:8]}, vout${ABC[0:8]}, 2);
          }
          if (c & 1) {
            vst1_lane_s8(output, vout${ABC[0:8]}, 0); output += 1;
          }
      }${" while (c != 0);" if CHANNEL_TILE > 8 else ""}
    }

    output = (int8_t*) ((uintptr_t) output + output_increment);
  } while (--output_width != 0);
}