xref: /aosp_15_r20/external/XNNPACK/src/f32-igemm/4x8-aarch64-neonfma-cortex-a75.cc (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1 // Copyright 2022 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5 
6 #include <cassert>
7 #include <cstddef>
8 #include <limits>
9 
10 #include <xnnpack.h>
11 #include <xnnpack/aarch64-assembler.h>
12 #include <xnnpack/allocator.h>
13 #include <xnnpack/igemm.h>
14 
15 namespace xnnpack {
16 namespace aarch64 {
17 namespace {
18 class Generator : public Assembler {
19   using Assembler::Assembler;
20  public:
21   void generate(bool prefetch, size_t max_mr, size_t nc_mod_nr, size_t kc, size_t ks, float min, float max);
22 };
23 
24 // void xnn_f32_igemm_minmax_ukernel_4x8__aarch64_neonfma_prfm_cortex_a75(
25 //     size_t mr,                         x0
26 //     size_t nc,                         x1
27 //     size_t kc,                         x2 / x0
28 //     size_t ks,                         x3 / x9
29 //     const float**restrict a,           x4
30 //     const float*restrict w,            x5
31 //     float*restrict c,                  x6
32 //     size_t cm_stride,                  x7
33 //     size_t cn_stride,                  [sp] -> x10
34 //     size_t a_offset,                   [sp + 8] -> x11
35 //     const float* zero,                 [sp + 16] -> x12
36 //     const xnn_f32_minmax_params params [sp + 24] -> x8
37 
38 // d8-d15, x19-x30 need to be preserved if used. x18 is reserved by the OS.
39 // x21 used to store params->max if (!clamp_min && clamp_max).
40 
41 // A pointers
42 // x20 a0
43 // x13 a1
44 // x14 a2
45 // x15 a3
46 
47 // C pointers
48 // x6  c0
49 // x16 c1
50 // x17 c2
51 // x7  c3 / cm_stride
52 
53 // Vector register usage
54 // A0  v0  v4
55 // A1  v1  v5
56 // A2  v2  v6
57 // A3  v3  v7
58 // B   v8  v9 v10 v11
59 // B  v12 v13 v14 v15
60 // B  v16 v17 v18 v19
61 // B  v20 v21 v22 v23
62 // C  v24 v25
63 // C  v26 v27
64 // C  v28 v29
65 // C  v30 v31
66 // Clamp v4 v5
67 
68 // Converted from: src/f32-igemm/gen/4x8-minmax-aarch64-neonfma-prfm-cortex-a75.S
generate(bool prefetch,size_t max_mr,size_t nc_mod_nr,size_t kc,size_t ks,float min,float max)69 void Generator::generate(bool prefetch, size_t max_mr, size_t nc_mod_nr, size_t kc, size_t ks, float min, float max) {
70   assert(nc_mod_nr < 8);
71   assert(kc != 0);
72   assert(kc % sizeof(float) == 0);
73 
74   Label l0, l1, l2, l3, l4, l5, l6, l7, l8, l9, l10, l11;
75   const bool clamp_min = min != -std::numeric_limits<float>::infinity();
76   const bool clamp_max = max != +std::numeric_limits<float>::infinity();
77 
78   // Load cn_stride, a_offset
79   ldp(x10, x11, mem[sp]);
80 
81   // Load zero, params pointer
82   ldp(x12, x8, mem[sp, 16]);
83 
84   // Load min/max values
85   if (clamp_min && clamp_max) {
86     ld2r({v4.v4s(), v5.v4s()}, mem[x8]);
87   } else if (clamp_min) {
88     ld1r({v4.v4s()}, mem[x8]);
89   } else if (clamp_max) {
90     add(x21, x8, 4);
91     ld1r({v5.v4s()}, mem[x21]);
92   }
93 
94   // Save x20, x21 on stack
95   stp(x20, x21, mem[sp, -80]++);
96 
97   // Save d8-d15 on stack
98   stp(d8, d9, mem[sp, 16]);
99   stp(d10, d11, mem[sp, 32]);
100   stp(d12, d13, mem[sp, 48]);
101   stp(d14, d15, mem[sp, 64]);
102 
103   // Clamp C pointers
104   cmp(x0, 2); // if mr < 2
105   add(x16, x6, x7); // c1 = c0 + cm_stride
106   csel(x16, x6, x16, kLO); //   c1 = c0
107 
108   add(x17, x16, x7); // c2 = c1 + cm_stride
109   // if mr <= 2
110   csel(x17, x16, x17, kLS); //   c2 = c1
111 
112   cmp(x0, 4); // if mr < 4
113   add(x7, x17, x7); // c3 = c2 + cm_stride
114   csel(x7, x17, x7, kLO); //   c3 = c2
115 
116   bind(l0);
117   // Load initial bias from w into accumulators
118   ldp(q24, q25, mem[x5], 32);
119   mov(v26.v16b(), v24.v16b());
120   mov(v27.v16b(), v25.v16b());
121   mov(v28.v16b(), v24.v16b());
122   mov(v29.v16b(), v25.v16b());
123   mov(v30.v16b(), v24.v16b());
124   mov(v31.v16b(), v25.v16b());
125 
126   mov(x9, x3); // p = ks
127 
128   bind(l1);
129   // Load next 4 A pointers
130   ldp(x20, x13, mem[x4], 16);
131   ldp(x14, x15, mem[x4], 16);
132 
133   cmp(x20, x12); // if a0 == zero
134   add(x20, x20, x11); // a0 += a_offset
135   csel(x20, x12, x20, kEQ); //   a0 = zero, else += a0 + a_offset
136   cmp(x13, x12); // if a1 == zero
137   add(x13, x13, x11); // a1 += a_offset
138   csel(x13, x12, x13, kEQ); //   a1 = zero, else += a1 + a_offset
139   cmp(x14, x12); // if a2 == zero
140   add(x14, x14, x11); // a2 += a_offset
141   csel(x14, x12, x14, kEQ); //   a2 = zero, else += a2 + a_offset
142   cmp(x15, x12); // if a3 == zero
143   add(x15, x15, x11); // a3 += a_offset
144   csel(x15, x12, x15, kEQ); //   a3 = zero, else += a3 + a_offset
145 
146   // Is there at least 8 floats (32 bytes) for prologue + epilogue?
147   subs(x0, x2, 32); // k = kc - 32
148   b_lo(l4);
149 
150   // 16 prologue
151   // Read first block of 4 A and B.
152   ldr(q0, mem[x20], 16);
153   ldp(q16, q17, mem[x5], 32);
154   ldr(q1, mem[x13], 16);
155   ldr(q2, mem[x14], 16);
156   ldr(q3, mem[x15], 16);
157   ldp(q18, q19, mem[x5], 32);
158   ldp(q20, q21, mem[x5], 32);
159   ldp(q22, q23, mem[x5], 32);
160 
161   // Is there at least 32.  yes do main loop
162   subs(x0, x0, 32);
163   b_lo(l3);
164 
165   // Main loop - 8 floats of A
166   bind(l2);
167   // First block of 4.  FMA for first 4, loads for 2nd block of 4.
168   fmla(v24.v4s(), v16.v4s(), v0.s()[0]);
169   ldp(q8, q9, mem[x5], 32);
170   fmla(v25.v4s(), v17.v4s(), v0.s()[0]);
171   fmla(v26.v4s(), v16.v4s(), v1.s()[0]);
172   ldp(q10, q11, mem[x5], 32);
173   fmla(v27.v4s(), v17.v4s(), v1.s()[0]);
174   fmla(v28.v4s(), v16.v4s(), v2.s()[0]);
175   ldp(q12, q13, mem[x5], 32);
176   fmla(v29.v4s(), v17.v4s(), v2.s()[0]);
177   fmla(v30.v4s(), v16.v4s(), v3.s()[0]);
178   ldp(q14, q15, mem[x5], 32);
179   fmla(v31.v4s(), v17.v4s(), v3.s()[0]);
180   fmla(v24.v4s(), v18.v4s(), v0.s()[1]);
181   ldr(q4, mem[x20], 16);
182   fmla(v25.v4s(), v19.v4s(), v0.s()[1]);
183   fmla(v26.v4s(), v18.v4s(), v1.s()[1]);
184   ldr(q5, mem[x13], 16);
185   fmla(v27.v4s(), v19.v4s(), v1.s()[1]);
186   fmla(v28.v4s(), v18.v4s(), v2.s()[1]);
187   ldr(q6, mem[x14], 16);
188   fmla(v29.v4s(), v19.v4s(), v2.s()[1]);
189   fmla(v30.v4s(), v18.v4s(), v3.s()[1]);
190   ldr(q7, mem[x15], 16);
191   fmla(v31.v4s(), v19.v4s(), v3.s()[1]);
192   fmla(v24.v4s(), v20.v4s(), v0.s()[2]);
193   if (prefetch) {
194     prfm(kPLDL1KEEP, mem[x5, 128]);
195   }
196   fmla(v25.v4s(), v21.v4s(), v0.s()[2]);
197   fmla(v26.v4s(), v20.v4s(), v1.s()[2]);
198   if (prefetch) {
199     prfm(kPLDL1KEEP, mem[x5, 192]);
200   }
201   fmla(v27.v4s(), v21.v4s(), v1.s()[2]);
202   fmla(v28.v4s(), v20.v4s(), v2.s()[2]);
203   if (prefetch) {
204     prfm(kPLDL1KEEP, mem[x5, 256]);
205   }
206   fmla(v29.v4s(), v21.v4s(), v2.s()[2]);
207   fmla(v30.v4s(), v20.v4s(), v3.s()[2]);
208   if (prefetch) {
209     prfm(kPLDL1KEEP, mem[x5, 320]);
210   }
211   fmla(v31.v4s(), v21.v4s(), v3.s()[2]);
212   fmla(v24.v4s(), v22.v4s(), v0.s()[3]);
213   fmla(v25.v4s(), v23.v4s(), v0.s()[3]);
214   fmla(v26.v4s(), v22.v4s(), v1.s()[3]);
215   fmla(v27.v4s(), v23.v4s(), v1.s()[3]);
216   fmla(v28.v4s(), v22.v4s(), v2.s()[3]);
217   fmla(v29.v4s(), v23.v4s(), v2.s()[3]);
218   fmla(v30.v4s(), v22.v4s(), v3.s()[3]);
219   fmla(v31.v4s(), v23.v4s(), v3.s()[3]);
220 
221   // Second block of 4.  FMA for second 4, loads for 1st block of 4.
222   fmla(v24.v4s(), v8.v4s(), v4.s()[0]);
223   ldp(q16, q17, mem[x5], 32);
224   fmla(v25.v4s(), v9.v4s(), v4.s()[0]);
225   fmla(v26.v4s(), v8.v4s(), v5.s()[0]);
226   ldp(q18, q19, mem[x5], 32);
227   fmla(v27.v4s(), v9.v4s(), v5.s()[0]);
228   fmla(v28.v4s(), v8.v4s(), v6.s()[0]);
229   ldp(q20, q21, mem[x5], 32);
230   fmla(v29.v4s(), v9.v4s(), v6.s()[0]);
231   fmla(v30.v4s(), v8.v4s(), v7.s()[0]);
232   ldp(q22, q23, mem[x5], 32);
233   fmla(v31.v4s(), v9.v4s(), v7.s()[0]);
234   fmla(v24.v4s(), v10.v4s(), v4.s()[1]);
235   ldr(q0, mem[x20], 16);
236   fmla(v25.v4s(), v11.v4s(), v4.s()[1]);
237   fmla(v26.v4s(), v10.v4s(), v5.s()[1]);
238   ldr(q1, mem[x13], 16);
239   fmla(v27.v4s(), v11.v4s(), v5.s()[1]);
240   fmla(v28.v4s(), v10.v4s(), v6.s()[1]);
241   ldr(q2, mem[x14], 16);
242   fmla(v29.v4s(), v11.v4s(), v6.s()[1]);
243   fmla(v30.v4s(), v10.v4s(), v7.s()[1]);
244   ldr(q3, mem[x15], 16);
245   fmla(v31.v4s(), v11.v4s(), v7.s()[1]);
246   fmla(v24.v4s(), v12.v4s(), v4.s()[2]);
247   fmla(v25.v4s(), v13.v4s(), v4.s()[2]);
248   fmla(v26.v4s(), v12.v4s(), v5.s()[2]);
249   fmla(v27.v4s(), v13.v4s(), v5.s()[2]);
250   fmla(v28.v4s(), v12.v4s(), v6.s()[2]);
251   fmla(v29.v4s(), v13.v4s(), v6.s()[2]);
252   fmla(v30.v4s(), v12.v4s(), v7.s()[2]);
253   fmla(v31.v4s(), v13.v4s(), v7.s()[2]);
254   fmla(v24.v4s(), v14.v4s(), v4.s()[3]);
255   fmla(v25.v4s(), v15.v4s(), v4.s()[3]);
256   fmla(v26.v4s(), v14.v4s(), v5.s()[3]);
257   fmla(v27.v4s(), v15.v4s(), v5.s()[3]);
258   fmla(v28.v4s(), v14.v4s(), v6.s()[3]);
259   fmla(v29.v4s(), v15.v4s(), v6.s()[3]);
260   subs(x0, x0, 32);
261   fmla(v30.v4s(), v14.v4s(), v7.s()[3]);
262   fmla(v31.v4s(), v15.v4s(), v7.s()[3]);
263 
264   b_hs(l2);
265 
266   bind(l3);
267   // Epilogue
268   // First block of 4.  FMA for first 4, loads for 2nd block of 4.
269   fmla(v24.v4s(), v16.v4s(), v0.s()[0]);
270   ldp(q8, q9, mem[x5], 32);
271   fmla(v25.v4s(), v17.v4s(), v0.s()[0]);
272   fmla(v26.v4s(), v16.v4s(), v1.s()[0]);
273   ldp(q10, q11, mem[x5], 32);
274   fmla(v27.v4s(), v17.v4s(), v1.s()[0]);
275   fmla(v28.v4s(), v16.v4s(), v2.s()[0]);
276   ldp(q12, q13, mem[x5], 32);
277   fmla(v29.v4s(), v17.v4s(), v2.s()[0]);
278   fmla(v30.v4s(), v16.v4s(), v3.s()[0]);
279   ldp(q14, q15, mem[x5], 32);
280   fmla(v31.v4s(), v17.v4s(), v3.s()[0]);
281   fmla(v24.v4s(), v18.v4s(), v0.s()[1]);
282   ldr(q4, mem[x20], 16);
283   fmla(v25.v4s(), v19.v4s(), v0.s()[1]);
284   fmla(v26.v4s(), v18.v4s(), v1.s()[1]);
285   ldr(q5, mem[x13], 16);
286   fmla(v27.v4s(), v19.v4s(), v1.s()[1]);
287   fmla(v28.v4s(), v18.v4s(), v2.s()[1]);
288   ldr(q6, mem[x14], 16);
289   fmla(v29.v4s(), v19.v4s(), v2.s()[1]);
290   fmla(v30.v4s(), v18.v4s(), v3.s()[1]);
291   ldr(q7, mem[x15], 16);
292   fmla(v31.v4s(), v19.v4s(), v3.s()[1]);
293   fmla(v24.v4s(), v20.v4s(), v0.s()[2]);
294   fmla(v25.v4s(), v21.v4s(), v0.s()[2]);
295   fmla(v26.v4s(), v20.v4s(), v1.s()[2]);
296   fmla(v27.v4s(), v21.v4s(), v1.s()[2]);
297   fmla(v28.v4s(), v20.v4s(), v2.s()[2]);
298   fmla(v29.v4s(), v21.v4s(), v2.s()[2]);
299   fmla(v30.v4s(), v20.v4s(), v3.s()[2]);
300   fmla(v31.v4s(), v21.v4s(), v3.s()[2]);
301   fmla(v24.v4s(), v22.v4s(), v0.s()[3]);
302   fmla(v25.v4s(), v23.v4s(), v0.s()[3]);
303   fmla(v26.v4s(), v22.v4s(), v1.s()[3]);
304   fmla(v27.v4s(), v23.v4s(), v1.s()[3]);
305   fmla(v28.v4s(), v22.v4s(), v2.s()[3]);
306   fmla(v29.v4s(), v23.v4s(), v2.s()[3]);
307   fmla(v30.v4s(), v22.v4s(), v3.s()[3]);
308   fmla(v31.v4s(), v23.v4s(), v3.s()[3]);
309 
310   // Second block of 4.  FMA for second 4, noloads
311   fmla(v24.v4s(), v8.v4s(), v4.s()[0]);
312   fmla(v25.v4s(), v9.v4s(), v4.s()[0]);
313   fmla(v26.v4s(), v8.v4s(), v5.s()[0]);
314   fmla(v27.v4s(), v9.v4s(), v5.s()[0]);
315   fmla(v28.v4s(), v8.v4s(), v6.s()[0]);
316   fmla(v29.v4s(), v9.v4s(), v6.s()[0]);
317   fmla(v30.v4s(), v8.v4s(), v7.s()[0]);
318   fmla(v31.v4s(), v9.v4s(), v7.s()[0]);
319   fmla(v24.v4s(), v10.v4s(), v4.s()[1]);
320   fmla(v25.v4s(), v11.v4s(), v4.s()[1]);
321   fmla(v26.v4s(), v10.v4s(), v5.s()[1]);
322   fmla(v27.v4s(), v11.v4s(), v5.s()[1]);
323   fmla(v28.v4s(), v10.v4s(), v6.s()[1]);
324   fmla(v29.v4s(), v11.v4s(), v6.s()[1]);
325   fmla(v30.v4s(), v10.v4s(), v7.s()[1]);
326   fmla(v31.v4s(), v11.v4s(), v7.s()[1]);
327   fmla(v24.v4s(), v12.v4s(), v4.s()[2]);
328   fmla(v25.v4s(), v13.v4s(), v4.s()[2]);
329   fmla(v26.v4s(), v12.v4s(), v5.s()[2]);
330   fmla(v27.v4s(), v13.v4s(), v5.s()[2]);
331   fmla(v28.v4s(), v12.v4s(), v6.s()[2]);
332   fmla(v29.v4s(), v13.v4s(), v6.s()[2]);
333   fmla(v30.v4s(), v12.v4s(), v7.s()[2]);
334   fmla(v31.v4s(), v13.v4s(), v7.s()[2]);
335 
336   fmla(v24.v4s(), v14.v4s(), v4.s()[3]);
337   fmla(v25.v4s(), v15.v4s(), v4.s()[3]);
338   fmla(v26.v4s(), v14.v4s(), v5.s()[3]);
339   fmla(v27.v4s(), v15.v4s(), v5.s()[3]);
340 
341   // Load min/max values
342   if (clamp_min && clamp_max) {
343     ld2r({v4.v4s(), v5.v4s()}, mem[x8]);
344   } else if (clamp_min) {
345     ld1r({v4.v4s()}, mem[x8]);
346   } else if (clamp_max) {
347     add(x13, x8, 4);
348     ld1r({v5.v4s()}, mem[x13]);
349   }
350 
351   fmla(v28.v4s(), v14.v4s(), v6.s()[3]);
352   fmla(v29.v4s(), v15.v4s(), v6.s()[3]);
353   fmla(v30.v4s(), v14.v4s(), v7.s()[3]);
354   fmla(v31.v4s(), v15.v4s(), v7.s()[3]);
355 
356   bind(l4);
357   // Remainder- 4 floats of A
358   tbz(x0, 4, l5);
359 
360   ldr(q0, mem[x20], 16);
361   ldp(q16, q17, mem[x5], 32);
362   ldr(q1, mem[x13], 16);
363   ldr(q2, mem[x14], 16);
364   ldr(q3, mem[x15], 16);
365   fmla(v24.v4s(), v16.v4s(), v0.s()[0]);
366   fmla(v25.v4s(), v17.v4s(), v0.s()[0]);
367   ldp(q18, q19, mem[x5], 32);
368   fmla(v26.v4s(), v16.v4s(), v1.s()[0]);
369   fmla(v27.v4s(), v17.v4s(), v1.s()[0]);
370   ldp(q20, q21, mem[x5], 32);
371   fmla(v28.v4s(), v16.v4s(), v2.s()[0]);
372   fmla(v29.v4s(), v17.v4s(), v2.s()[0]);
373   ldp(q22, q23, mem[x5], 32);
374   fmla(v30.v4s(), v16.v4s(), v3.s()[0]);
375   fmla(v31.v4s(), v17.v4s(), v3.s()[0]);
376   fmla(v24.v4s(), v18.v4s(), v0.s()[1]);
377   fmla(v25.v4s(), v19.v4s(), v0.s()[1]);
378   fmla(v26.v4s(), v18.v4s(), v1.s()[1]);
379   fmla(v27.v4s(), v19.v4s(), v1.s()[1]);
380   fmla(v28.v4s(), v18.v4s(), v2.s()[1]);
381   fmla(v29.v4s(), v19.v4s(), v2.s()[1]);
382   fmla(v30.v4s(), v18.v4s(), v3.s()[1]);
383   fmla(v31.v4s(), v19.v4s(), v3.s()[1]);
384   fmla(v24.v4s(), v20.v4s(), v0.s()[2]);
385   fmla(v25.v4s(), v21.v4s(), v0.s()[2]);
386   fmla(v26.v4s(), v20.v4s(), v1.s()[2]);
387   fmla(v27.v4s(), v21.v4s(), v1.s()[2]);
388   fmla(v28.v4s(), v20.v4s(), v2.s()[2]);
389   fmla(v29.v4s(), v21.v4s(), v2.s()[2]);
390   fmla(v30.v4s(), v20.v4s(), v3.s()[2]);
391   fmla(v31.v4s(), v21.v4s(), v3.s()[2]);
392   fmla(v24.v4s(), v22.v4s(), v0.s()[3]);
393   fmla(v25.v4s(), v23.v4s(), v0.s()[3]);
394   fmla(v26.v4s(), v22.v4s(), v1.s()[3]);
395   fmla(v27.v4s(), v23.v4s(), v1.s()[3]);
396   fmla(v28.v4s(), v22.v4s(), v2.s()[3]);
397   fmla(v29.v4s(), v23.v4s(), v2.s()[3]);
398   fmla(v30.v4s(), v22.v4s(), v3.s()[3]);
399   fmla(v31.v4s(), v23.v4s(), v3.s()[3]);
400 
401   bind(l5);
402   // Remainder- 2 floats of A
403   tbz(x0, 3, l6);
404 
405   ldr(d0, mem[x20], 8);
406   ldp(q16, q17, mem[x5], 32);
407   ldr(d1, mem[x13], 8);
408   ldr(d2, mem[x14], 8);
409   ldr(d3, mem[x15], 8);
410   fmla(v24.v4s(), v16.v4s(), v0.s()[0]);
411   fmla(v25.v4s(), v17.v4s(), v0.s()[0]);
412   ldp(q18, q19, mem[x5], 32);
413   fmla(v26.v4s(), v16.v4s(), v1.s()[0]);
414   fmla(v27.v4s(), v17.v4s(), v1.s()[0]);
415   fmla(v28.v4s(), v16.v4s(), v2.s()[0]);
416   fmla(v29.v4s(), v17.v4s(), v2.s()[0]);
417   fmla(v30.v4s(), v16.v4s(), v3.s()[0]);
418   fmla(v31.v4s(), v17.v4s(), v3.s()[0]);
419   fmla(v24.v4s(), v18.v4s(), v0.s()[1]);
420   fmla(v25.v4s(), v19.v4s(), v0.s()[1]);
421   fmla(v26.v4s(), v18.v4s(), v1.s()[1]);
422   fmla(v27.v4s(), v19.v4s(), v1.s()[1]);
423   fmla(v28.v4s(), v18.v4s(), v2.s()[1]);
424   fmla(v29.v4s(), v19.v4s(), v2.s()[1]);
425   fmla(v30.v4s(), v18.v4s(), v3.s()[1]);
426   fmla(v31.v4s(), v19.v4s(), v3.s()[1]);
427 
428   bind(l6);
429   // Remainder- 1 float of A
430   tbz(x0, 2, l7);
431 
432   ldr(s0, mem[x20], 4);
433   ldp(q16, q17, mem[x5], 32);
434   ldr(s1, mem[x13], 4);
435   ldr(s2, mem[x14], 4);
436   ldr(s3, mem[x15], 4);
437   fmla(v24.v4s(), v16.v4s(), v0.s()[0]);
438   fmla(v25.v4s(), v17.v4s(), v0.s()[0]);
439   fmla(v26.v4s(), v16.v4s(), v1.s()[0]);
440   fmla(v27.v4s(), v17.v4s(), v1.s()[0]);
441   fmla(v28.v4s(), v16.v4s(), v2.s()[0]);
442   fmla(v29.v4s(), v17.v4s(), v2.s()[0]);
443   fmla(v30.v4s(), v16.v4s(), v3.s()[0]);
444   fmla(v31.v4s(), v17.v4s(), v3.s()[0]);
445 
446   bind(l7);
447   // ks loop
448   subs(x9, x9, 32); // ks -= MR * sizeof(void*)
449   b_hi(l1);
450 
451   // Clamp
452   if (clamp_min) {
453     fmax(v24.v4s(), v24.v4s(), v4.v4s());
454     fmax(v25.v4s(), v25.v4s(), v4.v4s());
455     fmax(v26.v4s(), v26.v4s(), v4.v4s());
456     fmax(v27.v4s(), v27.v4s(), v4.v4s());
457     fmax(v28.v4s(), v28.v4s(), v4.v4s());
458     fmax(v29.v4s(), v29.v4s(), v4.v4s());
459     fmax(v30.v4s(), v30.v4s(), v4.v4s());
460     fmax(v31.v4s(), v31.v4s(), v4.v4s());
461   }
462   if (clamp_max) {
463     fmin(v24.v4s(), v24.v4s(), v5.v4s());
464     fmin(v25.v4s(), v25.v4s(), v5.v4s());
465     fmin(v26.v4s(), v26.v4s(), v5.v4s());
466     fmin(v27.v4s(), v27.v4s(), v5.v4s());
467     fmin(v28.v4s(), v28.v4s(), v5.v4s());
468     fmin(v29.v4s(), v29.v4s(), v5.v4s());
469     fmin(v30.v4s(), v30.v4s(), v5.v4s());
470     fmin(v31.v4s(), v31.v4s(), v5.v4s());
471   }
472 
473   // Store full 4 x 8
474   subs(x1, x1, 8);
475   b_lo(l8);
476 
477   stp(q30, q31, mem[x7]);
478   add(x7, x7, x10);
479   stp(q28, q29, mem[x17]);
480   add(x17, x17, x10);
481   stp(q26, q27, mem[x16]);
482   add(x16, x16, x10);
483   stp(q24, q25, mem[x6]);
484   add(x6, x6, x10);
485 
486   sub(x4, x4, x3); // a -= ks
487 
488   // nc loop
489   b_hi(l0);
490 
491   // Restore d8-d15 from stack
492   ldp(d14, d15, mem[sp, 64]);
493   ldp(d12, d13, mem[sp, 48]);
494   ldp(d10, d11, mem[sp, 32]);
495   ldp(d8, d9, mem[sp, 16]);
496 
497   // Restore x20 from stack
498   ldr(x20, mem[sp], 80);
499   ret();
500 
501   // Store odd width
502   bind(l8);
503   tbz(x1, 2, l9);
504   str(q30, mem[x7], 16);
505   mov(v30.v16b(), v31.v16b());
506   str(q28, mem[x17], 16);
507   mov(v28.v16b(), v29.v16b());
508   str(q26, mem[x16], 16);
509   mov(v26.v16b(), v27.v16b());
510   str(q24, mem[x6], 16);
511   mov(v24.v16b(), v25.v16b());
512 
513   bind(l9);
514   tbz(x1, 1, l10);
515   str(d30, mem[x7], 8);
516   str(d28, mem[x17], 8);
517   dup(d30, v30.d()[1]);
518   dup(d28, v28.d()[1]);
519   str(d26, mem[x16], 8);
520   str(d24, mem[x6], 8);
521   dup(d26, v26.d()[1]);
522   dup(d24, v24.d()[1]);
523 
524   bind(l10);
525   tbz(x1, 0, l11);
526   str(s30, mem[x7]);
527   str(s28, mem[x17]);
528   str(s26, mem[x16]);
529   str(s24, mem[x6]);
530   bind(l11);
531   // Restore d8-d15 from stack
532   ldp(d14, d15, mem[sp, 64]);
533   ldp(d12, d13, mem[sp, 48]);
534   ldp(d10, d11, mem[sp, 32]);
535   ldp(d8, d9, mem[sp, 16]);
536 
537   // Restore x20, x21 from stack
538   ldp(x20, x21, mem[sp], 80);
539   ret();
540 
541   align(16, AlignInstruction::kHlt);
542 }
543 }  // namespace
544 }  // aarch64
545 }  // xnnpack
546 
xnn_generate_f32_igemm_ukernel_4x8__aarch64_neonfma_cortex_a75(xnn_code_buffer * code,size_t max_mr,size_t nc_mod_nr,size_t kc,size_t ks,const void * params)547 xnn_status_t xnn_generate_f32_igemm_ukernel_4x8__aarch64_neonfma_cortex_a75(xnn_code_buffer* code, size_t max_mr, size_t nc_mod_nr, size_t kc, size_t ks, const void* params) {
548   using namespace xnnpack::aarch64;
549   Generator g(code);
550   assert(params != nullptr);
551   const jit_gemm_params* gemm_params = static_cast<const jit_gemm_params*>(params);
552   g.generate(false, max_mr, nc_mod_nr, kc, ks, gemm_params->f32_minmax.min, gemm_params->f32_minmax.max);
553   g.finalize();
554   if (g.error() != xnnpack::Error::kNoError) {
555     return xnn_status_invalid_state;
556   }
557   return xnn_status_success;
558 }
559 
xnn_generate_f32_igemm_ukernel_4x8__aarch64_neonfma_prfm_cortex_a75(xnn_code_buffer * code,size_t max_mr,size_t nc_mod_nr,size_t kc,size_t ks,const void * params)560 xnn_status_t xnn_generate_f32_igemm_ukernel_4x8__aarch64_neonfma_prfm_cortex_a75(xnn_code_buffer* code, size_t max_mr, size_t nc_mod_nr, size_t kc, size_t ks, const void* params) {
561   using namespace xnnpack::aarch64;
562   Generator g(code);
563   assert(params != nullptr);
564   const jit_gemm_params* gemm_params = static_cast<const jit_gemm_params*>(params);
565   g.generate(true, max_mr, nc_mod_nr, kc, ks, gemm_params->f32_minmax.min, gemm_params->f32_minmax.max);
566   g.finalize();
567   if (g.error() != xnnpack::Error::kNoError) {
568     return xnn_status_invalid_state;
569   }
570   return xnn_status_success;
571 }
572