1 // Copyright 2022 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5
6
7 #include <cassert>
8
9 #include <xnnpack.h>
10 #include <xnnpack/aarch32-assembler.h>
11 #include <xnnpack/allocator.h>
12 #include <xnnpack/gemm.h>
13
14 namespace xnnpack {
15 namespace aarch32 {
16 namespace {
17 class Generator : public Assembler {
18 using Assembler::Assembler;
19 public:
20 void generate(bool prefetch, size_t max_mr, size_t nc_mod_nr, size_t kc, const void* params);
21 };
22
23
24 // void xnn_qc8_gemm_minmax_fp32_ukernel_4x8__aarch32_neonv8_mlal_lane_prfm_ld64(
25 // size_t mr, r0
26 // size_t nc, r1
27 // size_t kc, r2 -> r5
28 // const uint8_t*restrict a, r3
29 // size_t a_stride, sp + 64 -> (r7)
30 // const void*restrict w, sp + 68 -> r9
31 // uint8_t*restrict c, sp + 72 -> r11
32 // size_t cm_stride, sp + 76 -> (r6)
33 // size_t cn_stride, sp + 80 -> r7
34 // xnn_qs8_minmax_params params) sp + 84 -> (r5)
35
36 // d8-d15, r4-r11,r14(lr) need to be preserved if used. r13(sp),r15(pc) are reserved.
37
38 // Register usage
39
40 // A0 r3 d0-d1 q0
41 // A1 r12 d2-d3 q1
42 // A2 r10 d4-d5 q2
43 // A3 r0 d6-d7 q3
44
45 // B r9 d8-d9 q4
46
47 // C0 r11 d16-d17 q8 d18-d19 q9
48 // C1 r4 d20-d21 q10 d22-d23 q11
49 // C2 r8 d24-d25 q12 d26-d27 q13
50 // C3 r6 d28-d29 q14 d30-d31 q15
51
52 // unused q6 q7
53
54 // params structure is 4 bytes
55 // struct {
56 // int16_t output_zero_point; d11[2]
57 // int8_t output_min; d11[6]
58 // int8_t output_max; d11[7]
59 // } xnn_qs8_minmax_params.neonv8;
60
61 // Converted from: src/qc8-gemm/gen/4x8-minmax-fp32-aarch32-neonv8-mlal-lane-prfm-ld64.S
generate(bool prefetch,size_t max_mr,size_t nc_mod_nr,size_t kc,const void * params)62 void Generator::generate(bool prefetch, size_t max_mr, size_t nc_mod_nr, size_t kc, const void* params)
63 {
64 assert(nc_mod_nr < 8);
65 assert(kc != 0);
66
67 Label l0, l1, l2, l3, l4, l5, l6, l7;
68
69 // Push 64 bytes
70 push({r4, r5, r6, r7, r8, r9, r10, r11}); // 32
71 vpush({d8-d11}); // +32 = 64
72
73 ldr(r7, mem[sp, 64]); // a_stride
74 ldr(r11, mem[sp, 72]); // c
75 ldr(r6, mem[sp, 76]); // cm_stride
76 ldr(r9, mem[sp, 68]); // w
77 ldr(r5, mem[sp, 84]); // params
78
79 // Clamp A and C pointers
80 cmp(r0, 2); // if mr >= 2
81 add(r12, r3, r7); // a1 = a0 + a_stride
82 add(r4, r11, r6); // c1 = c0 + cm_stride
83 movlo(r12, r3); // a1
84 movlo(r4, r11); // c1
85 // if mr > 2
86 add(r10, r12, r7); // a2 = a1 + a_stride
87 add(r8, r4, r6); // c2 = c1 + cm_stride
88 movls(r10, r12); // a2
89 movls(r8, r4); // c2
90
91 cmp(r0, 4); // if mr >=4
92 add(r0, r10, r7); // a3 = a2 + a_stride
93 add(r6, r8, r6); // c3 = c2 + cm_stride
94 movlo(r0, r10); // a3
95 movlo(r6, r8); // c3
96
97 // Load params values
98 vld1r_32({d11}, mem[r5]); // QC8 params
99 ldr(r7, mem[sp, 80]); // cn_stride
100
101 if (prefetch) {
102 pld(mem[r9, 64]); // Prefetch B
103 pld(mem[r9, 128]);
104 pld(mem[r9, 192]);
105 pld(mem[r9, 256]);
106 pld(mem[r9, 320]);
107 pld(mem[r9, 384]);
108 }
109
110 align(8);
111 bind(l0);
112 // Load initial bias from w into accumulators
113 vldm(mem[r9]++, {d16-d19}); // Bias
114 subs(r5, r2, 8); // k = kc - 8
115
116 vmov(q10, q8);
117 if (prefetch) {
118 pld(mem[r3, 64]); // Prefetch A
119 }
120 vmov(q11, q9);
121 if (prefetch) {
122 pld(mem[r12, 64]);
123 }
124 vmov(q12, q8);
125 if (prefetch) {
126 pld(mem[r10, 64]);
127 }
128 vmov(q13, q9);
129 if (prefetch) {
130 pld(mem[r0, 64]);
131 }
132 vmov(q14, q8);
133 vmov(q15, q9);
134 blo(l3); // less than 8 channels?
135
136 // Main loop - 8 bytes
137 // 64 bytes for weights.
138 align(8);
139 bind(l1);
140 vld1_8({d0}, mem[r3]++); // A0
141 vld1_8({d8}, mem[r9]++); // B
142 vld1_8({d2}, mem[r12]++); // A1
143 vld1_8({d4}, mem[r10]++); // A2
144 vld1_8({d6}, mem[r0]++); // A3
145 subs(r5, r5, 8);
146 if (prefetch) {
147 pld(mem[r3, 128]);
148 }
149 vmovl_s8(q0, d0);
150 if (prefetch) {
151 pld(mem[r12, 128]);
152 }
153 vmovl_s8(q4, d8);
154 if (prefetch) {
155 pld(mem[r10, 128]);
156 }
157 vmovl_s8(q1, d2);
158 if (prefetch) {
159 pld(mem[r0, 128]);
160 }
161 vmovl_s8(q2, d4);
162 if (prefetch) {
163 pld(mem[r9, 448]);
164 }
165 vmovl_s8(q3, d6);
166 vmlal_s16(q8, d8, d0[0]);
167 vmlal_s16(q9, d9, d0[0]);
168 vmlal_s16(q10, d8, d2[0]);
169 vmlal_s16(q11, d9, d2[0]);
170 vmlal_s16(q12, d8, d4[0]);
171 vmlal_s16(q13, d9, d4[0]);
172 vmlal_s16(q14, d8, d6[0]);
173 vmlal_s16(q15, d9, d6[0]);
174
175 vld1_8({d8}, mem[r9]++);
176 vmovl_s8(q4, d8);
177 vmlal_s16(q8, d8, d0[1]);
178 vmlal_s16(q9, d9, d0[1]);
179 vmlal_s16(q10, d8, d2[1]);
180 vmlal_s16(q11, d9, d2[1]);
181 vmlal_s16(q12, d8, d4[1]);
182 vmlal_s16(q13, d9, d4[1]);
183 vmlal_s16(q14, d8, d6[1]);
184 vmlal_s16(q15, d9, d6[1]);
185
186 vld1_8({d8}, mem[r9]++);
187 vmovl_s8(q4, d8);
188 vmlal_s16(q8, d8, d0[2]);
189 vmlal_s16(q9, d9, d0[2]);
190 vmlal_s16(q10, d8, d2[2]);
191 vmlal_s16(q11, d9, d2[2]);
192 vmlal_s16(q12, d8, d4[2]);
193 vmlal_s16(q13, d9, d4[2]);
194 vmlal_s16(q14, d8, d6[2]);
195 vmlal_s16(q15, d9, d6[2]);
196
197 vld1_8({d8}, mem[r9]++);
198 vmovl_s8(q4, d8);
199 vmlal_s16(q8, d8, d0[3]);
200 vmlal_s16(q9, d9, d0[3]);
201 vmlal_s16(q10, d8, d2[3]);
202 vmlal_s16(q11, d9, d2[3]);
203 vmlal_s16(q12, d8, d4[3]);
204 vmlal_s16(q13, d9, d4[3]);
205 vmlal_s16(q14, d8, d6[3]);
206 vmlal_s16(q15, d9, d6[3]);
207
208 vld1_8({d8}, mem[r9]++);
209 vmovl_s8(q4, d8);
210 vmlal_s16(q8, d8, d1[0]);
211 vmlal_s16(q9, d9, d1[0]);
212 vmlal_s16(q10, d8, d3[0]);
213 vmlal_s16(q11, d9, d3[0]);
214 vmlal_s16(q12, d8, d5[0]);
215 vmlal_s16(q13, d9, d5[0]);
216 vmlal_s16(q14, d8, d7[0]);
217 vmlal_s16(q15, d9, d7[0]);
218
219 vld1_8({d8}, mem[r9]++);
220 vmovl_s8(q4, d8);
221 vmlal_s16(q8, d8, d1[1]);
222 vmlal_s16(q9, d9, d1[1]);
223 vmlal_s16(q10, d8, d3[1]);
224 vmlal_s16(q11, d9, d3[1]);
225 vmlal_s16(q12, d8, d5[1]);
226 vmlal_s16(q13, d9, d5[1]);
227 vmlal_s16(q14, d8, d7[1]);
228 vmlal_s16(q15, d9, d7[1]);
229
230 vld1_8({d8}, mem[r9]++);
231 vmovl_s8(q4, d8);
232 vmlal_s16(q8, d8, d1[2]);
233 vmlal_s16(q9, d9, d1[2]);
234 vmlal_s16(q10, d8, d3[2]);
235 vmlal_s16(q11, d9, d3[2]);
236 vmlal_s16(q12, d8, d5[2]);
237 vmlal_s16(q13, d9, d5[2]);
238 vmlal_s16(q14, d8, d7[2]);
239 vmlal_s16(q15, d9, d7[2]);
240
241 vld1_8({d8}, mem[r9]++);
242 vmovl_s8(q4, d8);
243 vmlal_s16(q8, d8, d1[3]);
244 vmlal_s16(q9, d9, d1[3]);
245 vmlal_s16(q10, d8, d3[3]);
246 vmlal_s16(q11, d9, d3[3]);
247 vmlal_s16(q12, d8, d5[3]);
248 vmlal_s16(q13, d9, d5[3]);
249 vmlal_s16(q14, d8, d7[3]);
250 vmlal_s16(q15, d9, d7[3]);
251 bhs(l1);
252
253 // Is there a remainder?- 1-7 bytes of A
254 adds(r5, r5, 8);
255 bne(l3);
256
257 bind(l2);
258 // QC8 FP32 quantization
259 vld1_8({q0-q1}, mem[r9]++);
260
261 vcvt_f32_s32(q8, q8);
262 vcvt_f32_s32(q9, q9);
263 vcvt_f32_s32(q10, q10);
264 vcvt_f32_s32(q11, q11);
265 vcvt_f32_s32(q12, q12);
266 vcvt_f32_s32(q13, q13);
267 vcvt_f32_s32(q14, q14);
268 vcvt_f32_s32(q15, q15);
269
270 vmul_f32(q8, q8, q0); // multiplier
271 vmul_f32(q9, q9, q1);
272 vmul_f32(q10, q10, q0);
273 vmul_f32(q11, q11, q1);
274 vmul_f32(q12, q12, q0);
275 vmul_f32(q13, q13, q1);
276 vmul_f32(q14, q14, q0);
277 vmul_f32(q15, q15, q1);
278
279 vcvtn_s32_f32(q8, q8);
280 vcvtn_s32_f32(q9, q9);
281 vcvtn_s32_f32(q10, q10);
282 vcvtn_s32_f32(q11, q11);
283 vcvtn_s32_f32(q12, q12);
284 vcvtn_s32_f32(q13, q13);
285 vcvtn_s32_f32(q14, q14);
286 vcvtn_s32_f32(q15, q15);
287
288 vdup_16(q0, d11[2]); // output_zero_point
289
290 vqmovn_s32(d16, q8);
291 vqmovn_s32(d17, q9);
292 vqmovn_s32(d18, q10);
293 vqmovn_s32(d19, q11);
294 vqmovn_s32(d20, q12);
295 vqmovn_s32(d21, q13);
296 vqmovn_s32(d22, q14);
297 vqmovn_s32(d23, q15);
298
299 vqadd_s16(q8, q8, q0);
300 vqadd_s16(q9, q9, q0);
301 vqadd_s16(q10, q10, q0);
302 vqadd_s16(q11, q11, q0);
303
304 vdup_8(q12, d11[6]); // output_min
305
306 vqmovn_s16(d0, q8);
307 vqmovn_s16(d1, q9);
308 vqmovn_s16(d2, q10);
309 vqmovn_s16(d3, q11);
310
311 vdup_8(q13, d11[7]); // output_max
312
313 vmax_s8(q0, q0, q12);
314 vmax_s8(q1, q1, q12);
315
316 subs(r1, r1, 8);
317
318 vmin_s8(q0, q0, q13);
319 vmin_s8(q1, q1, q13);
320
321 // Store full 4 x 8
322 blo(l4);
323 vst1_8({d0}, mem[r11], r7);
324 sub(r3, r3, r2);
325 vst1_8({d1}, mem[r4], r7);
326 sub(r12, r12, r2);
327 vst1_8({d2}, mem[r8], r7);
328 sub(r10, r10, r2);
329 vst1_8({d3}, mem[r6], r7);
330 sub(r0, r0, r2);
331 bhi(l0);
332
333 vpop({d8-d11});
334 pop({r4, r5, r6, r7, r8, r9, r10, r11});
335 bx(lr);
336
337 // Remainder- 1 to 7 bytes of A
338 align(8);
339 bind(l3);
340 and_(r5, r5, 7); // kc remainder 1 to 7
341
342 vld1_8({d0}, mem[r3], r5);
343 vld1_8({d8}, mem[r9]++);
344 vld1_8({d2}, mem[r12], r5);
345 vld1_8({d4}, mem[r10], r5);
346 vld1_8({d6}, mem[r0], r5);
347
348 vmovl_s8(q0, d0);
349 vmovl_s8(q4, d8);
350 vmovl_s8(q1, d2);
351 vmovl_s8(q2, d4);
352 vmovl_s8(q3, d6);
353 vmlal_s16(q8, d8, d0[0]);
354 vmlal_s16(q9, d9, d0[0]);
355 vmlal_s16(q10, d8, d2[0]);
356 vmlal_s16(q11, d9, d2[0]);
357 vmlal_s16(q12, d8, d4[0]);
358 vmlal_s16(q13, d9, d4[0]);
359 vmlal_s16(q14, d8, d6[0]);
360 vmlal_s16(q15, d9, d6[0]);
361 cmp(r5, 2);
362 blo(l2);
363
364 vld1_8({d8}, mem[r9]++);
365 vmovl_s8(q4, d8);
366 vmlal_s16(q8, d8, d0[1]);
367 vmlal_s16(q9, d9, d0[1]);
368 vmlal_s16(q10, d8, d2[1]);
369 vmlal_s16(q11, d9, d2[1]);
370 vmlal_s16(q12, d8, d4[1]);
371 vmlal_s16(q13, d9, d4[1]);
372 vmlal_s16(q14, d8, d6[1]);
373 vmlal_s16(q15, d9, d6[1]);
374 beq(l2);
375
376 vld1_8({d8}, mem[r9]++);
377 vmovl_s8(q4, d8);
378 vmlal_s16(q8, d8, d0[2]);
379 vmlal_s16(q9, d9, d0[2]);
380 vmlal_s16(q10, d8, d2[2]);
381 vmlal_s16(q11, d9, d2[2]);
382 vmlal_s16(q12, d8, d4[2]);
383 vmlal_s16(q13, d9, d4[2]);
384 vmlal_s16(q14, d8, d6[2]);
385 vmlal_s16(q15, d9, d6[2]);
386 cmp(r5, 4);
387 blo(l2);
388
389 vld1_8({d8}, mem[r9]++);
390 vmovl_s8(q4, d8);
391 vmlal_s16(q8, d8, d0[3]);
392 vmlal_s16(q9, d9, d0[3]);
393 vmlal_s16(q10, d8, d2[3]);
394 vmlal_s16(q11, d9, d2[3]);
395 vmlal_s16(q12, d8, d4[3]);
396 vmlal_s16(q13, d9, d4[3]);
397 vmlal_s16(q14, d8, d6[3]);
398 vmlal_s16(q15, d9, d6[3]);
399 beq(l2);
400
401 vld1_8({d8}, mem[r9]++);
402 vmovl_s8(q4, d8);
403 vmlal_s16(q8, d8, d1[0]);
404 vmlal_s16(q9, d9, d1[0]);
405 vmlal_s16(q10, d8, d3[0]);
406 vmlal_s16(q11, d9, d3[0]);
407 vmlal_s16(q12, d8, d5[0]);
408 vmlal_s16(q13, d9, d5[0]);
409 vmlal_s16(q14, d8, d7[0]);
410 vmlal_s16(q15, d9, d7[0]);
411 cmp(r5, 6);
412 blo(l2);
413
414 vld1_8({d8}, mem[r9]++);
415 vmovl_s8(q4, d8);
416 vmlal_s16(q8, d8, d1[1]);
417 vmlal_s16(q9, d9, d1[1]);
418 vmlal_s16(q10, d8, d3[1]);
419 vmlal_s16(q11, d9, d3[1]);
420 vmlal_s16(q12, d8, d5[1]);
421 vmlal_s16(q13, d9, d5[1]);
422 vmlal_s16(q14, d8, d7[1]);
423 vmlal_s16(q15, d9, d7[1]);
424 beq(l2);
425
426 vld1_8({d8}, mem[r9]++);
427 vmovl_s8(q4, d8);
428 vmlal_s16(q8, d8, d1[2]);
429 vmlal_s16(q9, d9, d1[2]);
430 vmlal_s16(q10, d8, d3[2]);
431 vmlal_s16(q11, d9, d3[2]);
432 vmlal_s16(q12, d8, d5[2]);
433 vmlal_s16(q13, d9, d5[2]);
434 vmlal_s16(q14, d8, d7[2]);
435 vmlal_s16(q15, d9, d7[2]);
436 b(l2);
437
438 // Store odd width
439 align(8);
440 bind(l4);
441 tst(r1, 4);
442 beq(l5);
443 vst1_32({d0[0]}, mem[r11]++);
444 vst1_32({d1[0]}, mem[r4]++);
445 vst1_32({d2[0]}, mem[r8]++);
446 vst1_32({d3[0]}, mem[r6]++);
447 vext_8(q0, q0, q0, 4);
448 vext_8(q1, q1, q1, 4);
449 bind(l5);
450 tst(r1, 2);
451 beq(l6);
452 vst1_16({d0[0]}, mem[r11]++);
453 vst1_16({d1[0]}, mem[r4]++);
454 vst1_16({d2[0]}, mem[r8]++);
455 vst1_16({d3[0]}, mem[r6]++);
456 vext_8(q0, q0, q0, 2);
457 vext_8(q1, q1, q1, 2);
458
459 bind(l6);
460 tst(r1, 1);
461 beq(l7);
462 vst1_8({d0[0]}, mem[r11]);
463 vst1_8({d1[0]}, mem[r4]);
464 vst1_8({d2[0]}, mem[r8]);
465 vst1_8({d3[0]}, mem[r6]);
466
467 bind(l7);
468 vpop({d8-d11});
469 pop({r4, r5, r6, r7, r8, r9, r10, r11});
470 bx(lr);
471 }
472 } // namespace
473 } // aarch32
474 } // xnnpack
475
xnn_generate_qc8_gemm_fp32_ukernel_4x8__aarch32_neonv8_mlal_lane_ld64(xnn_code_buffer * code,size_t max_mr,size_t nc_mod_nr,size_t kc,const void * params)476 xnn_status_t xnn_generate_qc8_gemm_fp32_ukernel_4x8__aarch32_neonv8_mlal_lane_ld64(xnn_code_buffer* code, size_t max_mr, size_t nc_mod_nr, size_t kc, const void* params) {
477 using namespace xnnpack::aarch32;
478 Generator g(code);
479 g.generate(false, max_mr, nc_mod_nr, kc, nullptr);
480 g.finalize();
481 if (g.error() != xnnpack::Error::kNoError) {
482 return xnn_status_invalid_state;
483 }
484 return xnn_status_success;
485 }
486
xnn_generate_qc8_gemm_fp32_ukernel_4x8__aarch32_neonv8_mlal_lane_prfm_ld64(xnn_code_buffer * code,size_t max_mr,size_t nc_mod_nr,size_t kc,const void * params)487 xnn_status_t xnn_generate_qc8_gemm_fp32_ukernel_4x8__aarch32_neonv8_mlal_lane_prfm_ld64(xnn_code_buffer* code, size_t max_mr, size_t nc_mod_nr, size_t kc, const void* params) {
488 using namespace xnnpack::aarch32;
489 Generator g(code);
490 g.generate(true, max_mr, nc_mod_nr, kc, nullptr);
491 g.finalize();
492 if (g.error() != xnnpack::Error::kNoError) {
493 return xnn_status_invalid_state;
494 }
495 return xnn_status_success;
496 }
497