1 // Copyright 2022 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5
6 #include <cassert>
7 #include <cstddef>
8 #include <limits>
9
10 #include <xnnpack.h>
11 #include <xnnpack/aarch64-assembler.h>
12 #include <xnnpack/allocator.h>
13 #include <xnnpack/igemm.h>
14
15 namespace xnnpack {
16 namespace aarch64 {
17 namespace {
18 class Generator : public Assembler {
19 using Assembler::Assembler;
20 public:
21 void generate(size_t max_mr, size_t nc_mod_nr, size_t kc, size_t ks, float min, float max);
22 };
23
24 // void xnn_f32_igemm_minmax_ukernel_6x8__aarch64_neonfma_ld128(
25 // size_t mr, x0
26 // size_t nc, x1
27 // size_t kc, x2 / x0
28 // size_t ks, x3 / x9
29 // const float**restrict a, x4
30 // const void*restrict w, x5
31 // uint8_t*restrict c, x6
32 // size_t cm_stride, x7
33 // size_t cn_stride, [sp] -> (x0)
34 // size_t a_offset, [sp + 8] -> x11
35 // const float* zero, [sp + 16] -> x12
36 // const xnn_f32_minmax_params params [sp + 24] -> x8
37
38 // d8-d15, x19-x30 need to be preserved if used. x18 is reserved by the OS.
39
40 // A pointers
41 // x14 a0
42 // x15 a1
43 // x20 a2
44 // x21 a3
45 // x22 a4
46 // x23 a5
47
48 // C pointers
49 // x6 c0
50 // x16 c1
51 // x17 c2
52 // x10 c3
53 // x13 c4
54 // x7 c5
55
56 // Vector register usage
57 // A0 v0
58 // A1 v1
59 // A2 v2
60 // A3 v3
61 // A4 v4
62 // A5 v5
63 // B v16 v17 v18 v19
64 // C v20 v21
65 // C v22 v23
66 // C v24 v25
67 // C v26 v27
68 // C v28 v29
69 // C v30 v31
70 // Clamp v6 v7
71 // unused A v8 v9 v10 v11
72 // unused B v12 v13 v14 v15
73
74 // Converted from: src/f32-igemm/gen/6x8-minmax-aarch64-neonfma-ld128.S
generate(size_t max_mr,size_t nc_mod_nr,size_t kc,size_t ks,float min,float max)75 void Generator::generate(size_t max_mr, size_t nc_mod_nr, size_t kc, size_t ks, float min, float max) {
76 assert(max_mr <= 6);
77 assert(nc_mod_nr < 8);
78 assert(kc != 0);
79 assert(kc % sizeof(float) == 0);
80
81 Label l0, l1, l2, l3, l4, l5, l6, l7, l8, l9;
82
83 const bool clamp_min = min != -std::numeric_limits<float>::infinity();
84 const bool clamp_max = max != +std::numeric_limits<float>::infinity();
85
86 // Load zero, params pointer
87 ldp(x12, x8, mem[sp, 16]);
88
89 // Clamp C pointers
90 cmp(x0, 2); // if mr < 2
91 add(x16, x6, x7); // c1 = c0 + cm_stride
92 csel(x16, x6, x16, kLO); // c1 = c0
93
94 // Load min/max values
95 if (clamp_min || clamp_max) {
96 ld2r({v6.v4s(), v7.v4s()}, mem[x8]);
97 }
98
99 add(x17, x16, x7); // c2 = c1 + cm_stride
100 // if mr <= 2
101 csel(x17, x16, x17, kLS); // c2 = c1
102
103 // Save x20,x21,x22,x23 on stack
104 stp(x20, x21, mem[sp, -32]++);
105
106 cmp(x0, 4); // if mr < 4
107 add(x10, x17, x7); // c3 = c2 + cm_stride
108 csel(x10, x17, x10, kLO); // c3 = c2
109
110 stp(x22, x23, mem[sp, 16]);
111
112 add(x13, x10, x7); // c4 = c3 + cm_stride
113 // if mr <= 4
114 csel(x13, x10, x13, kLS); // c4 = c3
115
116 // Load a_offset
117 ldr(x11, mem[sp, 40]);
118
119 cmp(x0, 6); // if mr < 6
120 add(x7, x13, x7); // c5 = c4 + cm_stride
121 csel(x7, x13, x7, kLO); // c5 = c4
122
123 bind(l0);
124 // Load initial bias from w into accumulators
125 ldp(q20, q21, mem[x5], 32);
126 mov(v22.v16b(), v20.v16b());
127 mov(v23.v16b(), v21.v16b());
128 mov(v24.v16b(), v20.v16b());
129 mov(v25.v16b(), v21.v16b());
130 mov(v26.v16b(), v20.v16b());
131 mov(v27.v16b(), v21.v16b());
132 mov(v28.v16b(), v20.v16b());
133 mov(v29.v16b(), v21.v16b());
134 mov(v30.v16b(), v20.v16b());
135 mov(v31.v16b(), v21.v16b());
136
137 mov(x9, x3); // p = ks
138
139 bind(l1);
140 // Load next 6 A pointers
141 ldp(x14, x15, mem[x4], 16);
142 ldp(x20, x21, mem[x4], 16);
143 ldp(x22, x23, mem[x4], 16);
144
145 cmp(x14, x12); // if a0 == zero
146 add(x14, x14, x11); // a0 += a_offset
147 csel(x14, x12, x14, kEQ); // a0 = zero, else += a0 + a_offset
148 cmp(x15, x12); // if a1 == zero
149 add(x15, x15, x11); // a1 += a_offset
150 csel(x15, x12, x15, kEQ); // a1 = zero, else += a1 + a_offset
151 cmp(x20, x12); // if a2 == zero
152 add(x20, x20, x11); // a2 += a_offset
153 csel(x20, x12, x20, kEQ); // a2 = zero, else += a2 + a_offset
154 cmp(x21, x12); // if a3 == zero
155 add(x21, x21, x11); // a3 += a_offset
156 csel(x21, x12, x21, kEQ); // a3 = zero, else += a3 + a_offset
157 cmp(x22, x12); // if a4 == zero
158 add(x22, x22, x11); // a4 += a_offset
159 csel(x22, x12, x22, kEQ); // a4 = zero, else += a4 + a_offset
160 cmp(x23, x12); // if a5 == zero
161 add(x23, x23, x11); // a5 += a_offset
162 csel(x23, x12, x23, kEQ); // a5 = zero, else += a5 + a_offset
163
164 // Is there at least 4 floats (16 bytes)?
165 subs(x0, x2, 16); // k = kc - 16
166 b_lo(l4);
167
168 // Main loop - 4 floats of A (16 bytes)
169 // 48 FMA + 6 ld128 A + 4 LDP B
170 bind(l2);
171 ldp(q16, q17, mem[x5], 32);
172 ldr(q0, mem[x14], 16);
173 ldr(q1, mem[x15], 16);
174 ldr(q2, mem[x20], 16);
175 ldr(q3, mem[x21], 16);
176 ldr(q4, mem[x22], 16);
177 ldr(q5, mem[x23], 16);
178 fmla(v20.v4s(), v16.v4s(), v0.s()[0]);
179 fmla(v22.v4s(), v16.v4s(), v1.s()[0]);
180 fmla(v24.v4s(), v16.v4s(), v2.s()[0]);
181 fmla(v26.v4s(), v16.v4s(), v3.s()[0]);
182 ldp(q18, q19, mem[x5], 32);
183 fmla(v28.v4s(), v16.v4s(), v4.s()[0]);
184 fmla(v30.v4s(), v16.v4s(), v5.s()[0]);
185 fmla(v21.v4s(), v17.v4s(), v0.s()[0]);
186 fmla(v23.v4s(), v17.v4s(), v1.s()[0]);
187 fmla(v25.v4s(), v17.v4s(), v2.s()[0]);
188 fmla(v27.v4s(), v17.v4s(), v3.s()[0]);
189 fmla(v29.v4s(), v17.v4s(), v4.s()[0]);
190 fmla(v31.v4s(), v17.v4s(), v5.s()[0]);
191
192 fmla(v20.v4s(), v18.v4s(), v0.s()[1]);
193 ldp(q16, q17, mem[x5], 32);
194 fmla(v22.v4s(), v18.v4s(), v1.s()[1]);
195 fmla(v24.v4s(), v18.v4s(), v2.s()[1]);
196 fmla(v26.v4s(), v18.v4s(), v3.s()[1]);
197 fmla(v28.v4s(), v18.v4s(), v4.s()[1]);
198 fmla(v30.v4s(), v18.v4s(), v5.s()[1]);
199 fmla(v21.v4s(), v19.v4s(), v0.s()[1]);
200 fmla(v23.v4s(), v19.v4s(), v1.s()[1]);
201 fmla(v25.v4s(), v19.v4s(), v2.s()[1]);
202 fmla(v27.v4s(), v19.v4s(), v3.s()[1]);
203 fmla(v29.v4s(), v19.v4s(), v4.s()[1]);
204 fmla(v31.v4s(), v19.v4s(), v5.s()[1]);
205
206 fmla(v20.v4s(), v16.v4s(), v0.s()[2]);
207 ldp(q18, q19, mem[x5], 32);
208 fmla(v22.v4s(), v16.v4s(), v1.s()[2]);
209 fmla(v24.v4s(), v16.v4s(), v2.s()[2]);
210 fmla(v26.v4s(), v16.v4s(), v3.s()[2]);
211 fmla(v28.v4s(), v16.v4s(), v4.s()[2]);
212 fmla(v30.v4s(), v16.v4s(), v5.s()[2]);
213 fmla(v21.v4s(), v17.v4s(), v0.s()[2]);
214 fmla(v23.v4s(), v17.v4s(), v1.s()[2]);
215 fmla(v25.v4s(), v17.v4s(), v2.s()[2]);
216 fmla(v27.v4s(), v17.v4s(), v3.s()[2]);
217 fmla(v29.v4s(), v17.v4s(), v4.s()[2]);
218 fmla(v31.v4s(), v17.v4s(), v5.s()[2]);
219
220 fmla(v20.v4s(), v18.v4s(), v0.s()[3]);
221 fmla(v22.v4s(), v18.v4s(), v1.s()[3]);
222 fmla(v24.v4s(), v18.v4s(), v2.s()[3]);
223 fmla(v26.v4s(), v18.v4s(), v3.s()[3]);
224 fmla(v28.v4s(), v18.v4s(), v4.s()[3]);
225 fmla(v30.v4s(), v18.v4s(), v5.s()[3]);
226 fmla(v21.v4s(), v19.v4s(), v0.s()[3]);
227 fmla(v23.v4s(), v19.v4s(), v1.s()[3]);
228 fmla(v25.v4s(), v19.v4s(), v2.s()[3]);
229 fmla(v27.v4s(), v19.v4s(), v3.s()[3]);
230 subs(x0, x0, 16);
231 fmla(v29.v4s(), v19.v4s(), v4.s()[3]);
232 fmla(v31.v4s(), v19.v4s(), v5.s()[3]);
233 b_hs(l2);
234
235 // Is there a remainder?- 2 floats of A (8 bytes) or less
236 tst(x0, 15);
237 b_ne(l4);
238
239 bind(l3);
240 // ks loop
241 subs(x9, x9, 48); // ks -= MR * sizeof(void*)
242 b_hi(l1);
243
244 // Load cn_stride
245 ldr(x0, mem[sp, 32]);
246 // Clamp
247 if (clamp_min) {
248 fmax(v20.v4s(), v20.v4s(), v6.v4s());
249 fmax(v21.v4s(), v21.v4s(), v6.v4s());
250 fmax(v22.v4s(), v22.v4s(), v6.v4s());
251 fmax(v23.v4s(), v23.v4s(), v6.v4s());
252 fmax(v24.v4s(), v24.v4s(), v6.v4s());
253 fmax(v25.v4s(), v25.v4s(), v6.v4s());
254 fmax(v26.v4s(), v26.v4s(), v6.v4s());
255 fmax(v27.v4s(), v27.v4s(), v6.v4s());
256 fmax(v28.v4s(), v28.v4s(), v6.v4s());
257 fmax(v29.v4s(), v29.v4s(), v6.v4s());
258 fmax(v30.v4s(), v30.v4s(), v6.v4s());
259 fmax(v31.v4s(), v31.v4s(), v6.v4s());
260 }
261 subs(x1, x1, 8);
262 if (clamp_max) {
263 fmin(v20.v4s(), v20.v4s(), v7.v4s());
264 fmin(v21.v4s(), v21.v4s(), v7.v4s());
265 fmin(v22.v4s(), v22.v4s(), v7.v4s());
266 fmin(v23.v4s(), v23.v4s(), v7.v4s());
267 fmin(v24.v4s(), v24.v4s(), v7.v4s());
268 fmin(v25.v4s(), v25.v4s(), v7.v4s());
269 fmin(v26.v4s(), v26.v4s(), v7.v4s());
270 fmin(v27.v4s(), v27.v4s(), v7.v4s());
271 fmin(v28.v4s(), v28.v4s(), v7.v4s());
272 fmin(v29.v4s(), v29.v4s(), v7.v4s());
273 fmin(v30.v4s(), v30.v4s(), v7.v4s());
274 fmin(v31.v4s(), v31.v4s(), v7.v4s());
275 }
276
277 // Store full 6 x 8
278 b_lo(l6);
279
280 stp(q30, q31, mem[x7]);
281 add(x7, x7, x0);
282 stp(q28, q29, mem[x13]);
283 add(x13, x13, x0);
284 stp(q26, q27, mem[x10]);
285 add(x10, x10, x0);
286 stp(q24, q25, mem[x17]);
287 add(x17, x17, x0);
288 stp(q22, q23, mem[x16]);
289 add(x16, x16, x0);
290 stp(q20, q21, mem[x6]);
291 add(x6, x6, x0);
292
293 sub(x4, x4, x3); // a -= ks
294
295 // nc loop
296 b_hi(l0);
297
298 // Restore x20,x21,x22,x23 from stack
299 ldp(x22, x23, mem[sp, 16]);
300 ldp(x20, x21, mem[sp], 32);
301 ret();
302
303 bind(l4);
304 // Is there a remainder?- 2 floats of A (8 bytes)
305 tbz(x0, 3, l5);
306
307 // Remainder- 2 floats of A (8 bytes)
308 ldr(d0, mem[x14], 8);
309 ldp(q16, q17, mem[x5], 32);
310 ldr(d1, mem[x15], 8);
311 ldr(d2, mem[x20], 8);
312 ldr(d3, mem[x21], 8);
313 ldr(d4, mem[x22], 8);
314 ldr(d5, mem[x23], 8);
315 fmla(v20.v4s(), v16.v4s(), v0.s()[0]);
316 fmla(v22.v4s(), v16.v4s(), v1.s()[0]);
317 fmla(v24.v4s(), v16.v4s(), v2.s()[0]);
318 fmla(v26.v4s(), v16.v4s(), v3.s()[0]);
319 ldp(q18, q19, mem[x5], 32);
320 fmla(v28.v4s(), v16.v4s(), v4.s()[0]);
321 fmla(v30.v4s(), v16.v4s(), v5.s()[0]);
322 fmla(v21.v4s(), v17.v4s(), v0.s()[0]);
323 fmla(v23.v4s(), v17.v4s(), v1.s()[0]);
324 fmla(v25.v4s(), v17.v4s(), v2.s()[0]);
325 fmla(v27.v4s(), v17.v4s(), v3.s()[0]);
326 fmla(v29.v4s(), v17.v4s(), v4.s()[0]);
327 fmla(v31.v4s(), v17.v4s(), v5.s()[0]);
328
329 fmla(v20.v4s(), v18.v4s(), v0.s()[1]);
330 fmla(v22.v4s(), v18.v4s(), v1.s()[1]);
331 fmla(v24.v4s(), v18.v4s(), v2.s()[1]);
332 fmla(v26.v4s(), v18.v4s(), v3.s()[1]);
333 fmla(v28.v4s(), v18.v4s(), v4.s()[1]);
334 fmla(v30.v4s(), v18.v4s(), v5.s()[1]);
335 fmla(v21.v4s(), v19.v4s(), v0.s()[1]);
336 fmla(v23.v4s(), v19.v4s(), v1.s()[1]);
337 fmla(v25.v4s(), v19.v4s(), v2.s()[1]);
338 fmla(v27.v4s(), v19.v4s(), v3.s()[1]);
339 fmla(v29.v4s(), v19.v4s(), v4.s()[1]);
340 fmla(v31.v4s(), v19.v4s(), v5.s()[1]);
341
342 // Is there a remainder?- 1 float of A (4 bytes)
343 tbz(x0, 2, l3);
344
345 // Remainder- 1 float of A (4 bytes)
346 bind(l5);
347 ldr(s0, mem[x14], 4);
348 ldp(q16, q17, mem[x5], 32);
349 ldr(s1, mem[x15], 4);
350 ldr(s2, mem[x20], 4);
351 ldr(s3, mem[x21], 4);
352 ldr(s4, mem[x22], 4);
353 ldr(s5, mem[x23], 4);
354 fmla(v20.v4s(), v16.v4s(), v0.s()[0]);
355 fmla(v22.v4s(), v16.v4s(), v1.s()[0]);
356 fmla(v24.v4s(), v16.v4s(), v2.s()[0]);
357 fmla(v26.v4s(), v16.v4s(), v3.s()[0]);
358 fmla(v28.v4s(), v16.v4s(), v4.s()[0]);
359 fmla(v30.v4s(), v16.v4s(), v5.s()[0]);
360 fmla(v21.v4s(), v17.v4s(), v0.s()[0]);
361 fmla(v23.v4s(), v17.v4s(), v1.s()[0]);
362 fmla(v25.v4s(), v17.v4s(), v2.s()[0]);
363 fmla(v27.v4s(), v17.v4s(), v3.s()[0]);
364 fmla(v29.v4s(), v17.v4s(), v4.s()[0]);
365 fmla(v31.v4s(), v17.v4s(), v5.s()[0]);
366 b(l3);
367
368 // Store odd width
369 bind(l6);
370 tbz(x1, 2, l7);
371 str(q30, mem[x7], 16);
372 mov(v30.v16b(), v31.v16b());
373 str(q28, mem[x13], 16);
374 mov(v28.v16b(), v29.v16b());
375 str(q26, mem[x10], 16);
376 mov(v26.v16b(), v27.v16b());
377 str(q24, mem[x17], 16);
378 mov(v24.v16b(), v25.v16b());
379 str(q22, mem[x16], 16);
380 mov(v22.v16b(), v23.v16b());
381 str(q20, mem[x6], 16);
382 mov(v20.v16b(), v21.v16b());
383 bind(l7);
384 tbz(x1, 1, l8);
385 str(d30, mem[x7], 8);
386 str(d28, mem[x13], 8);
387 dup(d30, v30.d()[1]);
388 dup(d28, v28.d()[1]);
389 str(d26, mem[x10], 8);
390 str(d24, mem[x17], 8);
391 dup(d26, v26.d()[1]);
392 dup(d24, v24.d()[1]);
393 str(d22, mem[x16], 8);
394 str(d20, mem[x6], 8);
395 dup(d22, v22.d()[1]);
396 dup(d20, v20.d()[1]);
397
398 bind(l8);
399 tbz(x1, 0, l9);
400 str(s30, mem[x7]);
401 str(s28, mem[x13]);
402 str(s26, mem[x10]);
403 str(s24, mem[x17]);
404 str(s22, mem[x16]);
405 str(s20, mem[x6]);
406 bind(l9);
407 // Restore x20,x21,x22,x23 from stack
408 ldp(x22, x23, mem[sp, 16]);
409 ldp(x20, x21, mem[sp], 32);
410 ret();
411
412 align(16, AlignInstruction::kHlt);
413 }
414 } // namespace
415 } // aarch64
416 } // xnnpack
417
xnn_generate_f32_igemm_ukernel_6x8__aarch64_neonfma_ld128(xnn_code_buffer * code,size_t max_mr,size_t nc_mod_nr,size_t kc,size_t ks,const void * params)418 xnn_status_t xnn_generate_f32_igemm_ukernel_6x8__aarch64_neonfma_ld128(
419 xnn_code_buffer* code, size_t max_mr, size_t nc_mod_nr, size_t kc, size_t ks, const void* params) {
420 using namespace xnnpack::aarch64;
421 Generator g(code);
422 assert(params != nullptr);
423 const jit_gemm_params* gemm_params = static_cast<const jit_gemm_params*>(params);
424 g.generate(max_mr, nc_mod_nr, kc, ks, gemm_params->f32_minmax.min, gemm_params->f32_minmax.max);
425 g.finalize();
426 if (g.error() != xnnpack::Error::kNoError) {
427 return xnn_status_invalid_state;
428 }
429 return xnn_status_success;
430 }
431