// xref: /aosp_15_r20/external/XNNPACK/src/f16-igemm/4x16-minmax-aarch64-neonfp16arith-ld64.S (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
// Copyright 2022 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

6#include <xnnpack/assembly.h>

# void xnn_f16_igemm_minmax_ukernel_4x16__aarch64_neonfp16arith_ld64(
#     size_t mr,                         x0
#     size_t nc,                         x1
#     size_t kc,                         x2 / x0
#     size_t ks,                         x3 / x9
#     const void**restrict a,            x4
#     const void*restrict w,             x5
#     void*restrict c,                   x6
#     size_t cm_stride,                  x7
#     size_t cn_stride,                  [sp] -> x10
#     size_t a_offset,                   [sp + 8] -> x11
#     const void* zero,                  [sp + 16] -> x12
#     const xnn_f16_minmax_params params [sp + 24] -> (x8)
#
# kc is a byte count (2 bytes per halffloat) and ks is a byte count of A
# pointers (8 bytes per pointer). After setup, x0 and x9 hold their running
# counters. x8 holds the params address until it is reused as A0.

# d8-d15, x19-x30 need to be preserved if used. x18 is reserved by the OS.
# This kernel uses none of them, so nothing is saved or restored.

# Register usage
# A0  x8 v0
# A1 x13 v1
# A2 x14 v2
# A3 x15 v3

# B   x5 v20 v21 v22 v23 v16 v17 v18 v19

# C0  x6 v24 v25
# C1 x16 v26 v27
# C2 x17 v28 v29
# C3  x7 v30 v31

# Clamp v4 (lower bound, FMAX), v5 (upper bound, FMIN)

BEGIN_FUNCTION xnn_f16_igemm_minmax_ukernel_4x16__aarch64_neonfp16arith_ld64

        # Load cn_stride, a_offset
        LDP     x10, x11, [sp]

        # Load zero, params pointer
        LDP     x12, x8, [sp, 16]

        # Load params values. LD2R broadcasts the first halffloat at [x8] into
        # every lane of v4 and the second into every lane of v5. v4 is used as
        # the lower clamp bound (FMAX) and v5 as the upper bound (FMIN).
        LD2R    {v4.8h, v5.8h}, [x8]

        # Clamp C pointers. Rows at or beyond mr alias the previous row pointer.
        # The stores below write row 3 first and row 0 last, so every aliased
        # pointer ends up holding its lowest (valid) row.
        CMP     x0, 2                   // if mr < 2
        ADD     x16, x6, x7             // c1 = c0 + cm_stride
        CSEL    x16, x6, x16, LO        //   c1 = c0
        ADD     x17, x16, x7            // c2 = c1 + cm_stride
                                        // if mr <= 2
        CSEL    x17, x16, x17, LS       //   c2 = c1
        CMP     x0, 4                   // if mr < 4
        ADD     x7, x17, x7             // c3 = c2 + cm_stride
        CSEL    x7, x17, x7, LO         //   c3 = c2

0:
        # Load initial bias from w into accumulators
        LDR     q24, [x5], 16
        LDR     q25, [x5], 16
        MOV     v26.16b, v24.16b
        MOV     v28.16b, v24.16b
        MOV     v30.16b, v24.16b
        MOV     v27.16b, v25.16b
        MOV     v29.16b, v25.16b
        MOV     v31.16b, v25.16b

        MOV     x9, x3                  // p = ks

1:
        # Load next 4 A pointers
        LDP     x8, x13, [x4], 16
        LDP     x14, x15, [x4], 16

        # A pointer equal to zero is used unchanged, any other gets a_offset.
        CMP     x8, x12                 // if a0 == zero
        ADD     x8, x8, x11             // a0 += a_offset
        CSEL    x8, x12, x8, EQ         //   a0 = zero, else a0 + a_offset
        CMP     x13, x12                // if a1 == zero
        ADD     x13, x13, x11           // a1 += a_offset
        CSEL    x13, x12, x13, EQ       //   a1 = zero, else a1 + a_offset
        CMP     x14, x12                // if a2 == zero
        ADD     x14, x14, x11           // a2 += a_offset
        CSEL    x14, x12, x14, EQ       //   a2 = zero, else a2 + a_offset
        CMP     x15, x12                // if a3 == zero
        ADD     x15, x15, x11           // a3 += a_offset
        CSEL    x15, x12, x15, EQ       //   a3 = zero, else a3 + a_offset

        # Are there at least 4 halffloats (8 bytes)?
        SUBS    x0, x2, 8               // k = kc - 8
        B.LO    4f                      // no: low 3 bits of x0 equal kc mod 8

        .p2align 3
        # Main loop - 4 halffloats of A (8 bytes) per row, 4 x 16 B per pass
2:
        LDR     d0,  [x8], 8
        LDR     q20, [x5], 16
        LDR     q21, [x5], 16
        LDR     d1, [x13], 8
        LDR     d2, [x14], 8
        LDR     d3, [x15], 8
        LDR     q22, [x5], 16
        LDR     q23, [x5], 16
        LDR     q16, [x5], 16
        LDR     q17, [x5], 16
        LDR     q18, [x5], 16
        LDR     q19, [x5], 16
        SUBS    x0, x0, 8
        FMLA    v24.8h, v20.8h, v0.h[0]
        FMLA    v25.8h, v21.8h, v0.h[0]
        FMLA    v26.8h, v20.8h, v1.h[0]
        FMLA    v27.8h, v21.8h, v1.h[0]
        FMLA    v28.8h, v20.8h, v2.h[0]
        FMLA    v29.8h, v21.8h, v2.h[0]
        FMLA    v30.8h, v20.8h, v3.h[0]
        FMLA    v31.8h, v21.8h, v3.h[0]
        FMLA    v24.8h, v22.8h, v0.h[1]
        FMLA    v25.8h, v23.8h, v0.h[1]
        FMLA    v26.8h, v22.8h, v1.h[1]
        FMLA    v27.8h, v23.8h, v1.h[1]
        FMLA    v28.8h, v22.8h, v2.h[1]
        FMLA    v29.8h, v23.8h, v2.h[1]
        FMLA    v30.8h, v22.8h, v3.h[1]
        FMLA    v31.8h, v23.8h, v3.h[1]

        FMLA    v24.8h, v16.8h, v0.h[2]
        FMLA    v25.8h, v17.8h, v0.h[2]
        FMLA    v26.8h, v16.8h, v1.h[2]
        FMLA    v27.8h, v17.8h, v1.h[2]
        FMLA    v28.8h, v16.8h, v2.h[2]
        FMLA    v29.8h, v17.8h, v2.h[2]
        FMLA    v30.8h, v16.8h, v3.h[2]
        FMLA    v31.8h, v17.8h, v3.h[2]
        FMLA    v24.8h, v18.8h, v0.h[3]
        FMLA    v25.8h, v19.8h, v0.h[3]
        FMLA    v26.8h, v18.8h, v1.h[3]
        FMLA    v27.8h, v19.8h, v1.h[3]
        FMLA    v28.8h, v18.8h, v2.h[3]
        FMLA    v29.8h, v19.8h, v2.h[3]
        FMLA    v30.8h, v18.8h, v3.h[3]
        FMLA    v31.8h, v19.8h, v3.h[3]
        B.HS    2b

        # Is there a remainder? 1 to 3 halffloats of A (2 to 6 bytes)
        ANDS    x0, x0, 7               // x0 = leftover bytes: 0, 2, 4 or 6
        B.NE    4f

3:
        # ks loop
        SUBS    x9, x9, 32              // ks -= MR * sizeof(void*)
        B.HI    1b

        # Clamp: raise to the lower bound v4, then cap at the upper bound v5
        FMAX    v24.8h, v24.8h, v4.8h
        FMAX    v25.8h, v25.8h, v4.8h
        FMAX    v26.8h, v26.8h, v4.8h
        FMAX    v27.8h, v27.8h, v4.8h
        FMAX    v28.8h, v28.8h, v4.8h
        FMAX    v29.8h, v29.8h, v4.8h
        FMAX    v30.8h, v30.8h, v4.8h
        FMAX    v31.8h, v31.8h, v4.8h
        FMIN    v24.8h, v24.8h, v5.8h
        FMIN    v25.8h, v25.8h, v5.8h
        FMIN    v26.8h, v26.8h, v5.8h
        FMIN    v27.8h, v27.8h, v5.8h
        FMIN    v28.8h, v28.8h, v5.8h
        FMIN    v29.8h, v29.8h, v5.8h
        FMIN    v30.8h, v30.8h, v5.8h
        FMIN    v31.8h, v31.8h, v5.8h

        # Store full 4 x 16
        SUBS    x1, x1, 16
        B.LO    6f

        # Row 3 first, row 0 last, so aliased row pointers keep the valid row.
        STP     q30, q31,  [x7]
        ADD     x7,  x7, x10
        STP     q28, q29, [x17]
        ADD     x17, x17, x10
        STP     q26, q27, [x16]
        ADD     x16, x16, x10
        STP     q24, q25,  [x6]
        ADD     x6,  x6, x10

        SUB     x4, x4, x3              // a -= ks

        # nc loop
        B.HI    0b                      // flags still from SUBS x1, x1, 16
        RET

        # Remainder - 1 to 3 halffloats of A (2 to 6 bytes).
        # Only the low bits of x0 matter here: bit 2 selects a 2-halffloat
        # step and bit 1 selects a 1-halffloat step.
4:
        TBZ     x0, 2, 5f               // 2 halffloats (4 bytes)?
        LDR     s0, [x8], 4
        LDR     q20, [x5], 16
        LDR     q21, [x5], 16
        LDR     s1, [x13], 4
        LDR     s2, [x14], 4
        LDR     s3, [x15], 4
        LDR     q22, [x5], 16
        LDR     q23, [x5], 16
        FMLA    v24.8h, v20.8h, v0.h[0]
        FMLA    v25.8h, v21.8h, v0.h[0]
        FMLA    v26.8h, v20.8h, v1.h[0]
        FMLA    v27.8h, v21.8h, v1.h[0]
        FMLA    v28.8h, v20.8h, v2.h[0]
        FMLA    v29.8h, v21.8h, v2.h[0]
        FMLA    v30.8h, v20.8h, v3.h[0]
        FMLA    v31.8h, v21.8h, v3.h[0]
        FMLA    v24.8h, v22.8h, v0.h[1]
        FMLA    v25.8h, v23.8h, v0.h[1]
        FMLA    v26.8h, v22.8h, v1.h[1]
        FMLA    v27.8h, v23.8h, v1.h[1]
        FMLA    v28.8h, v22.8h, v2.h[1]
        FMLA    v29.8h, v23.8h, v2.h[1]
        FMLA    v30.8h, v22.8h, v3.h[1]
        FMLA    v31.8h, v23.8h, v3.h[1]
        TBZ     x0, 1, 3b               // no trailing halffloat: next ks

5:
        # 1 halffloat (2 bytes)
        LDR     h0, [x8], 2
        LDR     q20, [x5], 16
        LDR     q21, [x5], 16
        LDR     h1, [x13], 2
        LDR     h2, [x14], 2
        LDR     h3, [x15], 2
        FMLA    v24.8h, v20.8h, v0.h[0]
        FMLA    v25.8h, v21.8h, v0.h[0]
        FMLA    v26.8h, v20.8h, v1.h[0]
        FMLA    v27.8h, v21.8h, v1.h[0]
        FMLA    v28.8h, v20.8h, v2.h[0]
        FMLA    v29.8h, v21.8h, v2.h[0]
        FMLA    v30.8h, v20.8h, v3.h[0]
        FMLA    v31.8h, v21.8h, v3.h[0]
        B       3b

        # Store odd width. x1 = nc - 16 here, so its low 4 bits hold the
        # remaining column count. After each partial store the unstored
        # upper lanes are shifted down for the next, narrower store.
6:
        TBZ     x1, 3, 7f               // 8 columns
        STR     q30, [x7], 16
        MOV     v30.16b, v31.16b
        STR     q28, [x17], 16
        MOV     v28.16b, v29.16b
        STR     q26, [x16], 16
        MOV     v26.16b, v27.16b
        STR     q24, [x6], 16
        MOV     v24.16b, v25.16b

7:
        TBZ     x1, 2, 8f               // 4 columns
        STR     d30, [x7], 8
        STR     d28, [x17], 8
        DUP     d30, v30.d[1]
        DUP     d28, v28.d[1]
        STR     d26, [x16], 8
        STR     d24, [x6], 8
        DUP     d26, v26.d[1]
        DUP     d24, v24.d[1]

8:
        TBZ     x1, 1, 9f               // 2 columns
        STR     s30,  [x7], 4
        STR     s28, [x17], 4
        DUP     s30, v30.s[1]
        DUP     s28, v28.s[1]
        STR     s26, [x16], 4
        STR     s24,  [x6], 4
        DUP     s26, v26.s[1]
        DUP     s24, v24.s[1]
9:
        TBZ     x1, 0, 10f              // 1 column
        STR     h30,  [x7]
        STR     h28, [x17]
        STR     h26, [x16]
        STR     h24,  [x6]
10:
        RET

END_FUNCTION xnn_f16_igemm_minmax_ukernel_4x16__aarch64_neonfp16arith_ld64

283
284#ifdef __ELF__
285.section ".note.GNU-stack","",%progbits
286#endif
287