// xref: /aosp_15_r20/external/ComputeLibrary/src/core/NEON/kernels/arm_gemm/kernels/a64_gemm_u16_8x12/generic.cpp (revision c217d954acce2dbc11938adb493fc0abd69584f3)
/*
 * Copyright (c) 2017-2018 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#ifdef __aarch64__

#include <arm_neon.h>

#include "../../asmlib.hpp"

namespace arm_gemm {

/*
 * Unsigned 16-bit GEMM micro-kernel producing 8x12 blocks of 32-bit results.
 *
 * For every (A block, B block) pair the kernel zeroes 24 accumulator vectors
 * (v5-v28, four uint32 lanes each), performs K widening multiply-accumulate
 * steps (UMLAL/UMLAL2, uint16 x uint16 added into uint32 lanes) and stores
 * 96 uint32 values (0x180 bytes) to the output panel.
 *
 * Layout, as consumed by the assembly below:
 *  - Apanel: 'ablocks' consecutive blocks; every K step reads 8 uint16 values
 *    (16 bytes). The same A block is re-read for each of the 'bblocks' B blocks.
 *  - Bpanel: 'bblocks' consecutive blocks; every K step reads 12 uint16 values
 *    (24 bytes). The B panel is re-read from its start for every A block.
 *  - Cpanel: one 8x12 block per (A block, B block) pair, A-block-major. Within
 *    a block, row r (12 consecutive uint32) holds the sums of A element r
 *    multiplied by each of the 12 B elements.
 *
 * Parameters:
 *  Apanel   pointer to the packed A data described above.
 *  Bpanel   pointer to the packed B data described above.
 *  Cpanel   pointer to the output; ablocks * bblocks * 96 uint32 are written.
 *  ablocks  number of 8-row blocks in Apanel.
 *  bblocks  number of 12-column blocks in Bpanel.
 *  K        number of multiply-accumulate steps per block. Must be >= 1: the
 *           first A/B vectors are loaded unconditionally, and for K == 0 the
 *           loop counter (K+1)/2 - 1 would start at -1 and never hit the
 *           "zero iterations" early-out.
 */
void a64_gemm_u16_asimd_8x12(const uint16_t *Apanel, const uint16_t *Bpanel, uint32_t *Cpanel, int ablocks, int bblocks, int K)
{
  const uint16_t *a_ptr = Apanel;
  uint32_t *c_ptr = Cpanel;

  for (int yb = 0; yb < ablocks; yb++)
  {
    // Remember where this A block starts: it is replayed for every B block.
    const uint16_t *a_ptr0 = a_ptr;
    const uint16_t *b_ptr = Bpanel;

    for (int xb = 0; xb < bblocks; xb++)
    {
      a_ptr = a_ptr0;
      // The main loop handles two K steps per iteration; the last one (odd K)
      // or the last two (even K) steps are done by the tail code, which also
      // interleaves the stores of the results.
      // Kept as 'int' (not bool) so the full 32-bit register that the
      // assembly tests through %w[odd_k] has a well-defined value.
      const int odd_k = K & 0x1;
      int k = (K+1)/2 - 1;

      // Operand vectors are pinned to v0-v4; the accumulators v5-v28 are
      // addressed by name inside the assembly and listed as clobbers.
      register uint16x8_t aa asm("v0");
      register uint16x8_t ab asm("v1");
      register uint16x8_t b0 asm("v2");
      register uint16x8_t b1 asm("v3");
      register uint16x8_t b2 asm("v4");

      // NOTE: 'k' and 'odd_k' are 32-bit values, so they are referenced with
      // the %w modifier. Using the 64-bit %x view would make the loop
      // count / tail selection depend on the unspecified upper register bits.
      __asm __volatile (
        "ldr %d[aa], [%x[a_ptr]]\n"  // Load A[A].lower
        "movi v5.4s, #0\n"
        "ldr x20, [%x[a_ptr], #0x08]\n"  // Load A[A].upper
        "movi v6.4s, #0\n"
        "ldr %d[b0], [%x[b_ptr]]\n"  // Load B[0].lower
        "ins %[aa].d[1], x20\n"  // Merge A[A].lower and upper
        "movi v7.4s, #0\n"
        ASM_PREFETCH("[%[a_ptr], #64]")
        "movi v8.4s, #0\n"
        "ldr x20, [%x[b_ptr], #0x08]\n"  // Load B[0].upper
        "movi v9.4s, #0\n"
        ASM_PREFETCH("[%[b_ptr], #64]")
        "movi v10.4s, #0\n"
        "ldr %d[b1], [%x[b_ptr], #0x10]\n"  // Load B[1].lower
        "ins %[b0].d[1], x20\n"  // Merge B[0].lower and upper
        "movi v11.4s, #0\n"
        ASM_PREFETCH("[%[a_ptr], #96]")
        "movi v12.4s, #0\n"
        "movi v13.4s, #0\n"
        ASM_PREFETCH("[%[b_ptr], #96]")
        "movi v14.4s, #0\n"
        "movi v15.4s, #0\n"
        ASM_PREFETCH("[%[a_ptr], #128]")
        "movi v16.4s, #0\n"
        "movi v17.4s, #0\n"
        ASM_PREFETCH("[%[b_ptr], #128]")
        "movi v18.4s, #0\n"
        "movi v19.4s, #0\n"
        ASM_PREFETCH("[%[a_ptr], #160]")
        "movi v20.4s, #0\n"
        "movi v21.4s, #0\n"
        ASM_PREFETCH("[%[b_ptr], #160]")
        "movi v22.4s, #0\n"
        "movi v23.4s, #0\n"
        ASM_PREFETCH("[%[a_ptr], #192]")
        "movi v24.4s, #0\n"
        "add %x[a_ptr], %x[a_ptr], #0x10\n"
        "movi v25.4s, #0\n"
        ASM_PREFETCH("[%[b_ptr], #192]")
        "movi v26.4s, #0\n"
        "add %x[b_ptr], %x[b_ptr], #0x18\n"
        "movi v27.4s, #0\n"
        "movi v28.4s, #0\n"

        "cbz %w[k], 2f\n"  // Skip the loop if doing zero iterations.

        "1:\n"  // Main loop
          // First unroll
          "umlal v5.4s, %[b0].4h, %[aa].h[0]\n"
          "ldr x20, [%x[b_ptr]]\n"  // Load B[1].upper
          "umlal v6.4s, %[b0].4h, %[aa].h[1]\n"
          "umlal v7.4s, %[b0].4h, %[aa].h[2]\n"
          "ldr %d[ab], [%x[a_ptr]]\n"  // Load A[B].lower
          "ins %[b1].d[1], x20\n"  // Merge B[1].lower and .upper
          "umlal v8.4s, %[b0].4h, %[aa].h[3]\n"
          "umlal v9.4s, %[b0].4h, %[aa].h[4]\n"
          "ldr x20, [%x[a_ptr], #0x8]\n"  // Load A[B].upper
          "umlal v10.4s, %[b0].4h, %[aa].h[5]\n"
          "umlal v11.4s, %[b0].4h, %[aa].h[6]\n"
          "ldr %d[b2], [%x[b_ptr], #0x8]\n"  // Load B[2].lower
          "ins %[ab].d[1], x20\n"  // Merge A[B].lower and .upper
          "umlal v12.4s, %[b0].4h, %[aa].h[7]\n"
          "umlal2 v13.4s, %[b0].8h, %[aa].h[0]\n"
          "ldr x20, [%x[b_ptr], #0x10]\n"  // Load B[2].upper
          "umlal2 v14.4s, %[b0].8h, %[aa].h[1]\n"
          "umlal2 v15.4s, %[b0].8h, %[aa].h[2]\n"
          "umlal2 v16.4s, %[b0].8h, %[aa].h[3]\n"
          "umlal2 v17.4s, %[b0].8h, %[aa].h[4]\n"
          "umlal2 v18.4s, %[b0].8h, %[aa].h[5]\n"
          "umlal2 v19.4s, %[b0].8h, %[aa].h[6]\n"
          "umlal2 v20.4s, %[b0].8h, %[aa].h[7]\n"
          "ldr %d[b0], [%x[b_ptr], #0x18]\n"  // Load B[0].lower
          "ins %[b2].d[1], x20\n"  // Merge B[2].lower and .upper
          "umlal v21.4s, %[b1].4h, %[aa].h[0]\n"
          "umlal v22.4s, %[b1].4h, %[aa].h[1]\n"
          "ldr x20, [%x[b_ptr], #0x20]\n"  // Load B[0].upper
          "umlal v23.4s, %[b1].4h, %[aa].h[2]\n"
          "umlal v24.4s, %[b1].4h, %[aa].h[3]\n"
          "umlal v25.4s, %[b1].4h, %[aa].h[4]\n"
          "umlal v26.4s, %[b1].4h, %[aa].h[5]\n"
          "umlal v27.4s, %[b1].4h, %[aa].h[6]\n"
          "umlal v28.4s, %[b1].4h, %[aa].h[7]\n"

          // Second unroll
          "umlal2 v5.4s, %[b1].8h, %[ab].h[0]\n"
          "ldr %d[aa], [%x[a_ptr], #0x10]\n"  // Load A[A].lower
          "ins %[b0].d[1], x20\n"  // Merge B[0].lower and .upper
          "umlal2 v6.4s, %[b1].8h, %[ab].h[1]\n"
          "umlal2 v7.4s, %[b1].8h, %[ab].h[2]\n"
          "ldr x20, [%x[a_ptr], #0x18]\n"  // Load A[A].upper
          "umlal2 v8.4s, %[b1].8h, %[ab].h[3]\n"
          "umlal2 v9.4s, %[b1].8h, %[ab].h[4]\n"
          "umlal2 v10.4s, %[b1].8h, %[ab].h[5]\n"
          "umlal2 v11.4s, %[b1].8h, %[ab].h[6]\n"
          "add %x[a_ptr], %x[a_ptr], #0x20\n"
          "umlal2 v12.4s, %[b1].8h, %[ab].h[7]\n"
          "umlal v13.4s, %[b2].4h, %[ab].h[0]\n"
          ASM_PREFETCH("[%[b_ptr], #320]")
          "umlal v14.4s, %[b2].4h, %[ab].h[1]\n"
          "umlal v15.4s, %[b2].4h, %[ab].h[2]\n"
          ASM_PREFETCH("[%[a_ptr], #320]")
          "umlal v16.4s, %[b2].4h, %[ab].h[3]\n"
          "umlal v17.4s, %[b2].4h, %[ab].h[4]\n"
          ASM_PREFETCH("[%[b_ptr], #448]")
          "umlal v18.4s, %[b2].4h, %[ab].h[5]\n"
          "umlal v19.4s, %[b2].4h, %[ab].h[6]\n"
          "umlal v20.4s, %[b2].4h, %[ab].h[7]\n"
          "umlal2 v21.4s, %[b2].8h, %[ab].h[0]\n"
          "umlal2 v22.4s, %[b2].8h, %[ab].h[1]\n"
          "subs %w[k], %w[k], #0x1\n"
          "umlal2 v23.4s, %[b2].8h, %[ab].h[2]\n"
          "umlal2 v24.4s, %[b2].8h, %[ab].h[3]\n"
          "ldr %d[b1], [%x[b_ptr], #0x28]\n"  // Load B[1].lower
          "ins %[aa].d[1], x20\n"  // Merge A[A].lower and .upper
          "umlal2 v25.4s, %[b2].8h, %[ab].h[4]\n"
          "umlal2 v26.4s, %[b2].8h, %[ab].h[5]\n"
          "add %x[b_ptr], %x[b_ptr], #0x30\n"
          "umlal2 v27.4s, %[b2].8h, %[ab].h[6]\n"
          "umlal2 v28.4s, %[b2].8h, %[ab].h[7]\n"
          "bne 1b\n"

        "2:\n"  // Even tail
          "cbnz %w[odd_k], 3f\n"

          "umlal v5.4s, %[b0].4h, %[aa].h[0]\n"
          "ldr x20, [%x[b_ptr]]\n"  // Load B[1].upper
          "umlal v6.4s, %[b0].4h, %[aa].h[1]\n"
          "umlal v7.4s, %[b0].4h, %[aa].h[2]\n"
          "ldr %d[ab], [%x[a_ptr]]\n"  // Load A[B].lower
          "ins %[b1].d[1], x20\n"  // Merge B[1].lower and .upper
          "umlal v8.4s, %[b0].4h, %[aa].h[3]\n"
          "umlal v9.4s, %[b0].4h, %[aa].h[4]\n"
          "ldr x20, [%x[a_ptr], #0x8]\n"  // Load A[B].upper
          "umlal v10.4s, %[b0].4h, %[aa].h[5]\n"
          "umlal v11.4s, %[b0].4h, %[aa].h[6]\n"
          "ldr %d[b2], [%x[b_ptr], #0x8]\n"  // Load B[2].lower
          "ins %[ab].d[1], x20\n"  // Merge A[B].lower and .upper
          "umlal v12.4s, %[b0].4h, %[aa].h[7]\n"
          "umlal2 v13.4s, %[b0].8h, %[aa].h[0]\n"
          "ldr x20, [%x[b_ptr], #0x10]\n"  // Load B[2].upper
          "umlal2 v14.4s, %[b0].8h, %[aa].h[1]\n"
          "umlal2 v15.4s, %[b0].8h, %[aa].h[2]\n"
          "umlal2 v16.4s, %[b0].8h, %[aa].h[3]\n"
          "add %[a_ptr], %[a_ptr], #0x10\n"
          "umlal2 v17.4s, %[b0].8h, %[aa].h[4]\n"
          "add %[b_ptr], %[b_ptr], #0x18\n"
          "umlal2 v18.4s, %[b0].8h, %[aa].h[5]\n"
          "umlal2 v19.4s, %[b0].8h, %[aa].h[6]\n"
          "umlal2 v20.4s, %[b0].8h, %[aa].h[7]\n"
          "ins %[b2].d[1], x20\n"  // Merge B[2].lower and .upper
          "umlal v21.4s, %[b1].4h, %[aa].h[0]\n"
          "umlal v22.4s, %[b1].4h, %[aa].h[1]\n"
          "umlal v23.4s, %[b1].4h, %[aa].h[2]\n"
          "umlal v24.4s, %[b1].4h, %[aa].h[3]\n"
          "umlal v25.4s, %[b1].4h, %[aa].h[4]\n"
          "umlal v26.4s, %[b1].4h, %[aa].h[5]\n"
          "umlal v27.4s, %[b1].4h, %[aa].h[6]\n"
          "umlal v28.4s, %[b1].4h, %[aa].h[7]\n"

          // Final K step: each accumulator is stored as soon as its last
          // multiply-accumulate has been issued.
          "umlal2 v5.4s, %[b1].8h, %[ab].h[0]\n"
          "umlal v13.4s, %[b2].4h, %[ab].h[0]\n"
          "umlal2 v21.4s, %[b2].8h, %[ab].h[0]\n"
          "umlal2 v6.4s, %[b1].8h, %[ab].h[1]\n"
          "umlal v14.4s, %[b2].4h, %[ab].h[1]\n"
          "str q5, [%x[c_ptr]]\n"
          "umlal2 v22.4s, %[b2].8h, %[ab].h[1]\n"
          "str q13, [%x[c_ptr], #0x10]\n"
          "umlal2 v7.4s, %[b1].8h, %[ab].h[2]\n"
          "str q21, [%x[c_ptr], #0x20]\n"
          "umlal v15.4s, %[b2].4h, %[ab].h[2]\n"
          "str q6, [%x[c_ptr], #0x30]\n"
          "umlal2 v23.4s, %[b2].8h, %[ab].h[2]\n"
          "str q14, [%x[c_ptr], #0x40]\n"
          "umlal2 v8.4s, %[b1].8h, %[ab].h[3]\n"
          "str q22, [%x[c_ptr], #0x50]\n"
          "umlal v16.4s, %[b2].4h, %[ab].h[3]\n"
          "str q7, [%x[c_ptr], #0x60]\n"
          "umlal2 v24.4s, %[b2].8h, %[ab].h[3]\n"
          "str q15, [%x[c_ptr], #0x70]\n"
          "umlal2 v9.4s, %[b1].8h, %[ab].h[4]\n"
          "str q23, [%x[c_ptr], #0x80]\n"
          "umlal v17.4s, %[b2].4h, %[ab].h[4]\n"
          "str q8, [%x[c_ptr], #0x90]\n"
          "umlal2 v25.4s, %[b2].8h, %[ab].h[4]\n"
          "str q16, [%x[c_ptr], #0xa0]\n"
          "umlal2 v10.4s, %[b1].8h, %[ab].h[5]\n"
          "str q24, [%x[c_ptr], #0xb0]\n"
          "umlal v18.4s, %[b2].4h, %[ab].h[5]\n"
          "str q9, [%x[c_ptr], #0xc0]\n"
          "umlal2 v26.4s, %[b2].8h, %[ab].h[5]\n"
          "str q17, [%x[c_ptr], #0xd0]\n"
          "umlal2 v11.4s, %[b1].8h, %[ab].h[6]\n"
          "str q25, [%x[c_ptr], #0xe0]\n"
          "umlal v19.4s, %[b2].4h, %[ab].h[6]\n"
          "str q10, [%x[c_ptr], #0xf0]\n"
          "umlal2 v27.4s, %[b2].8h, %[ab].h[6]\n"
          "str q18, [%x[c_ptr], #0x100]\n"
          "umlal2 v12.4s, %[b1].8h, %[ab].h[7]\n"
          "str q26, [%x[c_ptr], #0x110]\n"
          "umlal v20.4s, %[b2].4h, %[ab].h[7]\n"
          "str q11, [%x[c_ptr], #0x120]\n"
          "umlal2 v28.4s, %[b2].8h, %[ab].h[7]\n"
          "str q19, [%x[c_ptr], #0x130]\n"
          "b 4f\n"  // Complete write out

        "3:\n"  // Odd tail
          // Single remaining K step; only the lower half of b1 is used here.
          "umlal v5.4s, %[b0].4h, %[aa].h[0]\n"
          "umlal2 v13.4s, %[b0].8h, %[aa].h[0]\n"
          "umlal v21.4s, %[b1].4h, %[aa].h[0]\n"
          "umlal v6.4s, %[b0].4h, %[aa].h[1]\n"
          "umlal2 v14.4s, %[b0].8h, %[aa].h[1]\n"
          "umlal v22.4s, %[b1].4h, %[aa].h[1]\n"
          "str q5, [%x[c_ptr]]\n"
          "umlal v7.4s, %[b0].4h, %[aa].h[2]\n"
          "str q13, [%x[c_ptr], #0x10]\n"
          "umlal2 v15.4s, %[b0].8h, %[aa].h[2]\n"
          "str q21, [%x[c_ptr], #0x20]\n"
          "umlal v23.4s, %[b1].4h, %[aa].h[2]\n"
          "str q6, [%x[c_ptr], #0x30]\n"
          "umlal v8.4s, %[b0].4h, %[aa].h[3]\n"
          "str q14, [%x[c_ptr], #0x40]\n"
          "umlal2 v16.4s, %[b0].8h, %[aa].h[3]\n"
          "str q22, [%x[c_ptr], #0x50]\n"
          "umlal v24.4s, %[b1].4h, %[aa].h[3]\n"
          "str q7, [%x[c_ptr], #0x60]\n"
          "umlal v9.4s, %[b0].4h, %[aa].h[4]\n"
          "str q15, [%x[c_ptr], #0x70]\n"
          "umlal2 v17.4s, %[b0].8h, %[aa].h[4]\n"
          "str q23, [%x[c_ptr], #0x80]\n"
          "umlal v25.4s, %[b1].4h, %[aa].h[4]\n"
          "str q8, [%x[c_ptr], #0x90]\n"
          "umlal v10.4s, %[b0].4h, %[aa].h[5]\n"
          "str q16, [%x[c_ptr], #0xa0]\n"
          "umlal2 v18.4s, %[b0].8h, %[aa].h[5]\n"
          "str q24, [%x[c_ptr], #0xb0]\n"
          "umlal v26.4s, %[b1].4h, %[aa].h[5]\n"
          "str q9, [%x[c_ptr], #0xc0]\n"
          "umlal v11.4s, %[b0].4h, %[aa].h[6]\n"
          "str q17, [%x[c_ptr], #0xd0]\n"
          "umlal2 v19.4s, %[b0].8h, %[aa].h[6]\n"
          "str q25, [%x[c_ptr], #0xe0]\n"
          "umlal v27.4s, %[b1].4h, %[aa].h[6]\n"
          "str q10, [%x[c_ptr], #0xf0]\n"
          "umlal v12.4s, %[b0].4h, %[aa].h[7]\n"
          "str q18, [%x[c_ptr], #0x100]\n"
          "umlal2 v20.4s, %[b0].8h, %[aa].h[7]\n"
          "str q26, [%x[c_ptr], #0x110]\n"
          "umlal v28.4s, %[b1].4h, %[aa].h[7]\n"
          "str q11, [%x[c_ptr], #0x120]\n"

        "4:\n"  // End of function
          // Shared by both tails: store the remaining accumulators and step
          // c_ptr to the next 8x12 output block.
          "str q19, [%x[c_ptr], #0x130]\n"
          "str q27, [%x[c_ptr], #0x140]\n"
          "str q12, [%x[c_ptr], #0x150]\n"
          "str q20, [%x[c_ptr], #0x160]\n"
          "str q28, [%x[c_ptr], #0x170]\n"
          "add %x[c_ptr], %x[c_ptr], #0x180\n"
        : [a_ptr] "+r" (a_ptr), [b_ptr] "+r" (b_ptr), [c_ptr] "+r" (c_ptr), [k] "+r" (k),
          [aa] "+w" (aa), [ab] "+w" (ab), [b0] "+w" (b0), [b1] "+w" (b1), [b2] "+w" (b2)
        : [odd_k] "r" (odd_k)
        : "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "x20", "cc", "memory"
      );
    }
  }
}

} // namespace arm_gemm

#endif