// Copyright 2022 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <xnnpack/assembly.h>

# Indirect (igemm) f16 GEMM microkernel: computes a 4-row x 16-column tile of
# C = A*B + bias, with min/max clamping, for half-precision floats using
# NEON FP16 arithmetic. A is addressed indirectly through an array of row
# pointers (`a`), iterated `ks` times; rows equal to `zero` are not offset.
#
# void xnn_f16_igemm_minmax_ukernel_4x16__aarch64_neonfp16arith_ld64(
#     size_t mr,                         x0
#     size_t nc,                         x1
#     size_t kc,                         x2 / x0
#     size_t ks,                         x3 / x9
#     const void**restrict a,            x4
#     const void*restrict w,             x5
#     void*restrict c,                   x6
#     size_t cm_stride,                  x7
#     size_t cn_stride,                  [sp] -> x10
#     size_t a_offset,                   [sp + 8] -> x11
#     const void* zero,                  [sp + 16] -> x12
#     const xnn_f16_minmax_params params [sp + 24] -> (x8)

# d8-d15, x19-x30 need to be preserved if used. x18 is reserved by the OS.
# This kernel uses none of them, so no prologue/epilogue saves are needed.

# Register usage
# A0 x8  v0
# A1 x13 v1
# A2 x14 v2
# A3 x15 v3

# B x5 v20 v21 v22 v23 v16 v17 v18 v19

# C0 x6  v24 v25
# C1 x16 v26 v27
# C2 x17 v28 v29
# C3 x7  v30 v31

# Clamp v4, v5

BEGIN_FUNCTION xnn_f16_igemm_minmax_ukernel_4x16__aarch64_neonfp16arith_ld64

        # Load cn_stride, a_offset
        LDP x10, x11, [sp]

        # Load zero, params pointer
        LDP x12, x8, [sp, 16]

        # Load params values: two adjacent fp16 values broadcast to all lanes
        # (v4 = min, used with FMAX; v5 = max, used with FMIN below)
        LD2R {v4.8h, v5.8h}, [x8]

        # Clamp C pointers: rows beyond mr alias the previous row so their
        # stores are harmless duplicates instead of out-of-bounds writes.
        CMP x0, 2                // if mr < 2
        ADD x16, x6, x7          // c1 = c0 + cm_stride
        CSEL x16, x6, x16, LO    //   c1 = c0
        ADD x17, x16, x7         // c2 = c1 + cm_stride
                                 // if mr <= 2
        CSEL x17, x16, x17, LS   //   c2 = c1
        CMP x0, 4                // if mr < 4
        ADD x7, x17, x7          // c3 = c2 + cm_stride
        CSEL x7, x17, x7, LO     //   c3 = c2

0:
        # Load initial bias from w into accumulators
        LDR q24, [x5], 16
        LDR q25, [x5], 16
        MOV v26.16b, v24.16b
        MOV v28.16b, v24.16b
        MOV v30.16b, v24.16b
        MOV v27.16b, v25.16b
        MOV v29.16b, v25.16b
        MOV v31.16b, v25.16b

        MOV x9, x3               // p = ks

1:
        # Load next 4 A pointers
        LDP x8, x13, [x4], 16
        LDP x14, x15, [x4], 16

        # A pointer equal to `zero` selects the zero buffer and is NOT offset
        CMP x8, x12              // if a0 == zero
        ADD x8, x8, x11          // a0 += a_offset
        CSEL x8, x12, x8, EQ     //   a0 = zero, else += a0 + a_offset
        CMP x13, x12             // if a1 == zero
        ADD x13, x13, x11        // a1 += a_offset
        CSEL x13, x12, x13, EQ   //   a1 = zero, else += a1 + a_offset
        CMP x14, x12             // if a2 == zero
        ADD x14, x14, x11        // a2 += a_offset
        CSEL x14, x12, x14, EQ   //   a2 = zero, else += a2 + a_offset
        CMP x15, x12             // if a3 == zero
        ADD x15, x15, x11        // a3 += a_offset
        CSEL x15, x12, x15, EQ   //   a3 = zero, else += a3 + a_offset

        # Is there at least 4 halffloats (8 bytes)?
        SUBS x0, x2, 8           // k = kc - 8
        B.LO 4f

        .p2align 3
        # Main loop - 4 halffloats of A (8 bytes) per row per iteration,
        # against 8 q-registers (4 halffloats x 16 columns) of B
2:
        LDR d0, [x8], 8
        LDR q20, [x5], 16
        LDR q21, [x5], 16
        LDR d1, [x13], 8
        LDR d2, [x14], 8
        LDR d3, [x15], 8
        LDR q22, [x5], 16
        LDR q23, [x5], 16
        LDR q16, [x5], 16
        LDR q17, [x5], 16
        LDR q18, [x5], 16
        LDR q19, [x5], 16
        SUBS x0, x0, 8
        FMLA v24.8h, v20.8h, v0.h[0]
        FMLA v25.8h, v21.8h, v0.h[0]
        FMLA v26.8h, v20.8h, v1.h[0]
        FMLA v27.8h, v21.8h, v1.h[0]
        FMLA v28.8h, v20.8h, v2.h[0]
        FMLA v29.8h, v21.8h, v2.h[0]
        FMLA v30.8h, v20.8h, v3.h[0]
        FMLA v31.8h, v21.8h, v3.h[0]
        FMLA v24.8h, v22.8h, v0.h[1]
        FMLA v25.8h, v23.8h, v0.h[1]
        FMLA v26.8h, v22.8h, v1.h[1]
        FMLA v27.8h, v23.8h, v1.h[1]
        FMLA v28.8h, v22.8h, v2.h[1]
        FMLA v29.8h, v23.8h, v2.h[1]
        FMLA v30.8h, v22.8h, v3.h[1]
        FMLA v31.8h, v23.8h, v3.h[1]

        FMLA v24.8h, v16.8h, v0.h[2]
        FMLA v25.8h, v17.8h, v0.h[2]
        FMLA v26.8h, v16.8h, v1.h[2]
        FMLA v27.8h, v17.8h, v1.h[2]
        FMLA v28.8h, v16.8h, v2.h[2]
        FMLA v29.8h, v17.8h, v2.h[2]
        FMLA v30.8h, v16.8h, v3.h[2]
        FMLA v31.8h, v17.8h, v3.h[2]
        FMLA v24.8h, v18.8h, v0.h[3]
        FMLA v25.8h, v19.8h, v0.h[3]
        FMLA v26.8h, v18.8h, v1.h[3]
        FMLA v27.8h, v19.8h, v1.h[3]
        FMLA v28.8h, v18.8h, v2.h[3]
        FMLA v29.8h, v19.8h, v2.h[3]
        FMLA v30.8h, v18.8h, v3.h[3]
        FMLA v31.8h, v19.8h, v3.h[3]
        B.HS 2b

        # Is there a remainder?- 1 to 3 halffloats of A (2 to 6 bytes)
        ANDS x0, x0, 7
        B.NE 4f

3:
        # ks loop
        SUBS x9, x9, 32          // ks -= MR * sizeof(void*)
        B.HI 1b

        # Clamp accumulators to [min, max]
        FMAX v24.8h, v24.8h, v4.8h
        FMAX v25.8h, v25.8h, v4.8h
        FMAX v26.8h, v26.8h, v4.8h
        FMAX v27.8h, v27.8h, v4.8h
        FMAX v28.8h, v28.8h, v4.8h
        FMAX v29.8h, v29.8h, v4.8h
        FMAX v30.8h, v30.8h, v4.8h
        FMAX v31.8h, v31.8h, v4.8h
        FMIN v24.8h, v24.8h, v5.8h
        FMIN v25.8h, v25.8h, v5.8h
        FMIN v26.8h, v26.8h, v5.8h
        FMIN v27.8h, v27.8h, v5.8h
        FMIN v28.8h, v28.8h, v5.8h
        FMIN v29.8h, v29.8h, v5.8h
        FMIN v30.8h, v30.8h, v5.8h
        FMIN v31.8h, v31.8h, v5.8h

        # Store full 4 x 16
        SUBS x1, x1, 16
        B.LO 6f

        STP q30, q31, [x7]
        ADD x7, x7, x10
        STP q28, q29, [x17]
        ADD x17, x17, x10
        STP q26, q27, [x16]
        ADD x16, x16, x10
        STP q24, q25, [x6]
        ADD x6, x6, x10

        SUB x4, x4, x3           // a -= ks (rewind pointer array for next nc tile)

        # nc loop
        B.HI 0b
        RET


        # Remainder- 1 to 3 halffloats of A (2 to 6 bytes)
4:
        # bit 2 of k remainder (4 bytes) set: process 2 halffloats
        TBZ x0, 2, 5f
        LDR s0, [x8], 4
        LDR q20, [x5], 16
        LDR q21, [x5], 16
        LDR s1, [x13], 4
        LDR s2, [x14], 4
        LDR s3, [x15], 4
        LDR q22, [x5], 16
        LDR q23, [x5], 16
        FMLA v24.8h, v20.8h, v0.h[0]
        FMLA v25.8h, v21.8h, v0.h[0]
        FMLA v26.8h, v20.8h, v1.h[0]
        FMLA v27.8h, v21.8h, v1.h[0]
        FMLA v28.8h, v20.8h, v2.h[0]
        FMLA v29.8h, v21.8h, v2.h[0]
        FMLA v30.8h, v20.8h, v3.h[0]
        FMLA v31.8h, v21.8h, v3.h[0]
        FMLA v24.8h, v22.8h, v0.h[1]
        FMLA v25.8h, v23.8h, v0.h[1]
        FMLA v26.8h, v22.8h, v1.h[1]
        FMLA v27.8h, v23.8h, v1.h[1]
        FMLA v28.8h, v22.8h, v2.h[1]
        FMLA v29.8h, v23.8h, v2.h[1]
        FMLA v30.8h, v22.8h, v3.h[1]
        FMLA v31.8h, v23.8h, v3.h[1]
        # bit 1 clear: no single halffloat left, back to ks loop
        TBZ x0, 1, 3b

5:
        # process final 1 halffloat
        LDR h0, [x8], 2
        LDR q20, [x5], 16
        LDR q21, [x5], 16
        LDR h1, [x13], 2
        LDR h2, [x14], 2
        LDR h3, [x15], 2
        FMLA v24.8h, v20.8h, v0.h[0]
        FMLA v25.8h, v21.8h, v0.h[0]
        FMLA v26.8h, v20.8h, v1.h[0]
        FMLA v27.8h, v21.8h, v1.h[0]
        FMLA v28.8h, v20.8h, v2.h[0]
        FMLA v29.8h, v21.8h, v2.h[0]
        FMLA v30.8h, v20.8h, v3.h[0]
        FMLA v31.8h, v21.8h, v3.h[0]
        B 3b

        # Store odd width: nc bits 3/2/1/0 select 8/4/2/1 column stores;
        # after each partial store the high half is shifted down for the next.
6:
        TBZ x1, 3, 7f
        STR q30, [x7], 16
        MOV v30.16b, v31.16b
        STR q28, [x17], 16
        MOV v28.16b, v29.16b
        STR q26, [x16], 16
        MOV v26.16b, v27.16b
        STR q24, [x6], 16
        MOV v24.16b, v25.16b

7:
        TBZ x1, 2, 8f
        STR d30, [x7], 8
        STR d28, [x17], 8
        DUP d30, v30.d[1]
        DUP d28, v28.d[1]
        STR d26, [x16], 8
        STR d24, [x6], 8
        DUP d26, v26.d[1]
        DUP d24, v24.d[1]

8:
        TBZ x1, 1, 9f
        STR s30, [x7], 4
        STR s28, [x17], 4
        DUP s30, v30.s[1]
        DUP s28, v28.s[1]
        STR s26, [x16], 4
        STR s24, [x6], 4
        DUP s26, v26.s[1]
        DUP s24, v24.s[1]
9:
        TBZ x1, 0, 10f
        STR h30, [x7]
        STR h28, [x17]
        STR h26, [x16]
        STR h24, [x6]
10:
        RET

END_FUNCTION xnn_f16_igemm_minmax_ukernel_4x16__aarch64_neonfp16arith_ld64

#ifdef __ELF__
.section ".note.GNU-stack","",%progbits
#endif