// Copyright 2022 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <xnnpack/assembly.h>

# void xnn_f16_igemm_minmax_ukernel_1x16__aarch64_neonfp16arith_ld64(
#     size_t mr,                         (x0) - unused.  mr = 1
#     size_t nc,                         x1
#     size_t kc,                         x2 / x0
#     size_t ks,                         x3 / x9
#     const void**restrict a,            x4
#     const void*restrict w,             x5
#     void*restrict c,                   x6
#     size_t cm_stride,                  (x7) - unused
#     size_t cn_stride,                  [sp] -> x10
#     size_t a_offset,                   [sp + 8] -> x11
#     const void* zero,                  [sp + 16] -> x12
#     const xnn_f16_minmax_params params [sp + 24] -> (x8)

# d8-d15, x19-x30 need to be preserved if used. x18 is reserved by the OS.

# Register usage
# A0  x8 v0

# B   x5 v24 v25 v26 v27 v28 v29 v30 v31

# C0  x6 v16 v17 v18 v19 v20 v21 v22 v23

# Clamp v4, v5

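# For orientation, a rough C-style sketch of the computation this kernel
# performs (pseudocode only; the loop variables below are illustrative
# assumptions, not identifiers from this file):
#
#   do {                                      // nc loop, 16 columns at a time
#     acc[0..15] = bias loaded from w;
#     for (p = ks; p != 0; p -= sizeof(void*)) {
#       a0 = *a++;                            // indirect A row pointer
#       if (a0 != zero) a0 += a_offset;
#       for (k = 0; k < kc; k += sizeof(__fp16))
#         acc[0..15] += a0[k] * next 16 halffloats of w;
#     }
#     acc = min(max(acc, params.min), params.max);
#     store up to 16 halffloats of acc to c;  c += cn_stride;  a -= ks;
#     nc -= 16;
#   } while (nc > 0);
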
BEGIN_FUNCTION xnn_f16_igemm_minmax_ukernel_1x16__aarch64_neonfp16arith_ld64

        # Load cn_stride, a_offset
        LDP     x10, x11, [sp]

        # Load zero, params pointer
        LDP     x12, x8, [sp, 16]

        # Load params values
        LD2R    {v4.8h, v5.8h}, [x8]
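        # v4 now holds the broadcast min and v5 the broadcast max: LD2R
        # de-interleaves the two consecutive fp16 params values and
        # replicates each one across all 8 lanes of its vector.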

0:
        # Load initial bias from w into accumulators
        LDR     q16, [x5], 16
        LDR     q17, [x5], 16
        MOVI    v18.8h, 0               // 4 sets of C for pipelining FMLA
        MOVI    v19.8h, 0
        MOVI    v20.8h, 0
        MOVI    v21.8h, 0
        MOVI    v22.8h, 0
        MOVI    v23.8h, 0

        MOV     x9, x3                  // p = ks

1:
        # Load next A pointer
        LDR     x8, [x4], 8

        CMP     x8, x12                 // if a0 == zero
        ADD     x8, x8, x11             // a0 += a_offset
        CSEL    x8, x12, x8, EQ         //   a0 = zero, else a0 += a_offset
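        # An A pointer equal to `zero` selects the zero buffer (used for
        # padding rows); any other pointer is advanced by a_offset.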

        # Is there at least 4 halffloats (8 bytes)?
        SUBS    x0, x2, 8               // k = kc - 8
        B.LO    4f

        .p2align 3
        # Main loop - 4 halffloats of A (8 bytes)
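        # Each iteration consumes 4 halffloats of A and 8 vectors (128 bytes)
        # of B, issuing 8 FMLAs spread across 4 accumulator pairs to hide
        # FMLA latency.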
2:
        LDR     d0,  [x8], 8
        LDR     q24, [x5, 0]
        LDR     q25, [x5, 16]
        LDR     q26, [x5, 32]
        LDR     q27, [x5, 48]
        LDR     q28, [x5, 64]
        LDR     q29, [x5, 80]
        LDR     q30, [x5, 96]
        LDR     q31, [x5, 112]
        SUBS    x0, x0, 8
        FMLA    v16.8h, v24.8h, v0.h[0]
        FMLA    v17.8h, v25.8h, v0.h[0]
        FMLA    v18.8h, v26.8h, v0.h[1]
        FMLA    v19.8h, v27.8h, v0.h[1]
        FMLA    v20.8h, v28.8h, v0.h[2]
        FMLA    v21.8h, v29.8h, v0.h[2]
        FMLA    v22.8h, v30.8h, v0.h[3]
        FMLA    v23.8h, v31.8h, v0.h[3]
        ADD     x5, x5, 128
        B.HS    2b

        # Is there a remainder?- 1 to 3 halffloats of A (2 to 6 bytes)
        ANDS    x0, x0, 7
        B.NE    4f

3:
        # ks loop
        SUBS    x9, x9, 8               // ks -= MR * sizeof(void*)
        B.HI    1b

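        # Reduce the 4 accumulator pairs down to a single pair in v16:v17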
        FADD    v16.8h, v16.8h, v18.8h
        FADD    v17.8h, v17.8h, v19.8h
        FADD    v20.8h, v20.8h, v22.8h
        FADD    v21.8h, v21.8h, v23.8h
        FADD    v16.8h, v16.8h, v20.8h
        FADD    v17.8h, v17.8h, v21.8h

        # Clamp
        FMAX    v16.8h, v16.8h, v4.8h
        FMAX    v17.8h, v17.8h, v4.8h
        FMIN    v16.8h, v16.8h, v5.8h
        FMIN    v17.8h, v17.8h, v5.8h

        # Store full 1 x 16
        SUBS    x1, x1, 16
        B.LO    6f

        STP     q16, q17,  [x6]
        ADD     x6,  x6, x10

        SUB     x4, x4, x3              // a -= ks

        # nc loop
        B.HI    0b
        RET


        # Remainder- 1 to 3 halffloats of A (2 to 6 bytes)
4:
        TBZ     x0, 2, 5f
        LDR     s0,  [x8], 4
        LDR     q24, [x5, 0]
        LDR     q25, [x5, 16]
        LDR     q26, [x5, 32]
        LDR     q27, [x5, 48]
        FMLA    v16.8h, v24.8h, v0.h[0]
        FMLA    v17.8h, v25.8h, v0.h[0]
        FMLA    v18.8h, v26.8h, v0.h[1]
        FMLA    v19.8h, v27.8h, v0.h[1]
        ADD     x5, x5, 64
        TBZ     x0, 1, 3b
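        # Remainder- 1 halffloat of A (2 bytes)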
5:
        LDR     h0, [x8], 2
        LDR     q24, [x5, 0]
        LDR     q25, [x5, 16]
        FMLA    v16.8h, v24.8h, v0.h[0]
        FMLA    v17.8h, v25.8h, v0.h[0]
        ADD     x5, x5, 32
        B       3b

        # Store odd width
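        # 1 to 15 columns remain; each TBZ below tests one low bit of nc
        # (unchanged by the earlier SUBS) to store 8, 4, 2, then 1 halffloats.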
6:
        TBZ     x1, 3, 7f
        STR     q16, [x6], 16
        MOV     v16.16b, v17.16b
7:
        TBZ     x1, 2, 8f
        STR     d16, [x6], 8
        DUP     d16, v16.d[1]
8:
        TBZ     x1, 1, 9f
        STR     s16,  [x6], 4
        DUP     s16, v16.s[1]
9:
        TBZ     x1, 0, 10f
        STR     h16, [x6]
10:
        RET

END_FUNCTION xnn_f16_igemm_minmax_ukernel_1x16__aarch64_neonfp16arith_ld64

#ifdef __ELF__
.section ".note.GNU-stack","",%progbits
#endif