1 // Copyright 2022 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5
6
7 #include <cassert>
8
9 #include <xnnpack.h>
10 #include <xnnpack/aarch32-assembler.h>
11 #include <xnnpack/allocator.h>
12 #include <xnnpack/igemm.h>
13
14 namespace xnnpack {
15 namespace aarch32 {
16 namespace {
17 class Generator : public Assembler {
18 using Assembler::Assembler;
19 public:
20 void generate(size_t max_mr, size_t nc_mod_nr, size_t kc, const void* params);
21 };
22
23
24 // void xnn_f32_igemm_minmax_ukernel_4x8__aarch32_neon_cortex_a55(
25 // size_t mr, r0
26 // size_t nc, r1
27 // size_t kc, r2 -> r5
28 // size_t ks, r3 -> sp + 64 -> r14
29 // const float**restrict a, sp + 104 -> (r5)
30 // const void*restrict w, sp + 108 -> r9
31 // uint8_t*restrict c, sp + 112 -> r11
32 // size_t cm_stride, sp + 116 -> (r6)
33 // size_t cn_stride, sp + 120 -> (r0)
34 // size_t a_offset, sp + 124 -> (r5)
35 // const float* zero, sp + 128 -> (r0)
36 // minmax_params*params, sp + 132 -> (r5)
37
38 // d8-d15, r4-r11,r14(lr) need to be preserved if used. r13(sp),r15(pc) are reserved.
39
40 // Register usage
41
42 // A0 r3 d0
43 // A1 r12 d1
44 // A2 r10 d2
45 // A3 r7 d3
46
47 // B r9 d8, d9, d10, d11
48 // B d12, d13, d14, d15
49
50 // C0 r11 d16-d17 q8 d18-d19 q9
51 // C1 r4 d20-d21 q10 d22-d23 q11
52 // C2 r8 d24-d25 q12 d26-d27 q13
53 // C3 r6 d28-d29 q14 d30-d31 q15
54
55 // Clamp (r5) d4 d5 d6 d7
56
57 // Converted from: src/f32-igemm/4x8-minmax-aarch32-neon-cortex-a55.S
generate(size_t max_mr,size_t nc_mod_nr,size_t kc,const void * params)58 void Generator::generate(size_t max_mr, size_t nc_mod_nr, size_t kc, const void* params)
59 {
60 assert(nc_mod_nr < 8);
61 assert(kc != 0);
62 assert(kc % sizeof(float) == 0);
63
64 Label l0, l1, l2, l3, l4, l5, l6, l7, l8, l9, l10;
65
66 // Push 104 bytes
67 push({r3, r4, r5, r6, r7, r8, r9, r10, r11, lr}); // +40
68 vpush({d8-d15}); // +64 = 104
69
70 ldr(r11, mem[sp, 112]); // c
71 ldr(r6, mem[sp, 116]); // cm_stride
72 ldr(r5, mem[sp, 104]); // a
73 ldr(r9, mem[sp, 108]); // w
74 mov(r14, r3); // p = ks
75
76 // Clamp C pointers
77 cmp(r0, 2); // if mr >= 2
78 add(r4, r11, r6); // c1 = c0 + cm_stride
79 movlo(r4, r11); // c1
80 // if mr > 2
81 add(r8, r4, r6); // c2 = c1 + cm_stride
82 movls(r8, r4); // c2
83 cmp(r0, 4); // if mr >=4
84 add(r6, r8, r6); // c3 = c2 + cm_stride
85 movlo(r6, r8); // c3
86
87
88 align(8);
89 bind(l0);
90 // Load initial bias from w into accumulators
91 vldm(mem[r9]++, {d16-d19}); // Bias
92
93 vmov(q10, q8);
94 vmov(q11, q9);
95 vmov(q12, q8);
96 vmov(q13, q9);
97 pld(mem[r9, 0]); // Prefetch B
98 pld(mem[r9, 64]);
99 vmov(q14, q8);
100 pld(mem[r9, 128]);
101 pld(mem[r9, 192]);
102 vmov(q15, q9);
103 pld(mem[r9, 256]);
104 pld(mem[r9, 320]);
105
106 bind(l1);
107 // Load next 4 A pointers
108 ldr(r3, mem[r5, 0]);
109 ldr(r12, mem[r5, 4]);
110 ldr(r10, mem[r5, 8]);
111 ldr(r7, mem[r5, 12]);
112 add(r5, r5, 16);
113 pld(mem[r3, 0]); // Prefetch A
114 str(r5, mem[sp, 104]); // a
115 pld(mem[r3, 64]);
116 ldr(r0, mem[sp, 128]); // zero
117 pld(mem[r12, 0]);
118 ldr(r5, mem[sp, 124]); // a_offset
119 pld(mem[r12, 64]);
120 pld(mem[r10, 0]);
121 pld(mem[r10, 64]);
122 pld(mem[r7, 0]);
123 pld(mem[r7, 64]);
124
125 // Add a_offset
126 cmp(r3, r0); // if a0 == zero
127 add(r3, r3, r5); // a0 += a_offset
128 moveq(r3, r0); // a0 = zero, else += a0 + a_offset
129 cmp(r12, r0); // if a1 == zero
130 add(r12, r12, r5); // a1 += a_offset
131 moveq(r12, r0); // a1 = zero, else += a1 + a_offset
132 cmp(r10, r0); // if a2 == zero
133 add(r10, r10, r5); // a2 += a_offset
134 moveq(r10, r0); // a2 = zero, else += a2 + a_offset
135 cmp(r7, r0); // if a3 == zero
136 add(r7, r7, r5); // a3 += a_offset
137 moveq(r7, r0); // a3 = zero, else += a3 + a_offset
138
139 subs(r5, r2, 16); // kc - 16
140 blo(l5); // less than 4 channels?
141
142 // Prologue
143 vld1_32({d0}, mem[r3]++); // A0
144 vld1_32({d1}, mem[r12]++); // A1
145 vld1_32({d2}, mem[r10]++); // A2
146 vld1_32({d3}, mem[r7]++); // A3
147 subs(r5, r5, 16);
148 vldm(mem[r9], {d8-d11}); // B0
149 vldr(d15, mem[r9, 56]); // B1CK 0
150 vldr(d13, mem[r9, 40]); // B1
151
152 blo(l3); // less than 4 channels? skip main loop
153
154 // Main loop - 4 floats of A (16 bytes)
155 // 32 FMA + 8 LD64 A + 8 LDR B
156 align(8);
157 bind(l2);
158 // First group of 16 FMA, Second group loads
159 // BLOCK 0
160 vmla_f32(q8, q4, d0[0]);
161 vld1_32({d4}, mem[r3]++); // A0
162 vmla_f32(q10, q4, d1[0]);
163 vld1_32({d5}, mem[r12]++); // A1
164 vmla_f32(q12, q4, d2[0]);
165
166 // BLOCK 1
167 vmla_f32(q14, q4, d3[0]);
168 vldr(d12, mem[r9, 32]); // B1
169 vmla_f32(q9, q5, d0[0]);
170 vldr(d9, mem[r9, 72]); // B0
171 vmla_f32(q11, q5, d1[0]);
172
173 // BLOCK 2
174 vmla_f32(q13, q5, d2[0]);
175 vld1_32({d6}, mem[r10]++); // A2
176 vmla_f32(q15, q5, d3[0]);
177 vld1_32({d7}, mem[r7]++); // A3
178 vmla_f32(q8, q6, d0[1]);
179
180 // BLOCK 3
181 vmla_f32(q10, q6, d1[1]);
182 vldr(d14, mem[r9, 48]); // B1
183 vmla_f32(q12, q6, d2[1]);
184 vldr(d11, mem[r9, 88]); // B0
185 vmla_f32(q14, q6, d3[1]);
186
187 // BLOCK 4
188 vmla_f32(q9, q7, d0[1]);
189 vldr(d8, mem[r9, 64]); // B0
190 vmla_f32(q11, q7, d1[1]);
191 vldr(d13, mem[r9, 104]); // B1
192 vmla_f32(q13, q7, d2[1]);
193 vldr(d10, mem[r9, 80]); // B0
194
195 // BLOCK 5
196 vmla_f32(q15, q7, d3[1]);
197 vldr(d15, mem[r9, 120]); // B1
198
199 // Second group of 16 FMA, First group of loads
200 // BLOCK 0
201 vmla_f32(q8, q4, d4[0]);
202 vld1_32({d0}, mem[r3]++); // A0
203 vmla_f32(q10, q4, d5[0]);
204 vld1_32({d1}, mem[r12]++); // A1
205 vmla_f32(q12, q4, d6[0]);
206
207 // BLOCK 1
208 vmla_f32(q14, q4, d7[0]);
209 vldr(d12, mem[r9, 96]); // B1
210 vmla_f32(q9, q5, d4[0]);
211 vldr(d9, mem[r9, 136]); // B0
212 vmla_f32(q11, q5, d5[0]);
213
214 // BLOCK 2
215 vmla_f32(q13, q5, d6[0]);
216 vld1_32({d2}, mem[r10]++); // A2
217 vmla_f32(q15, q5, d7[0]);
218 vld1_32({d3}, mem[r7]++); // A3
219 vmla_f32(q8, q6, d4[1]);
220 subs(r5, r5, 16);
221
222 // BLOCK 3
223 vmla_f32(q10, q6, d5[1]);
224 vldr(d14, mem[r9, 112]); // B1
225 vmla_f32(q12, q6, d6[1]);
226 vldr(d11, mem[r9, 152]); // B0
227 vmla_f32(q14, q6, d7[1]);
228
229 // BLOCK 4
230 vmla_f32(q9, q7, d4[1]);
231 vldr(d8, mem[r9, 128]); // B0
232 vmla_f32(q11, q7, d5[1]);
233 vldr(d13, mem[r9, 168]); // B1
234 vmla_f32(q13, q7, d6[1]);
235 vldr(d10, mem[r9, 144]); // B0
236
237 // BLOCK 5
238 vmla_f32(q15, q7, d7[1]);
239 vldr(d15, mem[r9, 184]); // B1
240 add(r9, r9, 128); // B++
241 bhs(l2);
242
243 // Epilogue - 4 floats of A (16 bytes)
244 bind(l3);
245 // First group of 16 FMA, Second group loads
246 // BLOCK 0
247 vmla_f32(q8, q4, d0[0]);
248 vld1_32({d4}, mem[r3]++); // A0
249 vmla_f32(q10, q4, d1[0]);
250 vld1_32({d5}, mem[r12]++); // A1
251 vmla_f32(q12, q4, d2[0]);
252
253 // BLOCK 1
254 vmla_f32(q14, q4, d3[0]);
255 vldr(d12, mem[r9, 32]); // B1
256 vmla_f32(q9, q5, d0[0]);
257 vldr(d9, mem[r9, 72]); // B0
258 vmla_f32(q11, q5, d1[0]);
259
260 // BLOCK 2
261 vmla_f32(q13, q5, d2[0]);
262 vld1_32({d6}, mem[r10]++); // A2
263 vmla_f32(q15, q5, d3[0]);
264 vld1_32({d7}, mem[r7]++); // A3
265 vmla_f32(q8, q6, d0[1]);
266
267 // BLOCK 3
268 vmla_f32(q10, q6, d1[1]);
269 vldr(d14, mem[r9, 48]); // B1
270 vmla_f32(q12, q6, d2[1]);
271 vldr(d11, mem[r9, 88]); // B0
272 vmla_f32(q14, q6, d3[1]);
273
274 // BLOCK 4
275 vmla_f32(q9, q7, d0[1]);
276 vldr(d8, mem[r9, 64]); // B0
277 vmla_f32(q11, q7, d1[1]);
278 vldr(d13, mem[r9, 104]); // B1
279 vmla_f32(q13, q7, d2[1]);
280 vldr(d10, mem[r9, 80]); // B0
281
282 // BLOCK 5
283 vmla_f32(q15, q7, d3[1]);
284 vldr(d15, mem[r9, 120]); // B1
285
286 // Second group of 16 FMA, First group of loads
287 // BLOCK 0
288 vmla_f32(q8, q4, d4[0]);
289 vldr(d12, mem[r9, 96]); // B1
290 vmla_f32(q10, q4, d5[0]);
291 vmla_f32(q12, q4, d6[0]);
292
293 // BLOCK 1
294 vmla_f32(q14, q4, d7[0]);
295 vldr(d14, mem[r9, 112]); // B1
296 vmla_f32(q9, q5, d4[0]);
297 vmla_f32(q11, q5, d5[0]);
298
299 // BLOCK 2
300 vmla_f32(q13, q5, d6[0]);
301 vmla_f32(q15, q5, d7[0]);
302 vmla_f32(q8, q6, d4[1]);
303 add(r9, r9, 128); // B++
304
305 // BLOCK 3
306 vmla_f32(q10, q6, d5[1]);
307 vmla_f32(q12, q6, d6[1]);
308 vmla_f32(q14, q6, d7[1]);
309 tst(r5, 15);
310
311 // BLOCK 4
312 vmla_f32(q9, q7, d4[1]);
313 vmla_f32(q11, q7, d5[1]);
314 vmla_f32(q13, q7, d6[1]);
315
316 // BLOCK 5
317 vmla_f32(q15, q7, d7[1]);
318
319 // Is there a remainder?- 1 to 3 floats of A (4, 8 or 12 bytes)
320 bne(l5);
321
322 align(8);
323 bind(l4);
324 ldr(r5, mem[sp, 104]); // a
325 subs(r14, r14, 16); // ks -= MR * sizeof(void*)
326
327 // ks loop
328 bhi(l1);
329
330 // Load params pointer
331 ldr(r0, mem[sp, 132]); // params
332 ldr(r14, mem[sp, 64]); // p = ks
333 // Load min/max values
334 vld1r_32({d4,d5}, mem[r0]++);
335 vld1r_32({d6,d7}, mem[r0]);
336 subs(r1, r1, 8);
337 ldr(r0, mem[sp, 120]); // cn_stride
338
339 // Clamp
340 vmax_f32(q8, q8, q2);
341 vmax_f32(q9, q9, q2);
342 vmax_f32(q10, q10, q2);
343 vmax_f32(q11, q11, q2);
344 vmax_f32(q12, q12, q2);
345 vmax_f32(q13, q13, q2);
346 vmax_f32(q14, q14, q2);
347 vmax_f32(q15, q15, q2);
348 vmin_f32(q8, q8, q3);
349 vmin_f32(q9, q9, q3);
350 vmin_f32(q10, q10, q3);
351 vmin_f32(q11, q11, q3);
352 vmin_f32(q12, q12, q3);
353 vmin_f32(q13, q13, q3);
354 vmin_f32(q14, q14, q3);
355 vmin_f32(q15, q15, q3);
356
357 // Store full 4 x 8
358 blo(l7);
359 vst1_32({d28-d31}, mem[r6], r0);
360 vst1_32({d24-d27}, mem[r8], r0);
361 vst1_32({d20-d23}, mem[r4], r0);
362 vst1_32({d16-d19}, mem[r11], r0);
363
364 sub(r5, r5, r14); // a -= ks
365
366 bhi(l0);
367
368 vpop({d8-d15});
369 add(sp, sp, 4); // skip r3
370 pop({r4, r5, r6, r7, r8, r9, r10, r11, pc});
371
372 align(8);
373 bind(l5);
374 // Is there a remainder?- 2 floats of A (8 bytes)
375 tst(r5, 8);
376 beq(l6);
377
378 // Remainder - 2 floats of A (8 bytes)
379 vld1_32({d0}, mem[r3]++); // A0
380 vldm(mem[r9]++, {d8-d11}); // B0
381 vld1_32({d1}, mem[r12]++); // A1
382 vld1_32({d2}, mem[r10]++); // A2
383 vld1_32({d3}, mem[r7]++); // A3
384
385 vmla_f32(q8, q4, d0[0]);
386 vmla_f32(q9, q5, d0[0]);
387 vmla_f32(q10, q4, d1[0]);
388 vmla_f32(q11, q5, d1[0]);
389 vldm(mem[r9]++, {d12-d15}); // B1
390 vmla_f32(q12, q4, d2[0]);
391 vmla_f32(q13, q5, d2[0]);
392 vmla_f32(q14, q4, d3[0]);
393 vmla_f32(q15, q5, d3[0]);
394 vmla_f32(q8, q6, d0[1]);
395 vmla_f32(q9, q7, d0[1]);
396 vmla_f32(q10, q6, d1[1]);
397 vmla_f32(q11, q7, d1[1]);
398 vmla_f32(q12, q6, d2[1]);
399 vmla_f32(q13, q7, d2[1]);
400 vmla_f32(q14, q6, d3[1]);
401 vmla_f32(q15, q7, d3[1]);
402
403 // Is there a remainder?- 1 float of A (4 bytes)
404 tst(r5, 4);
405 beq(l4);
406
407 bind(l6);
408 // Remainder- 1 float of A (4 bytes)
409 vldm(mem[r3]++, {s0}); // A0
410 vldm(mem[r9]++, {d8-d11}); // B0
411 vldm(mem[r12]++, {s2}); // A1
412 vldm(mem[r10]++, {s4}); // A2
413 vldm(mem[r7]++, {s6}); // A3
414 vmla_f32(q8, q4, d0[0]);
415 vmla_f32(q9, q5, d0[0]);
416 vmla_f32(q10, q4, d1[0]);
417 vmla_f32(q11, q5, d1[0]);
418 vmla_f32(q12, q4, d2[0]);
419 vmla_f32(q13, q5, d2[0]);
420 vmla_f32(q14, q4, d3[0]);
421 vmla_f32(q15, q5, d3[0]);
422 b(l4);
423
424 // Store odd width
425 bind(l7);
426 tst(r1, 4);
427 beq(l8);
428 vst1_32({d28-d29}, mem[r6]++);
429 vst1_32({d24-d25}, mem[r8]++);
430 vmov(q14, q15);
431 vmov(q12, q13);
432 vst1_32({d20-d21}, mem[r4]++);
433 vst1_32({d16-d17}, mem[r11]++);
434 vmov(q10, q11);
435 vmov(q8, q9);
436
437 bind(l8);
438 tst(r1, 2);
439 beq(l9);
440 vst1_32({d28}, mem[r6]++);
441 vst1_32({d24}, mem[r8]++);
442 vmov(d28, d29);
443 vmov(d24, d25);
444 vst1_32({d20}, mem[r4]++);
445 vst1_32({d16}, mem[r11]++);
446 vmov(d20, d21);
447 vmov(d16, d17);
448
449 bind(l9);
450 tst(r1, 1);
451 beq(l10);
452 vst1_32({d28[0]}, mem[r6]++);
453 vst1_32({d24[0]}, mem[r8]++);
454 vst1_32({d20[0]}, mem[r4]++);
455 vst1_32({d16[0]}, mem[r11]++);
456
457 bind(l10);
458 vpop({d8-d15});
459 add(sp, sp, 4); // skip r3
460 pop({r4, r5, r6, r7, r8, r9, r10, r11, pc});
461 }
462 } // namespace
463 } // aarch32
464 } // xnnpack
465
xnn_generate_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a55(xnn_code_buffer * code,size_t max_mr,size_t nc_mod_nr,size_t kc,size_t ks,const void * params)466 xnn_status_t xnn_generate_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a55(xnn_code_buffer* code, size_t max_mr, size_t nc_mod_nr, size_t kc, size_t ks, const void* params) {
467 using namespace xnnpack::aarch32;
468 Generator g(code);
469 assert(params != nullptr);
470 g.generate(max_mr, nc_mod_nr, kc, nullptr);
471 g.finalize();
472 if (g.error() != xnnpack::Error::kNoError) {
473 return xnn_status_invalid_state;
474 }
475 return xnn_status_success;
476 }
477