// Source: /aosp_15_r20/external/XNNPACK/src/qs8-igemm/4x16c4-aarch64-neondot-ld64.S.in (revision 4bdc94577ba0e567308109d787f7fec7b531ce36)
// Copyright 2021 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

// Supported variants: RNDNU or FP32 requantization; the per-channel (qc8)
// variant exists only for FP32.
$assert REQUANTIZATION in ["FP32", "RNDNU"]
$assert not CHANNELWISE or REQUANTIZATION == "FP32"

#include <xnnpack/assembly.h>

$DATATYPE = "qc8" if CHANNELWISE else "qs8"
$PARAMS_UNION = "xnn_qs8_minmax_params" if CHANNELWISE else "xnn_qs8_conv_minmax_params"
// REWIND_DECREMENT is the number of bytes the epilogue advances the params
// pointer (x11) with post-incremented LD1R loads; it is subtracted again so
// the next output tile re-reads the same params. The max clamp value is
// loaded without post-increment and is therefore not counted.
//   RNDNU:       4 (preshift) + 4 (scale) + 4 (postshift) + 2 (bias) + 1 (min) = 15
//   FP32:        4 (scale) + 2 (bias) + 1 (min) = 7
//   FP32 (qc8):  2 (bias) + 1 (min) = 3, the scales are read from w instead
$REWIND_DECREMENT = 3 if CHANNELWISE else {"RNDNU": 15, "FP32": 7}[REQUANTIZATION]
# void xnn_${DATATYPE}_igemm_minmax_${REQUANTIZATION.lower()}_ukernel_4x16c4__aarch64_neondot_ld64(
#     size_t mr,                 x0
#     size_t nc,                 x1
#     size_t kc,                 x2 / x0
#     size_t ks,                 x3 / x9
#     const int8_t**restrict a,  x4
#     const int8_t* restrict w,  x5
#     int8_t* restrict c,        x6
#     size_t cm_stride,          x7
#     size_t cn_stride,                  [sp] -> (x0)
#     size_t a_offset,                   [sp + 8] -> x8
#     const int8_t* zero,                [sp + 16] -> x12
#     const union ${PARAMS_UNION} params [sp + 24] -> x11
#
# Produces up to 4 rows x 16 columns of int8 output per pass of the nc loop.
# Each pass starts the accumulators from 16 int32 values read from w. It then
# walks the indirection buffer a in groups of 4 row pointers (ks loop). A
# pointer equal to zero is used unchanged, any other is advanced by a_offset.
# Dot products are accumulated with SDOT: kc is rounded up to a multiple of 4
# and consumed 8 bytes per main-loop iteration plus an optional 4-byte tail.
# The accumulators are then requantized, narrowed to int8, clamped and stored.
# Registers reused: x2 holds the rounded kc, x0 is the k counter and later
# cn_stride, x7 becomes the row 3 output pointer.

# d8-d15, x19-x30 need to be preserved if used. x18 is reserved by the OS.

# Register usage
# A0  x13  v0
# A1  x14  v1
# A2  x15  v2
# A3  x10  v3
# B    x5  v4  v5  v6  v7
# C0   x6 v16 v20 v24 v28
# C1  x16 v17 v21 v25 v29
# C2  x17 v18 v22 v26 v30
# C3   x7 v19 v23 v27 v31
# unused v8 v9 v10 v11 v12 v13 v14 v15

BEGIN_FUNCTION xnn_${DATATYPE}_igemm_minmax_${REQUANTIZATION.lower()}_ukernel_4x16c4__aarch64_neondot_ld64

        # Clamp C pointers: rows at or beyond mr alias the previous row.
        # Stores are issued row 3 first and row 0 last, so a valid row always
        # overwrites whatever an aliased row wrote to the same address.
        CMP     x0, 2                   // if mr < 2
        LDR     x8, [sp, 8]             // Load a_offset
        ADD     x16, x6, x7             // c1 = c0 + cm_stride
        CSEL    x16, x6,  x16, LO       //   c1 = c0
        ADD     x2, x2, 3               // kc = (kc + 3) & ~3

        ADD     x17, x16, x7            // c2 = c1 + cm_stride
        LDP     x12, x11, [sp, 16]      // Load zero, params pointer
                                        // if mr <= 2
        CSEL    x17, x16, x17, LS       //   c2 = c1
        BIC     x2, x2, 3

        CMP     x0, 4                   // if mr < 4
        ADD     x7,  x17, x7            // c3 = c2 + cm_stride
        CSEL    x7,  x17, x7, LO        //   c3 = c2

        .p2align 3
0:
        # Load initial bias from w into accumulators
        LDP     q16, q20, [x5], 32
        MOV     v17.16b, v16.16b
        MOV     v18.16b, v16.16b
        LDP     q24, q28, [x5], 32
        MOV     v19.16b, v16.16b
        MOV     v21.16b, v20.16b
        MOV     v22.16b, v20.16b
        MOV     v23.16b, v20.16b
        MOV     v25.16b, v24.16b
        MOV     v26.16b, v24.16b
        MOV     v27.16b, v24.16b
        MOV     v29.16b, v28.16b
        MOV     v30.16b, v28.16b
        MOV     v31.16b, v28.16b
        MOV     x9, x3                  // p = ks

        .p2align 3
1:
        # Load next 4 A pointers
        LDP     x13, x14, [x4], 16
        LDP     x15, x10, [x4], 16

        CMP     x13, x12                // if a0 == zero
        ADD     x13, x13, x8            // a0 += a_offset
        CSEL    x13, x12, x13, EQ       //   a0 = zero, else a0 + a_offset
        CMP     x14, x12                // if a1 == zero
        ADD     x14, x14, x8            // a1 += a_offset
        CSEL    x14, x12, x14, EQ       //   a1 = zero, else a1 + a_offset
        CMP     x15, x12                // if a2 == zero
        ADD     x15, x15, x8            // a2 += a_offset
        CSEL    x15, x12, x15, EQ       //   a2 = zero, else a2 + a_offset
        CMP     x10, x12                // if a3 == zero
        ADD     x10, x10, x8            // a3 += a_offset
        CSEL    x10, x12, x10, EQ       //   a3 = zero, else a3 + a_offset

        # Are there at least 8 bytes of A for the main loop?
        SUBS    x0, x2, 8               // k = kc - 8
        B.LO    4f

        # Main loop: 8 bytes of A per row and 128 bytes of B per iteration
        .p2align 3
2:
        LDR     d0, [x13], 8
        LDR     q4,  [x5], 16
        LDR     d1, [x14], 8
        LDR     d2, [x15], 8
        LDR     d3, [x10], 8
        LDR     q5,  [x5], 16
        SDOT    v16.4s, v4.16b,  v0.4b[0]
        SDOT    v17.4s, v4.16b,  v1.4b[0]
        LDP     q6, q7, [x5], 32
        SDOT    v18.4s, v4.16b,  v2.4b[0]
        SDOT    v19.4s, v4.16b,  v3.4b[0]
        SDOT    v20.4s, v5.16b,  v0.4b[0]
        SDOT    v21.4s, v5.16b,  v1.4b[0]
        SDOT    v22.4s, v5.16b,  v2.4b[0]
        SDOT    v23.4s, v5.16b,  v3.4b[0]
        SDOT    v24.4s, v6.16b, v0.4b[0]
        SDOT    v25.4s, v6.16b, v1.4b[0]
        LDP     q4, q5, [x5], 32        // B for the second 4 bytes of A
        SDOT    v26.4s, v6.16b, v2.4b[0]
        SDOT    v27.4s, v6.16b, v3.4b[0]
        SDOT    v28.4s, v7.16b, v0.4b[0]
        SDOT    v29.4s, v7.16b, v1.4b[0]
        SDOT    v30.4s, v7.16b, v2.4b[0]
        SDOT    v31.4s, v7.16b, v3.4b[0]
        SDOT    v16.4s, v4.16b,  v0.4b[1]
        SDOT    v17.4s, v4.16b,  v1.4b[1]
        LDP     q6, q7, [x5], 32
        SDOT    v18.4s, v4.16b,  v2.4b[1]
        SDOT    v19.4s, v4.16b,  v3.4b[1]
        SDOT    v20.4s, v5.16b,  v0.4b[1]
        SDOT    v21.4s, v5.16b,  v1.4b[1]
        SDOT    v22.4s, v5.16b,  v2.4b[1]
        SDOT    v23.4s, v5.16b,  v3.4b[1]
        SDOT    v24.4s, v6.16b,  v0.4b[1]
        SDOT    v25.4s, v6.16b,  v1.4b[1]
        SDOT    v26.4s, v6.16b,  v2.4b[1]
        SDOT    v27.4s, v6.16b,  v3.4b[1]
        SDOT    v28.4s, v7.16b,  v0.4b[1]
        SDOT    v29.4s, v7.16b,  v1.4b[1]
        SDOT    v30.4s, v7.16b,  v2.4b[1]
        SUBS    x0, x0, 8
        SDOT    v31.4s, v7.16b,  v3.4b[1]
        B.HS    2b

        # Is there a 4-byte remainder of A? Since kc is a multiple of 4,
        # k is now either -4 (4 bytes left, bit 2 set) or -8 (none left).
        TBNZ    x0, 2, 4f

        # ks loop
        SUBS    x9, x9, 32              // ks -= MR * sizeof(int8_t*)
        B.HI    1b

3:
        # Requantize the int32 accumulators.
        $if REQUANTIZATION == "RNDNU":
          # Apply params - preshift, scale, postshift, bias and clamp
          LD1R    {v4.4s}, [x11], 4
          SQSHL   v16.4s, v16.4s, v4.4s   // shift to upper bits
          SQSHL   v17.4s, v17.4s, v4.4s
          SQSHL   v18.4s, v18.4s, v4.4s
          SQSHL   v19.4s, v19.4s, v4.4s
          SQSHL   v20.4s, v20.4s, v4.4s
          SQSHL   v21.4s, v21.4s, v4.4s
          SQSHL   v22.4s, v22.4s, v4.4s
          SQSHL   v23.4s, v23.4s, v4.4s
          LD1R    {v5.4s}, [x11], 4
          SQSHL   v24.4s, v24.4s, v4.4s
          SQSHL   v25.4s, v25.4s, v4.4s
          SQSHL   v26.4s, v26.4s, v4.4s
          SQSHL   v27.4s, v27.4s, v4.4s
          SQSHL   v28.4s, v28.4s, v4.4s
          SQSHL   v29.4s, v29.4s, v4.4s
          SQSHL   v30.4s, v30.4s, v4.4s
          SQSHL   v31.4s, v31.4s, v4.4s
          LD1R    {v6.4s}, [x11], 4
          SQDMULH v16.4s, v16.4s, v5.4s   // scale without rounding
          SQDMULH v17.4s, v17.4s, v5.4s
          SQDMULH v18.4s, v18.4s, v5.4s
          SQDMULH v19.4s, v19.4s, v5.4s
          SQDMULH v20.4s, v20.4s, v5.4s
          SQDMULH v21.4s, v21.4s, v5.4s
          SQDMULH v22.4s, v22.4s, v5.4s
          SQDMULH v23.4s, v23.4s, v5.4s
          SQDMULH v24.4s, v24.4s, v5.4s
          SQDMULH v25.4s, v25.4s, v5.4s
          SQDMULH v26.4s, v26.4s, v5.4s
          SQDMULH v27.4s, v27.4s, v5.4s
          SQDMULH v28.4s, v28.4s, v5.4s
          SQDMULH v29.4s, v29.4s, v5.4s
          SQDMULH v30.4s, v30.4s, v5.4s
          SQDMULH v31.4s, v31.4s, v5.4s
          SRSHL   v16.4s, v16.4s, v6.4s   // signed rounding shift left
          SRSHL   v17.4s, v17.4s, v6.4s
          SRSHL   v18.4s, v18.4s, v6.4s
          SRSHL   v19.4s, v19.4s, v6.4s
          SRSHL   v20.4s, v20.4s, v6.4s
          SRSHL   v21.4s, v21.4s, v6.4s
          SRSHL   v22.4s, v22.4s, v6.4s
          SRSHL   v23.4s, v23.4s, v6.4s
          SRSHL   v24.4s, v24.4s, v6.4s
          SRSHL   v25.4s, v25.4s, v6.4s
          SRSHL   v26.4s, v26.4s, v6.4s
          SRSHL   v27.4s, v27.4s, v6.4s
          SRSHL   v28.4s, v28.4s, v6.4s
          SRSHL   v29.4s, v29.4s, v6.4s
          SRSHL   v30.4s, v30.4s, v6.4s
          SRSHL   v31.4s, v31.4s, v6.4s
        $elif REQUANTIZATION == "FP32":
          SCVTF   v16.4s, v16.4s
          SCVTF   v17.4s, v17.4s
          $if not CHANNELWISE:
            # Apply params - scale, bias and clamp
            LD1R    {v4.4s}, [x11], 4
            SCVTF   v18.4s, v18.4s
            SCVTF   v19.4s, v19.4s
          $else:
            # Load per channel scale values from weights
            LDR     q4, [x5], 16
            SCVTF   v18.4s, v18.4s
            SCVTF   v19.4s, v19.4s
            LDR     q5, [x5], 16
          SCVTF   v20.4s, v20.4s
          SCVTF   v21.4s, v21.4s
          SCVTF   v22.4s, v22.4s
          SCVTF   v23.4s, v23.4s
          SCVTF   v24.4s, v24.4s
          SCVTF   v25.4s, v25.4s
          SCVTF   v26.4s, v26.4s
          SCVTF   v27.4s, v27.4s
          SCVTF   v28.4s, v28.4s
          SCVTF   v29.4s, v29.4s
          SCVTF   v30.4s, v30.4s
          SCVTF   v31.4s, v31.4s

          $if CHANNELWISE:
            # Columns 0-3 use v4, 4-7 use v5, 8-11 use v6 and 12-15 use the
            # reloaded v4 (loaded once columns 0-3 are done with it).
            LDR     q6, [x5], 16
            FMUL    v16.4s, v16.4s, v4.4s
            FMUL    v17.4s, v17.4s, v4.4s
            FMUL    v18.4s, v18.4s, v4.4s
            FMUL    v19.4s, v19.4s, v4.4s
            FMUL    v20.4s, v20.4s, v5.4s
            LDR     q4, [x5], 16
            FMUL    v21.4s, v21.4s, v5.4s
            FMUL    v22.4s, v22.4s, v5.4s
            FMUL    v23.4s, v23.4s, v5.4s
            FMUL    v24.4s, v24.4s, v6.4s
            FMUL    v25.4s, v25.4s, v6.4s
            FMUL    v26.4s, v26.4s, v6.4s
            FMUL    v27.4s, v27.4s, v6.4s
            FMUL    v28.4s, v28.4s, v4.4s
            FMUL    v29.4s, v29.4s, v4.4s
            FMUL    v30.4s, v30.4s, v4.4s
            FMUL    v31.4s, v31.4s, v4.4s
          $else:
            FMUL    v16.4s, v16.4s, v4.4s
            FMUL    v17.4s, v17.4s, v4.4s
            FMUL    v18.4s, v18.4s, v4.4s
            FMUL    v19.4s, v19.4s, v4.4s
            FMUL    v20.4s, v20.4s, v4.4s
            FMUL    v21.4s, v21.4s, v4.4s
            FMUL    v22.4s, v22.4s, v4.4s
            FMUL    v23.4s, v23.4s, v4.4s
            FMUL    v24.4s, v24.4s, v4.4s
            FMUL    v25.4s, v25.4s, v4.4s
            FMUL    v26.4s, v26.4s, v4.4s
            FMUL    v27.4s, v27.4s, v4.4s
            FMUL    v28.4s, v28.4s, v4.4s
            FMUL    v29.4s, v29.4s, v4.4s
            FMUL    v30.4s, v30.4s, v4.4s
            FMUL    v31.4s, v31.4s, v4.4s

          # Convert back to int32, rounding to nearest with ties to even
          FCVTNS  v16.4s, v16.4s
          FCVTNS  v17.4s, v17.4s
          FCVTNS  v18.4s, v18.4s
          FCVTNS  v19.4s, v19.4s
          FCVTNS  v20.4s, v20.4s
          FCVTNS  v21.4s, v21.4s
          FCVTNS  v22.4s, v22.4s
          FCVTNS  v23.4s, v23.4s
          FCVTNS  v24.4s, v24.4s
          FCVTNS  v25.4s, v25.4s
          FCVTNS  v26.4s, v26.4s
          FCVTNS  v27.4s, v27.4s
          FCVTNS  v28.4s, v28.4s
          FCVTNS  v29.4s, v29.4s
          FCVTNS  v30.4s, v30.4s
          FCVTNS  v31.4s, v31.4s

        # Saturating narrow int32 to int16, add the int16 bias from params,
        # saturating narrow to int8, then clamp to [min, max] from params.
        SQXTN   v16.4h, v16.4s
        SQXTN   v17.4h, v17.4s
        SQXTN   v18.4h, v18.4s
        SQXTN   v19.4h, v19.4s
        SQXTN   v24.4h, v24.4s
        SQXTN   v25.4h, v25.4s
        SQXTN   v26.4h, v26.4s
        SQXTN   v27.4h, v27.4s
        LD1R    {v6.8h}, [x11], 2        // add bias

        SQXTN2  v16.8h, v20.4s
        SQXTN2  v17.8h, v21.4s
        SQXTN2  v18.8h, v22.4s
        SQXTN2  v19.8h, v23.4s
        SQXTN2  v24.8h, v28.4s
        SQXTN2  v25.8h, v29.4s
        SQXTN2  v26.8h, v30.4s
        SQXTN2  v27.8h, v31.4s

        SQADD   v16.8h, v16.8h, v6.8h
        SQADD   v17.8h, v17.8h, v6.8h
        SQADD   v18.8h, v18.8h, v6.8h
        SQADD   v19.8h, v19.8h, v6.8h
        SQADD   v24.8h, v24.8h, v6.8h
        SQADD   v25.8h, v25.8h, v6.8h
        SQADD   v26.8h, v26.8h, v6.8h
        SQADD   v27.8h, v27.8h, v6.8h
        LD1R    {v4.16b}, [x11], 1       // clamp min value

        SQXTN   v0.8b, v16.8h
        SQXTN   v1.8b, v17.8h
        SQXTN   v2.8b, v18.8h
        SQXTN   v3.8b, v19.8h
        LD1R    {v5.16b}, [x11]          // clamp max value (no post-increment)
        SQXTN2  v0.16b, v24.8h
        SQXTN2  v1.16b, v25.8h
        SQXTN2  v2.16b, v26.8h
        SQXTN2  v3.16b, v27.8h
        LDR     x0, [sp]                 // cn_stride
        SMAX    v0.16b, v0.16b, v4.16b
        SMAX    v1.16b, v1.16b, v4.16b
        SUB     x11, x11, ${REWIND_DECREMENT}          // rewind params pointer
        SMAX    v2.16b, v2.16b, v4.16b
        SMAX    v3.16b, v3.16b, v4.16b
        SUBS    x1, x1, 16               // nc -= 16, flags used by B.LO 5f and B.HI 0b
        SMIN    v0.16b, v0.16b, v5.16b
        SMIN    v1.16b, v1.16b, v5.16b
        SMIN    v2.16b, v2.16b, v5.16b
        SMIN    v3.16b, v3.16b, v5.16b
        B.LO    5f

        # Store full 4 x 16 (row 3 first, row 0 last)
        ST1     {v3.16b},  [x7], x0
        ST1     {v2.16b}, [x17], x0
        ST1     {v1.16b}, [x16], x0
        ST1     {v0.16b},  [x6], x0

        SUB     x4, x4, x3              // a -= ks

        # nc loop
        B.HI    0b
        RET

        # Remainder: 4 bytes of A per row and 64 bytes of B
        .p2align 3
4:
        LDR     s0, [x13], 4
        LDR     q4, [x5], 16
        LDR     s1, [x14], 4
        LDR     s2, [x15], 4
        LDR     s3, [x10], 4
        LDR     q5, [x5], 16
        SDOT    v16.4s, v4.16b,  v0.4b[0]
        SDOT    v17.4s, v4.16b,  v1.4b[0]
        LDP     q6, q7, [x5], 32
        SDOT    v18.4s, v4.16b,  v2.4b[0]
        SDOT    v19.4s, v4.16b,  v3.4b[0]
        SDOT    v20.4s, v5.16b,  v0.4b[0]
        SDOT    v21.4s, v5.16b,  v1.4b[0]
        SDOT    v22.4s, v5.16b,  v2.4b[0]
        SDOT    v23.4s, v5.16b,  v3.4b[0]
        SDOT    v24.4s, v6.16b, v0.4b[0]
        SDOT    v25.4s, v6.16b, v1.4b[0]
        SDOT    v26.4s, v6.16b, v2.4b[0]
        SDOT    v27.4s, v6.16b, v3.4b[0]
        SDOT    v28.4s, v7.16b, v0.4b[0]
        SDOT    v29.4s, v7.16b, v1.4b[0]
        SDOT    v30.4s, v7.16b, v2.4b[0]
        SDOT    v31.4s, v7.16b, v3.4b[0]

        # ks loop
        SUBS    x9, x9, 32              // ks -= MR * sizeof(int8_t*)
        B.HI    1b
        B       3b

        # Store odd width. x1 = nc - 16 is negative here, but its low 4 bits
        # still equal the remaining column count, so bits 3, 2, 1 and 0 select
        # 8, 4, 2 and 1 byte stores. After each store the unwritten bytes are
        # moved down to the bottom of the register with DUP.
        .p2align 3
5:
        TBZ     x1, 3, 6f
        STR     d3, [x7], 8
        STR     d2, [x17], 8
        DUP     d3, v3.d[1]
        DUP     d2, v2.d[1]
        STR     d1, [x16], 8
        STR     d0, [x6], 8
        DUP     d1, v1.d[1]
        DUP     d0, v0.d[1]
6:
        TBZ     x1, 2, 7f
        STR     s3, [x7], 4
        STR     s2, [x17], 4
        DUP     s3, v3.s[1]
        DUP     s2, v2.s[1]
        STR     s1, [x16], 4
        STR     s0, [x6], 4
        DUP     s1, v1.s[1]
        DUP     s0, v0.s[1]
7:
        TBZ     x1, 1, 8f
        STR     h3, [x7], 2
        STR     h2, [x17], 2
        DUP     h3, v3.h[1]
        DUP     h2, v2.h[1]
        STR     h1, [x16], 2
        STR     h0, [x6], 2
        DUP     h1, v1.h[1]
        DUP     h0, v0.h[1]
8:
        TBZ     x1, 0, 9f
        STR     b3, [x7]
        STR     b2, [x17]
        STR     b1, [x16]
        STR     b0, [x6]
9:
        RET

END_FUNCTION xnn_${DATATYPE}_igemm_minmax_${REQUANTIZATION.lower()}_ukernel_4x16c4__aarch64_neondot_ld64

// On ELF targets, emit an empty .note.GNU-stack section so the linker does
// not mark the stack as executable because of this object.
#ifdef __ELF__
.section ".note.GNU-stack","",%progbits
#endif
