1 // Copyright 2022 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5
6 #include <cassert>
7 #include <cstddef>
8 #include <limits>
9
10 #include <xnnpack.h>
11 #include <xnnpack/aarch64-assembler.h>
12 #include <xnnpack/allocator.h>
13 #include <xnnpack/gemm.h>
14
15 namespace xnnpack {
16 namespace aarch64 {
17 namespace {
18 class Generator : public Assembler {
19 using Assembler::Assembler;
20 public:
21 void generate(bool prefetch, size_t max_mr, size_t nc_mod_nr, size_t kc, float min, float max);
22 };
23
24 // void xnn_f32_gemm_minmax_ukernel_4x8__aarch64_neonfma_prfm_cortex_a75(
25 // size_t mr, x0
26 // size_t nc, x1
27 // size_t kc, x2 / x0
28 // const uint8_t*restrict a, x3
29 // size_t a_stride, x4
30 // const void*restrict w, x5
31 // uint8_t*restrict c, x6
32 // size_t cm_stride, x7
33 // size_t cn_stride, [sp] -> x14
34 // const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) [sp + 8] -> x8
35
36 // d8-d15, x19-x30 need to be preserved if used. x18 is reserved by the OS.
37 // x13 used to store pointer to params->max if (!clamp_min && clamp_max).
38
39 // A pointers
40 // x3 a0
41 // x11 a1
42 // x12 a2
43 // x4 a3 / a_stride
44
45 // C pointers
46 // x6 c0
47 // x9 c1
48 // x10 c2
49 // x7 c3 / cm_stride
50
51 // Vector register usage
52 // A0 v0 v4
53 // A1 v1 v5
54 // A2 v2 v6
55 // A3 v3 v7
56 // B v8 v9 v10 v11
57 // B v12 v13 v14 v15
58 // B v16 v17 v18 v19
59 // B v20 v21 v22 v23
60 // C v24 v25
61 // C v26 v27
62 // C v28 v29
63 // C v30 v31
64 // Clamp v4 v5
65
66 // Converted from: src/f32-gemm/gen/4x8-minmax-aarch64-neonfma-prfm-cortex-a75.S
generate(bool prefetch,size_t max_mr,size_t nc_mod_nr,size_t kc,float min,float max)67 void Generator::generate(bool prefetch, size_t max_mr, size_t nc_mod_nr, size_t kc, float min, float max) {
68 assert(nc_mod_nr < 8);
69 assert(kc != 0);
70 assert(kc % sizeof(float) == 0);
71
72 Label l0, l1, l2, l3, l4, l5, l6, l7, l8, l9, l10;
73 const bool clamp_min = min != -std::numeric_limits<float>::infinity();
74 const bool clamp_max = max != +std::numeric_limits<float>::infinity();
75
76 // Load cn_stride, params pointer
77 ldp(x14, x8, mem[sp]);
78
79 // Load min/max values
80 if (clamp_min && clamp_max) {
81 ld2r({v4.v4s(), v5.v4s()}, mem[x8]);
82 } else if (clamp_min) {
83 ld1r({v4.v4s()}, mem[x8]);
84 } else if (clamp_max) {
85 add(x13, x8, 4);
86 ld1r({v5.v4s()}, mem[x13]);
87 }
88
89 // Save d8-d15 on stack
90 stp(d8, d9, mem[sp, -64]++);
91 stp(d10, d11, mem[sp, 16]);
92 stp(d12, d13, mem[sp, 32]);
93 stp(d14, d15, mem[sp, 48]);
94
95 // Clamp A and C pointers
96 cmp(x0, 2); // if mr < 2
97 add(x11, x3, x4); // a1 = a0 + a_stride
98 add(x9, x6, x7); // c1 = c0 + cm_stride
99 csel(x11, x3, x11, kLO); // a1 = a0
100 csel(x9, x6, x9, kLO); // c1 = c0
101
102 add(x12, x11, x4); // a2 = a1 + a_stride
103 add(x10, x9, x7); // c2 = c1 + cm_stride
104 // if mr <= 2
105 csel(x12, x11, x12, kLS); // a2 = a1
106 csel(x10, x9, x10, kLS); // c2 = c1
107
108 cmp(x0, 4); // if mr < 4
109 add(x4, x12, x4); // a3 = a2 + a_stride
110 add(x7, x10, x7); // c3 = c2 + cm_stride
111 csel(x4, x12, x4, kLO); // a3 = a2
112 csel(x7, x10, x7, kLO); // c3 = c2
113
114 bind(l0);
115 // Load initial bias from w into accumulators
116 ldp(q24, q25, mem[x5], 32);
117 mov(v26.v16b(), v24.v16b());
118 mov(v27.v16b(), v25.v16b());
119 mov(v28.v16b(), v24.v16b());
120 mov(v29.v16b(), v25.v16b());
121 mov(v30.v16b(), v24.v16b());
122 mov(v31.v16b(), v25.v16b());
123
124 // Is there at least 8 floats (32 bytes) for prologue + epilogue?
125 subs(x0, x2, 32); // k = kc - 32
126 b_lo(l3);
127
128 // 16 prologue
129 // Read first block of 4 A and B.
130 ldr(q0, mem[x3], 16);
131 ldp(q16, q17, mem[x5], 32);
132 ldr(q1, mem[x11], 16);
133 ldr(q2, mem[x12], 16);
134 ldr(q3, mem[x4], 16);
135 ldp(q18, q19, mem[x5], 32);
136 ldp(q20, q21, mem[x5], 32);
137 ldp(q22, q23, mem[x5], 32);
138
139 // Is there at least 32. yes do main loop
140 subs(x0, x0, 32);
141 b_lo(l2);
142
143 // Main loop - 8 floats of A (32 bytes)
144 bind(l1);
145 // First block of 4. FMA for first 4, loads for 2nd block of 4.
146 fmla(v24.v4s(), v16.v4s(), v0.s()[0]);
147 ldp(q8, q9, mem[x5], 32);
148 fmla(v25.v4s(), v17.v4s(), v0.s()[0]);
149 fmla(v26.v4s(), v16.v4s(), v1.s()[0]);
150 ldp(q10, q11, mem[x5], 32);
151 fmla(v27.v4s(), v17.v4s(), v1.s()[0]);
152 fmla(v28.v4s(), v16.v4s(), v2.s()[0]);
153 ldp(q12, q13, mem[x5], 32);
154 fmla(v29.v4s(), v17.v4s(), v2.s()[0]);
155 fmla(v30.v4s(), v16.v4s(), v3.s()[0]);
156 ldp(q14, q15, mem[x5], 32);
157 fmla(v31.v4s(), v17.v4s(), v3.s()[0]);
158 fmla(v24.v4s(), v18.v4s(), v0.s()[1]);
159 ldr(q4, mem[x3], 16);
160 fmla(v25.v4s(), v19.v4s(), v0.s()[1]);
161 fmla(v26.v4s(), v18.v4s(), v1.s()[1]);
162 ldr(q5, mem[x11], 16);
163 fmla(v27.v4s(), v19.v4s(), v1.s()[1]);
164 fmla(v28.v4s(), v18.v4s(), v2.s()[1]);
165 ldr(q6, mem[x12], 16);
166 fmla(v29.v4s(), v19.v4s(), v2.s()[1]);
167 fmla(v30.v4s(), v18.v4s(), v3.s()[1]);
168 ldr(q7, mem[x4], 16);
169 fmla(v31.v4s(), v19.v4s(), v3.s()[1]);
170 fmla(v24.v4s(), v20.v4s(), v0.s()[2]);
171 if (prefetch) {
172 prfm(kPLDL1KEEP, mem[x5, 128]);
173 }
174 fmla(v25.v4s(), v21.v4s(), v0.s()[2]);
175 fmla(v26.v4s(), v20.v4s(), v1.s()[2]);
176 if (prefetch) {
177 prfm(kPLDL1KEEP, mem[x5, 192]);
178 }
179 fmla(v27.v4s(), v21.v4s(), v1.s()[2]);
180 fmla(v28.v4s(), v20.v4s(), v2.s()[2]);
181 if (prefetch) {
182 prfm(kPLDL1KEEP, mem[x5, 256]);
183 }
184 fmla(v29.v4s(), v21.v4s(), v2.s()[2]);
185 fmla(v30.v4s(), v20.v4s(), v3.s()[2]);
186 if (prefetch) {
187 prfm(kPLDL1KEEP, mem[x5, 320]);
188 }
189 fmla(v31.v4s(), v21.v4s(), v3.s()[2]);
190 fmla(v24.v4s(), v22.v4s(), v0.s()[3]);
191 fmla(v25.v4s(), v23.v4s(), v0.s()[3]);
192 fmla(v26.v4s(), v22.v4s(), v1.s()[3]);
193 fmla(v27.v4s(), v23.v4s(), v1.s()[3]);
194 fmla(v28.v4s(), v22.v4s(), v2.s()[3]);
195 fmla(v29.v4s(), v23.v4s(), v2.s()[3]);
196 fmla(v30.v4s(), v22.v4s(), v3.s()[3]);
197 fmla(v31.v4s(), v23.v4s(), v3.s()[3]);
198
199 // Second block of 4. FMA for second 4, loads for 1st block of 4.
200 fmla(v24.v4s(), v8.v4s(), v4.s()[0]);
201 ldp(q16, q17, mem[x5], 32);
202 fmla(v25.v4s(), v9.v4s(), v4.s()[0]);
203 fmla(v26.v4s(), v8.v4s(), v5.s()[0]);
204 ldp(q18, q19, mem[x5], 32);
205 fmla(v27.v4s(), v9.v4s(), v5.s()[0]);
206 fmla(v28.v4s(), v8.v4s(), v6.s()[0]);
207 ldp(q20, q21, mem[x5], 32);
208 fmla(v29.v4s(), v9.v4s(), v6.s()[0]);
209 fmla(v30.v4s(), v8.v4s(), v7.s()[0]);
210 ldp(q22, q23, mem[x5], 32);
211 fmla(v31.v4s(), v9.v4s(), v7.s()[0]);
212 fmla(v24.v4s(), v10.v4s(), v4.s()[1]);
213 ldr(q0, mem[x3], 16);
214 fmla(v25.v4s(), v11.v4s(), v4.s()[1]);
215 fmla(v26.v4s(), v10.v4s(), v5.s()[1]);
216 ldr(q1, mem[x11], 16);
217 fmla(v27.v4s(), v11.v4s(), v5.s()[1]);
218 fmla(v28.v4s(), v10.v4s(), v6.s()[1]);
219 ldr(q2, mem[x12], 16);
220 fmla(v29.v4s(), v11.v4s(), v6.s()[1]);
221 fmla(v30.v4s(), v10.v4s(), v7.s()[1]);
222 ldr(q3, mem[x4], 16);
223 fmla(v31.v4s(), v11.v4s(), v7.s()[1]);
224 fmla(v24.v4s(), v12.v4s(), v4.s()[2]);
225 fmla(v25.v4s(), v13.v4s(), v4.s()[2]);
226 fmla(v26.v4s(), v12.v4s(), v5.s()[2]);
227 fmla(v27.v4s(), v13.v4s(), v5.s()[2]);
228 fmla(v28.v4s(), v12.v4s(), v6.s()[2]);
229 fmla(v29.v4s(), v13.v4s(), v6.s()[2]);
230 fmla(v30.v4s(), v12.v4s(), v7.s()[2]);
231 fmla(v31.v4s(), v13.v4s(), v7.s()[2]);
232 fmla(v24.v4s(), v14.v4s(), v4.s()[3]);
233 fmla(v25.v4s(), v15.v4s(), v4.s()[3]);
234 fmla(v26.v4s(), v14.v4s(), v5.s()[3]);
235 fmla(v27.v4s(), v15.v4s(), v5.s()[3]);
236 fmla(v28.v4s(), v14.v4s(), v6.s()[3]);
237 fmla(v29.v4s(), v15.v4s(), v6.s()[3]);
238 subs(x0, x0, 32);
239 fmla(v30.v4s(), v14.v4s(), v7.s()[3]);
240 fmla(v31.v4s(), v15.v4s(), v7.s()[3]);
241 b_hs(l1);
242
243 bind(l2);
244 // Epilogue
245 // First block of 4. FMA for first 4, loads for 2nd block of 4.
246 fmla(v24.v4s(), v16.v4s(), v0.s()[0]);
247 ldp(q8, q9, mem[x5], 32);
248 fmla(v25.v4s(), v17.v4s(), v0.s()[0]);
249 fmla(v26.v4s(), v16.v4s(), v1.s()[0]);
250 ldp(q10, q11, mem[x5], 32);
251 fmla(v27.v4s(), v17.v4s(), v1.s()[0]);
252 fmla(v28.v4s(), v16.v4s(), v2.s()[0]);
253 ldp(q12, q13, mem[x5], 32);
254 fmla(v29.v4s(), v17.v4s(), v2.s()[0]);
255 fmla(v30.v4s(), v16.v4s(), v3.s()[0]);
256 ldp(q14, q15, mem[x5], 32);
257 fmla(v31.v4s(), v17.v4s(), v3.s()[0]);
258 fmla(v24.v4s(), v18.v4s(), v0.s()[1]);
259 ldr(q4, mem[x3], 16);
260 fmla(v25.v4s(), v19.v4s(), v0.s()[1]);
261 fmla(v26.v4s(), v18.v4s(), v1.s()[1]);
262 ldr(q5, mem[x11], 16);
263 fmla(v27.v4s(), v19.v4s(), v1.s()[1]);
264 fmla(v28.v4s(), v18.v4s(), v2.s()[1]);
265 ldr(q6, mem[x12], 16);
266 fmla(v29.v4s(), v19.v4s(), v2.s()[1]);
267 fmla(v30.v4s(), v18.v4s(), v3.s()[1]);
268 ldr(q7, mem[x4], 16);
269 fmla(v31.v4s(), v19.v4s(), v3.s()[1]);
270 fmla(v24.v4s(), v20.v4s(), v0.s()[2]);
271 fmla(v25.v4s(), v21.v4s(), v0.s()[2]);
272 fmla(v26.v4s(), v20.v4s(), v1.s()[2]);
273 fmla(v27.v4s(), v21.v4s(), v1.s()[2]);
274 fmla(v28.v4s(), v20.v4s(), v2.s()[2]);
275 fmla(v29.v4s(), v21.v4s(), v2.s()[2]);
276 fmla(v30.v4s(), v20.v4s(), v3.s()[2]);
277 fmla(v31.v4s(), v21.v4s(), v3.s()[2]);
278 fmla(v24.v4s(), v22.v4s(), v0.s()[3]);
279 fmla(v25.v4s(), v23.v4s(), v0.s()[3]);
280 fmla(v26.v4s(), v22.v4s(), v1.s()[3]);
281 fmla(v27.v4s(), v23.v4s(), v1.s()[3]);
282 fmla(v28.v4s(), v22.v4s(), v2.s()[3]);
283 fmla(v29.v4s(), v23.v4s(), v2.s()[3]);
284 fmla(v30.v4s(), v22.v4s(), v3.s()[3]);
285 fmla(v31.v4s(), v23.v4s(), v3.s()[3]);
286
287 // Second block of 4. FMA for second 4, noloads
288 fmla(v24.v4s(), v8.v4s(), v4.s()[0]);
289 fmla(v25.v4s(), v9.v4s(), v4.s()[0]);
290 fmla(v26.v4s(), v8.v4s(), v5.s()[0]);
291 fmla(v27.v4s(), v9.v4s(), v5.s()[0]);
292 fmla(v28.v4s(), v8.v4s(), v6.s()[0]);
293 fmla(v29.v4s(), v9.v4s(), v6.s()[0]);
294 fmla(v30.v4s(), v8.v4s(), v7.s()[0]);
295 fmla(v31.v4s(), v9.v4s(), v7.s()[0]);
296
297 fmla(v24.v4s(), v10.v4s(), v4.s()[1]);
298 fmla(v25.v4s(), v11.v4s(), v4.s()[1]);
299 fmla(v26.v4s(), v10.v4s(), v5.s()[1]);
300 fmla(v27.v4s(), v11.v4s(), v5.s()[1]);
301 fmla(v28.v4s(), v10.v4s(), v6.s()[1]);
302 fmla(v29.v4s(), v11.v4s(), v6.s()[1]);
303 fmla(v30.v4s(), v10.v4s(), v7.s()[1]);
304 fmla(v31.v4s(), v11.v4s(), v7.s()[1]);
305
306 fmla(v24.v4s(), v12.v4s(), v4.s()[2]);
307 fmla(v25.v4s(), v13.v4s(), v4.s()[2]);
308 fmla(v26.v4s(), v12.v4s(), v5.s()[2]);
309 fmla(v27.v4s(), v13.v4s(), v5.s()[2]);
310 fmla(v28.v4s(), v12.v4s(), v6.s()[2]);
311 fmla(v29.v4s(), v13.v4s(), v6.s()[2]);
312 fmla(v30.v4s(), v12.v4s(), v7.s()[2]);
313 fmla(v31.v4s(), v13.v4s(), v7.s()[2]);
314
315 fmla(v24.v4s(), v14.v4s(), v4.s()[3]);
316 fmla(v25.v4s(), v15.v4s(), v4.s()[3]);
317 fmla(v26.v4s(), v14.v4s(), v5.s()[3]);
318 fmla(v27.v4s(), v15.v4s(), v5.s()[3]);
319
320 // Load min/max values
321 if (clamp_min && clamp_max) {
322 ld2r({v4.v4s(), v5.v4s()}, mem[x8]);
323 } else if (clamp_min) {
324 ld1r({v4.v4s()}, mem[x8]);
325 } else if (clamp_max) {
326 add(x13, x8, 4);
327 ld1r({v5.v4s()}, mem[x13]);
328 }
329
330 fmla(v28.v4s(), v14.v4s(), v6.s()[3]);
331 fmla(v29.v4s(), v15.v4s(), v6.s()[3]);
332 fmla(v30.v4s(), v14.v4s(), v7.s()[3]);
333 fmla(v31.v4s(), v15.v4s(), v7.s()[3]);
334
335 bind(l3);
336 // Remainder- 4 floats of A (16 bytes)
337 tbz(x0, 4, l4);
338
339 ldr(q0, mem[x3], 16);
340 ldp(q16, q17, mem[x5], 32);
341 ldr(q1, mem[x11], 16);
342 ldr(q2, mem[x12], 16);
343 ldr(q3, mem[x4], 16);
344 fmla(v24.v4s(), v16.v4s(), v0.s()[0]);
345 fmla(v25.v4s(), v17.v4s(), v0.s()[0]);
346 ldp(q18, q19, mem[x5], 32);
347 fmla(v26.v4s(), v16.v4s(), v1.s()[0]);
348 fmla(v27.v4s(), v17.v4s(), v1.s()[0]);
349 ldp(q20, q21, mem[x5], 32);
350 fmla(v28.v4s(), v16.v4s(), v2.s()[0]);
351 fmla(v29.v4s(), v17.v4s(), v2.s()[0]);
352 ldp(q22, q23, mem[x5], 32);
353 fmla(v30.v4s(), v16.v4s(), v3.s()[0]);
354 fmla(v31.v4s(), v17.v4s(), v3.s()[0]);
355 fmla(v24.v4s(), v18.v4s(), v0.s()[1]);
356 fmla(v25.v4s(), v19.v4s(), v0.s()[1]);
357 fmla(v26.v4s(), v18.v4s(), v1.s()[1]);
358 fmla(v27.v4s(), v19.v4s(), v1.s()[1]);
359 fmla(v28.v4s(), v18.v4s(), v2.s()[1]);
360 fmla(v29.v4s(), v19.v4s(), v2.s()[1]);
361 fmla(v30.v4s(), v18.v4s(), v3.s()[1]);
362 fmla(v31.v4s(), v19.v4s(), v3.s()[1]);
363 fmla(v24.v4s(), v20.v4s(), v0.s()[2]);
364 fmla(v25.v4s(), v21.v4s(), v0.s()[2]);
365 fmla(v26.v4s(), v20.v4s(), v1.s()[2]);
366 fmla(v27.v4s(), v21.v4s(), v1.s()[2]);
367 fmla(v28.v4s(), v20.v4s(), v2.s()[2]);
368 fmla(v29.v4s(), v21.v4s(), v2.s()[2]);
369 fmla(v30.v4s(), v20.v4s(), v3.s()[2]);
370 fmla(v31.v4s(), v21.v4s(), v3.s()[2]);
371 fmla(v24.v4s(), v22.v4s(), v0.s()[3]);
372 fmla(v25.v4s(), v23.v4s(), v0.s()[3]);
373 fmla(v26.v4s(), v22.v4s(), v1.s()[3]);
374 fmla(v27.v4s(), v23.v4s(), v1.s()[3]);
375 fmla(v28.v4s(), v22.v4s(), v2.s()[3]);
376 fmla(v29.v4s(), v23.v4s(), v2.s()[3]);
377 fmla(v30.v4s(), v22.v4s(), v3.s()[3]);
378 fmla(v31.v4s(), v23.v4s(), v3.s()[3]);
379
380 bind(l4);
381 // Remainder- 2 floats of A (8 bytes)
382 tbz(x0, 3, l5);
383
384 ldr(d0, mem[x3], 8);
385 ldp(q16, q17, mem[x5], 32);
386 ldr(d1, mem[x11], 8);
387 ldr(d2, mem[x12], 8);
388 ldr(d3, mem[x4], 8);
389 fmla(v24.v4s(), v16.v4s(), v0.s()[0]);
390 fmla(v25.v4s(), v17.v4s(), v0.s()[0]);
391 ldp(q18, q19, mem[x5], 32);
392 fmla(v26.v4s(), v16.v4s(), v1.s()[0]);
393 fmla(v27.v4s(), v17.v4s(), v1.s()[0]);
394 fmla(v28.v4s(), v16.v4s(), v2.s()[0]);
395 fmla(v29.v4s(), v17.v4s(), v2.s()[0]);
396 fmla(v30.v4s(), v16.v4s(), v3.s()[0]);
397 fmla(v31.v4s(), v17.v4s(), v3.s()[0]);
398 fmla(v24.v4s(), v18.v4s(), v0.s()[1]);
399 fmla(v25.v4s(), v19.v4s(), v0.s()[1]);
400 fmla(v26.v4s(), v18.v4s(), v1.s()[1]);
401 fmla(v27.v4s(), v19.v4s(), v1.s()[1]);
402 fmla(v28.v4s(), v18.v4s(), v2.s()[1]);
403 fmla(v29.v4s(), v19.v4s(), v2.s()[1]);
404 fmla(v30.v4s(), v18.v4s(), v3.s()[1]);
405 fmla(v31.v4s(), v19.v4s(), v3.s()[1]);
406
407 bind(l5);
408 // Remainder- 1 float of A (4 bytes)
409 tbz(x0, 2, l6);
410
411 ldr(s0, mem[x3], 4);
412 ldp(q16, q17, mem[x5], 32);
413 ldr(s1, mem[x11], 4);
414 ldr(s2, mem[x12], 4);
415 ldr(s3, mem[x4], 4);
416 fmla(v24.v4s(), v16.v4s(), v0.s()[0]);
417 fmla(v25.v4s(), v17.v4s(), v0.s()[0]);
418 fmla(v26.v4s(), v16.v4s(), v1.s()[0]);
419 fmla(v27.v4s(), v17.v4s(), v1.s()[0]);
420 fmla(v28.v4s(), v16.v4s(), v2.s()[0]);
421 fmla(v29.v4s(), v17.v4s(), v2.s()[0]);
422 fmla(v30.v4s(), v16.v4s(), v3.s()[0]);
423 fmla(v31.v4s(), v17.v4s(), v3.s()[0]);
424
425 bind(l6);
426 // Clamp
427 subs(x1, x1, 8);
428 if (clamp_min) {
429 fmax(v24.v4s(), v24.v4s(), v4.v4s());
430 fmax(v25.v4s(), v25.v4s(), v4.v4s());
431 fmax(v26.v4s(), v26.v4s(), v4.v4s());
432 fmax(v27.v4s(), v27.v4s(), v4.v4s());
433 fmax(v28.v4s(), v28.v4s(), v4.v4s());
434 fmax(v29.v4s(), v29.v4s(), v4.v4s());
435 fmax(v30.v4s(), v30.v4s(), v4.v4s());
436 fmax(v31.v4s(), v31.v4s(), v4.v4s());
437 }
438 if (clamp_max) {
439 fmin(v24.v4s(), v24.v4s(), v5.v4s());
440 fmin(v25.v4s(), v25.v4s(), v5.v4s());
441 fmin(v26.v4s(), v26.v4s(), v5.v4s());
442 fmin(v27.v4s(), v27.v4s(), v5.v4s());
443 fmin(v28.v4s(), v28.v4s(), v5.v4s());
444 fmin(v29.v4s(), v29.v4s(), v5.v4s());
445 fmin(v30.v4s(), v30.v4s(), v5.v4s());
446 fmin(v31.v4s(), v31.v4s(), v5.v4s());
447 }
448
449 // Store full 4 x 8
450 b_lo(l7);
451
452 stp(q24, q25, mem[x6]);
453 sub(x3, x3, x2); // a0 -= kc
454 add(x6, x6, x14);
455 stp(q26, q27, mem[x9]);
456 sub(x11, x11, x2); // a1 -= kc
457 add(x9, x9, x14);
458 stp(q28, q29, mem[x10]);
459 sub(x12, x12, x2); // a2 -= kc
460 add(x10, x10, x14);
461 stp(q30, q31, mem[x7]);
462 sub(x4, x4, x2); // a3 -= kc
463 add(x7, x7, x14);
464
465 b_hi(l0);
466
467 // Restore d8-d15 from stack
468 ldp(d14, d15, mem[sp, 48]);
469 ldp(d12, d13, mem[sp, 32]);
470 ldp(d10, d11, mem[sp, 16]);
471 ldp(d8, d9, mem[sp], 64);
472 ret();
473
474 // Store odd width
475 bind(l7);
476 tbz(x1, 2, l8);
477 str(q24, mem[x6], 16);
478 mov(v24.v16b(), v25.v16b());
479 str(q26, mem[x9], 16);
480 mov(v26.v16b(), v27.v16b());
481 str(q28, mem[x10], 16);
482 mov(v28.v16b(), v29.v16b());
483 str(q30, mem[x7], 16);
484 mov(v30.v16b(), v31.v16b());
485
486 bind(l8);
487 tbz(x1, 1, l9);
488 str(d24, mem[x6], 8);
489 str(d26, mem[x9], 8);
490 dup(d24, v24.d()[1]);
491 dup(d26, v26.d()[1]);
492 str(d28, mem[x10], 8);
493 str(d30, mem[x7], 8);
494 dup(d28, v28.d()[1]);
495 dup(d30, v30.d()[1]);
496
497 bind(l9);
498 tbz(x1, 0, l10);
499 str(s24, mem[x6]);
500 str(s26, mem[x9]);
501 str(s28, mem[x10]);
502 str(s30, mem[x7]);
503 bind(l10);
504 // Restore d8-d15 from stack
505 ldp(d14, d15, mem[sp, 48]);
506 ldp(d12, d13, mem[sp, 32]);
507 ldp(d10, d11, mem[sp, 16]);
508 ldp(d8, d9, mem[sp], 64);
509 ret();
510
511
512 align(16, AlignInstruction::kHlt);
513 }
514 } // namespace
515 } // aarch64
516 } // xnnpack
517
xnn_generate_f32_gemm_ukernel_4x8__aarch64_neonfma_cortex_a75(xnn_code_buffer * code,size_t max_mr,size_t nc_mod_nr,size_t kc,const void * params)518 xnn_status_t xnn_generate_f32_gemm_ukernel_4x8__aarch64_neonfma_cortex_a75(xnn_code_buffer* code, size_t max_mr, size_t nc_mod_nr, size_t kc, const void* params) {
519 using namespace xnnpack::aarch64;
520 Generator g(code);
521 assert(params != nullptr);
522 const jit_gemm_params* gemm_params = static_cast<const jit_gemm_params*>(params);
523 g.generate(false, max_mr, nc_mod_nr, kc, gemm_params->f32_minmax.min, gemm_params->f32_minmax.max);
524 g.finalize();
525 if (g.error() != xnnpack::Error::kNoError) {
526 return xnn_status_invalid_state;
527 }
528 return xnn_status_success;
529 }
530
xnn_generate_f32_gemm_ukernel_4x8__aarch64_neonfma_prfm_cortex_a75(xnn_code_buffer * code,size_t max_mr,size_t nc_mod_nr,size_t kc,const void * params)531 xnn_status_t xnn_generate_f32_gemm_ukernel_4x8__aarch64_neonfma_prfm_cortex_a75(xnn_code_buffer* code, size_t max_mr, size_t nc_mod_nr, size_t kc, const void* params) {
532 using namespace xnnpack::aarch64;
533 Generator g(code);
534 assert(params != nullptr);
535 const jit_gemm_params* gemm_params = static_cast<const jit_gemm_params*>(params);
536 g.generate(true, max_mr, nc_mod_nr, kc, gemm_params->f32_minmax.min, gemm_params->f32_minmax.max);
537 g.finalize();
538 if (g.error() != xnnpack::Error::kNoError) {
539 return xnn_status_invalid_state;
540 }
541 return xnn_status_success;
542 }
543