1 // Copyright 2022 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5
6
7 #include <cassert>
8 #include <cstddef>
9 #include <limits>
10
11 #include <xnnpack.h>
12 #include <xnnpack/aarch64-assembler.h>
13 #include <xnnpack/allocator.h>
14 #include <xnnpack/gemm.h>
15
16 namespace xnnpack {
17 namespace aarch64 {
18 namespace {
// JIT code generator for the f32 1x8 minmax GEMM microkernel.
// Inherits the base aarch64 Assembler constructors (which take the
// xnn_code_buffer that the generated machine code is emitted into).
class Generator : public Assembler {
  using Assembler::Assembler;

 public:
  // Emits the kernel body; see the definition below for parameter semantics.
  void generate(bool prefetch, size_t max_mr, size_t nc_mod_nr, size_t kc, float min, float max);
};
25
26 // void xnn_f32_gemm_minmax_ukernel_1x8__aarch64_neonfma_prfm_cortex_a75(
27 // size_t mr, (x0) - unused. mr = 1
28 // size_t nc, x1
29 // size_t kc, x2 / x0
30 // const uint8_t*restrict a, x3
31 // size_t a_stride, (x4) - unused
32 // const void*restrict w, x5
33 // uint8_t*restrict c, x6
34 // size_t cm_stride, (x7) - unused
35 // size_t cn_stride, [sp] -> x14
36 // const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) [sp + 8] -> (x8)
37
38 // d8-d15, x19-x30 need to be preserved if used. x18 is reserved by the OS.
39
40 // A pointer
41 // x3 a0
42
43 // C pointer
44 // x6 c0
45
46 // Clamp v4 v5
47
48 // Converted from: src/f32-gemm/gen/1x8-minmax-aarch64-neonfma-prfm-cortex-a75.S
// Emits the 1x8 f32 GEMM minmax microkernel (optionally with PLD prefetch
// hints for the weight stream). Register usage follows the comment block
// above: a0 = x3, c0 = x6, w = x5, clamp values in v4/v5.
//
// prefetch  - when true, emit kPLDL1KEEP prefetches for the weights in x5.
// max_mr    - NOTE(review): accepted for signature uniformity with multi-row
//             generators but unused here — this kernel is fixed at MR = 1.
// nc_mod_nr - nc % nr; must be < 8 (NR is 8 for this kernel).
// kc        - K extent in bytes; must be a non-zero multiple of sizeof(float).
// min/max   - output clamp bounds; +/-infinity elides the respective clamp.
void Generator::generate(bool prefetch, size_t max_mr, size_t nc_mod_nr, size_t kc, float min, float max)
{
  assert(nc_mod_nr < 8);
  assert(kc != 0);
  assert(kc % sizeof(float) == 0);

  Label l0, l1, l2, l3, l4, l5, l6, l7, l8, l9, l10, l11, l12;
  // Clamp instructions are emitted only when the bound is finite.
  const bool clamp_min = min != -std::numeric_limits<float>::infinity();
  const bool clamp_max = max != +std::numeric_limits<float>::infinity();

  // Load cn_stride, params pointer
  ldp(x14, x8, mem[sp]);

  // Load min/max values
  if (clamp_max) {
    // Both bounds needed: replicate {min, max} into v4/v5 with one LD2R.
    ld2r({v4.v4s(), v5.v4s()}, mem[x8]);
  } else if (clamp_min) {
    if (min == 0.f) {
      // A zero lower bound can be materialized without a memory access.
      movi(v4.v4s(), 0);
    } else {
      ld1r({v4.v4s()}, mem[x8]);
    }
  }

  bind(l0);
  // Load initial bias from w into accumulators
  ldp(q16, q17, mem[x5], 32);

  movi(v18.v4s(), 0);  // second set of C for pipelining FMLA
  if (prefetch) {
    prfm(kPLDL1KEEP, mem[x5]);
  }
  movi(v19.v4s(), 0);
  if (prefetch) {
    prfm(kPLDL1KEEP, mem[x5, 64]);
    prfm(kPLDL1KEEP, mem[x5, 128]);
    prfm(kPLDL1KEEP, mem[x5, 192]);
  }

  // Is there at least 8 floats (32 bytes) for prologue + epilogue?
  subs(x0, x2, 32);  // k = kc - 32

  b_lo(l3);

  // 16 prologue
  // Read first block of 1 A and B.
  ldp(q20, q21, mem[x5], 32);
  ldp(q22, q23, mem[x5], 32);
  ldp(q24, q25, mem[x5], 32);
  ldp(q26, q27, mem[x5], 32);
  ldr(q0, mem[x3], 16);

  // Is there at least 32. yes do main loop
  subs(x0, x0, 32);
  b_lo(l2);

  // Main loop - 8 floats of A (32 bytes)
  // FMLAs for the current block are interleaved with loads for the next
  // block so the multiply and load pipes are both kept busy.
  bind(l1);
  // First block of 4. FMA for first 4, loads for 2nd block of 4.
  fmla(v16.v4s(), v20.v4s(), v0.s()[0]);
  ldr(q1, mem[x3], 16);
  fmla(v17.v4s(), v21.v4s(), v0.s()[0]);
  ldp(q20, q21, mem[x5], 32);
  fmla(v18.v4s(), v22.v4s(), v0.s()[1]);
  if (prefetch) {
    prfm(kPLDL1KEEP, mem[x5, 96]);
  }
  fmla(v19.v4s(), v23.v4s(), v0.s()[1]);
  ldp(q22, q23, mem[x5], 32);
  fmla(v16.v4s(), v24.v4s(), v0.s()[2]);
  fmla(v17.v4s(), v25.v4s(), v0.s()[2]);
  ldp(q24, q25, mem[x5], 32);
  fmla(v18.v4s(), v26.v4s(), v0.s()[3]);
  fmla(v19.v4s(), v27.v4s(), v0.s()[3]);
  ldp(q26, q27, mem[x5], 32);

  // Second block of 4. FMA for second 4, loads for 1st block of 4.
  fmla(v16.v4s(), v20.v4s(), v1.s()[0]);
  ldr(q0, mem[x3], 16);
  fmla(v17.v4s(), v21.v4s(), v1.s()[0]);
  ldp(q20, q21, mem[x5], 32);
  fmla(v18.v4s(), v22.v4s(), v1.s()[1]);
  fmla(v19.v4s(), v23.v4s(), v1.s()[1]);
  ldp(q22, q23, mem[x5], 32);
  fmla(v16.v4s(), v24.v4s(), v1.s()[2]);
  fmla(v17.v4s(), v25.v4s(), v1.s()[2]);
  ldp(q24, q25, mem[x5], 32);
  fmla(v18.v4s(), v26.v4s(), v1.s()[3]);
  fmla(v19.v4s(), v27.v4s(), v1.s()[3]);
  subs(x0, x0, 32);
  ldp(q26, q27, mem[x5], 32);
  b_hs(l1);

  bind(l2);
  // Epilogue
  // Consumes the A/B values loaded in the prologue (or by the last loop
  // iteration) without loading another A block.

  // First block of 4. FMA for first 4, loads for 2nd block of 4.
  fmla(v16.v4s(), v20.v4s(), v0.s()[0]);
  ldr(q1, mem[x3], 16);
  fmla(v17.v4s(), v21.v4s(), v0.s()[0]);
  ldp(q20, q21, mem[x5], 32);
  fmla(v18.v4s(), v22.v4s(), v0.s()[1]);
  fmla(v19.v4s(), v23.v4s(), v0.s()[1]);
  ldp(q22, q23, mem[x5], 32);
  fmla(v16.v4s(), v24.v4s(), v0.s()[2]);
  fmla(v17.v4s(), v25.v4s(), v0.s()[2]);
  ldp(q24, q25, mem[x5], 32);
  fmla(v18.v4s(), v26.v4s(), v0.s()[3]);
  fmla(v19.v4s(), v27.v4s(), v0.s()[3]);
  ldp(q26, q27, mem[x5], 32);

  // Second block of 4. no loads
  fmla(v16.v4s(), v20.v4s(), v1.s()[0]);
  fmla(v17.v4s(), v21.v4s(), v1.s()[0]);
  fmla(v18.v4s(), v22.v4s(), v1.s()[1]);
  fmla(v19.v4s(), v23.v4s(), v1.s()[1]);
  fmla(v16.v4s(), v24.v4s(), v1.s()[2]);
  fmla(v17.v4s(), v25.v4s(), v1.s()[2]);
  fmla(v18.v4s(), v26.v4s(), v1.s()[3]);
  fmla(v19.v4s(), v27.v4s(), v1.s()[3]);

  bind(l3);
  // Remainder dispatch: after the k -= 32 bias above, bits 4/3/2 of x0
  // select 4-, 2- and 1-float tails respectively.
  // Is there a remainder?- 4 floats of A (16 bytes)
  tbnz(x0, 4, l5);
  // Is there a remainder?- 2 floats of A (8 bytes)
  tbnz(x0, 3, l6);
  // Is there a remainder?- 1 float of A (4 bytes)
  tbnz(x0, 2, l8);

  bind(l4);
  // Reduce the two pipelined accumulator sets into one result row.
  fadd(v16.v4s(), v16.v4s(), v18.v4s());
  subs(x1, x1, 8);
  fadd(v17.v4s(), v17.v4s(), v19.v4s());

  // Clamp
  if (clamp_min) {
    fmax(v16.v4s(), v16.v4s(), v4.v4s());
    fmax(v17.v4s(), v17.v4s(), v4.v4s());
  }
  if (clamp_max) {
    fmin(v16.v4s(), v16.v4s(), v5.v4s());
    fmin(v17.v4s(), v17.v4s(), v5.v4s());
  }

  // Store full 1 x 8
  b_lo(l9);

  stp(q16, q17, mem[x6]);
  add(x6, x6, x14);

  sub(x3, x3, x2);  // a0 -= kc

  b_hi(l0);

  ret();

  bind(l5);
  // Remainder- 4 floats of A (16 bytes)
  ldp(q20, q21, mem[x5], 32);
  ldr(q0, mem[x3], 16);
  fmla(v16.v4s(), v20.v4s(), v0.s()[0]);
  fmla(v17.v4s(), v21.v4s(), v0.s()[0]);
  ldp(q22, q23, mem[x5], 32);
  ldp(q24, q25, mem[x5], 32);
  ldp(q26, q27, mem[x5], 32);
  fmla(v18.v4s(), v22.v4s(), v0.s()[1]);
  fmla(v19.v4s(), v23.v4s(), v0.s()[1]);
  fmla(v16.v4s(), v24.v4s(), v0.s()[2]);
  fmla(v17.v4s(), v25.v4s(), v0.s()[2]);
  fmla(v18.v4s(), v26.v4s(), v0.s()[3]);
  fmla(v19.v4s(), v27.v4s(), v0.s()[3]);

  tbz(x0, 3, l7);
  bind(l6);
  // Remainder- 2 floats of A (8 bytes)
  ldp(q20, q21, mem[x5], 32);
  ldr(d0, mem[x3], 8);
  fmla(v16.v4s(), v20.v4s(), v0.s()[0]);
  fmla(v17.v4s(), v21.v4s(), v0.s()[0]);
  ldp(q22, q23, mem[x5], 32);
  fmla(v18.v4s(), v22.v4s(), v0.s()[1]);
  fmla(v19.v4s(), v23.v4s(), v0.s()[1]);
  bind(l7);
  tbz(x0, 2, l4);
  bind(l8);
  // Remainder- 1 float of A (4 bytes)
  ldp(q20, q21, mem[x5], 32);
  ldr(s0, mem[x3], 4);
  fmla(v16.v4s(), v20.v4s(), v0.s()[0]);
  fmla(v17.v4s(), v21.v4s(), v0.s()[0]);
  b(l4);

  // Store odd channels
  // nc_mod_nr tail: bits 2/1/0 of the remaining nc select 4-, 2- and
  // 1-float stores; v16 is shifted down as elements are written out.
  bind(l9);
  tbz(x1, 2, l10);
  str(q16, mem[x6], 16);
  mov(v16.v16b(), v17.v16b());

  bind(l10);
  tbz(x1, 1, l11);
  str(d16, mem[x6], 8);
  dup(d16, v16.d()[1]);

  bind(l11);
  tbz(x1, 0, l12);
  str(s16, mem[x6]);
  bind(l12);
  ret();

  align(16, AlignInstruction::kHlt);
}
260 } // namespace
261 } // namespace aarch64
262 } // namespace xnnpack
263
xnn_generate_f32_gemm_ukernel_1x8__aarch64_neonfma_cortex_a75(xnn_code_buffer * code,size_t max_mr,size_t nc_mod_nr,size_t kc,const void * params)264 xnn_status_t xnn_generate_f32_gemm_ukernel_1x8__aarch64_neonfma_cortex_a75(
265 xnn_code_buffer* code, size_t max_mr, size_t nc_mod_nr, size_t kc, const void* params)
266 {
267 using namespace xnnpack::aarch64;
268 Generator g(code);
269 assert(params != nullptr);
270 const jit_gemm_params* gemm_params = static_cast<const jit_gemm_params*>(params);
271 g.generate(false, max_mr, nc_mod_nr, kc, gemm_params->f32_minmax.min, gemm_params->f32_minmax.max);
272 g.finalize();
273 if (g.error() != xnnpack::Error::kNoError) {
274 return xnn_status_invalid_state;
275 }
276 return xnn_status_success;
277 }
278
xnn_generate_f32_gemm_ukernel_1x8__aarch64_neonfma_prfm_cortex_a75(xnn_code_buffer * code,size_t max_mr,size_t nc_mod_nr,size_t kc,const void * params)279 xnn_status_t xnn_generate_f32_gemm_ukernel_1x8__aarch64_neonfma_prfm_cortex_a75(
280 xnn_code_buffer* code, size_t max_mr, size_t nc_mod_nr, size_t kc, const void* params)
281 {
282 using namespace xnnpack::aarch64;
283 Generator g(code);
284 assert(params != nullptr);
285 const jit_gemm_params* gemm_params = static_cast<const jit_gemm_params*>(params);
286 g.generate(true, max_mr, nc_mod_nr, kc, gemm_params->f32_minmax.min, gemm_params->f32_minmax.max);
287 g.finalize();
288 if (g.error() != xnnpack::Error::kNoError) {
289 return xnn_status_invalid_state;
290 }
291 return xnn_status_success;
292 }
293