1 // Copyright 2022 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5
6
7 #include <cassert>
8
9 #include <xnnpack.h>
10 #include <xnnpack/aarch32-assembler.h>
11 #include <xnnpack/allocator.h>
12 #include <xnnpack/igemm.h>
13
14 namespace xnnpack {
15 namespace aarch32 {
16 namespace {
17 class Generator : public Assembler {
18 using Assembler::Assembler;
19 public:
20 void generate(bool prefetch, size_t max_mr, size_t nc_mod_nr, size_t kc, const void* params);
21 };
22
23
24 // void xnn_f32_igemm_minmax_ukernel_4x8__aarch32_neon_prfm_cortex_a75(
25 // size_t mr, r0
26 // size_t nc, r1
27 // size_t kc, r2 -> r5 -> sp + 68
28 // size_t ks, r3 -> sp + 72 -> r14
29 // const float**restrict a, sp + 112 -> r2
30 // const void*restrict w, sp + 116 -> r9
31 // uint8_t*restrict c, sp + 120 -> r11
32 // size_t cm_stride, sp + 124 -> (r6)
33 // size_t cn_stride, sp + 128 -> (r7)
34 // size_t a_offset, sp + 132 -> (r5)
35 // const float* zero, sp + 136 -> (r7)
36 // minmax_params*params, sp + 140 -> (r5)
37
38 // d8-d15, r4-r11,r14(lr) need to be preserved if used. r13(sp),r15(pc) are reserved.
39
40 // Register usage
41
42 // A0 r3 d0
43 // A1 r12 d1
44 // A2 r10 d2
45 // A3 r0 d3
46
47 // B r9 d8, d9, d10, d11
48 // B d12, d13, d14, d15
49
50 // C0 r11 d16-d17 q8 d18-d19 q9
51 // C1 r4 d20-d21 q10 d22-d23 q11
52 // C2 r8 d24-d25 q12 d26-d27 q13
53 // C3 r6 d28-d29 q14 d30-d31 q15
54
55 // Clamp (r5) d4 d5 d6 d7
56
57 // Converted from: src/f32-igemm/gen/4x8-minmax-aarch32-neon-prfm-cortex-a75.S
generate(bool prefetch,size_t max_mr,size_t nc_mod_nr,size_t kc,const void * params)58 void Generator::generate(bool prefetch, size_t max_mr, size_t nc_mod_nr, size_t kc, const void* params)
59 {
60 assert(nc_mod_nr < 8);
61 assert(kc != 0);
62 assert(kc % sizeof(float) == 0);
63
64 Label l0, l1, l2, l3, l4, l5, l6, l7, l8, l9, l10;
65
66 // Push 112 bytes
67 // r2 will be reloaded in outer loop. r3 is ks
68 push({r2, r3, r4, r5, r6, r7, r8, r9, r10, r11, lr}); // +44
69 sub(sp, sp, 4); // 4
70 vpush({d8-d15}); // +64 = 112
71
72 ldr(r11, mem[sp, 120]); // c
73 ldr(r6, mem[sp, 124]); // cm_stride
74 ldr(r2, mem[sp, 112]); // a
75 ldr(r9, mem[sp, 116]); // w
76 mov(r14, r3); // p = ks
77
78 // Clamp C pointers
79 cmp(r0, 2); // if mr >= 2
80 add(r4, r11, r6); // c1 = c0 + cm_stride
81 movlo(r4, r11); // c1
82 // if mr > 2
83 add(r8, r4, r6); // c2 = c1 + cm_stride
84 movls(r8, r4); // c2
85 cmp(r0, 4); // if mr >=4
86 add(r6, r8, r6); // c3 = c2 + cm_stride
87 movlo(r6, r8); // c3
88
89 align(8);
90 bind(l0);
91 // Load initial bias from w into accumulators
92 vldm(mem[r9]++, {d16-d19}); // Bias
93 vmov(q10, q8);
94 vmov(q11, q9);
95 vmov(q12, q8);
96 vmov(q13, q9);
97 vmov(q14, q8);
98 vmov(q15, q9);
99
100 if (prefetch) {
101 pld(mem[r9, 0]); // Prefetch B
102 pld(mem[r9, 64]);
103 pld(mem[r9, 128]);
104 pld(mem[r9, 192]);
105 pld(mem[r9, 256]);
106 pld(mem[r9, 320]);
107 pld(mem[r9, 384]);
108 }
109
110 bind(l1);
111 // Load next 4 A pointers
112 ldr(r3, mem[r2, 0]);
113 ldr(r12, mem[r2, 4]);
114 ldr(r10, mem[r2, 8]);
115 ldr(r0, mem[r2, 12]);
116 add(r2, r2, 16);
117
118 // Add a_offset
119 ldr(r5, mem[sp, 132]); // a_offset
120 ldr(r7, mem[sp, 136]); // zero
121 cmp(r3, r7); // if a0 == zero
122 add(r3, r3, r5); // a0 += a_offset
123 moveq(r3, r7); // a0 = zero, else += a0 + a_offset
124 cmp(r12, r7); // if a1 == zero
125 add(r12, r12, r5); // a1 += a_offset
126 moveq(r12, r7); // a1 = zero, else += a1 + a_offset
127 cmp(r10, r7); // if a2 == zero
128 add(r10, r10, r5); // a2 += a_offset
129 moveq(r10, r7); // a2 = zero, else += a2 + a_offset
130 cmp(r0, r7); // if a3 == zero
131 add(r0, r0, r5); // a3 += a_offset
132 ldr(r5, mem[sp, 68]); // kc
133 moveq(r0, r7); // a3 = zero, else += a3 + a_offset
134
135 if (prefetch) {
136 pld(mem[r3, 0]); // Prefetch A
137 pld(mem[r3, 64]);
138 pld(mem[r12, 0]);
139 pld(mem[r12, 64]);
140 pld(mem[r10, 0]);
141 pld(mem[r10, 64]);
142 pld(mem[r0, 0]);
143 pld(mem[r0, 64]);
144 }
145
146 subs(r5, r5, 16); // kc - 16
147 blo(l5); // less than 4 channels?
148
149 // Prologue
150 vld1_32({d0}, mem[r3]++); // A0
151 vldm(mem[r9]++, {d8-d11}); // B0
152 vld1_32({d1}, mem[r12]++); // A1
153 vld1_32({d2}, mem[r10]++); // A2
154 vld1_32({d3}, mem[r0]++); // A3
155
156 subs(r5, r5, 16);
157 blo(l3); // less than 4 channels? skip main loop
158
159 align(8);
160
161 // Main loop - 4 floats of A (16 bytes)
162 bind(l2);
163 vmla_f32(q8, q4, d0[0]);
164 vldm(mem[r9]++, {d12-d15}); // B1
165 vmla_f32(q10, q4, d1[0]);
166 vmla_f32(q12, q4, d2[0]);
167 vld1_32({d4}, mem[r3]++); // A0
168 vmla_f32(q14, q4, d3[0]);
169 vmla_f32(q9, q5, d0[0]);
170 vld1_32({d5}, mem[r12]++); // A1
171 vmla_f32(q11, q5, d1[0]);
172 vmla_f32(q13, q5, d2[0]);
173 vmla_f32(q15, q5, d3[0]);
174 vld1_32({d6}, mem[r10]++); // A2
175 vmla_f32(q8, q6, d0[1]);
176 vmla_f32(q10, q6, d1[1]);
177 vld1_32({d7}, mem[r0]++); // A3
178 vmla_f32(q12, q6, d2[1]);
179 vmla_f32(q14, q6, d3[1]);
180 vldm(mem[r9]++, {d8-d11}); // B0
181 vmla_f32(q9, q7, d0[1]);
182 vmla_f32(q11, q7, d1[1]);
183 vmla_f32(q13, q7, d2[1]);
184 vmla_f32(q15, q7, d3[1]);
185
186 vmla_f32(q8, q4, d4[0]);
187 vldm(mem[r9]++, {d12-d15}); // B1
188 vmla_f32(q10, q4, d5[0]);
189 if (prefetch) {
190 pld(mem[r3, 128]); // Prefetch A0
191 }
192 vmla_f32(q12, q4, d6[0]);
193 vld1_32({d0}, mem[r3]++); // A0
194 vmla_f32(q14, q4, d7[0]);
195 if (prefetch) {
196 pld(mem[r12, 128]); // Prefetch A1
197 }
198 vmla_f32(q9, q5, d4[0]);
199 vld1_32({d1}, mem[r12]++); // A1
200 vmla_f32(q11, q5, d5[0]);
201 if (prefetch) {
202 pld(mem[r10, 128]); // Prefetch A2
203 }
204 vmla_f32(q13, q5, d6[0]);
205 vld1_32({d2}, mem[r10]++); // A2
206 vmla_f32(q15, q5, d7[0]);
207 if (prefetch) {
208 pld(mem[r0, 128]); // Prefetch A3
209 }
210 vmla_f32(q8, q6, d4[1]);
211 vld1_32({d3}, mem[r0]++); // A3
212 vmla_f32(q10, q6, d5[1]);
213 if (prefetch) {
214 pld(mem[r9, 352]); // Prefetch B
215 }
216 vmla_f32(q12, q6, d6[1]);
217 if (prefetch) {
218 pld(mem[r9, 416]); // Prefetch B
219 }
220 vmla_f32(q14, q6, d7[1]);
221 vldm(mem[r9]++, {d8-d11}); // B0
222 vmla_f32(q9, q7, d4[1]);
223 vmla_f32(q11, q7, d5[1]);
224 subs(r5, r5, 16);
225 vmla_f32(q13, q7, d6[1]);
226 vmla_f32(q15, q7, d7[1]);
227 bhs(l2);
228
229 // Epilogue
230 bind(l3);
231 vmla_f32(q8, q4, d0[0]);
232 vldm(mem[r9]++, {d12-d15}); // B1
233 vmla_f32(q10, q4, d1[0]);
234 vmla_f32(q12, q4, d2[0]);
235 vld1_32({d4}, mem[r3]++); // A0
236 vmla_f32(q14, q4, d3[0]);
237 vmla_f32(q9, q5, d0[0]);
238 vld1_32({d5}, mem[r12]++); // A1
239 vmla_f32(q11, q5, d1[0]);
240 vmla_f32(q13, q5, d2[0]);
241 vmla_f32(q15, q5, d3[0]);
242 vld1_32({d6}, mem[r10]++); // A2
243 vmla_f32(q8, q6, d0[1]);
244 vmla_f32(q10, q6, d1[1]);
245 vld1_32({d7}, mem[r0]++); // A3
246 vmla_f32(q12, q6, d2[1]);
247 vmla_f32(q14, q6, d3[1]);
248 vldm(mem[r9]++, {d8-d11}); // B0
249 vmla_f32(q9, q7, d0[1]);
250 vmla_f32(q11, q7, d1[1]);
251 vmla_f32(q13, q7, d2[1]);
252 vmla_f32(q15, q7, d3[1]);
253
254 vmla_f32(q8, q4, d4[0]);
255 vldm(mem[r9]++, {d12-d15}); // B1
256 vmla_f32(q10, q4, d5[0]);
257 vmla_f32(q12, q4, d6[0]);
258 vmla_f32(q14, q4, d7[0]);
259 vmla_f32(q9, q5, d4[0]);
260 vmla_f32(q11, q5, d5[0]);
261 vmla_f32(q13, q5, d6[0]);
262 vmla_f32(q15, q5, d7[0]);
263 vmla_f32(q8, q6, d4[1]);
264 vmla_f32(q10, q6, d5[1]);
265 vmla_f32(q12, q6, d6[1]);
266 vmla_f32(q14, q6, d7[1]);
267 vmla_f32(q9, q7, d4[1]);
268 vmla_f32(q11, q7, d5[1]);
269 vmla_f32(q13, q7, d6[1]);
270 vmla_f32(q15, q7, d7[1]);
271
272 // Is there a remainder?- 1 to 3 floats of A (4, 8 or 12 bytes)
273 tst(r5, 12);
274 bne(l5);
275
276 align(8);
277 bind(l4);
278 // ks loop
279 subs(r14, r14, 16); // ks -= MR * sizeof(void*)
280 bhi(l1);
281
282 // Load params pointer
283 ldr(r5, mem[sp, 140]); // params
284 ldr(r7, mem[sp, 128]); // cn_stride
285 ldr(r14, mem[sp, 72]); // p = ks
286
287 // Load min/max values
288 vld1r_32({d4,d5}, mem[r5]++);
289 subs(r1, r1, 8);
290 vld1r_32({d6,d7}, mem[r5]);
291
292 // Clamp
293 vmax_f32(q8, q8, q2);
294 vmax_f32(q9, q9, q2);
295 vmax_f32(q10, q10, q2);
296 vmax_f32(q11, q11, q2);
297 vmax_f32(q12, q12, q2);
298 vmax_f32(q13, q13, q2);
299 vmax_f32(q14, q14, q2);
300 vmax_f32(q15, q15, q2);
301 vmin_f32(q8, q8, q3);
302 vmin_f32(q9, q9, q3);
303 vmin_f32(q10, q10, q3);
304 vmin_f32(q11, q11, q3);
305 vmin_f32(q12, q12, q3);
306 vmin_f32(q13, q13, q3);
307 vmin_f32(q14, q14, q3);
308 vmin_f32(q15, q15, q3);
309
310 // Store full 4 x 8
311 blo(l7);
312 vst1_32({d28-d31}, mem[r6], r7);
313 vst1_32({d24-d27}, mem[r8], r7);
314 vst1_32({d20-d23}, mem[r4], r7);
315 vst1_32({d16-d19}, mem[r11], r7);
316 sub(r2, r2, r14); // a -= ks
317 bhi(l0);
318
319 vpop({d8-d15});
320 add(sp, sp, 12); // skip pad, r2, r3
321 pop({r4, r5, r6, r7, r8, r9, r10, r11, pc});
322
323 align(8);
324 bind(l5);
325 // Is there a remainder?- 2 floats of A (8 bytes)
326 tst(r5, 8);
327 beq(l6);
328
329 // Remainder - 2 floats of A (8 bytes)
330 vld1_32({d0}, mem[r3]++); // A0
331 vldm(mem[r9]++, {d8-d11}); // B0
332 vld1_32({d1}, mem[r12]++); // A1
333 vld1_32({d2}, mem[r10]++); // A2
334 vld1_32({d3}, mem[r0]++); // A3
335
336 vmla_f32(q8, q4, d0[0]);
337 vmla_f32(q9, q5, d0[0]);
338 vmla_f32(q10, q4, d1[0]);
339 vmla_f32(q11, q5, d1[0]);
340 vldm(mem[r9]++, {d12-d15}); // B1
341 vmla_f32(q12, q4, d2[0]);
342 vmla_f32(q13, q5, d2[0]);
343 vmla_f32(q14, q4, d3[0]);
344 vmla_f32(q15, q5, d3[0]);
345 vmla_f32(q8, q6, d0[1]);
346 vmla_f32(q9, q7, d0[1]);
347 vmla_f32(q10, q6, d1[1]);
348 vmla_f32(q11, q7, d1[1]);
349 vmla_f32(q12, q6, d2[1]);
350 vmla_f32(q13, q7, d2[1]);
351 vmla_f32(q14, q6, d3[1]);
352 vmla_f32(q15, q7, d3[1]);
353
354 // Is there a remainder?- 1 float of A (4 bytes)
355 tst(r5, 4);
356 beq(l4);
357
358 bind(l6);
359 // Remainder- 1 float of A (4 bytes)
360 vldm(mem[r3]++, {s0}); // A0
361 vldm(mem[r9]++, {d8-d11}); // B0
362 vldm(mem[r12]++, {s2}); // A1
363 vldm(mem[r10]++, {s4}); // A2
364 vldm(mem[r0]++, {s6}); // A3
365 vmla_f32(q8, q4, d0[0]);
366 vmla_f32(q9, q5, d0[0]);
367 vmla_f32(q10, q4, d1[0]);
368 vmla_f32(q11, q5, d1[0]);
369 vmla_f32(q12, q4, d2[0]);
370 vmla_f32(q13, q5, d2[0]);
371 vmla_f32(q14, q4, d3[0]);
372 vmla_f32(q15, q5, d3[0]);
373 b(l4);
374
375 // Store odd width
376 bind(l7);
377 tst(r1, 4);
378 beq(l8);
379 vst1_32({d28-d29}, mem[r6]++);
380 vst1_32({d24-d25}, mem[r8]++);
381 vmov(q14, q15);
382 vmov(q12, q13);
383 vst1_32({d20-d21}, mem[r4]++);
384 vst1_32({d16-d17}, mem[r11]++);
385 vmov(q10, q11);
386 vmov(q8, q9);
387
388 bind(l8);
389 tst(r1, 2);
390 beq(l9);
391 vst1_32({d28}, mem[r6]++);
392 vst1_32({d24}, mem[r8]++);
393 vmov(d28, d29);
394 vmov(d24, d25);
395 vst1_32({d20}, mem[r4]++);
396 vst1_32({d16}, mem[r11]++);
397 vmov(d20, d21);
398 vmov(d16, d17);
399
400 bind(l9);
401 tst(r1, 1);
402 beq(l10);
403 vst1_32({d28[0]}, mem[r6]++);
404 vst1_32({d24[0]}, mem[r8]++);
405 vst1_32({d20[0]}, mem[r4]++);
406 vst1_32({d16[0]}, mem[r11]++);
407
408 bind(l10);
409 vpop({d8-d15});
410 add(sp, sp, 12); // skip pad, r2, r3
411 pop({r4, r5, r6, r7, r8, r9, r10, r11, pc});
412 }
413 } // namespace
414 } // aarch32
415 } // xnnpack
416
xnn_generate_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a75(xnn_code_buffer * code,size_t max_mr,size_t nc_mod_nr,size_t kc,size_t ks,const void * params)417 xnn_status_t xnn_generate_f32_igemm_ukernel_4x8__aarch32_neon_cortex_a75(xnn_code_buffer* code, size_t max_mr, size_t nc_mod_nr, size_t kc, size_t ks, const void* params) {
418 using namespace xnnpack::aarch32;
419 Generator g(code);
420 assert(params != nullptr);
421 g.generate(false, max_mr, nc_mod_nr, kc, nullptr);
422 g.finalize();
423 if (g.error() != xnnpack::Error::kNoError) {
424 return xnn_status_invalid_state;
425 }
426 return xnn_status_success;
427 }
428
xnn_generate_f32_igemm_ukernel_4x8__aarch32_neon_prfm_cortex_a75(xnn_code_buffer * code,size_t max_mr,size_t nc_mod_nr,size_t kc,size_t ks,const void * params)429 xnn_status_t xnn_generate_f32_igemm_ukernel_4x8__aarch32_neon_prfm_cortex_a75(xnn_code_buffer* code, size_t max_mr, size_t nc_mod_nr, size_t kc, size_t ks, const void* params) {
430 using namespace xnnpack::aarch32;
431 Generator g(code);
432 assert(params != nullptr);
433 g.generate(true, max_mr, nc_mod_nr, kc, nullptr);
434 g.finalize();
435 if (g.error() != xnnpack::Error::kNoError) {
436 return xnn_status_invalid_state;
437 }
438 return xnn_status_success;
439 }
440