1 // Copyright 2022 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5
6 #include <cassert>
7 #include <cstddef>
8 #include <limits>
9
10 #include <xnnpack.h>
11 #include <xnnpack/aarch64-assembler.h>
12 #include <xnnpack/allocator.h>
13 #include <xnnpack/igemm.h>
14
15 namespace xnnpack {
16 namespace aarch64 {
17 namespace {
18 class Generator : public Assembler {
19 using Assembler::Assembler;
20 public:
21 void generate(bool prefetch, size_t max_mr, size_t nc_mod_nr, size_t kc, size_t ks, float min, float max);
22 };
23
24 // void xnn_f32_igemm_minmax_ukernel_4x8__aarch64_neonfma_prfm_cortex_a75(
25 // size_t mr, x0
26 // size_t nc, x1
27 // size_t kc, x2 / x0
28 // size_t ks, x3 / x9
29 // const float**restrict a, x4
30 // const float*restrict w, x5
31 // float*restrict c, x6
32 // size_t cm_stride, x7
33 // size_t cn_stride, [sp] -> x10
34 // size_t a_offset, [sp + 8] -> x11
35 // const float* zero, [sp + 16] -> x12
36 // const xnn_f32_minmax_params params [sp + 24] -> x8
37
38 // d8-d15, x19-x30 need to be preserved if used. x18 is reserved by the OS.
39 // x21 used to store params->max if (!clamp_min && clamp_max).
40
41 // A pointers
42 // x20 a0
43 // x13 a1
44 // x14 a2
45 // x15 a3
46
47 // C pointers
48 // x6 c0
49 // x16 c1
50 // x17 c2
51 // x7 c3 / cm_stride
52
53 // Vector register usage
54 // A0 v0 v4
55 // A1 v1 v5
56 // A2 v2 v6
57 // A3 v3 v7
58 // B v8 v9 v10 v11
59 // B v12 v13 v14 v15
60 // B v16 v17 v18 v19
61 // B v20 v21 v22 v23
62 // C v24 v25
63 // C v26 v27
64 // C v28 v29
65 // C v30 v31
66 // Clamp v4 v5
67
68 // Converted from: src/f32-igemm/gen/4x8-minmax-aarch64-neonfma-prfm-cortex-a75.S
generate(bool prefetch,size_t max_mr,size_t nc_mod_nr,size_t kc,size_t ks,float min,float max)69 void Generator::generate(bool prefetch, size_t max_mr, size_t nc_mod_nr, size_t kc, size_t ks, float min, float max) {
70 assert(nc_mod_nr < 8);
71 assert(kc != 0);
72 assert(kc % sizeof(float) == 0);
73
74 Label l0, l1, l2, l3, l4, l5, l6, l7, l8, l9, l10, l11;
75 const bool clamp_min = min != -std::numeric_limits<float>::infinity();
76 const bool clamp_max = max != +std::numeric_limits<float>::infinity();
77
78 // Load cn_stride, a_offset
79 ldp(x10, x11, mem[sp]);
80
81 // Load zero, params pointer
82 ldp(x12, x8, mem[sp, 16]);
83
84 // Load min/max values
85 if (clamp_min && clamp_max) {
86 ld2r({v4.v4s(), v5.v4s()}, mem[x8]);
87 } else if (clamp_min) {
88 ld1r({v4.v4s()}, mem[x8]);
89 } else if (clamp_max) {
90 add(x21, x8, 4);
91 ld1r({v5.v4s()}, mem[x21]);
92 }
93
94 // Save x20, x21 on stack
95 stp(x20, x21, mem[sp, -80]++);
96
97 // Save d8-d15 on stack
98 stp(d8, d9, mem[sp, 16]);
99 stp(d10, d11, mem[sp, 32]);
100 stp(d12, d13, mem[sp, 48]);
101 stp(d14, d15, mem[sp, 64]);
102
103 // Clamp C pointers
104 cmp(x0, 2); // if mr < 2
105 add(x16, x6, x7); // c1 = c0 + cm_stride
106 csel(x16, x6, x16, kLO); // c1 = c0
107
108 add(x17, x16, x7); // c2 = c1 + cm_stride
109 // if mr <= 2
110 csel(x17, x16, x17, kLS); // c2 = c1
111
112 cmp(x0, 4); // if mr < 4
113 add(x7, x17, x7); // c3 = c2 + cm_stride
114 csel(x7, x17, x7, kLO); // c3 = c2
115
116 bind(l0);
117 // Load initial bias from w into accumulators
118 ldp(q24, q25, mem[x5], 32);
119 mov(v26.v16b(), v24.v16b());
120 mov(v27.v16b(), v25.v16b());
121 mov(v28.v16b(), v24.v16b());
122 mov(v29.v16b(), v25.v16b());
123 mov(v30.v16b(), v24.v16b());
124 mov(v31.v16b(), v25.v16b());
125
126 mov(x9, x3); // p = ks
127
128 bind(l1);
129 // Load next 4 A pointers
130 ldp(x20, x13, mem[x4], 16);
131 ldp(x14, x15, mem[x4], 16);
132
133 cmp(x20, x12); // if a0 == zero
134 add(x20, x20, x11); // a0 += a_offset
135 csel(x20, x12, x20, kEQ); // a0 = zero, else += a0 + a_offset
136 cmp(x13, x12); // if a1 == zero
137 add(x13, x13, x11); // a1 += a_offset
138 csel(x13, x12, x13, kEQ); // a1 = zero, else += a1 + a_offset
139 cmp(x14, x12); // if a2 == zero
140 add(x14, x14, x11); // a2 += a_offset
141 csel(x14, x12, x14, kEQ); // a2 = zero, else += a2 + a_offset
142 cmp(x15, x12); // if a3 == zero
143 add(x15, x15, x11); // a3 += a_offset
144 csel(x15, x12, x15, kEQ); // a3 = zero, else += a3 + a_offset
145
146 // Is there at least 8 floats (32 bytes) for prologue + epilogue?
147 subs(x0, x2, 32); // k = kc - 32
148 b_lo(l4);
149
150 // 16 prologue
151 // Read first block of 4 A and B.
152 ldr(q0, mem[x20], 16);
153 ldp(q16, q17, mem[x5], 32);
154 ldr(q1, mem[x13], 16);
155 ldr(q2, mem[x14], 16);
156 ldr(q3, mem[x15], 16);
157 ldp(q18, q19, mem[x5], 32);
158 ldp(q20, q21, mem[x5], 32);
159 ldp(q22, q23, mem[x5], 32);
160
161 // Is there at least 32. yes do main loop
162 subs(x0, x0, 32);
163 b_lo(l3);
164
165 // Main loop - 8 floats of A
166 bind(l2);
167 // First block of 4. FMA for first 4, loads for 2nd block of 4.
168 fmla(v24.v4s(), v16.v4s(), v0.s()[0]);
169 ldp(q8, q9, mem[x5], 32);
170 fmla(v25.v4s(), v17.v4s(), v0.s()[0]);
171 fmla(v26.v4s(), v16.v4s(), v1.s()[0]);
172 ldp(q10, q11, mem[x5], 32);
173 fmla(v27.v4s(), v17.v4s(), v1.s()[0]);
174 fmla(v28.v4s(), v16.v4s(), v2.s()[0]);
175 ldp(q12, q13, mem[x5], 32);
176 fmla(v29.v4s(), v17.v4s(), v2.s()[0]);
177 fmla(v30.v4s(), v16.v4s(), v3.s()[0]);
178 ldp(q14, q15, mem[x5], 32);
179 fmla(v31.v4s(), v17.v4s(), v3.s()[0]);
180 fmla(v24.v4s(), v18.v4s(), v0.s()[1]);
181 ldr(q4, mem[x20], 16);
182 fmla(v25.v4s(), v19.v4s(), v0.s()[1]);
183 fmla(v26.v4s(), v18.v4s(), v1.s()[1]);
184 ldr(q5, mem[x13], 16);
185 fmla(v27.v4s(), v19.v4s(), v1.s()[1]);
186 fmla(v28.v4s(), v18.v4s(), v2.s()[1]);
187 ldr(q6, mem[x14], 16);
188 fmla(v29.v4s(), v19.v4s(), v2.s()[1]);
189 fmla(v30.v4s(), v18.v4s(), v3.s()[1]);
190 ldr(q7, mem[x15], 16);
191 fmla(v31.v4s(), v19.v4s(), v3.s()[1]);
192 fmla(v24.v4s(), v20.v4s(), v0.s()[2]);
193 if (prefetch) {
194 prfm(kPLDL1KEEP, mem[x5, 128]);
195 }
196 fmla(v25.v4s(), v21.v4s(), v0.s()[2]);
197 fmla(v26.v4s(), v20.v4s(), v1.s()[2]);
198 if (prefetch) {
199 prfm(kPLDL1KEEP, mem[x5, 192]);
200 }
201 fmla(v27.v4s(), v21.v4s(), v1.s()[2]);
202 fmla(v28.v4s(), v20.v4s(), v2.s()[2]);
203 if (prefetch) {
204 prfm(kPLDL1KEEP, mem[x5, 256]);
205 }
206 fmla(v29.v4s(), v21.v4s(), v2.s()[2]);
207 fmla(v30.v4s(), v20.v4s(), v3.s()[2]);
208 if (prefetch) {
209 prfm(kPLDL1KEEP, mem[x5, 320]);
210 }
211 fmla(v31.v4s(), v21.v4s(), v3.s()[2]);
212 fmla(v24.v4s(), v22.v4s(), v0.s()[3]);
213 fmla(v25.v4s(), v23.v4s(), v0.s()[3]);
214 fmla(v26.v4s(), v22.v4s(), v1.s()[3]);
215 fmla(v27.v4s(), v23.v4s(), v1.s()[3]);
216 fmla(v28.v4s(), v22.v4s(), v2.s()[3]);
217 fmla(v29.v4s(), v23.v4s(), v2.s()[3]);
218 fmla(v30.v4s(), v22.v4s(), v3.s()[3]);
219 fmla(v31.v4s(), v23.v4s(), v3.s()[3]);
220
221 // Second block of 4. FMA for second 4, loads for 1st block of 4.
222 fmla(v24.v4s(), v8.v4s(), v4.s()[0]);
223 ldp(q16, q17, mem[x5], 32);
224 fmla(v25.v4s(), v9.v4s(), v4.s()[0]);
225 fmla(v26.v4s(), v8.v4s(), v5.s()[0]);
226 ldp(q18, q19, mem[x5], 32);
227 fmla(v27.v4s(), v9.v4s(), v5.s()[0]);
228 fmla(v28.v4s(), v8.v4s(), v6.s()[0]);
229 ldp(q20, q21, mem[x5], 32);
230 fmla(v29.v4s(), v9.v4s(), v6.s()[0]);
231 fmla(v30.v4s(), v8.v4s(), v7.s()[0]);
232 ldp(q22, q23, mem[x5], 32);
233 fmla(v31.v4s(), v9.v4s(), v7.s()[0]);
234 fmla(v24.v4s(), v10.v4s(), v4.s()[1]);
235 ldr(q0, mem[x20], 16);
236 fmla(v25.v4s(), v11.v4s(), v4.s()[1]);
237 fmla(v26.v4s(), v10.v4s(), v5.s()[1]);
238 ldr(q1, mem[x13], 16);
239 fmla(v27.v4s(), v11.v4s(), v5.s()[1]);
240 fmla(v28.v4s(), v10.v4s(), v6.s()[1]);
241 ldr(q2, mem[x14], 16);
242 fmla(v29.v4s(), v11.v4s(), v6.s()[1]);
243 fmla(v30.v4s(), v10.v4s(), v7.s()[1]);
244 ldr(q3, mem[x15], 16);
245 fmla(v31.v4s(), v11.v4s(), v7.s()[1]);
246 fmla(v24.v4s(), v12.v4s(), v4.s()[2]);
247 fmla(v25.v4s(), v13.v4s(), v4.s()[2]);
248 fmla(v26.v4s(), v12.v4s(), v5.s()[2]);
249 fmla(v27.v4s(), v13.v4s(), v5.s()[2]);
250 fmla(v28.v4s(), v12.v4s(), v6.s()[2]);
251 fmla(v29.v4s(), v13.v4s(), v6.s()[2]);
252 fmla(v30.v4s(), v12.v4s(), v7.s()[2]);
253 fmla(v31.v4s(), v13.v4s(), v7.s()[2]);
254 fmla(v24.v4s(), v14.v4s(), v4.s()[3]);
255 fmla(v25.v4s(), v15.v4s(), v4.s()[3]);
256 fmla(v26.v4s(), v14.v4s(), v5.s()[3]);
257 fmla(v27.v4s(), v15.v4s(), v5.s()[3]);
258 fmla(v28.v4s(), v14.v4s(), v6.s()[3]);
259 fmla(v29.v4s(), v15.v4s(), v6.s()[3]);
260 subs(x0, x0, 32);
261 fmla(v30.v4s(), v14.v4s(), v7.s()[3]);
262 fmla(v31.v4s(), v15.v4s(), v7.s()[3]);
263
264 b_hs(l2);
265
266 bind(l3);
267 // Epilogue
268 // First block of 4. FMA for first 4, loads for 2nd block of 4.
269 fmla(v24.v4s(), v16.v4s(), v0.s()[0]);
270 ldp(q8, q9, mem[x5], 32);
271 fmla(v25.v4s(), v17.v4s(), v0.s()[0]);
272 fmla(v26.v4s(), v16.v4s(), v1.s()[0]);
273 ldp(q10, q11, mem[x5], 32);
274 fmla(v27.v4s(), v17.v4s(), v1.s()[0]);
275 fmla(v28.v4s(), v16.v4s(), v2.s()[0]);
276 ldp(q12, q13, mem[x5], 32);
277 fmla(v29.v4s(), v17.v4s(), v2.s()[0]);
278 fmla(v30.v4s(), v16.v4s(), v3.s()[0]);
279 ldp(q14, q15, mem[x5], 32);
280 fmla(v31.v4s(), v17.v4s(), v3.s()[0]);
281 fmla(v24.v4s(), v18.v4s(), v0.s()[1]);
282 ldr(q4, mem[x20], 16);
283 fmla(v25.v4s(), v19.v4s(), v0.s()[1]);
284 fmla(v26.v4s(), v18.v4s(), v1.s()[1]);
285 ldr(q5, mem[x13], 16);
286 fmla(v27.v4s(), v19.v4s(), v1.s()[1]);
287 fmla(v28.v4s(), v18.v4s(), v2.s()[1]);
288 ldr(q6, mem[x14], 16);
289 fmla(v29.v4s(), v19.v4s(), v2.s()[1]);
290 fmla(v30.v4s(), v18.v4s(), v3.s()[1]);
291 ldr(q7, mem[x15], 16);
292 fmla(v31.v4s(), v19.v4s(), v3.s()[1]);
293 fmla(v24.v4s(), v20.v4s(), v0.s()[2]);
294 fmla(v25.v4s(), v21.v4s(), v0.s()[2]);
295 fmla(v26.v4s(), v20.v4s(), v1.s()[2]);
296 fmla(v27.v4s(), v21.v4s(), v1.s()[2]);
297 fmla(v28.v4s(), v20.v4s(), v2.s()[2]);
298 fmla(v29.v4s(), v21.v4s(), v2.s()[2]);
299 fmla(v30.v4s(), v20.v4s(), v3.s()[2]);
300 fmla(v31.v4s(), v21.v4s(), v3.s()[2]);
301 fmla(v24.v4s(), v22.v4s(), v0.s()[3]);
302 fmla(v25.v4s(), v23.v4s(), v0.s()[3]);
303 fmla(v26.v4s(), v22.v4s(), v1.s()[3]);
304 fmla(v27.v4s(), v23.v4s(), v1.s()[3]);
305 fmla(v28.v4s(), v22.v4s(), v2.s()[3]);
306 fmla(v29.v4s(), v23.v4s(), v2.s()[3]);
307 fmla(v30.v4s(), v22.v4s(), v3.s()[3]);
308 fmla(v31.v4s(), v23.v4s(), v3.s()[3]);
309
310 // Second block of 4. FMA for second 4, noloads
311 fmla(v24.v4s(), v8.v4s(), v4.s()[0]);
312 fmla(v25.v4s(), v9.v4s(), v4.s()[0]);
313 fmla(v26.v4s(), v8.v4s(), v5.s()[0]);
314 fmla(v27.v4s(), v9.v4s(), v5.s()[0]);
315 fmla(v28.v4s(), v8.v4s(), v6.s()[0]);
316 fmla(v29.v4s(), v9.v4s(), v6.s()[0]);
317 fmla(v30.v4s(), v8.v4s(), v7.s()[0]);
318 fmla(v31.v4s(), v9.v4s(), v7.s()[0]);
319 fmla(v24.v4s(), v10.v4s(), v4.s()[1]);
320 fmla(v25.v4s(), v11.v4s(), v4.s()[1]);
321 fmla(v26.v4s(), v10.v4s(), v5.s()[1]);
322 fmla(v27.v4s(), v11.v4s(), v5.s()[1]);
323 fmla(v28.v4s(), v10.v4s(), v6.s()[1]);
324 fmla(v29.v4s(), v11.v4s(), v6.s()[1]);
325 fmla(v30.v4s(), v10.v4s(), v7.s()[1]);
326 fmla(v31.v4s(), v11.v4s(), v7.s()[1]);
327 fmla(v24.v4s(), v12.v4s(), v4.s()[2]);
328 fmla(v25.v4s(), v13.v4s(), v4.s()[2]);
329 fmla(v26.v4s(), v12.v4s(), v5.s()[2]);
330 fmla(v27.v4s(), v13.v4s(), v5.s()[2]);
331 fmla(v28.v4s(), v12.v4s(), v6.s()[2]);
332 fmla(v29.v4s(), v13.v4s(), v6.s()[2]);
333 fmla(v30.v4s(), v12.v4s(), v7.s()[2]);
334 fmla(v31.v4s(), v13.v4s(), v7.s()[2]);
335
336 fmla(v24.v4s(), v14.v4s(), v4.s()[3]);
337 fmla(v25.v4s(), v15.v4s(), v4.s()[3]);
338 fmla(v26.v4s(), v14.v4s(), v5.s()[3]);
339 fmla(v27.v4s(), v15.v4s(), v5.s()[3]);
340
341 // Load min/max values
342 if (clamp_min && clamp_max) {
343 ld2r({v4.v4s(), v5.v4s()}, mem[x8]);
344 } else if (clamp_min) {
345 ld1r({v4.v4s()}, mem[x8]);
346 } else if (clamp_max) {
347 add(x13, x8, 4);
348 ld1r({v5.v4s()}, mem[x13]);
349 }
350
351 fmla(v28.v4s(), v14.v4s(), v6.s()[3]);
352 fmla(v29.v4s(), v15.v4s(), v6.s()[3]);
353 fmla(v30.v4s(), v14.v4s(), v7.s()[3]);
354 fmla(v31.v4s(), v15.v4s(), v7.s()[3]);
355
356 bind(l4);
357 // Remainder- 4 floats of A
358 tbz(x0, 4, l5);
359
360 ldr(q0, mem[x20], 16);
361 ldp(q16, q17, mem[x5], 32);
362 ldr(q1, mem[x13], 16);
363 ldr(q2, mem[x14], 16);
364 ldr(q3, mem[x15], 16);
365 fmla(v24.v4s(), v16.v4s(), v0.s()[0]);
366 fmla(v25.v4s(), v17.v4s(), v0.s()[0]);
367 ldp(q18, q19, mem[x5], 32);
368 fmla(v26.v4s(), v16.v4s(), v1.s()[0]);
369 fmla(v27.v4s(), v17.v4s(), v1.s()[0]);
370 ldp(q20, q21, mem[x5], 32);
371 fmla(v28.v4s(), v16.v4s(), v2.s()[0]);
372 fmla(v29.v4s(), v17.v4s(), v2.s()[0]);
373 ldp(q22, q23, mem[x5], 32);
374 fmla(v30.v4s(), v16.v4s(), v3.s()[0]);
375 fmla(v31.v4s(), v17.v4s(), v3.s()[0]);
376 fmla(v24.v4s(), v18.v4s(), v0.s()[1]);
377 fmla(v25.v4s(), v19.v4s(), v0.s()[1]);
378 fmla(v26.v4s(), v18.v4s(), v1.s()[1]);
379 fmla(v27.v4s(), v19.v4s(), v1.s()[1]);
380 fmla(v28.v4s(), v18.v4s(), v2.s()[1]);
381 fmla(v29.v4s(), v19.v4s(), v2.s()[1]);
382 fmla(v30.v4s(), v18.v4s(), v3.s()[1]);
383 fmla(v31.v4s(), v19.v4s(), v3.s()[1]);
384 fmla(v24.v4s(), v20.v4s(), v0.s()[2]);
385 fmla(v25.v4s(), v21.v4s(), v0.s()[2]);
386 fmla(v26.v4s(), v20.v4s(), v1.s()[2]);
387 fmla(v27.v4s(), v21.v4s(), v1.s()[2]);
388 fmla(v28.v4s(), v20.v4s(), v2.s()[2]);
389 fmla(v29.v4s(), v21.v4s(), v2.s()[2]);
390 fmla(v30.v4s(), v20.v4s(), v3.s()[2]);
391 fmla(v31.v4s(), v21.v4s(), v3.s()[2]);
392 fmla(v24.v4s(), v22.v4s(), v0.s()[3]);
393 fmla(v25.v4s(), v23.v4s(), v0.s()[3]);
394 fmla(v26.v4s(), v22.v4s(), v1.s()[3]);
395 fmla(v27.v4s(), v23.v4s(), v1.s()[3]);
396 fmla(v28.v4s(), v22.v4s(), v2.s()[3]);
397 fmla(v29.v4s(), v23.v4s(), v2.s()[3]);
398 fmla(v30.v4s(), v22.v4s(), v3.s()[3]);
399 fmla(v31.v4s(), v23.v4s(), v3.s()[3]);
400
401 bind(l5);
402 // Remainder- 2 floats of A
403 tbz(x0, 3, l6);
404
405 ldr(d0, mem[x20], 8);
406 ldp(q16, q17, mem[x5], 32);
407 ldr(d1, mem[x13], 8);
408 ldr(d2, mem[x14], 8);
409 ldr(d3, mem[x15], 8);
410 fmla(v24.v4s(), v16.v4s(), v0.s()[0]);
411 fmla(v25.v4s(), v17.v4s(), v0.s()[0]);
412 ldp(q18, q19, mem[x5], 32);
413 fmla(v26.v4s(), v16.v4s(), v1.s()[0]);
414 fmla(v27.v4s(), v17.v4s(), v1.s()[0]);
415 fmla(v28.v4s(), v16.v4s(), v2.s()[0]);
416 fmla(v29.v4s(), v17.v4s(), v2.s()[0]);
417 fmla(v30.v4s(), v16.v4s(), v3.s()[0]);
418 fmla(v31.v4s(), v17.v4s(), v3.s()[0]);
419 fmla(v24.v4s(), v18.v4s(), v0.s()[1]);
420 fmla(v25.v4s(), v19.v4s(), v0.s()[1]);
421 fmla(v26.v4s(), v18.v4s(), v1.s()[1]);
422 fmla(v27.v4s(), v19.v4s(), v1.s()[1]);
423 fmla(v28.v4s(), v18.v4s(), v2.s()[1]);
424 fmla(v29.v4s(), v19.v4s(), v2.s()[1]);
425 fmla(v30.v4s(), v18.v4s(), v3.s()[1]);
426 fmla(v31.v4s(), v19.v4s(), v3.s()[1]);
427
428 bind(l6);
429 // Remainder- 1 float of A
430 tbz(x0, 2, l7);
431
432 ldr(s0, mem[x20], 4);
433 ldp(q16, q17, mem[x5], 32);
434 ldr(s1, mem[x13], 4);
435 ldr(s2, mem[x14], 4);
436 ldr(s3, mem[x15], 4);
437 fmla(v24.v4s(), v16.v4s(), v0.s()[0]);
438 fmla(v25.v4s(), v17.v4s(), v0.s()[0]);
439 fmla(v26.v4s(), v16.v4s(), v1.s()[0]);
440 fmla(v27.v4s(), v17.v4s(), v1.s()[0]);
441 fmla(v28.v4s(), v16.v4s(), v2.s()[0]);
442 fmla(v29.v4s(), v17.v4s(), v2.s()[0]);
443 fmla(v30.v4s(), v16.v4s(), v3.s()[0]);
444 fmla(v31.v4s(), v17.v4s(), v3.s()[0]);
445
446 bind(l7);
447 // ks loop
448 subs(x9, x9, 32); // ks -= MR * sizeof(void*)
449 b_hi(l1);
450
451 // Clamp
452 if (clamp_min) {
453 fmax(v24.v4s(), v24.v4s(), v4.v4s());
454 fmax(v25.v4s(), v25.v4s(), v4.v4s());
455 fmax(v26.v4s(), v26.v4s(), v4.v4s());
456 fmax(v27.v4s(), v27.v4s(), v4.v4s());
457 fmax(v28.v4s(), v28.v4s(), v4.v4s());
458 fmax(v29.v4s(), v29.v4s(), v4.v4s());
459 fmax(v30.v4s(), v30.v4s(), v4.v4s());
460 fmax(v31.v4s(), v31.v4s(), v4.v4s());
461 }
462 if (clamp_max) {
463 fmin(v24.v4s(), v24.v4s(), v5.v4s());
464 fmin(v25.v4s(), v25.v4s(), v5.v4s());
465 fmin(v26.v4s(), v26.v4s(), v5.v4s());
466 fmin(v27.v4s(), v27.v4s(), v5.v4s());
467 fmin(v28.v4s(), v28.v4s(), v5.v4s());
468 fmin(v29.v4s(), v29.v4s(), v5.v4s());
469 fmin(v30.v4s(), v30.v4s(), v5.v4s());
470 fmin(v31.v4s(), v31.v4s(), v5.v4s());
471 }
472
473 // Store full 4 x 8
474 subs(x1, x1, 8);
475 b_lo(l8);
476
477 stp(q30, q31, mem[x7]);
478 add(x7, x7, x10);
479 stp(q28, q29, mem[x17]);
480 add(x17, x17, x10);
481 stp(q26, q27, mem[x16]);
482 add(x16, x16, x10);
483 stp(q24, q25, mem[x6]);
484 add(x6, x6, x10);
485
486 sub(x4, x4, x3); // a -= ks
487
488 // nc loop
489 b_hi(l0);
490
491 // Restore d8-d15 from stack
492 ldp(d14, d15, mem[sp, 64]);
493 ldp(d12, d13, mem[sp, 48]);
494 ldp(d10, d11, mem[sp, 32]);
495 ldp(d8, d9, mem[sp, 16]);
496
497 // Restore x20 from stack
498 ldr(x20, mem[sp], 80);
499 ret();
500
501 // Store odd width
502 bind(l8);
503 tbz(x1, 2, l9);
504 str(q30, mem[x7], 16);
505 mov(v30.v16b(), v31.v16b());
506 str(q28, mem[x17], 16);
507 mov(v28.v16b(), v29.v16b());
508 str(q26, mem[x16], 16);
509 mov(v26.v16b(), v27.v16b());
510 str(q24, mem[x6], 16);
511 mov(v24.v16b(), v25.v16b());
512
513 bind(l9);
514 tbz(x1, 1, l10);
515 str(d30, mem[x7], 8);
516 str(d28, mem[x17], 8);
517 dup(d30, v30.d()[1]);
518 dup(d28, v28.d()[1]);
519 str(d26, mem[x16], 8);
520 str(d24, mem[x6], 8);
521 dup(d26, v26.d()[1]);
522 dup(d24, v24.d()[1]);
523
524 bind(l10);
525 tbz(x1, 0, l11);
526 str(s30, mem[x7]);
527 str(s28, mem[x17]);
528 str(s26, mem[x16]);
529 str(s24, mem[x6]);
530 bind(l11);
531 // Restore d8-d15 from stack
532 ldp(d14, d15, mem[sp, 64]);
533 ldp(d12, d13, mem[sp, 48]);
534 ldp(d10, d11, mem[sp, 32]);
535 ldp(d8, d9, mem[sp, 16]);
536
537 // Restore x20, x21 from stack
538 ldp(x20, x21, mem[sp], 80);
539 ret();
540
541 align(16, AlignInstruction::kHlt);
542 }
543 } // namespace
544 } // aarch64
545 } // xnnpack
546
xnn_generate_f32_igemm_ukernel_4x8__aarch64_neonfma_cortex_a75(xnn_code_buffer * code,size_t max_mr,size_t nc_mod_nr,size_t kc,size_t ks,const void * params)547 xnn_status_t xnn_generate_f32_igemm_ukernel_4x8__aarch64_neonfma_cortex_a75(xnn_code_buffer* code, size_t max_mr, size_t nc_mod_nr, size_t kc, size_t ks, const void* params) {
548 using namespace xnnpack::aarch64;
549 Generator g(code);
550 assert(params != nullptr);
551 const jit_gemm_params* gemm_params = static_cast<const jit_gemm_params*>(params);
552 g.generate(false, max_mr, nc_mod_nr, kc, ks, gemm_params->f32_minmax.min, gemm_params->f32_minmax.max);
553 g.finalize();
554 if (g.error() != xnnpack::Error::kNoError) {
555 return xnn_status_invalid_state;
556 }
557 return xnn_status_success;
558 }
559
xnn_generate_f32_igemm_ukernel_4x8__aarch64_neonfma_prfm_cortex_a75(xnn_code_buffer * code,size_t max_mr,size_t nc_mod_nr,size_t kc,size_t ks,const void * params)560 xnn_status_t xnn_generate_f32_igemm_ukernel_4x8__aarch64_neonfma_prfm_cortex_a75(xnn_code_buffer* code, size_t max_mr, size_t nc_mod_nr, size_t kc, size_t ks, const void* params) {
561 using namespace xnnpack::aarch64;
562 Generator g(code);
563 assert(params != nullptr);
564 const jit_gemm_params* gemm_params = static_cast<const jit_gemm_params*>(params);
565 g.generate(true, max_mr, nc_mod_nr, kc, ks, gemm_params->f32_minmax.min, gemm_params->f32_minmax.max);
566 g.finalize();
567 if (g.error() != xnnpack::Error::kNoError) {
568 return xnn_status_invalid_state;
569 }
570 return xnn_status_success;
571 }
572