xref: /aosp_15_r20/external/XNNPACK/src/f32-gemm/4x8-aarch64-neonfma-cortex-a75.cc (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
1 // Copyright 2022 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5 
6 #include <cassert>
7 #include <cstddef>
8 #include <limits>
9 
10 #include <xnnpack.h>
11 #include <xnnpack/aarch64-assembler.h>
12 #include <xnnpack/allocator.h>
13 #include <xnnpack/gemm.h>
14 
15 namespace xnnpack {
16 namespace aarch64 {
17 namespace {
18 class Generator : public Assembler {
19   using Assembler::Assembler;
20  public:
21   void generate(bool prefetch, size_t max_mr, size_t nc_mod_nr, size_t kc, float min, float max);
22 };
23 
24 // void xnn_f32_gemm_minmax_ukernel_4x8__aarch64_neonfma_prfm_cortex_a75(
25 //     size_t mr,                x0
26 //     size_t nc,                x1
27 //     size_t kc,                x2 / x0
28 //     const uint8_t*restrict a, x3
29 //     size_t a_stride,          x4
30 //     const void*restrict w,    x5
31 //     uint8_t*restrict c,       x6
32 //     size_t cm_stride,         x7
33 //     size_t cn_stride,         [sp] -> x14
34 //     const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])  [sp + 8] -> x8
35 
36 // d8-d15, x19-x30 need to be preserved if used. x18 is reserved by the OS.
37 // x13 used to store pointer to params->max if (!clamp_min && clamp_max).
38 
39 // A pointers
40 // x3  a0
41 // x11 a1
42 // x12 a2
43 // x4  a3 / a_stride
44 
45 // C pointers
46 // x6  c0
47 // x9  c1
48 // x10 c2
49 // x7  c3 / cm_stride
50 
51 // Vector register usage
52 // A0  v0  v4
53 // A1  v1  v5
54 // A2  v2  v6
55 // A3  v3  v7
56 // B   v8  v9 v10 v11
57 // B  v12 v13 v14 v15
58 // B  v16 v17 v18 v19
59 // B  v20 v21 v22 v23
60 // C  v24 v25
61 // C  v26 v27
62 // C  v28 v29
63 // C  v30 v31
64 // Clamp v4 v5
65 
66 // Converted from: src/f32-gemm/gen/4x8-minmax-aarch64-neonfma-prfm-cortex-a75.S
generate(bool prefetch,size_t max_mr,size_t nc_mod_nr,size_t kc,float min,float max)67 void Generator::generate(bool prefetch, size_t max_mr, size_t nc_mod_nr, size_t kc, float min, float max) {
68   assert(nc_mod_nr < 8);
69   assert(kc != 0);
70   assert(kc % sizeof(float) == 0);
71 
72   Label l0, l1, l2, l3, l4, l5, l6, l7, l8, l9, l10;
73   const bool clamp_min = min != -std::numeric_limits<float>::infinity();
74   const bool clamp_max = max != +std::numeric_limits<float>::infinity();
75 
76   // Load cn_stride, params pointer
77   ldp(x14, x8, mem[sp]);
78 
79   // Load min/max values
80   if (clamp_min && clamp_max) {
81     ld2r({v4.v4s(), v5.v4s()}, mem[x8]);
82   } else if (clamp_min) {
83     ld1r({v4.v4s()}, mem[x8]);
84   } else if (clamp_max) {
85     add(x13, x8, 4);
86     ld1r({v5.v4s()}, mem[x13]);
87   }
88 
89   // Save d8-d15 on stack
90   stp(d8, d9, mem[sp, -64]++);
91   stp(d10, d11, mem[sp, 16]);
92   stp(d12, d13, mem[sp, 32]);
93   stp(d14, d15, mem[sp, 48]);
94 
95   // Clamp A and C pointers
96   cmp(x0, 2); // if mr < 2
97   add(x11, x3, x4); // a1 = a0 + a_stride
98   add(x9, x6, x7); // c1 = c0 + cm_stride
99   csel(x11, x3, x11, kLO); //   a1 = a0
100   csel(x9, x6, x9, kLO); //   c1 = c0
101 
102   add(x12, x11, x4); // a2 = a1 + a_stride
103   add(x10, x9, x7); // c2 = c1 + cm_stride
104   // if mr <= 2
105   csel(x12, x11, x12, kLS); //   a2 = a1
106   csel(x10, x9, x10, kLS); //   c2 = c1
107 
108   cmp(x0, 4); // if mr < 4
109   add(x4, x12, x4); // a3 = a2 + a_stride
110   add(x7, x10, x7); // c3 = c2 + cm_stride
111   csel(x4, x12, x4, kLO); //   a3 = a2
112   csel(x7, x10, x7, kLO); //   c3 = c2
113 
114   bind(l0);
115   // Load initial bias from w into accumulators
116   ldp(q24, q25, mem[x5], 32);
117   mov(v26.v16b(), v24.v16b());
118   mov(v27.v16b(), v25.v16b());
119   mov(v28.v16b(), v24.v16b());
120   mov(v29.v16b(), v25.v16b());
121   mov(v30.v16b(), v24.v16b());
122   mov(v31.v16b(), v25.v16b());
123 
124   // Is there at least 8 floats (32 bytes) for prologue + epilogue?
125   subs(x0, x2, 32); // k = kc - 32
126   b_lo(l3);
127 
128   // 16 prologue
129   // Read first block of 4 A and B.
130   ldr(q0, mem[x3], 16);
131   ldp(q16, q17, mem[x5], 32);
132   ldr(q1, mem[x11], 16);
133   ldr(q2, mem[x12], 16);
134   ldr(q3, mem[x4], 16);
135   ldp(q18, q19, mem[x5], 32);
136   ldp(q20, q21, mem[x5], 32);
137   ldp(q22, q23, mem[x5], 32);
138 
139   // Is there at least 32.  yes do main loop
140   subs(x0, x0, 32);
141   b_lo(l2);
142 
143   // Main loop - 8 floats of A (32 bytes)
144   bind(l1);
145   // First block of 4.  FMA for first 4, loads for 2nd block of 4.
146   fmla(v24.v4s(), v16.v4s(), v0.s()[0]);
147   ldp(q8, q9, mem[x5], 32);
148   fmla(v25.v4s(), v17.v4s(), v0.s()[0]);
149   fmla(v26.v4s(), v16.v4s(), v1.s()[0]);
150   ldp(q10, q11, mem[x5], 32);
151   fmla(v27.v4s(), v17.v4s(), v1.s()[0]);
152   fmla(v28.v4s(), v16.v4s(), v2.s()[0]);
153   ldp(q12, q13, mem[x5], 32);
154   fmla(v29.v4s(), v17.v4s(), v2.s()[0]);
155   fmla(v30.v4s(), v16.v4s(), v3.s()[0]);
156   ldp(q14, q15, mem[x5], 32);
157   fmla(v31.v4s(), v17.v4s(), v3.s()[0]);
158   fmla(v24.v4s(), v18.v4s(), v0.s()[1]);
159   ldr(q4, mem[x3], 16);
160   fmla(v25.v4s(), v19.v4s(), v0.s()[1]);
161   fmla(v26.v4s(), v18.v4s(), v1.s()[1]);
162   ldr(q5, mem[x11], 16);
163   fmla(v27.v4s(), v19.v4s(), v1.s()[1]);
164   fmla(v28.v4s(), v18.v4s(), v2.s()[1]);
165   ldr(q6, mem[x12], 16);
166   fmla(v29.v4s(), v19.v4s(), v2.s()[1]);
167   fmla(v30.v4s(), v18.v4s(), v3.s()[1]);
168   ldr(q7, mem[x4], 16);
169   fmla(v31.v4s(), v19.v4s(), v3.s()[1]);
170   fmla(v24.v4s(), v20.v4s(), v0.s()[2]);
171   if (prefetch) {
172     prfm(kPLDL1KEEP, mem[x5, 128]);
173   }
174   fmla(v25.v4s(), v21.v4s(), v0.s()[2]);
175   fmla(v26.v4s(), v20.v4s(), v1.s()[2]);
176   if (prefetch) {
177     prfm(kPLDL1KEEP, mem[x5, 192]);
178   }
179   fmla(v27.v4s(), v21.v4s(), v1.s()[2]);
180   fmla(v28.v4s(), v20.v4s(), v2.s()[2]);
181   if (prefetch) {
182     prfm(kPLDL1KEEP, mem[x5, 256]);
183   }
184   fmla(v29.v4s(), v21.v4s(), v2.s()[2]);
185   fmla(v30.v4s(), v20.v4s(), v3.s()[2]);
186   if (prefetch) {
187     prfm(kPLDL1KEEP, mem[x5, 320]);
188   }
189   fmla(v31.v4s(), v21.v4s(), v3.s()[2]);
190   fmla(v24.v4s(), v22.v4s(), v0.s()[3]);
191   fmla(v25.v4s(), v23.v4s(), v0.s()[3]);
192   fmla(v26.v4s(), v22.v4s(), v1.s()[3]);
193   fmla(v27.v4s(), v23.v4s(), v1.s()[3]);
194   fmla(v28.v4s(), v22.v4s(), v2.s()[3]);
195   fmla(v29.v4s(), v23.v4s(), v2.s()[3]);
196   fmla(v30.v4s(), v22.v4s(), v3.s()[3]);
197   fmla(v31.v4s(), v23.v4s(), v3.s()[3]);
198 
199   // Second block of 4.  FMA for second 4, loads for 1st block of 4.
200   fmla(v24.v4s(), v8.v4s(), v4.s()[0]);
201   ldp(q16, q17, mem[x5], 32);
202   fmla(v25.v4s(), v9.v4s(), v4.s()[0]);
203   fmla(v26.v4s(), v8.v4s(), v5.s()[0]);
204   ldp(q18, q19, mem[x5], 32);
205   fmla(v27.v4s(), v9.v4s(), v5.s()[0]);
206   fmla(v28.v4s(), v8.v4s(), v6.s()[0]);
207   ldp(q20, q21, mem[x5], 32);
208   fmla(v29.v4s(), v9.v4s(), v6.s()[0]);
209   fmla(v30.v4s(), v8.v4s(), v7.s()[0]);
210   ldp(q22, q23, mem[x5], 32);
211   fmla(v31.v4s(), v9.v4s(), v7.s()[0]);
212   fmla(v24.v4s(), v10.v4s(), v4.s()[1]);
213   ldr(q0, mem[x3], 16);
214   fmla(v25.v4s(), v11.v4s(), v4.s()[1]);
215   fmla(v26.v4s(), v10.v4s(), v5.s()[1]);
216   ldr(q1, mem[x11], 16);
217   fmla(v27.v4s(), v11.v4s(), v5.s()[1]);
218   fmla(v28.v4s(), v10.v4s(), v6.s()[1]);
219   ldr(q2, mem[x12], 16);
220   fmla(v29.v4s(), v11.v4s(), v6.s()[1]);
221   fmla(v30.v4s(), v10.v4s(), v7.s()[1]);
222   ldr(q3, mem[x4], 16);
223   fmla(v31.v4s(), v11.v4s(), v7.s()[1]);
224   fmla(v24.v4s(), v12.v4s(), v4.s()[2]);
225   fmla(v25.v4s(), v13.v4s(), v4.s()[2]);
226   fmla(v26.v4s(), v12.v4s(), v5.s()[2]);
227   fmla(v27.v4s(), v13.v4s(), v5.s()[2]);
228   fmla(v28.v4s(), v12.v4s(), v6.s()[2]);
229   fmla(v29.v4s(), v13.v4s(), v6.s()[2]);
230   fmla(v30.v4s(), v12.v4s(), v7.s()[2]);
231   fmla(v31.v4s(), v13.v4s(), v7.s()[2]);
232   fmla(v24.v4s(), v14.v4s(), v4.s()[3]);
233   fmla(v25.v4s(), v15.v4s(), v4.s()[3]);
234   fmla(v26.v4s(), v14.v4s(), v5.s()[3]);
235   fmla(v27.v4s(), v15.v4s(), v5.s()[3]);
236   fmla(v28.v4s(), v14.v4s(), v6.s()[3]);
237   fmla(v29.v4s(), v15.v4s(), v6.s()[3]);
238   subs(x0, x0, 32);
239   fmla(v30.v4s(), v14.v4s(), v7.s()[3]);
240   fmla(v31.v4s(), v15.v4s(), v7.s()[3]);
241   b_hs(l1);
242 
243   bind(l2);
244   // Epilogue
245   // First block of 4.  FMA for first 4, loads for 2nd block of 4.
246   fmla(v24.v4s(), v16.v4s(), v0.s()[0]);
247   ldp(q8, q9, mem[x5], 32);
248   fmla(v25.v4s(), v17.v4s(), v0.s()[0]);
249   fmla(v26.v4s(), v16.v4s(), v1.s()[0]);
250   ldp(q10, q11, mem[x5], 32);
251   fmla(v27.v4s(), v17.v4s(), v1.s()[0]);
252   fmla(v28.v4s(), v16.v4s(), v2.s()[0]);
253   ldp(q12, q13, mem[x5], 32);
254   fmla(v29.v4s(), v17.v4s(), v2.s()[0]);
255   fmla(v30.v4s(), v16.v4s(), v3.s()[0]);
256   ldp(q14, q15, mem[x5], 32);
257   fmla(v31.v4s(), v17.v4s(), v3.s()[0]);
258   fmla(v24.v4s(), v18.v4s(), v0.s()[1]);
259   ldr(q4, mem[x3], 16);
260   fmla(v25.v4s(), v19.v4s(), v0.s()[1]);
261   fmla(v26.v4s(), v18.v4s(), v1.s()[1]);
262   ldr(q5, mem[x11], 16);
263   fmla(v27.v4s(), v19.v4s(), v1.s()[1]);
264   fmla(v28.v4s(), v18.v4s(), v2.s()[1]);
265   ldr(q6, mem[x12], 16);
266   fmla(v29.v4s(), v19.v4s(), v2.s()[1]);
267   fmla(v30.v4s(), v18.v4s(), v3.s()[1]);
268   ldr(q7, mem[x4], 16);
269   fmla(v31.v4s(), v19.v4s(), v3.s()[1]);
270   fmla(v24.v4s(), v20.v4s(), v0.s()[2]);
271   fmla(v25.v4s(), v21.v4s(), v0.s()[2]);
272   fmla(v26.v4s(), v20.v4s(), v1.s()[2]);
273   fmla(v27.v4s(), v21.v4s(), v1.s()[2]);
274   fmla(v28.v4s(), v20.v4s(), v2.s()[2]);
275   fmla(v29.v4s(), v21.v4s(), v2.s()[2]);
276   fmla(v30.v4s(), v20.v4s(), v3.s()[2]);
277   fmla(v31.v4s(), v21.v4s(), v3.s()[2]);
278   fmla(v24.v4s(), v22.v4s(), v0.s()[3]);
279   fmla(v25.v4s(), v23.v4s(), v0.s()[3]);
280   fmla(v26.v4s(), v22.v4s(), v1.s()[3]);
281   fmla(v27.v4s(), v23.v4s(), v1.s()[3]);
282   fmla(v28.v4s(), v22.v4s(), v2.s()[3]);
283   fmla(v29.v4s(), v23.v4s(), v2.s()[3]);
284   fmla(v30.v4s(), v22.v4s(), v3.s()[3]);
285   fmla(v31.v4s(), v23.v4s(), v3.s()[3]);
286 
287   // Second block of 4.  FMA for second 4, noloads
288   fmla(v24.v4s(), v8.v4s(), v4.s()[0]);
289   fmla(v25.v4s(), v9.v4s(), v4.s()[0]);
290   fmla(v26.v4s(), v8.v4s(), v5.s()[0]);
291   fmla(v27.v4s(), v9.v4s(), v5.s()[0]);
292   fmla(v28.v4s(), v8.v4s(), v6.s()[0]);
293   fmla(v29.v4s(), v9.v4s(), v6.s()[0]);
294   fmla(v30.v4s(), v8.v4s(), v7.s()[0]);
295   fmla(v31.v4s(), v9.v4s(), v7.s()[0]);
296 
297   fmla(v24.v4s(), v10.v4s(), v4.s()[1]);
298   fmla(v25.v4s(), v11.v4s(), v4.s()[1]);
299   fmla(v26.v4s(), v10.v4s(), v5.s()[1]);
300   fmla(v27.v4s(), v11.v4s(), v5.s()[1]);
301   fmla(v28.v4s(), v10.v4s(), v6.s()[1]);
302   fmla(v29.v4s(), v11.v4s(), v6.s()[1]);
303   fmla(v30.v4s(), v10.v4s(), v7.s()[1]);
304   fmla(v31.v4s(), v11.v4s(), v7.s()[1]);
305 
306   fmla(v24.v4s(), v12.v4s(), v4.s()[2]);
307   fmla(v25.v4s(), v13.v4s(), v4.s()[2]);
308   fmla(v26.v4s(), v12.v4s(), v5.s()[2]);
309   fmla(v27.v4s(), v13.v4s(), v5.s()[2]);
310   fmla(v28.v4s(), v12.v4s(), v6.s()[2]);
311   fmla(v29.v4s(), v13.v4s(), v6.s()[2]);
312   fmla(v30.v4s(), v12.v4s(), v7.s()[2]);
313   fmla(v31.v4s(), v13.v4s(), v7.s()[2]);
314 
315   fmla(v24.v4s(), v14.v4s(), v4.s()[3]);
316   fmla(v25.v4s(), v15.v4s(), v4.s()[3]);
317   fmla(v26.v4s(), v14.v4s(), v5.s()[3]);
318   fmla(v27.v4s(), v15.v4s(), v5.s()[3]);
319 
320   // Load min/max values
321   if (clamp_min && clamp_max) {
322     ld2r({v4.v4s(), v5.v4s()}, mem[x8]);
323   } else if (clamp_min) {
324     ld1r({v4.v4s()}, mem[x8]);
325   } else if (clamp_max) {
326     add(x13, x8, 4);
327     ld1r({v5.v4s()}, mem[x13]);
328   }
329 
330   fmla(v28.v4s(), v14.v4s(), v6.s()[3]);
331   fmla(v29.v4s(), v15.v4s(), v6.s()[3]);
332   fmla(v30.v4s(), v14.v4s(), v7.s()[3]);
333   fmla(v31.v4s(), v15.v4s(), v7.s()[3]);
334 
335   bind(l3);
336   // Remainder- 4 floats of A (16 bytes)
337   tbz(x0, 4, l4);
338 
339   ldr(q0, mem[x3], 16);
340   ldp(q16, q17, mem[x5], 32);
341   ldr(q1, mem[x11], 16);
342   ldr(q2, mem[x12], 16);
343   ldr(q3, mem[x4], 16);
344   fmla(v24.v4s(), v16.v4s(), v0.s()[0]);
345   fmla(v25.v4s(), v17.v4s(), v0.s()[0]);
346   ldp(q18, q19, mem[x5], 32);
347   fmla(v26.v4s(), v16.v4s(), v1.s()[0]);
348   fmla(v27.v4s(), v17.v4s(), v1.s()[0]);
349   ldp(q20, q21, mem[x5], 32);
350   fmla(v28.v4s(), v16.v4s(), v2.s()[0]);
351   fmla(v29.v4s(), v17.v4s(), v2.s()[0]);
352   ldp(q22, q23, mem[x5], 32);
353   fmla(v30.v4s(), v16.v4s(), v3.s()[0]);
354   fmla(v31.v4s(), v17.v4s(), v3.s()[0]);
355   fmla(v24.v4s(), v18.v4s(), v0.s()[1]);
356   fmla(v25.v4s(), v19.v4s(), v0.s()[1]);
357   fmla(v26.v4s(), v18.v4s(), v1.s()[1]);
358   fmla(v27.v4s(), v19.v4s(), v1.s()[1]);
359   fmla(v28.v4s(), v18.v4s(), v2.s()[1]);
360   fmla(v29.v4s(), v19.v4s(), v2.s()[1]);
361   fmla(v30.v4s(), v18.v4s(), v3.s()[1]);
362   fmla(v31.v4s(), v19.v4s(), v3.s()[1]);
363   fmla(v24.v4s(), v20.v4s(), v0.s()[2]);
364   fmla(v25.v4s(), v21.v4s(), v0.s()[2]);
365   fmla(v26.v4s(), v20.v4s(), v1.s()[2]);
366   fmla(v27.v4s(), v21.v4s(), v1.s()[2]);
367   fmla(v28.v4s(), v20.v4s(), v2.s()[2]);
368   fmla(v29.v4s(), v21.v4s(), v2.s()[2]);
369   fmla(v30.v4s(), v20.v4s(), v3.s()[2]);
370   fmla(v31.v4s(), v21.v4s(), v3.s()[2]);
371   fmla(v24.v4s(), v22.v4s(), v0.s()[3]);
372   fmla(v25.v4s(), v23.v4s(), v0.s()[3]);
373   fmla(v26.v4s(), v22.v4s(), v1.s()[3]);
374   fmla(v27.v4s(), v23.v4s(), v1.s()[3]);
375   fmla(v28.v4s(), v22.v4s(), v2.s()[3]);
376   fmla(v29.v4s(), v23.v4s(), v2.s()[3]);
377   fmla(v30.v4s(), v22.v4s(), v3.s()[3]);
378   fmla(v31.v4s(), v23.v4s(), v3.s()[3]);
379 
380   bind(l4);
381   // Remainder- 2 floats of A (8 bytes)
382   tbz(x0, 3, l5);
383 
384   ldr(d0, mem[x3], 8);
385   ldp(q16, q17, mem[x5], 32);
386   ldr(d1, mem[x11], 8);
387   ldr(d2, mem[x12], 8);
388   ldr(d3, mem[x4], 8);
389   fmla(v24.v4s(), v16.v4s(), v0.s()[0]);
390   fmla(v25.v4s(), v17.v4s(), v0.s()[0]);
391   ldp(q18, q19, mem[x5], 32);
392   fmla(v26.v4s(), v16.v4s(), v1.s()[0]);
393   fmla(v27.v4s(), v17.v4s(), v1.s()[0]);
394   fmla(v28.v4s(), v16.v4s(), v2.s()[0]);
395   fmla(v29.v4s(), v17.v4s(), v2.s()[0]);
396   fmla(v30.v4s(), v16.v4s(), v3.s()[0]);
397   fmla(v31.v4s(), v17.v4s(), v3.s()[0]);
398   fmla(v24.v4s(), v18.v4s(), v0.s()[1]);
399   fmla(v25.v4s(), v19.v4s(), v0.s()[1]);
400   fmla(v26.v4s(), v18.v4s(), v1.s()[1]);
401   fmla(v27.v4s(), v19.v4s(), v1.s()[1]);
402   fmla(v28.v4s(), v18.v4s(), v2.s()[1]);
403   fmla(v29.v4s(), v19.v4s(), v2.s()[1]);
404   fmla(v30.v4s(), v18.v4s(), v3.s()[1]);
405   fmla(v31.v4s(), v19.v4s(), v3.s()[1]);
406 
407   bind(l5);
408   // Remainder- 1 float of A (4 bytes)
409   tbz(x0, 2, l6);
410 
411   ldr(s0, mem[x3], 4);
412   ldp(q16, q17, mem[x5], 32);
413   ldr(s1, mem[x11], 4);
414   ldr(s2, mem[x12], 4);
415   ldr(s3, mem[x4], 4);
416   fmla(v24.v4s(), v16.v4s(), v0.s()[0]);
417   fmla(v25.v4s(), v17.v4s(), v0.s()[0]);
418   fmla(v26.v4s(), v16.v4s(), v1.s()[0]);
419   fmla(v27.v4s(), v17.v4s(), v1.s()[0]);
420   fmla(v28.v4s(), v16.v4s(), v2.s()[0]);
421   fmla(v29.v4s(), v17.v4s(), v2.s()[0]);
422   fmla(v30.v4s(), v16.v4s(), v3.s()[0]);
423   fmla(v31.v4s(), v17.v4s(), v3.s()[0]);
424 
425   bind(l6);
426   // Clamp
427   subs(x1, x1, 8);
428   if (clamp_min) {
429     fmax(v24.v4s(), v24.v4s(), v4.v4s());
430     fmax(v25.v4s(), v25.v4s(), v4.v4s());
431     fmax(v26.v4s(), v26.v4s(), v4.v4s());
432     fmax(v27.v4s(), v27.v4s(), v4.v4s());
433     fmax(v28.v4s(), v28.v4s(), v4.v4s());
434     fmax(v29.v4s(), v29.v4s(), v4.v4s());
435     fmax(v30.v4s(), v30.v4s(), v4.v4s());
436     fmax(v31.v4s(), v31.v4s(), v4.v4s());
437   }
438   if (clamp_max) {
439     fmin(v24.v4s(), v24.v4s(), v5.v4s());
440     fmin(v25.v4s(), v25.v4s(), v5.v4s());
441     fmin(v26.v4s(), v26.v4s(), v5.v4s());
442     fmin(v27.v4s(), v27.v4s(), v5.v4s());
443     fmin(v28.v4s(), v28.v4s(), v5.v4s());
444     fmin(v29.v4s(), v29.v4s(), v5.v4s());
445     fmin(v30.v4s(), v30.v4s(), v5.v4s());
446     fmin(v31.v4s(), v31.v4s(), v5.v4s());
447   }
448 
449   // Store full 4 x 8
450   b_lo(l7);
451 
452   stp(q24, q25, mem[x6]);
453   sub(x3, x3, x2); // a0 -= kc
454   add(x6, x6, x14);
455   stp(q26, q27, mem[x9]);
456   sub(x11, x11, x2); // a1 -= kc
457   add(x9, x9, x14);
458   stp(q28, q29, mem[x10]);
459   sub(x12, x12, x2); // a2 -= kc
460   add(x10, x10, x14);
461   stp(q30, q31, mem[x7]);
462   sub(x4, x4, x2); // a3 -= kc
463   add(x7, x7, x14);
464 
465   b_hi(l0);
466 
467   // Restore d8-d15 from stack
468   ldp(d14, d15, mem[sp, 48]);
469   ldp(d12, d13, mem[sp, 32]);
470   ldp(d10, d11, mem[sp, 16]);
471   ldp(d8, d9, mem[sp], 64);
472   ret();
473 
474   // Store odd width
475   bind(l7);
476   tbz(x1, 2, l8);
477   str(q24, mem[x6], 16);
478   mov(v24.v16b(), v25.v16b());
479   str(q26, mem[x9], 16);
480   mov(v26.v16b(), v27.v16b());
481   str(q28, mem[x10], 16);
482   mov(v28.v16b(), v29.v16b());
483   str(q30, mem[x7], 16);
484   mov(v30.v16b(), v31.v16b());
485 
486   bind(l8);
487   tbz(x1, 1, l9);
488   str(d24, mem[x6], 8);
489   str(d26, mem[x9], 8);
490   dup(d24, v24.d()[1]);
491   dup(d26, v26.d()[1]);
492   str(d28, mem[x10], 8);
493   str(d30, mem[x7], 8);
494   dup(d28, v28.d()[1]);
495   dup(d30, v30.d()[1]);
496 
497   bind(l9);
498   tbz(x1, 0, l10);
499   str(s24, mem[x6]);
500   str(s26, mem[x9]);
501   str(s28, mem[x10]);
502   str(s30, mem[x7]);
503   bind(l10);
504   // Restore d8-d15 from stack
505   ldp(d14, d15, mem[sp, 48]);
506   ldp(d12, d13, mem[sp, 32]);
507   ldp(d10, d11, mem[sp, 16]);
508   ldp(d8, d9, mem[sp], 64);
509   ret();
510 
511 
512   align(16, AlignInstruction::kHlt);
513 }
514 }  // namespace
515 }  // aarch64
516 }  // xnnpack
517 
xnn_generate_f32_gemm_ukernel_4x8__aarch64_neonfma_cortex_a75(xnn_code_buffer * code,size_t max_mr,size_t nc_mod_nr,size_t kc,const void * params)518 xnn_status_t xnn_generate_f32_gemm_ukernel_4x8__aarch64_neonfma_cortex_a75(xnn_code_buffer* code, size_t max_mr, size_t nc_mod_nr, size_t kc, const void* params) {
519   using namespace xnnpack::aarch64;
520   Generator g(code);
521   assert(params != nullptr);
522   const jit_gemm_params* gemm_params = static_cast<const jit_gemm_params*>(params);
523   g.generate(false, max_mr, nc_mod_nr, kc, gemm_params->f32_minmax.min, gemm_params->f32_minmax.max);
524   g.finalize();
525   if (g.error() != xnnpack::Error::kNoError) {
526     return xnn_status_invalid_state;
527   }
528   return xnn_status_success;
529 }
530 
xnn_generate_f32_gemm_ukernel_4x8__aarch64_neonfma_prfm_cortex_a75(xnn_code_buffer * code,size_t max_mr,size_t nc_mod_nr,size_t kc,const void * params)531 xnn_status_t xnn_generate_f32_gemm_ukernel_4x8__aarch64_neonfma_prfm_cortex_a75(xnn_code_buffer* code, size_t max_mr, size_t nc_mod_nr, size_t kc, const void* params) {
532   using namespace xnnpack::aarch64;
533   Generator g(code);
534   assert(params != nullptr);
535   const jit_gemm_params* gemm_params = static_cast<const jit_gemm_params*>(params);
536   g.generate(true, max_mr, nc_mod_nr, kc, gemm_params->f32_minmax.min, gemm_params->f32_minmax.max);
537   g.finalize();
538   if (g.error() != xnnpack::Error::kNoError) {
539     return xnn_status_invalid_state;
540   }
541   return xnn_status_success;
542 }
543