// xref: /aosp_15_r20/external/XNNPACK/src/f32-gemm/gen/4x2-minmax-aarch64-neonfma-ld64.S (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
// Auto-generated file. Do not edit!
//   Template: src/f32-gemm/4x2-aarch64-neonfma-ld64.S.in
//   Generator: tools/xngen
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <xnnpack/assembly.h>

# void xnn_f32_gemm_minmax_ukernel_4x2__aarch64_neonfma_ld64(
#     size_t mr,                x0
#     size_t nc,                x1
#     size_t kc,                x2 / x0
#     const uint8_t*restrict a, x3
#     size_t a_stride,          x4
#     const void*restrict w,    x5
#     uint8_t*restrict c,       x6
#     size_t cm_stride,         x7
#     size_t cn_stride,         [sp] -> x14
#     const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])  [sp + 8] -> (x8)
#
# kc, a_stride, cm_stride and cn_stride are all byte counts. Once the row
# pointers are set up, x0 (mr) is reused as the remaining-k counter.

# d8-d15, x19-x30 need to be preserved if used. x18 is reserved by the OS.
# This kernel uses none of them, so it saves no registers.

# A pointers
# x3  a0
# x11 a1
# x12 a2
# x4  a3 / a_stride

# C pointers
# x6  c0
# x9  c1
# x10 c2
# x7  c3 / cm_stride

# Vector register usage
# A0  v0
# A1  v1
# A2  v2
# A3  v3
# B  v20 v21
# C  v24 v25   (row 0: even-k / odd-k partial sums)
# C  v26 v27   (row 1)
# C  v28 v29   (row 2)
# C  v30 v31   (row 3)
# Clamp v4 (min) v5 (max)

BEGIN_FUNCTION xnn_f32_gemm_minmax_ukernel_4x2__aarch64_neonfma_ld64

        # Each pass of the outer loop (label 0) produces a tile of up to
        # 4 rows x 2 columns of C:
        #   C[m][n] = clamp(bias[n] + sum over k of A[m][k] * B[k][n])
        # Rows at or beyond mr alias the last valid row, so they recompute
        # and re-store the same values.
        # Leaf function: only x0-x14, v0-v5 and v20-v31 are touched, so
        # there is no prologue or epilogue.

        # Load cn_stride, params pointer
        LDP     x14, x8, [sp]

        # Clamp A and C pointers
        CMP     x0, 2                   // if mr < 2
        ADD     x11, x3, x4             // a1 = a0 + a_stride
        ADD     x9, x6, x7              // c1 = c0 + cm_stride
        CSEL    x11, x3, x11, LO        //   a1 = a0
        CSEL    x9, x6, x9, LO          //   c1 = c0

        # Load min/max values
        LD2R    {v4.2s, v5.2s}, [x8]    // v4 = lower bound, v5 = upper bound (broadcast)

        # Flags from CMP x0, 2 are still live: ADD and LD2R do not set NZCV
        ADD     x12, x11, x4            // a2 = a1 + a_stride
        ADD     x10, x9, x7             // c2 = c1 + cm_stride
                                        // if mr <= 2
        CSEL    x12, x11, x12, LS       //   a2 = a1
        CSEL    x10, x9, x10, LS        //   c2 = c1

        CMP     x0, 4                   // if mr < 4
        ADD     x4, x12, x4             // a3 = a2 + a_stride
        ADD     x7, x10, x7             // c3 = c2 + cm_stride
        CSEL    x4, x12, x4, LO         //   a3 = a2
        CSEL    x7, x10, x7, LO         //   c3 = c2

0:
        # Load initial bias from w into accumulators.
        # Even-k sums go to v24/v26/v28/v30 (seeded with the bias) and
        # odd-k sums go to v25/v27/v29/v31 (seeded with zero).
        LDR     d24, [x5], 8
        MOV     v26.8b, v24.8b
        MOV     v28.8b, v24.8b
        MOV     v30.8b, v24.8b
        MOVI    v25.2s, 0
        MOVI    v27.2s, 0
        MOVI    v29.2s, 0
        MOVI    v31.2s, 0

        # Are there at least 2 floats (8 bytes)?
        # NOTE(review): kc == 0 would also branch to the remainder path and
        # read one float per row; callers presumably guarantee kc >= 4.
        SUBS    x0, x2, 8               // k = kc - 8
        B.LO    3f

        # Main loop - 2 floats of A (8 bytes) per row
1:
        LDR     d0,  [x3], 8
        LDP     d20, d21, [x5], 16
        LDR     d1, [x11], 8
        LDR     d2, [x12], 8
        LDR     d3,  [x4], 8
        SUBS    x0, x0, 8
        FMLA    v24.2s, v20.2s, v0.s[0]
        FMLA    v26.2s, v20.2s, v1.s[0]
        FMLA    v28.2s, v20.2s, v2.s[0]
        FMLA    v30.2s, v20.2s, v3.s[0]
        FMLA    v25.2s, v21.2s, v0.s[1]
        FMLA    v27.2s, v21.2s, v1.s[1]
        FMLA    v29.2s, v21.2s, v2.s[1]
        FMLA    v31.2s, v21.2s, v3.s[1]
        B.HS    1b

        # Is there a remainder? - 1 float of A (4 bytes)
        # Assuming kc is a multiple of 4, k is now -8 (no remainder) or
        # -4 (one float left); bit 2 is set only for -4.
        TBNZ    x0, 2, 3f

2:
        # Fold the odd-k partial sums into the even-k ones
        FADD    v24.2s, v24.2s, v25.2s
        FADD    v26.2s, v26.2s, v27.2s
        FADD    v28.2s, v28.2s, v29.2s
        FADD    v30.2s, v30.2s, v31.2s

        # Clamp
        FMAX    v24.2s, v24.2s, v4.2s
        SUBS    x1, x1, 2               // nc -= 2; flags feed B.LO and B.HI below
        FMAX    v26.2s, v26.2s, v4.2s
        FMAX    v28.2s, v28.2s, v4.2s
        FMAX    v30.2s, v30.2s, v4.2s
        FMIN    v24.2s, v24.2s, v5.2s
        FMIN    v26.2s, v26.2s, v5.2s
        FMIN    v28.2s, v28.2s, v5.2s
        FMIN    v30.2s, v30.2s, v5.2s

        B.LO    4f                      // fewer than 2 columns left

        # Store full 4 x 2 and rewind each A row by kc for the next columns
        ST1     {v24.8b},  [x6], x14
        SUB     x3,  x3, x2             // a0 -= kc
        ST1     {v26.8b},  [x9], x14
        SUB     x11, x11, x2            // a1 -= kc
        ST1     {v28.8b}, [x10], x14
        SUB     x12, x12, x2            // a2 -= kc
        ST1     {v30.8b},  [x7], x14
        SUB     x4,  x4, x2             // a3 -= kc

        # ST1 and SUB leave NZCV alone, so this still tests SUBS x1
        B.HI    0b                      // more columns remain

        RET

        # Remainder - 1 float of A (4 bytes) per row.
        # Reached directly when kc < 8, or via TBNZ after the main loop.
        # The k counter is not updated here: it is dead until label 0
        # reloads it.
3:
        LDR     s0,  [x3], 4
        LDR     d20, [x5], 8
        LDR     s1, [x11], 4
        LDR     s2, [x12], 4
        LDR     s3,  [x4], 4
        FMLA    v24.2s, v20.2s, v0.s[0]
        FMLA    v26.2s, v20.2s, v1.s[0]
        FMLA    v28.2s, v20.2s, v2.s[0]
        FMLA    v30.2s, v20.2s, v3.s[0]
        B       2b

        # Store odd width: lane 0 (first column) of each row
4:
        STR     s24,  [x6]
        STR     s26,  [x9]
        STR     s28, [x10]
        STR     s30,  [x7]
        RET

END_FUNCTION xnn_f32_gemm_minmax_ukernel_4x2__aarch64_neonfma_ld64

#ifdef __ELF__
// Empty-flag .note.GNU-stack section: marks this object as not needing an
// executable stack when linked by GNU-compatible ELF linkers.
.section ".note.GNU-stack","",%progbits
#endif
