1 // Copyright 2022 Google LLC
2 //
3 // This source code is licensed under the BSD-style license found in the
4 // LICENSE file in the root directory of this source tree.
5
6
7 #include <cassert>
8
9 #include <xnnpack.h>
10 #include <xnnpack/aarch32-assembler.h>
11 #include <xnnpack/allocator.h>
12 #include <xnnpack/igemm.h>
13
14 namespace xnnpack {
15 namespace aarch32 {
16 namespace {
17 class Generator : public Assembler {
18 using Assembler::Assembler;
19 public:
20 void generate(bool prefetch, size_t max_mr, size_t nc_mod_nr, size_t kc, size_t ks, const void* params);
21 };
22
23
24 // void xnn_qs8_igemm_minmax_rndnu_ukernel_4x8__aarch32_neon_mlal_lane_prfm_ld64
25 // size_t mr, r0
26 // size_t nc, r1
27 // size_t kc, r2 -> r5 -> sp + 36
28 // size_t ks, r3 -> sp + 40 -> r14
29 // const int8_t**restrict a, sp + 80 -> r2
30 // const void*restrict w, sp + 84 -> r9
31 // int8_t*restrict c, sp + 88 -> r11
32 // size_t cm_stride, sp + 92 -> (r6)
33 // size_t cn_stride, sp + 96 -> (r7)
34 // size_t a_offset, sp + 100 -> (r5)
35 // const int8_t* zero, sp + 104 -> (r7)
36 // xnn_qs8_conv_minmax_params*params); sp + 108 -> (r5)
37
38 // d8-d15, r4-r11,r14(lr) need to be preserved if used. r13(sp),r15(pc) are reserved.
39
40 // Register usage
41
42 // A0 r3 d0-d1 q0
43 // A1 r12 d2-d3 q1
44 // A2 r10 d4-d5 q2
45 // A3 r0 d6-d7 q3
46
47 // B r9 d8-d9 q4
48
49 // C0 r11 d16-d17 q8 d18-d19 q9
50 // C1 r4 d20-d21 q10 d22-d23 q11
51 // C2 r8 d24-d25 q12 d26-d27 q13
52 // C3 r6 d28-d29 q14 d30-d31 q15
53
54 // Unused q6 q7
55
56 // params structure is 16 bytes
57 // struct {
58 // int32_t right_pre_shift; d10[0]
59 // int32_t multiplier; d10[1]
60 // int32_t right_post_shift; d11[0]
61 // int16_t output_zero_point; d11[2]
62 // int8_t output_min; d11[6]
63 // int8_t output_max; d11[7]
64 // } rndnu_neon;
65
66 // Converted from: src/qs8-igemm/gen/4x8-minmax-rndnu-aarch32-neon-mlal-lane-prfm-ld64.S
generate(bool prefetch,size_t max_mr,size_t nc_mod_nr,size_t kc,size_t ks,const void * params)67 void Generator::generate(bool prefetch, size_t max_mr, size_t nc_mod_nr, size_t kc, size_t ks, const void* params)
68 {
69 assert(nc_mod_nr < 8);
70 assert(kc != 0);
71 assert(ks != 0);
72
73 Label l0, l1, l2, l3, l4, l5, l6, l7, l8;
74
75 // Push 80 bytes
76 // r2 will be reloaded in outer loop. r3 is ks
77 push({r2, r3, r4, r5, r6, r7, r8, r9, r10, r11, lr}); // +44
78 sub(sp, sp, 4); // 4
79 vpush({d8-d11}); // +32 = 80
80
81 ldr(r11, mem[sp, 88]); // c
82 ldr(r6, mem[sp, 92]); // cm_stride
83 ldr(r2, mem[sp, 80]); // a
84 ldr(r9, mem[sp, 84]); // w
85 ldr(r5, mem[sp, 108]); // params
86 mov(r14, r3); // p = ks
87
88 // Clamp C pointers
89 cmp(r0, 2); // if mr >= 2
90 add(r4, r11, r6); // c1 = c0 + cm_stride
91 movlo(r4, r11); // c1
92 // if mr > 2
93 add(r8, r4, r6); // c2 = c1 + cm_stride
94 movls(r8, r4); // c2
95 cmp(r0, 4); // if mr >=4
96 add(r6, r8, r6); // c3 = c2 + cm_stride
97 movlo(r6, r8); // c3
98
99 // Load params values
100 vldm(mem[r5], {d10-d11}); // RNDNU params
101
102 if (prefetch) {
103 pld(mem[r9, 64]); // Prefetch B
104 pld(mem[r9, 128]);
105 pld(mem[r9, 192]);
106 pld(mem[r9, 256]);
107 pld(mem[r9, 320]);
108 pld(mem[r9, 384]);
109 }
110
111 align(8);
112 bind(l0);
113 // Load initial bias from w into accumulators
114 vldm(mem[r9]++, {d16-d19}); // Bias
115 vmov(q10, q8);
116 vmov(q11, q9);
117 vmov(q12, q8);
118 vmov(q13, q9);
119 vmov(q14, q8);
120 vmov(q15, q9);
121
122 align(8);
123 bind(l1);
124 // Load next 4 A pointers
125 ldr(r3, mem[r2, 0]);
126 ldr(r12, mem[r2, 4]);
127 ldr(r10, mem[r2, 8]);
128 ldr(r0, mem[r2, 12]);
129 add(r2, r2, 16);
130
131 if (prefetch) {
132 pld(mem[r3, 64]);
133 pld(mem[r12, 64]);
134 pld(mem[r10, 64]);
135 pld(mem[r0, 64]);
136 }
137
138 // Add a_offset
139 ldr(r5, mem[sp, 100]); // a_offset
140 ldr(r7, mem[sp, 104]); // zero
141 cmp(r3, r7); // if a0 == zero
142 add(r3, r3, r5); // a0 += a_offset
143 moveq(r3, r7); // a0 = zero, else += a0 + a_offset
144 cmp(r12, r7); // if a1 == zero
145 add(r12, r12, r5); // a1 += a_offset
146 moveq(r12, r7); // a1 = zero, else += a1 + a_offset
147 cmp(r10, r7); // if a2 == zero
148 add(r10, r10, r5); // a2 += a_offset
149 moveq(r10, r7); // a2 = zero, else += a2 + a_offset
150 cmp(r0, r7); // if a3 == zero
151 add(r0, r0, r5); // a3 += a_offset
152 ldr(r5, mem[sp, 36]); // kc
153 moveq(r0, r7); // a3 = zero, else += a3 + a_offset
154
155 subs(r5, r5, 8); // kc - 8
156 blo(l4); // less than 8 channels?
157
158 // Main loop - 8 bytes
159 // 64 bytes for weights.
160 align(8);
161 bind(l2);
162 vld1_8({d0}, mem[r3]++); // A0
163 vld1_8({d8}, mem[r9]++); // B
164 vld1_8({d2}, mem[r12]++); // A1
165 vld1_8({d4}, mem[r10]++); // A2
166 vld1_8({d6}, mem[r0]++); // A3
167 subs(r5, r5, 8);
168 if (prefetch) {
169 pld(mem[r3, 128]);
170 }
171 vmovl_s8(q0, d0);
172 if (prefetch) {
173 pld(mem[r12, 128]);
174 }
175 vmovl_s8(q4, d8);
176 if (prefetch) {
177 pld(mem[r10, 128]);
178 }
179 vmovl_s8(q1, d2);
180 if (prefetch) {
181 pld(mem[r0, 128]);
182 }
183 vmovl_s8(q2, d4);
184 if (prefetch) {
185 pld(mem[r9, 448]);
186 }
187 vmovl_s8(q3, d6);
188 vmlal_s16(q8, d8, d0[0]);
189 vmlal_s16(q9, d9, d0[0]);
190 vmlal_s16(q10, d8, d2[0]);
191 vmlal_s16(q11, d9, d2[0]);
192 vmlal_s16(q12, d8, d4[0]);
193 vmlal_s16(q13, d9, d4[0]);
194 vmlal_s16(q14, d8, d6[0]);
195 vmlal_s16(q15, d9, d6[0]);
196
197 vld1_8({d8}, mem[r9]++);
198 vmovl_s8(q4, d8);
199 vmlal_s16(q8, d8, d0[1]);
200 vmlal_s16(q9, d9, d0[1]);
201 vmlal_s16(q10, d8, d2[1]);
202 vmlal_s16(q11, d9, d2[1]);
203 vmlal_s16(q12, d8, d4[1]);
204 vmlal_s16(q13, d9, d4[1]);
205 vmlal_s16(q14, d8, d6[1]);
206 vmlal_s16(q15, d9, d6[1]);
207
208 vld1_8({d8}, mem[r9]++);
209 vmovl_s8(q4, d8);
210 vmlal_s16(q8, d8, d0[2]);
211 vmlal_s16(q9, d9, d0[2]);
212 vmlal_s16(q10, d8, d2[2]);
213 vmlal_s16(q11, d9, d2[2]);
214 vmlal_s16(q12, d8, d4[2]);
215 vmlal_s16(q13, d9, d4[2]);
216 vmlal_s16(q14, d8, d6[2]);
217 vmlal_s16(q15, d9, d6[2]);
218
219 vld1_8({d8}, mem[r9]++);
220 vmovl_s8(q4, d8);
221 vmlal_s16(q8, d8, d0[3]);
222 vmlal_s16(q9, d9, d0[3]);
223 vmlal_s16(q10, d8, d2[3]);
224 vmlal_s16(q11, d9, d2[3]);
225 vmlal_s16(q12, d8, d4[3]);
226 vmlal_s16(q13, d9, d4[3]);
227 vmlal_s16(q14, d8, d6[3]);
228 vmlal_s16(q15, d9, d6[3]);
229
230 vld1_8({d8}, mem[r9]++);
231 vmovl_s8(q4, d8);
232 vmlal_s16(q8, d8, d1[0]);
233 vmlal_s16(q9, d9, d1[0]);
234 vmlal_s16(q10, d8, d3[0]);
235 vmlal_s16(q11, d9, d3[0]);
236 vmlal_s16(q12, d8, d5[0]);
237 vmlal_s16(q13, d9, d5[0]);
238 vmlal_s16(q14, d8, d7[0]);
239 vmlal_s16(q15, d9, d7[0]);
240
241 vld1_8({d8}, mem[r9]++);
242 vmovl_s8(q4, d8);
243 vmlal_s16(q8, d8, d1[1]);
244 vmlal_s16(q9, d9, d1[1]);
245 vmlal_s16(q10, d8, d3[1]);
246 vmlal_s16(q11, d9, d3[1]);
247 vmlal_s16(q12, d8, d5[1]);
248 vmlal_s16(q13, d9, d5[1]);
249 vmlal_s16(q14, d8, d7[1]);
250 vmlal_s16(q15, d9, d7[1]);
251
252 vld1_8({d8}, mem[r9]++);
253 vmovl_s8(q4, d8);
254 vmlal_s16(q8, d8, d1[2]);
255 vmlal_s16(q9, d9, d1[2]);
256 vmlal_s16(q10, d8, d3[2]);
257 vmlal_s16(q11, d9, d3[2]);
258 vmlal_s16(q12, d8, d5[2]);
259 vmlal_s16(q13, d9, d5[2]);
260 vmlal_s16(q14, d8, d7[2]);
261 vmlal_s16(q15, d9, d7[2]);
262
263 vld1_8({d8}, mem[r9]++);
264 vmovl_s8(q4, d8);
265 vmlal_s16(q8, d8, d1[3]);
266 vmlal_s16(q9, d9, d1[3]);
267 vmlal_s16(q10, d8, d3[3]);
268 vmlal_s16(q11, d9, d3[3]);
269 vmlal_s16(q12, d8, d5[3]);
270 vmlal_s16(q13, d9, d5[3]);
271 vmlal_s16(q14, d8, d7[3]);
272 vmlal_s16(q15, d9, d7[3]);
273 bhs(l2);
274
275 // Is there a remainder?- 1-7 bytes of A
276 adds(r5, r5, 8);
277 bne(l4);
278
279 bind(l3);
280 // ks loop
281 subs(r14, r14, 16); // ks -= MR * sizeof(void*)
282 bhi(l1);
283
284 ldr(r7, mem[sp, 96]); // cn_stride
285 ldr(r14, mem[sp, 40]); // p = ks
286
287 // RNDNU quantization
288 vdup_32(q0, d10[0]); // right_pre_shift
289
290 vqshl_s32(q8, q8, q0);
291 vqshl_s32(q9, q9, q0);
292 vqshl_s32(q10, q10, q0);
293 vqshl_s32(q11, q11, q0);
294 vqshl_s32(q12, q12, q0);
295 vqshl_s32(q13, q13, q0);
296 vqshl_s32(q14, q14, q0);
297 vqshl_s32(q15, q15, q0);
298
299 vdup_32(q2, d11[0]); // right_post_shift
300
301 vqdmulh_s32(q8, q8, d10[1]); // multiplier
302 vqdmulh_s32(q9, q9, d10[1]);
303 vqdmulh_s32(q10, q10, d10[1]);
304 vqdmulh_s32(q11, q11, d10[1]);
305 vqdmulh_s32(q12, q12, d10[1]);
306 vqdmulh_s32(q13, q13, d10[1]);
307 vqdmulh_s32(q14, q14, d10[1]);
308 vqdmulh_s32(q15, q15, d10[1]);
309
310 vrshl_s32(q8, q8, q2);
311 vrshl_s32(q9, q9, q2);
312 vrshl_s32(q10, q10, q2);
313 vrshl_s32(q11, q11, q2);
314 vrshl_s32(q12, q12, q2);
315 vrshl_s32(q13, q13, q2);
316 vrshl_s32(q14, q14, q2);
317 vrshl_s32(q15, q15, q2);
318
319 vdup_16(q0, d11[2]); // output_zero_point
320
321 vqmovn_s32(d16, q8);
322 vqmovn_s32(d17, q9);
323 vqmovn_s32(d18, q10);
324 vqmovn_s32(d19, q11);
325 vqmovn_s32(d20, q12);
326 vqmovn_s32(d21, q13);
327 vqmovn_s32(d22, q14);
328 vqmovn_s32(d23, q15);
329
330 vqadd_s16(q8, q8, q0);
331 vqadd_s16(q9, q9, q0);
332 vqadd_s16(q10, q10, q0);
333 vqadd_s16(q11, q11, q0);
334
335 vdup_8(q12, d11[6]); // output_min
336
337 vqmovn_s16(d0, q8);
338 vqmovn_s16(d1, q9);
339 vqmovn_s16(d2, q10);
340 vqmovn_s16(d3, q11);
341
342 vdup_8(q13, d11[7]); // output_max
343
344 vmax_s8(q0, q0, q12);
345 vmax_s8(q1, q1, q12);
346
347 subs(r1, r1, 8); // nc -= 8
348
349 vmin_s8(q0, q0, q13);
350 vmin_s8(q1, q1, q13);
351
352 // Store full 4 x 8
353 blo(l5);
354 vst1_8({d3}, mem[r6], r7);
355 vst1_8({d2}, mem[r8], r7);
356 vst1_8({d1}, mem[r4], r7);
357 vst1_8({d0}, mem[r11], r7);
358 sub(r2, r2, r14); // a -= ks
359 bhi(l0);
360
361 vpop({d8-d11});
362 add(sp, sp, 12); // skip pad, r2, r3
363 pop({r4, r5, r6, r7, r8, r9, r10, r11, pc});
364
365 // Remainder- 1 to 7 bytes of A
366 align(8);
367 bind(l4);
368 and_(r5, r5, 7); // kc remainder 1 to 7
369
370 vld1_8({d0}, mem[r3]);
371 vld1_8({d8}, mem[r9]++);
372 vld1_8({d2}, mem[r12]);
373 vld1_8({d4}, mem[r10]);
374 vld1_8({d6}, mem[r0]);
375
376 vmovl_s8(q0, d0);
377 vmovl_s8(q4, d8);
378 vmovl_s8(q1, d2);
379 vmovl_s8(q2, d4);
380 vmovl_s8(q3, d6);
381 vmlal_s16(q8, d8, d0[0]);
382 vmlal_s16(q9, d9, d0[0]);
383 vmlal_s16(q10, d8, d2[0]);
384 vmlal_s16(q11, d9, d2[0]);
385 vmlal_s16(q12, d8, d4[0]);
386 vmlal_s16(q13, d9, d4[0]);
387 vmlal_s16(q14, d8, d6[0]);
388 vmlal_s16(q15, d9, d6[0]);
389 cmp(r5, 2);
390 blo(l3);
391
392 vld1_8({d8}, mem[r9]++);
393 vmovl_s8(q4, d8);
394 vmlal_s16(q8, d8, d0[1]);
395 vmlal_s16(q9, d9, d0[1]);
396 vmlal_s16(q10, d8, d2[1]);
397 vmlal_s16(q11, d9, d2[1]);
398 vmlal_s16(q12, d8, d4[1]);
399 vmlal_s16(q13, d9, d4[1]);
400 vmlal_s16(q14, d8, d6[1]);
401 vmlal_s16(q15, d9, d6[1]);
402 beq(l3);
403
404 vld1_8({d8}, mem[r9]++);
405 vmovl_s8(q4, d8);
406 vmlal_s16(q8, d8, d0[2]);
407 vmlal_s16(q9, d9, d0[2]);
408 vmlal_s16(q10, d8, d2[2]);
409 vmlal_s16(q11, d9, d2[2]);
410 vmlal_s16(q12, d8, d4[2]);
411 vmlal_s16(q13, d9, d4[2]);
412 vmlal_s16(q14, d8, d6[2]);
413 vmlal_s16(q15, d9, d6[2]);
414 cmp(r5, 4);
415 blo(l3);
416
417 vld1_8({d8}, mem[r9]++);
418 vmovl_s8(q4, d8);
419 vmlal_s16(q8, d8, d0[3]);
420 vmlal_s16(q9, d9, d0[3]);
421 vmlal_s16(q10, d8, d2[3]);
422 vmlal_s16(q11, d9, d2[3]);
423 vmlal_s16(q12, d8, d4[3]);
424 vmlal_s16(q13, d9, d4[3]);
425 vmlal_s16(q14, d8, d6[3]);
426 vmlal_s16(q15, d9, d6[3]);
427 beq(l3);
428
429 vld1_8({d8}, mem[r9]++);
430 vmovl_s8(q4, d8);
431 vmlal_s16(q8, d8, d1[0]);
432 vmlal_s16(q9, d9, d1[0]);
433 vmlal_s16(q10, d8, d3[0]);
434 vmlal_s16(q11, d9, d3[0]);
435 vmlal_s16(q12, d8, d5[0]);
436 vmlal_s16(q13, d9, d5[0]);
437 vmlal_s16(q14, d8, d7[0]);
438 vmlal_s16(q15, d9, d7[0]);
439 cmp(r5, 6);
440 blo(l3);
441
442 vld1_8({d8}, mem[r9]++);
443 vmovl_s8(q4, d8);
444 vmlal_s16(q8, d8, d1[1]);
445 vmlal_s16(q9, d9, d1[1]);
446 vmlal_s16(q10, d8, d3[1]);
447 vmlal_s16(q11, d9, d3[1]);
448 vmlal_s16(q12, d8, d5[1]);
449 vmlal_s16(q13, d9, d5[1]);
450 vmlal_s16(q14, d8, d7[1]);
451 vmlal_s16(q15, d9, d7[1]);
452 beq(l3);
453
454 vld1_8({d8}, mem[r9]++);
455 vmovl_s8(q4, d8);
456 vmlal_s16(q8, d8, d1[2]);
457 vmlal_s16(q9, d9, d1[2]);
458 vmlal_s16(q10, d8, d3[2]);
459 vmlal_s16(q11, d9, d3[2]);
460 vmlal_s16(q12, d8, d5[2]);
461 vmlal_s16(q13, d9, d5[2]);
462 vmlal_s16(q14, d8, d7[2]);
463 vmlal_s16(q15, d9, d7[2]);
464 b(l3);
465
466 // Store odd width
467 align(8);
468 bind(l5);
469 tst(r1, 4);
470 beq(l6);
471 vst1_32({d3[0]}, mem[r6]++);
472 vst1_32({d2[0]}, mem[r8]++);
473 vst1_32({d1[0]}, mem[r4]++);
474 vst1_32({d0[0]}, mem[r11]++);
475 vext_8(q1, q1, q1, 4);
476 vext_8(q0, q0, q0, 4);
477 bind(l6);
478 tst(r1, 2);
479 beq(l7);
480 vst1_16({d3[0]}, mem[r6]++);
481 vst1_16({d2[0]}, mem[r8]++);
482 vst1_16({d1[0]}, mem[r4]++);
483 vst1_16({d0[0]}, mem[r11]++);
484 vext_8(q1, q1, q1, 2);
485 vext_8(q0, q0, q0, 2);
486
487 bind(l7);
488 tst(r1, 1);
489 beq(l8);
490 vst1_8({d3[0]}, mem[r6]);
491 vst1_8({d2[0]}, mem[r8]);
492 vst1_8({d1[0]}, mem[r4]);
493 vst1_8({d0[0]}, mem[r11]);
494
495 bind(l8);
496 vpop({d8-d11});
497 add(sp, sp, 12); // skip pad, r2, r3
498 pop({r4, r5, r6, r7, r8, r9, r10, r11, pc});
499 }
500 } // namespace
501 } // aarch32
502 } // xnnpack
503
xnn_generate_qs8_igemm_rndnu_ukernel_4x8__aarch32_neon_mlal_lane_ld64(xnn_code_buffer * code,size_t max_mr,size_t nc_mod_nr,size_t kc,size_t ks,const void * params)504 xnn_status_t xnn_generate_qs8_igemm_rndnu_ukernel_4x8__aarch32_neon_mlal_lane_ld64(xnn_code_buffer* code, size_t max_mr, size_t nc_mod_nr, size_t kc, size_t ks, const void* params) {
505 using namespace xnnpack::aarch32;
506 Generator g(code);
507 g.generate(false, max_mr, nc_mod_nr, kc, ks, nullptr);
508 g.finalize();
509 if (g.error() != xnnpack::Error::kNoError) {
510 return xnn_status_invalid_state;
511 }
512 return xnn_status_success;
513 }
514
xnn_generate_qs8_igemm_rndnu_ukernel_4x8__aarch32_neon_mlal_lane_prfm_ld64(xnn_code_buffer * code,size_t max_mr,size_t nc_mod_nr,size_t kc,size_t ks,const void * params)515 xnn_status_t xnn_generate_qs8_igemm_rndnu_ukernel_4x8__aarch32_neon_mlal_lane_prfm_ld64(xnn_code_buffer* code, size_t max_mr, size_t nc_mod_nr, size_t kc, size_t ks, const void* params) {
516 using namespace xnnpack::aarch32;
517 Generator g(code);
518 g.generate(true, max_mr, nc_mod_nr, kc, ks, nullptr);
519 g.finalize();
520 if (g.error() != xnnpack::Error::kNoError) {
521 return xnn_status_invalid_state;
522 }
523 return xnn_status_success;
524 }
525