// Copyright 2021 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert REQUANTIZATION in ["FP32", "RNDNU"]
$assert not CHANNELWISE or REQUANTIZATION == "FP32"

#include <xnnpack/assembly.h>

$DATATYPE = "qc8" if CHANNELWISE else "qs8"
$PARAMS_UNION = "xnn_qs8_minmax_params" if CHANNELWISE else "xnn_qs8_conv_minmax_params"
$REWIND_DECREMENT = 3 if CHANNELWISE else {"RNDNU": 15, "FP32": 7}[REQUANTIZATION]
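# REWIND_DECREMENT is the number of bytes by which the epilogue below advances
# the params pointer through its post-incremented LD1R loads; it is subtracted
# back so the next nc iteration re-reads the same params.
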
# void xnn_${DATATYPE}_igemm_minmax_${REQUANTIZATION.lower()}_ukernel_2x8c8__aarch64_neon_mlal${"_prfm" if PREFETCH else ""}_cortex_a53(
#     size_t mr,                 x0
#     size_t nc,                 x1
#     size_t kc,                 x2 / x0
#     size_t ks,                 x3 / x9
#     const int8_t**restrict a,  x4
#     const int8_t* restrict w,  x5
#     int8_t* restrict c,        x6
#     size_t cm_stride,          x7
#     size_t cn_stride,                  [sp] -> x10
#     size_t a_offset,                   [sp + 8] -> x8
#     const int8_t* zero,                [sp + 16] -> x12
#     const union ${PARAMS_UNION} params [sp + 24] -> x11
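#
# Computes a 2x8 output tile of an indirect GEMM: 2 rows of int8 A data,
# gathered through the pointer array a, times 8 output channels of packed
# weights w, with K processed in 8-byte groups (c8).  Products are widened with
# SMULL/SMLAL and accumulated pairwise into 32-bit lanes with SADALP.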

# d8-d15, x19-x30 need to be preserved if used. x18 is reserved by the OS.

# Register usage
# A0 x13  v0  v6
# A1 x15  v1  v7
# B   x5  v4  v5  v8  v9
# C0  x6 v16 v18 v20 v22 v24 v26 v28 v30
# C1  x7 v17 v19 v21 v23 v25 v27 v29 v31
# temp0   v2 v10 v12 v14
# temp1   v3 v11 v13 v15
# x16, x17, x20, x21 temporary A53 GPR load data
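#
# The 64-bit B and A values read into x16/x17 and x20/x21 are moved into the
# vector registers with INS, so that on Cortex-A53 these GPR loads can issue
# alongside the NEON multiplies instead of occupying the vector load pipe.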


BEGIN_FUNCTION xnn_${DATATYPE}_igemm_minmax_${REQUANTIZATION.lower()}_ukernel_2x8c8__aarch64_neon_mlal${"_prfm" if PREFETCH else ""}_cortex_a53

        # Clamp C pointers
        LDP     x10, x8, [sp]           // Load cn_stride, a_offset
        CMP     x0, 2                   // if mr < 2
        LDP     x12, x11, [sp, 16]      // Load zero, params pointer
        ADD     x7, x6, x7              // c1 = c0 + cm_stride
        STP     d8, d9, [sp, -80]!
        ADD     x2, x2, 7               // kc = (kc + 7) & ~7
        STP     d10, d11, [sp, 16]
        CSEL    x7, x6, x7, LO          //   c1 = c0
        STP     d12, d13, [sp, 32]
        BIC     x2, x2, 7
        STP     d14, d15, [sp, 48]
        STP     x20, x21, [sp, 64]      // Save x20,x21 on stack
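        # kc is rounded up to a multiple of 8 because K is consumed 8 bytes at
        # a time; v8-v15 (d8-d15) and x20/x21 are callee-saved, hence the
        # 80-byte stack frame.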

        .p2align 3
0:
        # Load initial bias from w into accumulators
        LDP     s16, s18, [x5], 8
        MOV     v17.16b, v16.16b
        MOV     v19.16b, v18.16b
        LDP     s20, s22, [x5], 8
        MOV     v21.16b, v20.16b
        MOV     v23.16b, v22.16b
        LDP     s24, s26, [x5], 8
        MOV     v25.16b, v24.16b
        MOV     v27.16b, v26.16b
        LDP     s28, s30, [x5], 8
        MOV     v29.16b, v28.16b
        MOV     v31.16b, v30.16b
        MOV     x9, x3                  // p = ks
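        # Each of v16-v31 accumulates one (row, output channel) pair: the
        # 32-bit bias sits in lane 0 and SADALP spreads partial sums over all
        # 4 lanes, which are reduced with ADDP after the K loops.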

        .p2align 3
1:
        # Load next 2 A pointers
        LDP     x13, x15, [x4], 16
        CMP     x13, x12                // if a0 == zero
        ADD     x13, x13, x8            // a0 += a_offset
        CSEL    x13, x12, x13, EQ       //   a0 = zero, else a0 += a_offset
        CMP     x15, x12                // if a1 == zero
        ADD     x15, x15, x8            // a1 += a_offset
        CSEL    x15, x12, x15, EQ       //   a1 = zero, else a1 += a_offset
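        # An indirection entry equal to zero selects the shared zero buffer and
        # is not offset by a_offset.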

        # Is there at least 16 bytes for epilogue?
        SUBS    x0, x2, 16              // k = kc - 16
        B.LO    5f

        # Prologue: load A0, A1 and 2 B's
        LDP     d4, d5, [x5]            // Read B
        LDP     d0, d6, [x13], 16
        LDP     d1, d7, [x15], 16
//        LDP     d8, d9, [x5, 64]
        LDR     x17, [x5, 64]           // Read B
        LDR     x16, [x5, 16]

        # Is there at least 16 bytes for main loop?
        SUBS    x0, x0, 16              // k = k - 16
        B.LO    3f
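        # Loop structure: the main loop (2:) and the epilogue (3:) each consume
        # 16 bytes of K; a trailing 8-byte group, if any, is handled at 5:.
        # The prologue above pre-loads the first A and B values so each block
        # below can overlap its loads with the multiplies.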

        # Main loop - 16 bytes of A
        # 4 groups of 4 mul/mla/adalp + 2 load = 18 cycles.
        # 2 loads for A0 = +2 cycles.  Total 18 * 4 + 2 = 74 cycles.
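        # Each iteration reads 128 bytes of B: the first 8-byte A group uses B
        # at offsets 0-63, the second group uses 64-127, and loads at offset
        # 128 and beyond fetch the next iteration's B before x5 is advanced by
        # 128 in BLOCK 3.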

        .p2align 3
2:
        # BLOCK 0 - 18 cycles - includes prfm
        LDR     d9, [x5, 72]            // Read B
        INS     v8.d[0], x17
        SMULL   v2.8h, v4.8b, v0.8b
        SMULL   v3.8h, v4.8b, v1.8b
        LDR     x17, [x5, 80]
        SMULL   v10.8h, v5.8b, v0.8b
        SMULL   v11.8h, v5.8b, v1.8b
        LDR     d5, [x5, 24]
        INS     v4.d[0], x16
        SMLAL   v2.8h, v8.8b, v6.8b
        SMLAL   v3.8h, v8.8b, v7.8b
        LDR     x16, [x5, 32]
        SMLAL   v10.8h, v9.8b, v6.8b
        SMLAL   v11.8h, v9.8b, v7.8b
        $if PREFETCH:
          PRFM    PLDL1KEEP, [x5, 448]
        SADALP  v16.4s,  v2.8h
        SADALP  v17.4s,  v3.8h
        $if PREFETCH:
          PRFM    PLDL1KEEP, [x5, 512]
        SADALP  v18.4s, v10.8h
        SADALP  v19.4s, v11.8h

        # BLOCK 1 - 18 cycles
        LDR     d9, [x5, 88]
        INS     v8.d[0], x17
        SMULL   v12.8h, v4.8b, v0.8b
        SMULL   v13.8h, v4.8b, v1.8b
        LDR     x17, [x5, 96]
        SMULL   v14.8h, v5.8b, v0.8b
        SMULL   v15.8h, v5.8b, v1.8b
        LDR     d5, [x5, 40]
        INS     v4.d[0], x16
        SMLAL   v12.8h, v8.8b, v6.8b
        SMLAL   v13.8h, v8.8b, v7.8b
        LDR     x16, [x5, 48]
        SMLAL   v14.8h, v9.8b, v6.8b
        SMLAL   v15.8h, v9.8b, v7.8b
        $if PREFETCH:
          PRFM    PLDL1KEEP, [x13, 128]
        SADALP  v20.4s, v12.8h
        SADALP  v21.4s, v13.8h
        $if PREFETCH:
          PRFM    PLDL1KEEP, [x15, 128]
        SADALP  v22.4s, v14.8h
        SADALP  v23.4s, v15.8h

        # BLOCK 2 - 18 cycles
        LDR     d9, [x5, 104]
        INS     v8.d[0], x17
        SMULL   v2.8h, v4.8b, v0.8b
        SMULL   v3.8h, v4.8b, v1.8b
        LDR     x17, [x5, 112]
        SMULL   v10.8h, v5.8b, v0.8b
        SMULL   v11.8h, v5.8b, v1.8b
        LDR     d5, [x5, 56]
        INS     v4.d[0], x16
        SMLAL   v2.8h, v8.8b, v6.8b
        SMLAL   v3.8h, v8.8b, v7.8b
        LDR     x16, [x5, 128]
        SMLAL   v10.8h, v9.8b, v6.8b
        SMLAL   v11.8h, v9.8b, v7.8b
        SADALP  v24.4s,  v2.8h
        LDR     x20, [x13], 8           // Read A0
        SADALP  v25.4s,  v3.8h
        LDR     x21, [x15], 8           // Read A1
        SADALP  v26.4s, v10.8h
        SADALP  v27.4s, v11.8h
        SUBS    x0, x0, 16

        # BLOCK 3 - includes 2 cycles to read A0, A1 = 20 cycles
        LDR     d9, [x5, 120]
        INS     v8.d[0], x17
        SMULL   v12.8h, v4.8b, v0.8b
        SMULL   v13.8h, v4.8b, v1.8b
        LDR     x17, [x5, 192]          // Read B
        SMULL   v14.8h, v5.8b, v0.8b
        SMULL   v15.8h, v5.8b, v1.8b
        LDR     d5, [x5, 136]           // Read B
        INS     v4.d[0], x16
        SMLAL   v12.8h, v8.8b, v6.8b
        SMLAL   v13.8h, v8.8b, v7.8b
        LDR     x16, [x5, 144]
        SMLAL   v14.8h, v9.8b, v6.8b
        SMLAL   v15.8h, v9.8b, v7.8b
        LDR     d6, [x13], 8            // Read A0
        INS     v0.d[0], x20
        LDR     d7, [x15], 8            // Read A1
        INS     v1.d[0], x21
        SADALP  v28.4s, v12.8h
        SADALP  v29.4s, v13.8h
        ADD     x5, x5, 128
        SADALP  v30.4s, v14.8h
        SADALP  v31.4s, v15.8h
        B.HS    2b

        # Epilogue
        # Same as main loop except no loads at end of loop
        .p2align 3
3:
        # BLOCK 0 - 18 cycles
        LDR     d9, [x5, 72]            // Read B
        INS     v8.d[0], x17
        SMULL   v2.8h, v4.8b, v0.8b
        SMULL   v3.8h, v4.8b, v1.8b
        LDR     x17, [x5, 80]
        SMULL   v10.8h, v5.8b, v0.8b
        SMULL   v11.8h, v5.8b, v1.8b
        LDR     d5, [x5, 24]
        INS     v4.d[0], x16
        SMLAL   v2.8h, v8.8b, v6.8b
        SMLAL   v3.8h, v8.8b, v7.8b
        LDR     x16, [x5, 32]
        SMLAL   v10.8h, v9.8b, v6.8b
        SMLAL   v11.8h, v9.8b, v7.8b
        SADALP  v16.4s,  v2.8h
        SADALP  v17.4s,  v3.8h
        SADALP  v18.4s, v10.8h
        SADALP  v19.4s, v11.8h

        # BLOCK 1 - 18 cycles
        LDR     d9, [x5, 88]
        INS     v8.d[0], x17
        SMULL   v12.8h, v4.8b, v0.8b
        SMULL   v13.8h, v4.8b, v1.8b
        LDR     x17, [x5, 96]
        SMULL   v14.8h, v5.8b, v0.8b
        SMULL   v15.8h, v5.8b, v1.8b
        LDR     d5, [x5, 40]
        INS     v4.d[0], x16
        SMLAL   v12.8h, v8.8b, v6.8b
        SMLAL   v13.8h, v8.8b, v7.8b
        LDR     x16, [x5, 48]
        SMLAL   v14.8h, v9.8b, v6.8b
        SMLAL   v15.8h, v9.8b, v7.8b
        SADALP  v20.4s, v12.8h
        SADALP  v21.4s, v13.8h
        SADALP  v22.4s, v14.8h
        SADALP  v23.4s, v15.8h

        # BLOCK 2 - 18 cycles
        LDR     d9, [x5, 104]
        INS     v8.d[0], x17
        SMULL   v2.8h, v4.8b, v0.8b
        SMULL   v3.8h, v4.8b, v1.8b
        LDR     x17, [x5, 112]
        SMULL   v10.8h, v5.8b, v0.8b
        SMULL   v11.8h, v5.8b, v1.8b
        LDR     d5, [x5, 56]
        INS     v4.d[0], x16
        SMLAL   v2.8h, v8.8b, v6.8b
        SMLAL   v3.8h, v8.8b, v7.8b
        SMLAL   v10.8h, v9.8b, v6.8b
        SMLAL   v11.8h, v9.8b, v7.8b
        SADALP  v24.4s,  v2.8h
        SADALP  v25.4s,  v3.8h
        SADALP  v26.4s, v10.8h
        SADALP  v27.4s, v11.8h

        # BLOCK 3 - 17 cycles
        LDR     d9, [x5, 120]
        INS     v8.d[0], x17
        SMULL   v12.8h, v4.8b, v0.8b
        SMULL   v13.8h, v4.8b, v1.8b
        SMULL   v14.8h, v5.8b, v0.8b
        SMULL   v15.8h, v5.8b, v1.8b
        SMLAL   v12.8h, v8.8b, v6.8b
        SMLAL   v13.8h, v8.8b, v7.8b
        SMLAL   v14.8h, v9.8b, v6.8b
        SMLAL   v15.8h, v9.8b, v7.8b
        SADALP  v28.4s, v12.8h
        SADALP  v29.4s, v13.8h
        ADD     x5, x5, 128
        SADALP  v30.4s, v14.8h
        SADALP  v31.4s, v15.8h

        # Is there a remainder? - 8 bytes of A
        TBNZ    x0, 3, 5f

        # ks loop
        SUBS    x9, x9, 16              // ks -= MR * sizeof(int8_t*)
        B.HI    1b

4:
        # Add columns
        ADDP    v16.4s, v16.4s, v18.4s
        ADDP    v20.4s, v20.4s, v22.4s
        $if REQUANTIZATION == "RNDNU":
          LD1R    {v4.4s}, [x11], 4
        ADDP    v24.4s, v24.4s, v26.4s
        ADDP    v28.4s, v28.4s, v30.4s
        $if REQUANTIZATION == "RNDNU":
          LD1R    {v7.4s}, [x11], 4
        ADDP    v17.4s, v17.4s, v19.4s
        ADDP    v21.4s, v21.4s, v23.4s
        ADDP    v25.4s, v25.4s, v27.4s
        ADDP    v29.4s, v29.4s, v31.4s
        ADDP    v0.4s, v16.4s, v20.4s
        ADDP    v1.4s, v24.4s, v28.4s
        ADDP    v2.4s, v17.4s, v21.4s
        ADDP    v3.4s, v25.4s, v29.4s
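        # v0:v1 now hold the 8 int32 accumulators for row 0, v2:v3 for row 1.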

        $if REQUANTIZATION == "RNDNU":
          # Apply params - preshift, scale, postshift, bias and clamp
          LD1R    {v5.4s}, [x11], 4
          SQSHL   v0.4s, v0.4s, v4.4s     // shift to upper bits
          SQSHL   v1.4s, v1.4s, v4.4s
          SQSHL   v2.4s, v2.4s, v4.4s
          SQSHL   v3.4s, v3.4s, v4.4s
          SQDMULH v0.4s, v0.4s, v7.4s     // scale without rounding
          SQDMULH v1.4s, v1.4s, v7.4s
          SQDMULH v2.4s, v2.4s, v7.4s
          SQDMULH v3.4s, v3.4s, v7.4s
          SRSHL   v0.4s, v0.4s, v5.4s     // signed rounding shift left
          SRSHL   v1.4s, v1.4s, v5.4s
          SRSHL   v2.4s, v2.4s, v5.4s
          SRSHL   v3.4s, v3.4s, v5.4s
        $elif REQUANTIZATION == "FP32":
          $if not CHANNELWISE:
            # Apply params - scale, bias and clamp
            SCVTF   v0.4s, v0.4s
            LD1R    {v4.4s}, [x11], 4
            SCVTF   v1.4s, v1.4s
            SCVTF   v2.4s, v2.4s
            SCVTF   v3.4s, v3.4s
            FMUL    v0.4s, v0.4s, v4.4s
            FMUL    v1.4s, v1.4s, v4.4s
            FMUL    v2.4s, v2.4s, v4.4s
            FMUL    v3.4s, v3.4s, v4.4s
          $else:
            # Load per channel scale values from weights
            SCVTF   v0.4s, v0.4s
            LDR     q4, [x5], 16
            SCVTF   v1.4s, v1.4s
            LDR     q5, [x5], 16
            SCVTF   v2.4s, v2.4s
            SCVTF   v3.4s, v3.4s
            FMUL    v0.4s, v0.4s, v4.4s
            FMUL    v1.4s, v1.4s, v5.4s
            FMUL    v2.4s, v2.4s, v4.4s
            FMUL    v3.4s, v3.4s, v5.4s

          FCVTNS  v0.4s, v0.4s
          FCVTNS  v1.4s, v1.4s
          FCVTNS  v2.4s, v2.4s
          FCVTNS  v3.4s, v3.4s

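        # Narrow to int16, add the output zero point, narrow to int8, then
        # clamp to the output min/max loaded from params.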
        LD1R    {v5.8h}, [x11], 2
        SQXTN   v0.4h, v0.4s
        SQXTN   v2.4h, v2.4s
        SQXTN2  v0.8h, v1.4s
        SQXTN2  v2.8h, v3.4s
        SUBS    x1, x1, 8
        SQADD   v0.8h, v0.8h, v5.8h
        SQADD   v1.8h, v2.8h, v5.8h
        SQXTN   v0.8b, v0.8h
        SQXTN2  v0.16b, v1.8h
        LD1R    {v1.16b}, [x11], 1
        LD1R    {v2.16b}, [x11]
        SMAX    v0.16b, v0.16b, v1.16b
        SUB     x11, x11, ${REWIND_DECREMENT}          // rewind params pointer
        SMIN    v0.16b, v0.16b, v2.16b
        B.LO    6f

        # Store full 2 x 8
        ST1     {v0.d}[1], [x7], x10
        ST1     {v0.8b}, [x6], x10

        SUB     x4, x4, x3              // a -= ks

        # nc loop
        B.HI    0b

        # Restore x20,x21 from stack
        LDP     x20, x21, [sp, 64]

        # Restore d8-d15 from stack
        LDP     d14, d15, [sp, 48]
        LDP     d12, d13, [sp, 32]
        LDP     d10, d11, [sp, 16]
        LDP     d8, d9, [sp], 80
        RET

        # Remainder - 8 bytes of A
        .p2align 3
5:
        LDR     d0, [x13], 8
        LDP     d4, d5, [x5]
        LDR     d1, [x15], 8
        LDP     d6, d7, [x5, 16]
        SMULL   v2.8h, v4.8b, v0.8b
        SMULL   v3.8h, v4.8b, v1.8b
        SMULL   v10.8h, v5.8b, v0.8b
        SMULL   v11.8h, v5.8b, v1.8b
        SMULL   v12.8h, v6.8b, v0.8b
        SADALP  v16.4s,  v2.8h
        SMULL   v13.8h, v6.8b, v1.8b
        SADALP  v17.4s,  v3.8h
        SMULL   v14.8h, v7.8b, v0.8b
        SADALP  v18.4s, v10.8h
        SMULL   v15.8h, v7.8b, v1.8b
        SADALP  v19.4s, v11.8h
        LDP     d4, d5, [x5, 32]
        SMULL   v2.8h, v4.8b, v0.8b
        SADALP  v20.4s, v12.8h
        SMULL   v3.8h, v4.8b, v1.8b
        SADALP  v21.4s, v13.8h
        SMULL   v10.8h, v5.8b, v0.8b
        SADALP  v22.4s, v14.8h
        SMULL   v11.8h, v5.8b, v1.8b
        SADALP  v23.4s, v15.8h
        LDP     d6, d7, [x5, 48]
        SMULL   v12.8h, v6.8b, v0.8b
        SADALP  v24.4s,  v2.8h
        SMULL   v13.8h, v6.8b, v1.8b
        SADALP  v25.4s,  v3.8h
        SMULL   v14.8h, v7.8b, v0.8b
        SADALP  v26.4s, v10.8h
        SMULL   v15.8h, v7.8b, v1.8b
        SADALP  v27.4s, v11.8h
        ADD     x5, x5, 64
        SADALP  v28.4s, v12.8h
        SADALP  v29.4s, v13.8h
        SADALP  v30.4s, v14.8h
        SADALP  v31.4s, v15.8h

        # ks loop
        SUBS    x9, x9, 16              // ks -= MR * sizeof(int8_t*)
        B.HI    1b
        B       4b

        # Store odd width
        .p2align 3
6:
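        # nc is 1-7: test each bit of nc and store 4, 2, then 1 byte per row,
        # rotating v0 with EXT after each partial store.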
        TBZ     x1, 2, 7f
        ST1     {v0.s}[2], [x7], 4
        STR     s0, [x6], 4
        EXT     v0.16b, v0.16b, v0.16b, 4

7:
        TBZ     x1, 1, 8f
        ST1     {v0.h}[4], [x7], 2
        STR     h0, [x6], 2
        EXT     v0.16b, v0.16b, v0.16b, 2
8:
        TBZ     x1, 0, 9f
        ST1     {v0.b}[8], [x7]
        STR     b0, [x6]
9:
        # Restore x20,x21 from stack
        LDP     x20, x21, [sp, 64]

        # Restore d8-d15 from stack
        LDP     d14, d15, [sp, 48]
        LDP     d12, d13, [sp, 32]
        LDP     d10, d11, [sp, 16]
        LDP     d8, d9, [sp], 80
        RET

END_FUNCTION xnn_${DATATYPE}_igemm_minmax_${REQUANTIZATION.lower()}_ukernel_2x8c8__aarch64_neon_mlal${"_prfm" if PREFETCH else ""}_cortex_a53

#ifdef __ELF__
.section ".note.GNU-stack","",%progbits
#endif
