// Copyright 2022 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <xnnpack/assembly.h>

# void xnn_f16_igemm_minmax_ukernel_4x16__aarch64_neonfp16arith_ld32(
#     size_t mr,                         x0
#     size_t nc,                         x1
#     size_t kc,                         x2 / x0
#     size_t ks,                         x3 / x9
#     const void**restrict a,            x4
#     const void*restrict w,             x5
#     void*restrict c,                   x6
#     size_t cm_stride,                  x7
#     size_t cn_stride,                  [sp] -> x10
#     size_t a_offset,                   [sp + 8] -> x11
#     const void* zero,                  [sp + 16] -> x12
#     const xnn_f16_minmax_params params [sp + 24] -> (x8)

# d8-d15, x19-x30 need to be preserved if used. x18 is reserved by the OS.
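# This kernel touches only v0-v5, v20-v31 and x0-x17, so nothing is saved or restored.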

# Register usage
# A0  x8 v0
# A1 x13 v1
# A2 x14 v2
# A3 x15 v3

# B   x5 v20 v21 v22 v23

# C0  x6 v24 v25
# C1 x16 v26 v27
# C2 x17 v28 v29
# C3  x7 v30 v31

# Clamp v4, v5
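
# In outline, for each row m < mr and output channel n < nc this computes
#   c[m][n] = clamp(bias[n] + sum_i sum_k a_i[m][k] * w_i,k[n], min, max)
# summed over the indirection steps i (ks bytes of A pointers, MR per step) and the
# kc bytes / 2 halffloats k of each A row. w packs a 16-wide bias followed by
# 16 weights per halffloat of k; a_i[m] pointers equal to `zero` read the zero buffer,
# all others are advanced by a_offset.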

BEGIN_FUNCTION xnn_f16_igemm_minmax_ukernel_4x16__aarch64_neonfp16arith_ld32

        # Load cn_stride, a_offset
        LDP     x10, x11, [sp]

        # Load zero, params pointer
        LDP     x12, x8, [sp, 16]

        # Load params values
        LD2R    {v4.8h, v5.8h}, [x8]
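                                        // v4 = min, v5 = max (broadcast to all 8 lanes)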

        # Clamp C pointers
        CMP     x0, 2                   // if mr < 2
        ADD     x16, x6, x7             // c1 = c0 + cm_stride
        CSEL    x16, x6, x16, LO        //   c1 = c0
        ADD     x17, x16, x7            // c2 = c1 + cm_stride
                                        // if mr <= 2
        CSEL    x17, x16, x17, LS       //   c2 = c1
        CMP     x0, 4                   // if mr < 4
        ADD     x7, x17, x7             // c3 = c2 + cm_stride
        CSEL    x7, x17, x7, LO         //   c3 = c2

0:
        # Load initial bias from w into accumulators
        LDR     q24, [x5], 16
        LDR     q25, [x5], 16
        MOV     v26.16b, v24.16b
        MOV     v28.16b, v24.16b
        MOV     v30.16b, v24.16b
        MOV     v27.16b, v25.16b
        MOV     v29.16b, v25.16b
        MOV     v31.16b, v25.16b

        MOV     x9, x3                  // p = ks

1:
        # Load next 4 A pointers
        LDP     x8, x13, [x4], 16
        LDP     x14, x15, [x4], 16

        CMP     x8, x12                 // if a0 == zero
        ADD     x8, x8, x11             // a0 += a_offset
        CSEL    x8, x12, x8, EQ         //   a0 = zero, else a0 += a_offset
        CMP     x13, x12                // if a1 == zero
        ADD     x13, x13, x11           // a1 += a_offset
        CSEL    x13, x12, x13, EQ       //   a1 = zero, else a1 += a_offset
        CMP     x14, x12                // if a2 == zero
        ADD     x14, x14, x11           // a2 += a_offset
        CSEL    x14, x12, x14, EQ       //   a2 = zero, else a2 += a_offset
        CMP     x15, x12                // if a3 == zero
        ADD     x15, x15, x11           // a3 += a_offset
        CSEL    x15, x12, x15, EQ       //   a3 = zero, else a3 += a_offset

        # Are there at least 2 halffloats (4 bytes)?
        SUBS    x0, x2, 4               // k = kc - 4
        B.LO    4f

        .p2align 3
        # Main loop - 2 halffloats of A (4 bytes)
2:
        LDR     s0,  [x8], 4
        LDR     q20, [x5], 16
        LDR     q21, [x5], 16
        LDR     s1, [x13], 4
        LDR     s2, [x14], 4
        LDR     s3, [x15], 4
        LDR     q22, [x5], 16
        LDR     q23, [x5], 16
        SUBS    x0, x0, 4
        FMLA    v24.8h, v20.8h, v0.h[0]
        FMLA    v25.8h, v21.8h, v0.h[0]
        FMLA    v26.8h, v20.8h, v1.h[0]
        FMLA    v27.8h, v21.8h, v1.h[0]
        FMLA    v28.8h, v20.8h, v2.h[0]
        FMLA    v29.8h, v21.8h, v2.h[0]
        FMLA    v30.8h, v20.8h, v3.h[0]
        FMLA    v31.8h, v21.8h, v3.h[0]

        FMLA    v24.8h, v22.8h, v0.h[1]
        FMLA    v25.8h, v23.8h, v0.h[1]
        FMLA    v26.8h, v22.8h, v1.h[1]
        FMLA    v27.8h, v23.8h, v1.h[1]
        FMLA    v28.8h, v22.8h, v2.h[1]
        FMLA    v29.8h, v23.8h, v2.h[1]
        FMLA    v30.8h, v22.8h, v3.h[1]
        FMLA    v31.8h, v23.8h, v3.h[1]
        B.HS    2b

        # Is there a remainder? - 1 halffloat of A (2 bytes)
        TBNZ    x0, 1, 4f

3:
        # ks loop
        SUBS    x9, x9, 32              // ks -= MR * sizeof(void*)
        B.HI    1b

        # Clamp
        FMAX    v24.8h, v24.8h, v4.8h
        FMAX    v25.8h, v25.8h, v4.8h
        FMAX    v26.8h, v26.8h, v4.8h
        FMAX    v27.8h, v27.8h, v4.8h
        FMAX    v28.8h, v28.8h, v4.8h
        FMAX    v29.8h, v29.8h, v4.8h
        FMAX    v30.8h, v30.8h, v4.8h
        FMAX    v31.8h, v31.8h, v4.8h
        FMIN    v24.8h, v24.8h, v5.8h
        FMIN    v25.8h, v25.8h, v5.8h
        FMIN    v26.8h, v26.8h, v5.8h
        FMIN    v27.8h, v27.8h, v5.8h
        FMIN    v28.8h, v28.8h, v5.8h
        FMIN    v29.8h, v29.8h, v5.8h
        FMIN    v30.8h, v30.8h, v5.8h
        FMIN    v31.8h, v31.8h, v5.8h

        # Store full 4 x 16
        SUBS    x1, x1, 16
        B.LO    5f

        STP     q30, q31,  [x7]
        ADD     x7,  x7, x10
        STP     q28, q29, [x17]
        ADD     x17, x17, x10
        STP     q26, q27, [x16]
        ADD     x16, x16, x10
        STP     q24, q25,  [x6]
        ADD     x6,  x6, x10

        SUB     x4, x4, x3              // a -= ks

        # nc loop
        B.HI    0b
        RET

        # Remainder - 1 halffloat of A
4:
        LDR     h0, [x8], 2
        LDR     q20, [x5], 16
        LDR     q21, [x5], 16
        LDR     h1, [x13], 2
        LDR     h2, [x14], 2
        LDR     h3, [x15], 2
        FMLA    v24.8h, v20.8h, v0.h[0]
        FMLA    v25.8h, v21.8h, v0.h[0]
        FMLA    v26.8h, v20.8h, v1.h[0]
        FMLA    v27.8h, v21.8h, v1.h[0]
        FMLA    v28.8h, v20.8h, v2.h[0]
        FMLA    v29.8h, v21.8h, v2.h[0]
        FMLA    v30.8h, v20.8h, v3.h[0]
        FMLA    v31.8h, v21.8h, v3.h[0]
        B       3b

        # Store odd width
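        # nc bits select the tail: bit 3 -> 8 halffloats, bit 2 -> 4, bit 1 -> 2, bit 0 -> 1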
5:
        TBZ     x1, 3, 6f
        STR     q30, [x7], 16
        MOV     v30.16b, v31.16b
        STR     q28, [x17], 16
        MOV     v28.16b, v29.16b
        STR     q26, [x16], 16
        MOV     v26.16b, v27.16b
        STR     q24, [x6], 16
        MOV     v24.16b, v25.16b

6:
        TBZ     x1, 2, 7f
        STR     d30, [x7], 8
        STR     d28, [x17], 8
        DUP     d30, v30.d[1]
        DUP     d28, v28.d[1]
        STR     d26, [x16], 8
        STR     d24, [x6], 8
        DUP     d26, v26.d[1]
        DUP     d24, v24.d[1]

7:
        TBZ     x1, 1, 8f
        STR     s30,  [x7], 4
        STR     s28, [x17], 4
        DUP     s30, v30.s[1]
        DUP     s28, v28.s[1]
        STR     s26, [x16], 4
        STR     s24,  [x6], 4
        DUP     s26, v26.s[1]
        DUP     s24, v24.s[1]
8:
        TBZ     x1, 0, 9f
        STR     h30,  [x7]
        STR     h28, [x17]
        STR     h26, [x16]
        STR     h24,  [x6]
9:
        RET

END_FUNCTION xnn_f16_igemm_minmax_ukernel_4x16__aarch64_neonfp16arith_ld32

#ifdef __ELF__
.section ".note.GNU-stack","",%progbits
#endif