xref: /aosp_15_r20/external/XNNPACK/src/f32-gemm/4x8-aarch32-neon-cortex-a75.cc (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1 // Copyright 2022 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5 
6 
7 #include <cassert>
8 
9 #include <xnnpack.h>
10 #include <xnnpack/aarch32-assembler.h>
11 #include <xnnpack/allocator.h>
12 #include <xnnpack/gemm.h>
13 
14 namespace xnnpack {
15 namespace aarch32 {
16 namespace {
17 class Generator : public Assembler {
18   using Assembler::Assembler;
19  public:
20   void generate(bool prefetch, size_t max_mr, size_t nc_mod_nr, size_t kc, const void* params);
21 };
22 
23 
24 // void xnn_f32_gemm_minmax_ukernel_4x8__aarch32_neon_prfm_cortex_a75(
25 //     size_t mr,                            r0
26 //     size_t nc,                            r1
27 //     size_t kc,                            r2 -> r5
28 //     const uint8_t*restrict a,             r3
29 //     size_t a_stride,          sp + 96  -> (r7)
30 //     const void*restrict w,    sp + 100 -> r9
31 //     uint8_t*restrict c,       sp + 104 -> r11
32 //     size_t cm_stride,         sp + 108 -> (r6)
33 //     size_t cn_stride,         sp + 112 -> r7
34 //     const union xnn_f32_minmax_params params)  sp + 116 -> (r7)
35 
36 // d8-d15, r4-r11,r14(lr) need to be preserved if used. r13(sp),r15(pc) are reserved.
37 
38 // Register usage
39 
40 // A0   r3  d0
41 // A1  r12  d1
42 // A2  r10  d2
43 // A3   r0  d3
44 
45 // B    r9  d8,  d9, d10, d11
46 // B       d12, d13, d14, d15
47 
48 // C0  r11 d16-d17  q8  d18-d19  q9
49 // C1   r4 d20-d21 q10  d22-d23 q11
50 // C2   r8 d24-d25 q12  d26-d27 q13
51 // C3   r6 d28-d29 q14  d30-d31 q15
52 
53 // Clamp (r5) d4 d5 d6 d7
54 
55 // Converted from: src/f32-gemm/gen/4x8-minmax-aarch32-neon-prfm-cortex-a75.S
generate(bool prefetch,size_t max_mr,size_t nc_mod_nr,size_t kc,const void * params)56 void Generator::generate(bool prefetch, size_t max_mr, size_t nc_mod_nr, size_t kc, const void* params)
57 {
58   assert(nc_mod_nr < 8);
59   assert(kc != 0);
60   assert(kc % sizeof(float) == 0);
61 
62   Label l0, l1, l2, l3, l4, l5, l6, l7, l8, l9;
63 
64   // Push 96 bytes
65   push({r4, r5, r6, r7, r8, r9, r10, r11}); // 32
66   vpush({d8-d15}); // +64 = 96
67 
68   ldr(r7, mem[sp, 96]); // a_stride
69   ldr(r6, mem[sp, 108]); // cm_stride
70   ldr(r11, mem[sp, 104]); // c
71   ldr(r9, mem[sp, 100]); // w
72 
73   // Clamp A and C pointers
74   cmp(r0, 2); // if mr >= 2
75   add(r12, r3, r7); //   a1 = a0 + a_stride
76   add(r4, r11, r6); //   c1 = c0 + cm_stride
77   movlo(r12, r3); // a1
78   movlo(r4, r11); // c1
79   // if mr > 2
80   add(r10, r12, r7); //   a2 = a1 + a_stride
81   add(r8, r4, r6); //   c2 = c1 + cm_stride
82   movls(r10, r12); // a2
83   movls(r8, r4); // c2
84 
85   cmp(r0, 4); // if mr >=4
86   add(r0, r10, r7); //   a3 = a2 + a_stride
87   add(r6, r8, r6); //   c3 = c2 + cm_stride
88   movlo(r0, r10); // a3
89   movlo(r6, r8); // c3
90 
91   ldr(r7, mem[sp, 112]); // cn_stride
92 
93   align(8);
94   bind(l0);
95   // Load initial bias from w into accumulators
96   vldm(mem[r9]++, {d16-d19}); // Bias
97   subs(r5, r2, 16);
98   vmov(q10, q8);
99   vmov(q11, q9);
100   vmov(q12, q8);
101   vmov(q13, q9);
102   vmov(q14, q8);
103   vmov(q15, q9);
104 
105   if (prefetch) {
106     pld(mem[r3, 0]); // Prefetch A
107     pld(mem[r3, 64]);
108     pld(mem[r12, 0]);
109     pld(mem[r12, 64]);
110     pld(mem[r10, 0]);
111     pld(mem[r10, 64]);
112     pld(mem[r0, 0]);
113     pld(mem[r0, 64]);
114     pld(mem[r9, 0]); // Prefetch B
115     pld(mem[r9, 64]);
116     pld(mem[r9, 128]);
117     pld(mem[r9, 192]);
118     pld(mem[r9, 256]);
119     pld(mem[r9, 320]);
120     pld(mem[r9, 384]);
121   }
122 
123   blo(l4); // less than 4 channels?
124 
125   // Prologue
126   vld1_32({d0}, mem[r3]++); // A0
127   vldm(mem[r9]++, {d8-d11}); // B0
128   vld1_32({d1}, mem[r12]++); // A1
129   vld1_32({d2}, mem[r10]++); // A2
130   vld1_32({d3}, mem[r0]++); // A3
131 
132   subs(r5, r5, 16);
133   blo(l2); // less than 4 channels?  skip main loop
134 
135   align(8);
136 
137   // Main loop - 4 floats of A (16 bytes)
138   bind(l1);
139   vmla_f32(q8, q4, d0[0]);
140   vldm(mem[r9]++, {d12-d15}); // B1
141   vmla_f32(q10, q4, d1[0]);
142   vmla_f32(q12, q4, d2[0]);
143   vld1_32({d4}, mem[r3]++); // A0
144   vmla_f32(q14, q4, d3[0]);
145   vmla_f32(q9, q5, d0[0]);
146   vld1_32({d5}, mem[r12]++); // A1
147   vmla_f32(q11, q5, d1[0]);
148   vmla_f32(q13, q5, d2[0]);
149   vmla_f32(q15, q5, d3[0]);
150   vld1_32({d6}, mem[r10]++); // A2
151   vmla_f32(q8, q6, d0[1]);
152   vmla_f32(q10, q6, d1[1]);
153   vld1_32({d7}, mem[r0]++); // A3
154   vmla_f32(q12, q6, d2[1]);
155   vmla_f32(q14, q6, d3[1]);
156   vldm(mem[r9]++, {d8-d11}); // B0
157   vmla_f32(q9, q7, d0[1]);
158   vmla_f32(q11, q7, d1[1]);
159   vmla_f32(q13, q7, d2[1]);
160   vmla_f32(q15, q7, d3[1]);
161 
162   vmla_f32(q8, q4, d4[0]);
163   vldm(mem[r9]++, {d12-d15}); // B1
164   vmla_f32(q10, q4, d5[0]);
165   if (prefetch) {
166     pld(mem[r3, 128]); // Prefetch A0
167   }
168   vmla_f32(q12, q4, d6[0]);
169   vld1_32({d0}, mem[r3]++); // A0
170   vmla_f32(q14, q4, d7[0]);
171   if (prefetch) {
172     pld(mem[r12, 128]); // Prefetch A1
173   }
174   vmla_f32(q9, q5, d4[0]);
175   vld1_32({d1}, mem[r12]++); // A1
176   vmla_f32(q11, q5, d5[0]);
177   if (prefetch) {
178     pld(mem[r10, 128]); // Prefetch A2
179   }
180   vmla_f32(q13, q5, d6[0]);
181   vld1_32({d2}, mem[r10]++); // A2
182   vmla_f32(q15, q5, d7[0]);
183   if (prefetch) {
184     pld(mem[r0, 128]); // Prefetch A3
185   }
186   vmla_f32(q8, q6, d4[1]);
187   vld1_32({d3}, mem[r0]++); // A3
188   vmla_f32(q10, q6, d5[1]);
189   if (prefetch) {
190     pld(mem[r9, 352]); // Prefetch B
191   }
192   vmla_f32(q12, q6, d6[1]);
193   if (prefetch) {
194     pld(mem[r9, 416]); // Prefetch B
195   }
196   vmla_f32(q14, q6, d7[1]);
197   vldm(mem[r9]++, {d8-d11}); // B0
198   vmla_f32(q9, q7, d4[1]);
199   vmla_f32(q11, q7, d5[1]);
200   subs(r5, r5, 16);
201   vmla_f32(q13, q7, d6[1]);
202   vmla_f32(q15, q7, d7[1]);
203   bhs(l1);
204 
205   // Epilogue
206   bind(l2);
207   vmla_f32(q8, q4, d0[0]);
208   vldm(mem[r9]++, {d12-d15}); // B1
209   vmla_f32(q10, q4, d1[0]);
210   vmla_f32(q12, q4, d2[0]);
211   vld1_32({d4}, mem[r3]++); // A0
212   vmla_f32(q14, q4, d3[0]);
213   vmla_f32(q9, q5, d0[0]);
214   vld1_32({d5}, mem[r12]++); // A1
215   vmla_f32(q11, q5, d1[0]);
216   vmla_f32(q13, q5, d2[0]);
217   vmla_f32(q15, q5, d3[0]);
218   vld1_32({d6}, mem[r10]++); // A2
219   vmla_f32(q8, q6, d0[1]);
220   vmla_f32(q10, q6, d1[1]);
221   vld1_32({d7}, mem[r0]++); // A3
222   vmla_f32(q12, q6, d2[1]);
223   vmla_f32(q14, q6, d3[1]);
224   vldm(mem[r9]++, {d8-d11}); // B0
225   vmla_f32(q9, q7, d0[1]);
226   vmla_f32(q11, q7, d1[1]);
227   vmla_f32(q13, q7, d2[1]);
228   vmla_f32(q15, q7, d3[1]);
229 
230   vmla_f32(q8, q4, d4[0]);
231   vldm(mem[r9]++, {d12-d15}); // B1
232   vmla_f32(q10, q4, d5[0]);
233   vmla_f32(q12, q4, d6[0]);
234   vmla_f32(q14, q4, d7[0]);
235   vmla_f32(q9, q5, d4[0]);
236   vmla_f32(q11, q5, d5[0]);
237   vmla_f32(q13, q5, d6[0]);
238   vmla_f32(q15, q5, d7[0]);
239   vmla_f32(q8, q6, d4[1]);
240   vmla_f32(q10, q6, d5[1]);
241   vmla_f32(q12, q6, d6[1]);
242   vmla_f32(q14, q6, d7[1]);
243   vmla_f32(q9, q7, d4[1]);
244   vmla_f32(q11, q7, d5[1]);
245   tst(r5, 15);
246   vmla_f32(q13, q7, d6[1]);
247   vmla_f32(q15, q7, d7[1]);
248 
249   // Is there a remainder?- 1 to 3 floats of A (4, 8 or 12 bytes)
250   bne(l4);
251 
252   align(8);
253   bind(l3);
254   // Load params pointer
255   ldr(r5, mem[sp, 116]); // params
256 
257   // Load min/max values
258   vld1r_32({d4,d5}, mem[r5]++);
259   subs(r1, r1, 8);
260   vld1r_32({d6,d7}, mem[r5]);
261 
262   // Clamp
263   vmax_f32(q8, q8, q2);
264   vmax_f32(q9, q9, q2);
265   vmax_f32(q10, q10, q2);
266   vmax_f32(q11, q11, q2);
267   vmax_f32(q12, q12, q2);
268   vmax_f32(q13, q13, q2);
269   vmax_f32(q14, q14, q2);
270   vmax_f32(q15, q15, q2);
271   vmin_f32(q8, q8, q3);
272   vmin_f32(q9, q9, q3);
273   vmin_f32(q10, q10, q3);
274   vmin_f32(q11, q11, q3);
275   vmin_f32(q12, q12, q3);
276   vmin_f32(q13, q13, q3);
277   vmin_f32(q14, q14, q3);
278   vmin_f32(q15, q15, q3);
279 
280   // Store full 4 x 8
281   blo(l6);
282   vst1_32({d16-d19}, mem[r11], r7);
283   sub(r0, r0, r2);
284   vst1_32({d20-d23}, mem[r4], r7);
285   sub(r10, r10, r2);
286   vst1_32({d24-d27}, mem[r8], r7);
287   sub(r12, r12, r2);
288   vst1_32({d28-d31}, mem[r6], r7);
289   sub(r3, r3, r2);
290   bhi(l0);
291 
292   vpop({d8-d15});
293   pop({r4, r5, r6, r7, r8, r9, r10, r11});
294   bx(lr);
295 
296   align(8);
297   bind(l4);
298   // Is there a remainder?- 2 floats of A (8 bytes)
299   tst(r5, 8);
300   beq(l5);
301 
302   // Remainder - 2 floats of A (8 bytes)
303   vld1_32({d0}, mem[r3]++); // A0
304   vldm(mem[r9]++, {d8-d11}); // B0
305   vld1_32({d1}, mem[r12]++); // A1
306   vld1_32({d2}, mem[r10]++); // A2
307   vld1_32({d3}, mem[r0]++); // A3
308 
309   vmla_f32(q8, q4, d0[0]);
310   vmla_f32(q9, q5, d0[0]);
311   vmla_f32(q10, q4, d1[0]);
312   vmla_f32(q11, q5, d1[0]);
313   vldm(mem[r9]++, {d12-d15}); // B1
314   vmla_f32(q12, q4, d2[0]);
315   vmla_f32(q13, q5, d2[0]);
316   vmla_f32(q14, q4, d3[0]);
317   vmla_f32(q15, q5, d3[0]);
318   vmla_f32(q8, q6, d0[1]);
319   vmla_f32(q9, q7, d0[1]);
320   vmla_f32(q10, q6, d1[1]);
321   vmla_f32(q11, q7, d1[1]);
322   vmla_f32(q12, q6, d2[1]);
323   vmla_f32(q13, q7, d2[1]);
324   vmla_f32(q14, q6, d3[1]);
325   vmla_f32(q15, q7, d3[1]);
326 
327   // Is there a remainder?- 1 float of A (4 bytes)
328   tst(r5, 4);
329   beq(l3);
330 
331   bind(l5);
332   // Remainder- 1 float of A (4 bytes)
333   vldm(mem[r3]++, {s0}); // A0
334   vldm(mem[r9]++, {d8-d11}); // B0
335   vldm(mem[r12]++, {s2}); // A1
336   vldm(mem[r10]++, {s4}); // A2
337   vldm(mem[r0]++, {s6}); // A3
338   vmla_f32(q8, q4, d0[0]);
339   vmla_f32(q9, q5, d0[0]);
340   vmla_f32(q10, q4, d1[0]);
341   vmla_f32(q11, q5, d1[0]);
342   vmla_f32(q12, q4, d2[0]);
343   vmla_f32(q13, q5, d2[0]);
344   vmla_f32(q14, q4, d3[0]);
345   vmla_f32(q15, q5, d3[0]);
346   b(l3);
347 
348   // Store odd width
349   bind(l6);
350   tst(r1, 4);
351   beq(l7);
352   vst1_32({d16-d17}, mem[r11]++);
353   vst1_32({d20-d21}, mem[r4]++);
354   vmov(q8, q9);
355   vmov(q10, q11);
356   vst1_32({d24-d25}, mem[r8]++);
357   vst1_32({d28-d29}, mem[r6]++);
358   vmov(q12, q13);
359   vmov(q14, q15);
360 
361   bind(l7);
362   tst(r1, 2);
363   beq(l8);
364   vst1_32({d16}, mem[r11]++);
365   vst1_32({d20}, mem[r4]++);
366   vmov(d16, d17);
367   vmov(d20, d21);
368   vst1_32({d24}, mem[r8]++);
369   vst1_32({d28}, mem[r6]++);
370   vmov(d24, d25);
371   vmov(d28, d29);
372 
373   bind(l8);
374   tst(r1, 1);
375   beq(l9);
376   vst1_32({d16[0]}, mem[r11]);
377   vst1_32({d20[0]}, mem[r4]);
378   vst1_32({d24[0]}, mem[r8]);
379   vst1_32({d28[0]}, mem[r6]);
380 
381   bind(l9);
382   vpop({d8-d15});
383   pop({r4, r5, r6, r7, r8, r9, r10, r11});
384   bx(lr);
385 }
386 }  // namespace
387 }  // aarch32
388 }  // xnnpack
389 
xnn_generate_f32_gemm_ukernel_4x8__aarch32_neon_cortex_a75(xnn_code_buffer * code,size_t max_mr,size_t nc_mod_nr,size_t kc,const void * params)390 xnn_status_t xnn_generate_f32_gemm_ukernel_4x8__aarch32_neon_cortex_a75(xnn_code_buffer* code, size_t max_mr, size_t nc_mod_nr, size_t kc, const void* params) {
391   using namespace xnnpack::aarch32;
392   Generator g(code);
393   assert(params != nullptr);
394   g.generate(false, max_mr, nc_mod_nr, kc, nullptr);
395   g.finalize();
396   if (g.error() != xnnpack::Error::kNoError) {
397     return xnn_status_invalid_state;
398   }
399   return xnn_status_success;
400 }
401 
xnn_generate_f32_gemm_ukernel_4x8__aarch32_neon_prfm_cortex_a75(xnn_code_buffer * code,size_t max_mr,size_t nc_mod_nr,size_t kc,const void * params)402 xnn_status_t xnn_generate_f32_gemm_ukernel_4x8__aarch32_neon_prfm_cortex_a75(xnn_code_buffer* code, size_t max_mr, size_t nc_mod_nr, size_t kc, const void* params) {
403   using namespace xnnpack::aarch32;
404   Generator g(code);
405   assert(params != nullptr);
406   g.generate(true, max_mr, nc_mod_nr, kc, nullptr);
407   g.finalize();
408   if (g.error() != xnnpack::Error::kNoError) {
409     return xnn_status_invalid_state;
410   }
411   return xnn_status_success;
412 }
413