xref: /aosp_15_r20/external/XNNPACK/src/f32-igemm/6x8-aarch64-neonfma-ld128.cc (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1 // Copyright 2022 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5 
6 #include <cassert>
7 #include <cstddef>
8 #include <limits>
9 
10 #include <xnnpack.h>
11 #include <xnnpack/aarch64-assembler.h>
12 #include <xnnpack/allocator.h>
13 #include <xnnpack/igemm.h>
14 
15 namespace xnnpack {
16 namespace aarch64 {
17 namespace {
18 class Generator : public Assembler {
19   using Assembler::Assembler;
20  public:
21   void generate(size_t max_mr, size_t nc_mod_nr, size_t kc, size_t ks, float min, float max);
22 };
23 
24 // void xnn_f32_igemm_minmax_ukernel_6x8__aarch64_neonfma_ld128(
25 //     size_t mr,                         x0
26 //     size_t nc,                         x1
27 //     size_t kc,                         x2 / x0
28 //     size_t ks,                         x3 / x9
29 //     const float**restrict a,           x4
30 //     const void*restrict w,             x5
31 //     uint8_t*restrict c,                x6
32 //     size_t cm_stride,                  x7
33 //     size_t cn_stride,                  [sp] -> (x0)
34 //     size_t a_offset,                   [sp + 8] -> x11
35 //     const float* zero,                 [sp + 16] -> x12
36 //     const xnn_f32_minmax_params params [sp + 24] -> x8
37 
38 // d8-d15, x19-x30 need to be preserved if used. x18 is reserved by the OS.
39 
40 // A pointers
41 // x14 a0
42 // x15 a1
43 // x20 a2
44 // x21 a3
45 // x22 a4
46 // x23 a5
47 
48 // C pointers
49 //  x6 c0
50 // x16 c1
51 // x17 c2
52 // x10 c3
53 // x13 c4
54 //  x7 c5
55 
56 // Vector register usage
57 // A0   v0
58 // A1   v1
59 // A2   v2
60 // A3   v3
61 // A4   v4
62 // A5   v5
63 // B   v16 v17 v18 v19
64 // C   v20 v21
65 // C   v22 v23
66 // C   v24 v25
67 // C   v26 v27
68 // C   v28 v29
69 // C   v30 v31
70 // Clamp v6 v7
71 // unused A   v8 v9 v10 v11
72 // unused B   v12 v13 v14 v15
73 
74 // Converted from: src/f32-igemm/gen/6x8-minmax-aarch64-neonfma-ld128.S
generate(size_t max_mr,size_t nc_mod_nr,size_t kc,size_t ks,float min,float max)75 void Generator::generate(size_t max_mr, size_t nc_mod_nr, size_t kc, size_t ks, float min, float max) {
76   assert(max_mr <= 6);
77   assert(nc_mod_nr < 8);
78   assert(kc != 0);
79   assert(kc % sizeof(float) == 0);
80 
81   Label l0, l1, l2, l3, l4, l5, l6, l7, l8, l9;
82 
83   const bool clamp_min = min != -std::numeric_limits<float>::infinity();
84   const bool clamp_max = max != +std::numeric_limits<float>::infinity();
85 
86   // Load zero, params pointer
87   ldp(x12, x8, mem[sp, 16]);
88 
89   // Clamp C pointers
90   cmp(x0, 2); // if mr < 2
91   add(x16, x6, x7); // c1 = c0 + cm_stride
92   csel(x16, x6, x16, kLO); //   c1 = c0
93 
94   // Load min/max values
95   if (clamp_min || clamp_max) {
96     ld2r({v6.v4s(), v7.v4s()}, mem[x8]);
97   }
98 
99   add(x17, x16, x7); // c2 = c1 + cm_stride
100   // if mr <= 2
101   csel(x17, x16, x17, kLS); //   c2 = c1
102 
103   // Save x20,x21,x22,x23 on stack
104   stp(x20, x21, mem[sp, -32]++);
105 
106   cmp(x0, 4); // if mr < 4
107   add(x10, x17, x7); // c3 = c2 + cm_stride
108   csel(x10, x17, x10, kLO); //   c3 = c2
109 
110   stp(x22, x23, mem[sp, 16]);
111 
112   add(x13, x10, x7); // c4 = c3 + cm_stride
113   // if mr <= 4
114   csel(x13, x10, x13, kLS); //   c4 = c3
115 
116   // Load a_offset
117   ldr(x11, mem[sp, 40]);
118 
119   cmp(x0, 6); // if mr < 6
120   add(x7, x13, x7); // c5 = c4 + cm_stride
121   csel(x7, x13, x7, kLO); //   c5 = c4
122 
123   bind(l0);
124   // Load initial bias from w into accumulators
125   ldp(q20, q21, mem[x5], 32);
126   mov(v22.v16b(), v20.v16b());
127   mov(v23.v16b(), v21.v16b());
128   mov(v24.v16b(), v20.v16b());
129   mov(v25.v16b(), v21.v16b());
130   mov(v26.v16b(), v20.v16b());
131   mov(v27.v16b(), v21.v16b());
132   mov(v28.v16b(), v20.v16b());
133   mov(v29.v16b(), v21.v16b());
134   mov(v30.v16b(), v20.v16b());
135   mov(v31.v16b(), v21.v16b());
136 
137   mov(x9, x3); // p = ks
138 
139   bind(l1);
140   // Load next 6 A pointers
141   ldp(x14, x15, mem[x4], 16);
142   ldp(x20, x21, mem[x4], 16);
143   ldp(x22, x23, mem[x4], 16);
144 
145   cmp(x14, x12); // if a0 == zero
146   add(x14, x14, x11); // a0 += a_offset
147   csel(x14, x12, x14, kEQ); //   a0 = zero, else += a0 + a_offset
148   cmp(x15, x12); // if a1 == zero
149   add(x15, x15, x11); // a1 += a_offset
150   csel(x15, x12, x15, kEQ); //   a1 = zero, else += a1 + a_offset
151   cmp(x20, x12); // if a2 == zero
152   add(x20, x20, x11); // a2 += a_offset
153   csel(x20, x12, x20, kEQ); //   a2 = zero, else += a2 + a_offset
154   cmp(x21, x12); // if a3 == zero
155   add(x21, x21, x11); // a3 += a_offset
156   csel(x21, x12, x21, kEQ); //   a3 = zero, else += a3 + a_offset
157   cmp(x22, x12); // if a4 == zero
158   add(x22, x22, x11); // a4 += a_offset
159   csel(x22, x12, x22, kEQ); //   a4 = zero, else += a4 + a_offset
160   cmp(x23, x12); // if a5 == zero
161   add(x23, x23, x11); // a5 += a_offset
162   csel(x23, x12, x23, kEQ); //   a5 = zero, else += a5 + a_offset
163 
164   // Is there at least 4 floats (16 bytes)?
165   subs(x0, x2, 16); // k = kc - 16
166   b_lo(l4);
167 
168   // Main loop - 4 floats of A (16 bytes)
169   // 48 FMA + 6 ld128 A + 4 LDP B
170   bind(l2);
171   ldp(q16, q17, mem[x5], 32);
172   ldr(q0, mem[x14], 16);
173   ldr(q1, mem[x15], 16);
174   ldr(q2, mem[x20], 16);
175   ldr(q3, mem[x21], 16);
176   ldr(q4, mem[x22], 16);
177   ldr(q5, mem[x23], 16);
178   fmla(v20.v4s(), v16.v4s(), v0.s()[0]);
179   fmla(v22.v4s(), v16.v4s(), v1.s()[0]);
180   fmla(v24.v4s(), v16.v4s(), v2.s()[0]);
181   fmla(v26.v4s(), v16.v4s(), v3.s()[0]);
182   ldp(q18, q19, mem[x5], 32);
183   fmla(v28.v4s(), v16.v4s(), v4.s()[0]);
184   fmla(v30.v4s(), v16.v4s(), v5.s()[0]);
185   fmla(v21.v4s(), v17.v4s(), v0.s()[0]);
186   fmla(v23.v4s(), v17.v4s(), v1.s()[0]);
187   fmla(v25.v4s(), v17.v4s(), v2.s()[0]);
188   fmla(v27.v4s(), v17.v4s(), v3.s()[0]);
189   fmla(v29.v4s(), v17.v4s(), v4.s()[0]);
190   fmla(v31.v4s(), v17.v4s(), v5.s()[0]);
191 
192   fmla(v20.v4s(), v18.v4s(), v0.s()[1]);
193   ldp(q16, q17, mem[x5], 32);
194   fmla(v22.v4s(), v18.v4s(), v1.s()[1]);
195   fmla(v24.v4s(), v18.v4s(), v2.s()[1]);
196   fmla(v26.v4s(), v18.v4s(), v3.s()[1]);
197   fmla(v28.v4s(), v18.v4s(), v4.s()[1]);
198   fmla(v30.v4s(), v18.v4s(), v5.s()[1]);
199   fmla(v21.v4s(), v19.v4s(), v0.s()[1]);
200   fmla(v23.v4s(), v19.v4s(), v1.s()[1]);
201   fmla(v25.v4s(), v19.v4s(), v2.s()[1]);
202   fmla(v27.v4s(), v19.v4s(), v3.s()[1]);
203   fmla(v29.v4s(), v19.v4s(), v4.s()[1]);
204   fmla(v31.v4s(), v19.v4s(), v5.s()[1]);
205 
206   fmla(v20.v4s(), v16.v4s(), v0.s()[2]);
207   ldp(q18, q19, mem[x5], 32);
208   fmla(v22.v4s(), v16.v4s(), v1.s()[2]);
209   fmla(v24.v4s(), v16.v4s(), v2.s()[2]);
210   fmla(v26.v4s(), v16.v4s(), v3.s()[2]);
211   fmla(v28.v4s(), v16.v4s(), v4.s()[2]);
212   fmla(v30.v4s(), v16.v4s(), v5.s()[2]);
213   fmla(v21.v4s(), v17.v4s(), v0.s()[2]);
214   fmla(v23.v4s(), v17.v4s(), v1.s()[2]);
215   fmla(v25.v4s(), v17.v4s(), v2.s()[2]);
216   fmla(v27.v4s(), v17.v4s(), v3.s()[2]);
217   fmla(v29.v4s(), v17.v4s(), v4.s()[2]);
218   fmla(v31.v4s(), v17.v4s(), v5.s()[2]);
219 
220   fmla(v20.v4s(), v18.v4s(), v0.s()[3]);
221   fmla(v22.v4s(), v18.v4s(), v1.s()[3]);
222   fmla(v24.v4s(), v18.v4s(), v2.s()[3]);
223   fmla(v26.v4s(), v18.v4s(), v3.s()[3]);
224   fmla(v28.v4s(), v18.v4s(), v4.s()[3]);
225   fmla(v30.v4s(), v18.v4s(), v5.s()[3]);
226   fmla(v21.v4s(), v19.v4s(), v0.s()[3]);
227   fmla(v23.v4s(), v19.v4s(), v1.s()[3]);
228   fmla(v25.v4s(), v19.v4s(), v2.s()[3]);
229   fmla(v27.v4s(), v19.v4s(), v3.s()[3]);
230   subs(x0, x0, 16);
231   fmla(v29.v4s(), v19.v4s(), v4.s()[3]);
232   fmla(v31.v4s(), v19.v4s(), v5.s()[3]);
233   b_hs(l2);
234 
235   // Is there a remainder?- 2 floats of A (8 bytes) or less
236   tst(x0, 15);
237   b_ne(l4);
238 
239   bind(l3);
240   // ks loop
241   subs(x9, x9, 48); // ks -= MR * sizeof(void*)
242   b_hi(l1);
243 
244   // Load cn_stride
245   ldr(x0, mem[sp, 32]);
246   // Clamp
247   if (clamp_min) {
248     fmax(v20.v4s(), v20.v4s(), v6.v4s());
249     fmax(v21.v4s(), v21.v4s(), v6.v4s());
250     fmax(v22.v4s(), v22.v4s(), v6.v4s());
251     fmax(v23.v4s(), v23.v4s(), v6.v4s());
252     fmax(v24.v4s(), v24.v4s(), v6.v4s());
253     fmax(v25.v4s(), v25.v4s(), v6.v4s());
254     fmax(v26.v4s(), v26.v4s(), v6.v4s());
255     fmax(v27.v4s(), v27.v4s(), v6.v4s());
256     fmax(v28.v4s(), v28.v4s(), v6.v4s());
257     fmax(v29.v4s(), v29.v4s(), v6.v4s());
258     fmax(v30.v4s(), v30.v4s(), v6.v4s());
259     fmax(v31.v4s(), v31.v4s(), v6.v4s());
260   }
261   subs(x1, x1, 8);
262   if (clamp_max) {
263     fmin(v20.v4s(), v20.v4s(), v7.v4s());
264     fmin(v21.v4s(), v21.v4s(), v7.v4s());
265     fmin(v22.v4s(), v22.v4s(), v7.v4s());
266     fmin(v23.v4s(), v23.v4s(), v7.v4s());
267     fmin(v24.v4s(), v24.v4s(), v7.v4s());
268     fmin(v25.v4s(), v25.v4s(), v7.v4s());
269     fmin(v26.v4s(), v26.v4s(), v7.v4s());
270     fmin(v27.v4s(), v27.v4s(), v7.v4s());
271     fmin(v28.v4s(), v28.v4s(), v7.v4s());
272     fmin(v29.v4s(), v29.v4s(), v7.v4s());
273     fmin(v30.v4s(), v30.v4s(), v7.v4s());
274     fmin(v31.v4s(), v31.v4s(), v7.v4s());
275   }
276 
277   // Store full 6 x 8
278   b_lo(l6);
279 
280   stp(q30, q31, mem[x7]);
281   add(x7, x7, x0);
282   stp(q28, q29, mem[x13]);
283   add(x13, x13, x0);
284   stp(q26, q27, mem[x10]);
285   add(x10, x10, x0);
286   stp(q24, q25, mem[x17]);
287   add(x17, x17, x0);
288   stp(q22, q23, mem[x16]);
289   add(x16, x16, x0);
290   stp(q20, q21, mem[x6]);
291   add(x6, x6, x0);
292 
293   sub(x4, x4, x3); // a -= ks
294 
295   // nc loop
296   b_hi(l0);
297 
298   // Restore x20,x21,x22,x23 from stack
299   ldp(x22, x23, mem[sp, 16]);
300   ldp(x20, x21, mem[sp], 32);
301   ret();
302 
303   bind(l4);
304   // Is there a remainder?- 2 floats of A (8 bytes)
305   tbz(x0, 3, l5);
306 
307   // Remainder- 2 floats of A (8 bytes)
308   ldr(d0, mem[x14], 8);
309   ldp(q16, q17, mem[x5], 32);
310   ldr(d1, mem[x15], 8);
311   ldr(d2, mem[x20], 8);
312   ldr(d3, mem[x21], 8);
313   ldr(d4, mem[x22], 8);
314   ldr(d5, mem[x23], 8);
315   fmla(v20.v4s(), v16.v4s(), v0.s()[0]);
316   fmla(v22.v4s(), v16.v4s(), v1.s()[0]);
317   fmla(v24.v4s(), v16.v4s(), v2.s()[0]);
318   fmla(v26.v4s(), v16.v4s(), v3.s()[0]);
319   ldp(q18, q19, mem[x5], 32);
320   fmla(v28.v4s(), v16.v4s(), v4.s()[0]);
321   fmla(v30.v4s(), v16.v4s(), v5.s()[0]);
322   fmla(v21.v4s(), v17.v4s(), v0.s()[0]);
323   fmla(v23.v4s(), v17.v4s(), v1.s()[0]);
324   fmla(v25.v4s(), v17.v4s(), v2.s()[0]);
325   fmla(v27.v4s(), v17.v4s(), v3.s()[0]);
326   fmla(v29.v4s(), v17.v4s(), v4.s()[0]);
327   fmla(v31.v4s(), v17.v4s(), v5.s()[0]);
328 
329   fmla(v20.v4s(), v18.v4s(), v0.s()[1]);
330   fmla(v22.v4s(), v18.v4s(), v1.s()[1]);
331   fmla(v24.v4s(), v18.v4s(), v2.s()[1]);
332   fmla(v26.v4s(), v18.v4s(), v3.s()[1]);
333   fmla(v28.v4s(), v18.v4s(), v4.s()[1]);
334   fmla(v30.v4s(), v18.v4s(), v5.s()[1]);
335   fmla(v21.v4s(), v19.v4s(), v0.s()[1]);
336   fmla(v23.v4s(), v19.v4s(), v1.s()[1]);
337   fmla(v25.v4s(), v19.v4s(), v2.s()[1]);
338   fmla(v27.v4s(), v19.v4s(), v3.s()[1]);
339   fmla(v29.v4s(), v19.v4s(), v4.s()[1]);
340   fmla(v31.v4s(), v19.v4s(), v5.s()[1]);
341 
342   // Is there a remainder?- 1 float of A (4 bytes)
343   tbz(x0, 2, l3);
344 
345   // Remainder- 1 float of A (4 bytes)
346   bind(l5);
347   ldr(s0, mem[x14], 4);
348   ldp(q16, q17, mem[x5], 32);
349   ldr(s1, mem[x15], 4);
350   ldr(s2, mem[x20], 4);
351   ldr(s3, mem[x21], 4);
352   ldr(s4, mem[x22], 4);
353   ldr(s5, mem[x23], 4);
354   fmla(v20.v4s(), v16.v4s(), v0.s()[0]);
355   fmla(v22.v4s(), v16.v4s(), v1.s()[0]);
356   fmla(v24.v4s(), v16.v4s(), v2.s()[0]);
357   fmla(v26.v4s(), v16.v4s(), v3.s()[0]);
358   fmla(v28.v4s(), v16.v4s(), v4.s()[0]);
359   fmla(v30.v4s(), v16.v4s(), v5.s()[0]);
360   fmla(v21.v4s(), v17.v4s(), v0.s()[0]);
361   fmla(v23.v4s(), v17.v4s(), v1.s()[0]);
362   fmla(v25.v4s(), v17.v4s(), v2.s()[0]);
363   fmla(v27.v4s(), v17.v4s(), v3.s()[0]);
364   fmla(v29.v4s(), v17.v4s(), v4.s()[0]);
365   fmla(v31.v4s(), v17.v4s(), v5.s()[0]);
366   b(l3);
367 
368   // Store odd width
369   bind(l6);
370   tbz(x1, 2, l7);
371   str(q30, mem[x7], 16);
372   mov(v30.v16b(), v31.v16b());
373   str(q28, mem[x13], 16);
374   mov(v28.v16b(), v29.v16b());
375   str(q26, mem[x10], 16);
376   mov(v26.v16b(), v27.v16b());
377   str(q24, mem[x17], 16);
378   mov(v24.v16b(), v25.v16b());
379   str(q22, mem[x16], 16);
380   mov(v22.v16b(), v23.v16b());
381   str(q20, mem[x6], 16);
382   mov(v20.v16b(), v21.v16b());
383   bind(l7);
384   tbz(x1, 1, l8);
385   str(d30, mem[x7], 8);
386   str(d28, mem[x13], 8);
387   dup(d30, v30.d()[1]);
388   dup(d28, v28.d()[1]);
389   str(d26, mem[x10], 8);
390   str(d24, mem[x17], 8);
391   dup(d26, v26.d()[1]);
392   dup(d24, v24.d()[1]);
393   str(d22, mem[x16], 8);
394   str(d20, mem[x6], 8);
395   dup(d22, v22.d()[1]);
396   dup(d20, v20.d()[1]);
397 
398   bind(l8);
399   tbz(x1, 0, l9);
400   str(s30, mem[x7]);
401   str(s28, mem[x13]);
402   str(s26, mem[x10]);
403   str(s24, mem[x17]);
404   str(s22, mem[x16]);
405   str(s20, mem[x6]);
406   bind(l9);
407   // Restore x20,x21,x22,x23 from stack
408   ldp(x22, x23, mem[sp, 16]);
409   ldp(x20, x21, mem[sp], 32);
410   ret();
411 
412   align(16, AlignInstruction::kHlt);
413 }
414 }  // namespace
415 }  // aarch64
416 }  // xnnpack
417 
xnn_generate_f32_igemm_ukernel_6x8__aarch64_neonfma_ld128(xnn_code_buffer * code,size_t max_mr,size_t nc_mod_nr,size_t kc,size_t ks,const void * params)418 xnn_status_t xnn_generate_f32_igemm_ukernel_6x8__aarch64_neonfma_ld128(
419     xnn_code_buffer* code, size_t max_mr, size_t nc_mod_nr, size_t kc, size_t ks, const void* params) {
420   using namespace xnnpack::aarch64;
421   Generator g(code);
422   assert(params != nullptr);
423   const jit_gemm_params* gemm_params = static_cast<const jit_gemm_params*>(params);
424   g.generate(max_mr, nc_mod_nr, kc, ks, gemm_params->f32_minmax.min, gemm_params->f32_minmax.max);
425   g.finalize();
426   if (g.error() != xnnpack::Error::kNoError) {
427     return xnn_status_invalid_state;
428   }
429   return xnn_status_success;
430 }
431