xref: /aosp_15_r20/external/XNNPACK/src/f32-igemm/4x2-aarch64-neonfma-ld64.S.in (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
// Copyright 2021 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
5
6#include <xnnpack/assembly.h>
7
# void xnn_f32_igemm_minmax_ukernel_4x2__aarch64_neonfma_ld64(
#     size_t mr,                         x0
#     size_t nc,                         x1
#     size_t kc,                         x2 / x0
#     size_t ks,                         x3 / x9
#     const float**restrict a,           x4
#     const float*restrict w,            x5
#     float*restrict c,                  x6
#     size_t cm_stride,                  x7
#     size_t cn_stride,                  [sp] -> x10
#     size_t a_offset,                   [sp + 8] -> x11
#     const float* zero,                 [sp + 16] -> x12
#     const xnn_f32_minmax_params params [sp + 24] -> (x8)
#
# Indirect GEMM microkernel. Each nc iteration produces a tile of up to
# 4 rows by 2 columns of C, clamps it to [min, max] and stores it.
#
# Units and layouts, as consumed by the code below:
#   kc, ks, cm_stride, cn_stride and a_offset are byte counts.
#   kc is consumed 8 bytes (2 floats) per main-loop step plus a 4-byte tail.
#   NOTE(review): assumes kc is a nonzero multiple of sizeof(float).
#   ks is consumed 32 bytes (4 row pointers) per ks-loop step.
#   NOTE(review): assumes ks is a nonzero multiple of 32.
#   Each ks step reads 4 row pointers from a. A pointer equal to zero is
#   used unchanged; any other pointer has a_offset added to it.
#   Packed w, per 2-column tile: 2 bias floats, then 2 floats per k value
#   for every ks step.
#   params: min at byte offset 0, max at byte offset 4.

# d8-d15, x19-x30 need to be preserved if used. x18 is reserved by the OS.
# This leaf kernel uses none of them, so it has no prologue or epilogue.

# A pointers
# x8  a0
# x13 a1
# x14 a2
# x15 a3

# C pointers
# x6  c0
# x16 c1
# x17 c2
# x7  c3 / cm_stride

# Vector register usage
# A0  v0
# A1  v1
# A2  v2
# A3  v3
# B  v20 v21
# C  v24 v25
# C  v26 v27
# C  v28 v29
# C  v30 v31
# Clamp v4 v5

BEGIN_FUNCTION xnn_f32_igemm_minmax_ukernel_4x2__aarch64_neonfma_ld64

        # Load cn_stride, a_offset
        LDP     x10, x11, [sp]

        # Load zero, params pointer
        LDP     x12, x8, [sp, 16]

        # Clamp C pointers: a row at or beyond mr aliases the previous row,
        # so no store ever addresses a row past the last valid one.
        CMP     x0, 2                   // if mr < 2
        ADD     x16, x6, x7             // c1 = c0 + cm_stride
        CSEL    x16, x6, x16, LO        //   c1 = c0

        # Load min/max values
        LD2R    {v4.2s, v5.2s}, [x8]    // v4 = min, v5 = max, both lanes

        ADD     x17, x16, x7            // c2 = c1 + cm_stride
                                        // if mr <= 2 (flags still from CMP x0, 2)
        CSEL    x17, x16, x17, LS       //   c2 = c1

        CMP     x0, 4                   // if mr < 4
        ADD     x7, x17, x7             // c3 = c2 + cm_stride
        CSEL    x7, x17, x7, LO         //   c3 = c2

0:
        # Load initial bias from w into accumulators
        LDR     d24, [x5], 8
        MOV     v26.8b, v24.8b
        MOV     v28.8b, v24.8b
        MOV     v30.8b, v24.8b
        # Second accumulator per row (odd k) starts at zero; it is folded
        # into the first one after the ks loop.
        MOVI    v25.2s, 0
        MOVI    v27.2s, 0
        MOVI    v29.2s, 0
        MOVI    v31.2s, 0

        MOV     x9, x3                  // p = ks

1:
        # Load next 4 A pointers
        LDP     x8, x13, [x4], 16
        LDP     x14, x15, [x4], 16

        CMP     x8, x12                 // if a0 == zero
        ADD     x8, x8, x11             // a0 += a_offset
        CSEL    x8, x12, x8, EQ         //   a0 = zero, else a0 += a_offset
        CMP     x13, x12                // if a1 == zero
        ADD     x13, x13, x11           // a1 += a_offset
        CSEL    x13, x12, x13, EQ       //   a1 = zero, else a1 += a_offset
        CMP     x14, x12                // if a2 == zero
        ADD     x14, x14, x11           // a2 += a_offset
        CSEL    x14, x12, x14, EQ       //   a2 = zero, else a2 += a_offset
        CMP     x15, x12                // if a3 == zero
        ADD     x15, x15, x11           // a3 += a_offset
        CSEL    x15, x12, x15, EQ       //   a3 = zero, else a3 += a_offset

        # Are there at least 2 floats (8 bytes)? If not, only the tail is left.
        SUBS    x0, x2, 8               // k = kc - 8
        B.LO    4f

        # Main loop - 2 floats of A (8 bytes)
        # Even k accumulates into v24/v26/v28/v30 and odd k into
        # v25/v27/v29/v31, so consecutive FMLAs into one row are independent.
2:
        LDR     d0, [x8], 8
        LDP     d20, d21, [x5], 16
        LDR     d1, [x13], 8
        LDR     d2, [x14], 8
        LDR     d3, [x15], 8
        SUBS    x0, x0, 8
        FMLA    v24.2s, v20.2s, v0.s[0]
        FMLA    v26.2s, v20.2s, v1.s[0]
        FMLA    v28.2s, v20.2s, v2.s[0]
        FMLA    v30.2s, v20.2s, v3.s[0]
        FMLA    v25.2s, v21.2s, v0.s[1]
        FMLA    v27.2s, v21.2s, v1.s[1]
        FMLA    v29.2s, v21.2s, v2.s[1]
        FMLA    v31.2s, v21.2s, v3.s[1]
        B.HS    2b

        # Is there a remainder? - 1 float of A (4 bytes)
        # On loop exit x0 = (kc % 8) - 8, which is -8 or -4; bit 2 is set
        # only for -4, meaning one float remains.
        TBNZ    x0, 2, 4f

3:
        # ks loop
        SUBS    x9, x9, 32              // ks -= MR * sizeof(void*)
        B.HI    1b

        # Fold the odd-k accumulators into the even-k ones
        FADD    v24.2s, v24.2s, v25.2s
        FADD    v26.2s, v26.2s, v27.2s
        FADD    v28.2s, v28.2s, v29.2s
        FADD    v30.2s, v30.2s, v31.2s

        # Clamp
        FMAX    v24.2s, v24.2s, v4.2s
        SUBS    x1, x1, 2               // nc -= 2; flags read by B.LO 5f and B.HI 0b
        FMAX    v26.2s, v26.2s, v4.2s
        FMAX    v28.2s, v28.2s, v4.2s
        FMAX    v30.2s, v30.2s, v4.2s
        FMIN    v24.2s, v24.2s, v5.2s
        FMIN    v26.2s, v26.2s, v5.2s
        FMIN    v28.2s, v28.2s, v5.2s
        FMIN    v30.2s, v30.2s, v5.2s

        # Store full 4 x 2
        B.LO    5f

        # Rows are stored c3 first. When mr < 4 the higher rows alias lower
        # ones, so the valid lower row is always the last one written.
        STR     d30, [x7]
        ADD     x7,  x7, x10
        STR     d28, [x17]
        ADD     x17, x17, x10
        STR     d26, [x16]
        ADD     x16, x16, x10
        STR     d24, [x6]
        ADD     x6,  x6, x10

        SUB     x4, x4, x3              // a -= ks (rewind for the next tile)

        # nc loop (none of STR/ADD/SUB above changes the flags)
        B.HI    0b
        RET

        # Remainder - 1 float of A
4:
        LDR     s0, [x8], 4
        LDR     d20, [x5], 8
        LDR     s1, [x13], 4
        LDR     s2, [x14], 4
        LDR     s3, [x15], 4
        FMLA    v24.2s, v20.2s, v0.s[0]
        FMLA    v26.2s, v20.2s, v1.s[0]
        FMLA    v28.2s, v20.2s, v2.s[0]
        FMLA    v30.2s, v20.2s, v3.s[0]
        B       3b

        # Store odd width: fewer than 2 columns left, so store lane 0 only
5:
        STR     s30, [x7]
        STR     s28, [x17]
        STR     s26, [x16]
        STR     s24, [x6]
        RET

END_FUNCTION xnn_f32_igemm_minmax_ukernel_4x2__aarch64_neonfma_ld64
189
190#ifdef __ELF__
191.section ".note.GNU-stack","",%progbits
192#endif
193