// Copyright 2021 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert REQUANTIZATION in ["FP32", "RNDNU"]
$assert not CHANNELWISE or REQUANTIZATION == "FP32"

#include <xnnpack/assembly.h>

$DATATYPE = "qc8" if CHANNELWISE else "qs8"
$PARAMS_UNION = "xnn_qs8_minmax_params" if CHANNELWISE else "xnn_qs8_conv_minmax_params"
$REWIND_DECREMENT = 3 if CHANNELWISE else {"RNDNU": 15, "FP32": 7}[REQUANTIZATION]
# void xnn_${DATATYPE}_igemm_minmax_${REQUANTIZATION.lower()}_ukernel_4x16c4__aarch64_neondot_ld64(
#     size_t mr,                 x0
#     size_t nc,                 x1
#     size_t kc,                 x2 / x0
#     size_t ks,                 x3 / x9
#     const int8_t**restrict a,  x4
#     const int8_t* restrict w,  x5
#     int8_t* restrict c,        x6
#     size_t cm_stride,          x7
#     size_t cn_stride,          [sp] -> (x0)
#     size_t a_offset,           [sp + 8] -> x8
#     const int8_t* zero,        [sp + 16] -> x12
#     const union ${PARAMS_UNION} params [sp + 24] -> x11

# d8-d15, x19-x30 need to be preserved if used. x18 is reserved by the OS.
# This kernel uses only volatile registers (x0-x17, v0-v7, v16-v31), so no
# prologue/epilogue save of callee-saved state is needed.

# Register usage
# A0  x13  v0
# A1  x14  v1
# A2  x15  v2
# A3  x10  v3
# B    x5  v4  v5  v6  v7
# C0   x6 v16 v20 v24 v28
# C1  x16 v17 v21 v25 v29
# C2  x17 v18 v22 v26 v30
# C3   x7 v19 v23 v27 v31
# unused  v8 v9 v10 v11 v12 v13 v14 v15

BEGIN_FUNCTION xnn_${DATATYPE}_igemm_minmax_${REQUANTIZATION.lower()}_ukernel_4x16c4__aarch64_neondot_ld64

        # Clamp C pointers: rows beyond mr alias the previous row's pointer so
        # out-of-range rows are computed but stored on top of valid rows.
        CMP     x0, 2                   // if mr < 2
        LDR     x8, [sp, 8]             // Load a_offset
        ADD     x16, x6, x7             // c1 = c0 + cm_stride
        CSEL    x16, x6, x16, LO        //   c1 = c0
        ADD     x2, x2, 3               // kc = (kc + 3) & ~3 (round up to a multiple of 4)

        ADD     x17, x16, x7            // c2 = c1 + cm_stride
        LDP     x12, x11, [sp, 16]      // Load zero, params pointer
                                        // if mr <= 2
        CSEL    x17, x16, x17, LS       //   c2 = c1
        BIC     x2, x2, 3               // second half of kc rounding

        CMP     x0, 4                   // if mr < 4
        ADD     x7, x17, x7             // c3 = c2 + cm_stride
        CSEL    x7, x17, x7, LO         //   c3 = c2

        .p2align 3
0:
        # Load initial bias from w into accumulators.
        # v16/v20/v24/v28 get the 16 bias values; the other 12 accumulators
        # (rows 1-3) are copies of row 0's bias.
        LDP     q16, q20, [x5], 32
        MOV     v17.16b, v16.16b
        MOV     v18.16b, v16.16b
        LDP     q24, q28, [x5], 32
        MOV     v19.16b, v16.16b
        MOV     v21.16b, v20.16b
        MOV     v22.16b, v20.16b
        MOV     v23.16b, v20.16b
        MOV     v25.16b, v24.16b
        MOV     v26.16b, v24.16b
        MOV     v27.16b, v24.16b
        MOV     v29.16b, v28.16b
        MOV     v30.16b, v28.16b
        MOV     v31.16b, v28.16b
        MOV     x9, x3                  // p = ks

        .p2align 3
1:
        # Load next 4 A pointers
        LDP     x13, x14, [x4], 16
        LDP     x15, x10, [x4], 16

        # For each A pointer: if it equals the sentinel `zero` pointer, keep
        # pointing at the zero buffer (no a_offset); otherwise add a_offset.
        CMP     x13, x12                // if a0 == zero
        ADD     x13, x13, x8            // a0 += a_offset
        CSEL    x13, x12, x13, EQ       //   a0 = zero, else += a0 + a_offset
        CMP     x14, x12                // if a1 == zero
        ADD     x14, x14, x8            // a1 += a_offset
        CSEL    x14, x12, x14, EQ       //   a1 = zero, else += a1 + a_offset
        CMP     x15, x12                // if a2 == zero
        ADD     x15, x15, x8            // a2 += a_offset
        CSEL    x15, x12, x15, EQ       //   a2 = zero, else += a2 + a_offset
        CMP     x10, x12                // if a3 == zero
        ADD     x10, x10, x8            // a3 += a_offset
        CSEL    x10, x12, x10, EQ       //   a3 = zero, else += a3 + a_offset

        # Is there at least 8 bytes for main loop?
        SUBS    x0, x2, 8               // k = kc - 8
        B.LO    4f

        # Main loop - 8 bytes of A.  Each SDOT accumulates a 4-byte (c4) dot
        # product; two groups of 4 B vectors cover lanes [0] and [1] of the
        # 8 A bytes.  B loads are interleaved with the SDOTs for scheduling.
        .p2align 3
2:
        LDR     d0, [x13], 8
        LDR     q4, [x5], 16
        LDR     d1, [x14], 8
        LDR     d2, [x15], 8
        LDR     d3, [x10], 8
        LDR     q5, [x5], 16
        SDOT    v16.4s, v4.16b, v0.4b[0]
        SDOT    v17.4s, v4.16b, v1.4b[0]
        LDP     q6, q7, [x5], 32
        SDOT    v18.4s, v4.16b, v2.4b[0]
        SDOT    v19.4s, v4.16b, v3.4b[0]
        SDOT    v20.4s, v5.16b, v0.4b[0]
        SDOT    v21.4s, v5.16b, v1.4b[0]
        SDOT    v22.4s, v5.16b, v2.4b[0]
        SDOT    v23.4s, v5.16b, v3.4b[0]
        SDOT    v24.4s, v6.16b, v0.4b[0]
        SDOT    v25.4s, v6.16b, v1.4b[0]
        LDP     q4, q5, [x5], 32
        SDOT    v26.4s, v6.16b, v2.4b[0]
        SDOT    v27.4s, v6.16b, v3.4b[0]
        SDOT    v28.4s, v7.16b, v0.4b[0]
        SDOT    v29.4s, v7.16b, v1.4b[0]
        SDOT    v30.4s, v7.16b, v2.4b[0]
        SDOT    v31.4s, v7.16b, v3.4b[0]
        SDOT    v16.4s, v4.16b, v0.4b[1]
        SDOT    v17.4s, v4.16b, v1.4b[1]
        LDP     q6, q7, [x5], 32
        SDOT    v18.4s, v4.16b, v2.4b[1]
        SDOT    v19.4s, v4.16b, v3.4b[1]
        SDOT    v20.4s, v5.16b, v0.4b[1]
        SDOT    v21.4s, v5.16b, v1.4b[1]
        SDOT    v22.4s, v5.16b, v2.4b[1]
        SDOT    v23.4s, v5.16b, v3.4b[1]
        SDOT    v24.4s, v6.16b, v0.4b[1]
        SDOT    v25.4s, v6.16b, v1.4b[1]
        SDOT    v26.4s, v6.16b, v2.4b[1]
        SDOT    v27.4s, v6.16b, v3.4b[1]
        SDOT    v28.4s, v7.16b, v0.4b[1]
        SDOT    v29.4s, v7.16b, v1.4b[1]
        SDOT    v30.4s, v7.16b, v2.4b[1]
        SUBS    x0, x0, 8
        SDOT    v31.4s, v7.16b, v3.4b[1]
        B.HS    2b

        # Is there a remainder?- 4 bytes of A
        # (kc was rounded to a multiple of 4, so the remainder is 0 or 4;
        # bit 2 of the residual count selects the 4-byte tail.)
        TBNZ    x0, 2, 4f

        # ks loop
        SUBS    x9, x9, 32              // ks -= MR * sizeof(int8_t*)
        B.HI    1b

3:
        $if REQUANTIZATION == "RNDNU":
          # Apply params - preshift, scale, postshift, bias and clamp.
          # Three LD1R loads pull the three 32-bit shift/multiplier params,
          # interleaved with the vector ops for scheduling.
          LD1R    {v4.4s}, [x11], 4
          SQSHL   v16.4s, v16.4s, v4.4s   // shift to upper bits
          SQSHL   v17.4s, v17.4s, v4.4s
          SQSHL   v18.4s, v18.4s, v4.4s
          SQSHL   v19.4s, v19.4s, v4.4s
          SQSHL   v20.4s, v20.4s, v4.4s
          SQSHL   v21.4s, v21.4s, v4.4s
          SQSHL   v22.4s, v22.4s, v4.4s
          SQSHL   v23.4s, v23.4s, v4.4s
          LD1R    {v5.4s}, [x11], 4
          SQSHL   v24.4s, v24.4s, v4.4s
          SQSHL   v25.4s, v25.4s, v4.4s
          SQSHL   v26.4s, v26.4s, v4.4s
          SQSHL   v27.4s, v27.4s, v4.4s
          SQSHL   v28.4s, v28.4s, v4.4s
          SQSHL   v29.4s, v29.4s, v4.4s
          SQSHL   v30.4s, v30.4s, v4.4s
          SQSHL   v31.4s, v31.4s, v4.4s
          LD1R    {v6.4s}, [x11], 4
          SQDMULH v16.4s, v16.4s, v5.4s   // scale without rounding
          SQDMULH v17.4s, v17.4s, v5.4s
          SQDMULH v18.4s, v18.4s, v5.4s
          SQDMULH v19.4s, v19.4s, v5.4s
          SQDMULH v20.4s, v20.4s, v5.4s
          SQDMULH v21.4s, v21.4s, v5.4s
          SQDMULH v22.4s, v22.4s, v5.4s
          SQDMULH v23.4s, v23.4s, v5.4s
          SQDMULH v24.4s, v24.4s, v5.4s
          SQDMULH v25.4s, v25.4s, v5.4s
          SQDMULH v26.4s, v26.4s, v5.4s
          SQDMULH v27.4s, v27.4s, v5.4s
          SQDMULH v28.4s, v28.4s, v5.4s
          SQDMULH v29.4s, v29.4s, v5.4s
          SQDMULH v30.4s, v30.4s, v5.4s
          SQDMULH v31.4s, v31.4s, v5.4s
          SRSHL   v16.4s, v16.4s, v6.4s   // signed rounding shift left
          SRSHL   v17.4s, v17.4s, v6.4s
          SRSHL   v18.4s, v18.4s, v6.4s
          SRSHL   v19.4s, v19.4s, v6.4s
          SRSHL   v20.4s, v20.4s, v6.4s
          SRSHL   v21.4s, v21.4s, v6.4s
          SRSHL   v22.4s, v22.4s, v6.4s
          SRSHL   v23.4s, v23.4s, v6.4s
          SRSHL   v24.4s, v24.4s, v6.4s
          SRSHL   v25.4s, v25.4s, v6.4s
          SRSHL   v26.4s, v26.4s, v6.4s
          SRSHL   v27.4s, v27.4s, v6.4s
          SRSHL   v28.4s, v28.4s, v6.4s
          SRSHL   v29.4s, v29.4s, v6.4s
          SRSHL   v30.4s, v30.4s, v6.4s
          SRSHL   v31.4s, v31.4s, v6.4s
        $elif REQUANTIZATION == "FP32":
          # FP32 requantization: convert accumulators to float, scale, then
          # convert back with round-to-nearest (ties to even).
          SCVTF   v16.4s, v16.4s
          SCVTF   v17.4s, v17.4s
          $if not CHANNELWISE:
            # Apply params - scale, bias and clamp
            LD1R    {v4.4s}, [x11], 4
            SCVTF   v18.4s, v18.4s
            SCVTF   v19.4s, v19.4s
          $else:
            # Load per channel scale values from weights
            LDR     q4, [x5], 16
            SCVTF   v18.4s, v18.4s
            SCVTF   v19.4s, v19.4s
            LDR     q5, [x5], 16
          SCVTF   v20.4s, v20.4s
          SCVTF   v21.4s, v21.4s
          SCVTF   v22.4s, v22.4s
          SCVTF   v23.4s, v23.4s
          SCVTF   v24.4s, v24.4s
          SCVTF   v25.4s, v25.4s
          SCVTF   v26.4s, v26.4s
          SCVTF   v27.4s, v27.4s
          SCVTF   v28.4s, v28.4s
          SCVTF   v29.4s, v29.4s
          SCVTF   v30.4s, v30.4s
          SCVTF   v31.4s, v31.4s

          $if CHANNELWISE:
            # Per-channel scales: v4/v5/v6/v4 each cover 4 of the 16 output
            # channels; further scale vectors are streamed from weights.
            LDR     q6, [x5], 16
            FMUL    v16.4s, v16.4s, v4.4s
            FMUL    v17.4s, v17.4s, v4.4s
            FMUL    v18.4s, v18.4s, v4.4s
            FMUL    v19.4s, v19.4s, v4.4s
            FMUL    v20.4s, v20.4s, v5.4s
            LDR     q4, [x5], 16
            FMUL    v21.4s, v21.4s, v5.4s
            FMUL    v22.4s, v22.4s, v5.4s
            FMUL    v23.4s, v23.4s, v5.4s
            FMUL    v24.4s, v24.4s, v6.4s
            FMUL    v25.4s, v25.4s, v6.4s
            FMUL    v26.4s, v26.4s, v6.4s
            FMUL    v27.4s, v27.4s, v6.4s
            FMUL    v28.4s, v28.4s, v4.4s
            FMUL    v29.4s, v29.4s, v4.4s
            FMUL    v30.4s, v30.4s, v4.4s
            FMUL    v31.4s, v31.4s, v4.4s
          $else:
            # Single scale (v4) applied to all 16 channels.
            FMUL    v16.4s, v16.4s, v4.4s
            FMUL    v17.4s, v17.4s, v4.4s
            FMUL    v18.4s, v18.4s, v4.4s
            FMUL    v19.4s, v19.4s, v4.4s
            FMUL    v20.4s, v20.4s, v4.4s
            FMUL    v21.4s, v21.4s, v4.4s
            FMUL    v22.4s, v22.4s, v4.4s
            FMUL    v23.4s, v23.4s, v4.4s
            FMUL    v24.4s, v24.4s, v4.4s
            FMUL    v25.4s, v25.4s, v4.4s
            FMUL    v26.4s, v26.4s, v4.4s
            FMUL    v27.4s, v27.4s, v4.4s
            FMUL    v28.4s, v28.4s, v4.4s
            FMUL    v29.4s, v29.4s, v4.4s
            FMUL    v30.4s, v30.4s, v4.4s
            FMUL    v31.4s, v31.4s, v4.4s

          FCVTNS  v16.4s, v16.4s          // float -> int32, round to nearest even
          FCVTNS  v17.4s, v17.4s
          FCVTNS  v18.4s, v18.4s
          FCVTNS  v19.4s, v19.4s
          FCVTNS  v20.4s, v20.4s
          FCVTNS  v21.4s, v21.4s
          FCVTNS  v22.4s, v22.4s
          FCVTNS  v23.4s, v23.4s
          FCVTNS  v24.4s, v24.4s
          FCVTNS  v25.4s, v25.4s
          FCVTNS  v26.4s, v26.4s
          FCVTNS  v27.4s, v27.4s
          FCVTNS  v28.4s, v28.4s
          FCVTNS  v29.4s, v29.4s
          FCVTNS  v30.4s, v30.4s
          FCVTNS  v31.4s, v31.4s

        # Saturating-narrow int32 -> int16 (SQXTN low half, SQXTN2 high half).
        SQXTN   v16.4h, v16.4s
        SQXTN   v17.4h, v17.4s
        SQXTN   v18.4h, v18.4s
        SQXTN   v19.4h, v19.4s
        SQXTN   v24.4h, v24.4s
        SQXTN   v25.4h, v25.4s
        SQXTN   v26.4h, v26.4s
        SQXTN   v27.4h, v27.4s
        LD1R    {v6.8h}, [x11], 2       // add bias

        SQXTN2  v16.8h, v20.4s
        SQXTN2  v17.8h, v21.4s
        SQXTN2  v18.8h, v22.4s
        SQXTN2  v19.8h, v23.4s
        SQXTN2  v24.8h, v28.4s
        SQXTN2  v25.8h, v29.4s
        SQXTN2  v26.8h, v30.4s
        SQXTN2  v27.8h, v31.4s

        SQADD   v16.8h, v16.8h, v6.8h
        SQADD   v17.8h, v17.8h, v6.8h
        SQADD   v18.8h, v18.8h, v6.8h
        SQADD   v19.8h, v19.8h, v6.8h
        SQADD   v24.8h, v24.8h, v6.8h
        SQADD   v25.8h, v25.8h, v6.8h
        SQADD   v26.8h, v26.8h, v6.8h
        SQADD   v27.8h, v27.8h, v6.8h
        LD1R    {v4.16b}, [x11], 1      // clamp min value

        # Saturating-narrow int16 -> int8; one vector (v0..v3) per output row.
        SQXTN   v0.8b, v16.8h
        SQXTN   v1.8b, v17.8h
        SQXTN   v2.8b, v18.8h
        SQXTN   v3.8b, v19.8h
        LD1R    {v5.16b}, [x11]         // clamp max value
        SQXTN2  v0.16b, v24.8h
        SQXTN2  v1.16b, v25.8h
        SQXTN2  v2.16b, v26.8h
        SQXTN2  v3.16b, v27.8h
        LDR     x0, [sp]                // cn_stride
        SMAX    v0.16b, v0.16b, v4.16b
        SMAX    v1.16b, v1.16b, v4.16b
        SUB     x11, x11, ${REWIND_DECREMENT} // rewind params pointer
        SMAX    v2.16b, v2.16b, v4.16b
        SMAX    v3.16b, v3.16b, v4.16b
        SUBS    x1, x1, 16
        SMIN    v0.16b, v0.16b, v5.16b
        SMIN    v1.16b, v1.16b, v5.16b
        SMIN    v2.16b, v2.16b, v5.16b
        SMIN    v3.16b, v3.16b, v5.16b
        B.LO    5f

        # Store full 4 x 16
        ST1     {v3.16b}, [x7], x0
        ST1     {v2.16b}, [x17], x0
        ST1     {v1.16b}, [x16], x0
        ST1     {v0.16b}, [x6], x0

        SUB     x4, x4, x3              // a -= ks

        # nc loop
        B.HI    0b
        RET

        # Remainder- 4 bytes of A (single c4 group, lane [0] only)
        .p2align 3
4:
        LDR     s0, [x13], 4
        LDR     q4, [x5], 16
        LDR     s1, [x14], 4
        LDR     s2, [x15], 4
        LDR     s3, [x10], 4
        LDR     q5, [x5], 16
        SDOT    v16.4s, v4.16b, v0.4b[0]
        SDOT    v17.4s, v4.16b, v1.4b[0]
        LDP     q6, q7, [x5], 32
        SDOT    v18.4s, v4.16b, v2.4b[0]
        SDOT    v19.4s, v4.16b, v3.4b[0]
        SDOT    v20.4s, v5.16b, v0.4b[0]
        SDOT    v21.4s, v5.16b, v1.4b[0]
        SDOT    v22.4s, v5.16b, v2.4b[0]
        SDOT    v23.4s, v5.16b, v3.4b[0]
        SDOT    v24.4s, v6.16b, v0.4b[0]
        SDOT    v25.4s, v6.16b, v1.4b[0]
        SDOT    v26.4s, v6.16b, v2.4b[0]
        SDOT    v27.4s, v6.16b, v3.4b[0]
        SDOT    v28.4s, v7.16b, v0.4b[0]
        SDOT    v29.4s, v7.16b, v1.4b[0]
        SDOT    v30.4s, v7.16b, v2.4b[0]
        SDOT    v31.4s, v7.16b, v3.4b[0]

        # ks loop
        SUBS    x9, x9, 32              // ks -= MR * sizeof(int8_t*)
        B.HI    1b
        B       3b

        # Store odd width: write 8/4/2/1-byte pieces per set bit of nc,
        # shifting the remaining lanes down with DUP after each partial store.
        .p2align 3
5:
        TBZ     x1, 3, 6f
        STR     d3, [x7], 8
        STR     d2, [x17], 8
        DUP     d3, v3.d[1]
        DUP     d2, v2.d[1]
        STR     d1, [x16], 8
        STR     d0, [x6], 8
        DUP     d1, v1.d[1]
        DUP     d0, v0.d[1]
6:
        TBZ     x1, 2, 7f
        STR     s3, [x7], 4
        STR     s2, [x17], 4
        DUP     s3, v3.s[1]
        DUP     s2, v2.s[1]
        STR     s1, [x16], 4
        STR     s0, [x6], 4
        DUP     s1, v1.s[1]
        DUP     s0, v0.s[1]
7:
        TBZ     x1, 1, 8f
        STR     h3, [x7], 2
        STR     h2, [x17], 2
        DUP     h3, v3.h[1]
        DUP     h2, v2.h[1]
        STR     h1, [x16], 2
        STR     h0, [x6], 2
        DUP     h1, v1.h[1]
        DUP     h0, v0.h[1]
8:
        TBZ     x1, 0, 9f
        STR     b3, [x7]
        STR     b2, [x17]
        STR     b1, [x16]
        STR     b0, [x6]
9:
        RET

END_FUNCTION xnn_${DATATYPE}_igemm_minmax_${REQUANTIZATION.lower()}_ukernel_4x16c4__aarch64_neondot_ld64

#ifdef __ELF__
.section ".note.GNU-stack","",%progbits
#endif