/*
 * Copyright (c) 2017-2018 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

// Build on AArch64 where either FP16_KERNELS is set or FP16 is explicitly supported.
#if defined(__aarch64__) && (defined(FP16_KERNELS) || defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC))

#include <arm_neon.h>

#include "../../asmlib.hpp"

// Kernel implementation.
//
// Assume that "Apanel" points to a chunk of A blocks (each size 8xK) in read-order.
// Assume that "Bpanel" points to a chunk of B blocks (each size 24xK) in read-order.
// Assume that "Cpanel" points to a chunk of C output blocks (each size
// 24x8), the chunks being arranged in a row major fashion.
//
// Note that the intent of this is that either ablocks or bblocks will be 1
// - this construction allows the output loop to proceed in either order.

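// For reference, here is a minimal scalar sketch of the arithmetic one block
// pair produces, assuming the packed layouts described above.  This is purely
// illustrative: the guard macro and function name are hypothetical and not
// part of the library.
#ifdef A64_HGEMM_8X24_REFERENCE_SKETCH
static inline void hgemm_8x24_block_ref(const __fp16 *a_block, // 8 rows, k-major: a_block[k*8 + row]
                                        const __fp16 *b_block, // 24 cols, k-major: b_block[k*24 + col]
                                        __fp16       *c_block, // 8x24 block, row major
                                        int K) {
    for (int row = 0; row < 8; row++) {
        for (int col = 0; col < 24; col++) {
            // Accumulate in fp16, matching the FMLA instructions in the kernel.
            __fp16 acc = 0.0f;
            for (int k = 0; k < K; k++) {
                acc += a_block[k * 8 + row] * b_block[k * 24 + col];
            }
            c_block[row * 24 + col] = acc; // C is overwritten, not accumulated into.
        }
    }
}
#endif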

namespace arm_gemm {

void a64_hgemm_asimd_8x24_a55r1(const __fp16 *Apanel, const __fp16 *Bpanel, __fp16 *Cpanel, int ablocks, int bblocks, int K) {
    const __fp16 *a_ptr = Apanel;
    __fp16 *c_ptr = Cpanel;

    // Fix up for odd lengths - set a flag if K is odd, but make
    // sure we round up the iteration count.
    int oddk = (K & 1);
    int k_iters = ((K+1)/2) - 1;
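
    // The main loop below is software-pipelined and consumes two values of K
    // per iteration, with the final pair (or single odd value) peeled off
    // after it: for example, K=8 gives oddk=0 and k_iters=3 (3*2 + 2 = 8),
    // while K=7 gives oddk=1 and k_iters=3 (3*2 + 1 = 7).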

    for (int yb=0; yb<ablocks; yb++) {
        const __fp16 *a_ptr0 = a_ptr;
        const __fp16 *b_ptr = Bpanel;

        for (int xb=0; xb<bblocks; xb++) {
            int k = k_iters;
            a_ptr = a_ptr0;

            // As A55 requires 64-bit loads anyway, just use 64 bits of the
            // "A" operands to save on "ins" instructions.  Since A55 is
            // in-order, two sets of "A" operands and one set of "B" is
            // sufficient.
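            //
            // The 128-bit "B" loads are likewise split into a 64-bit vector
            // load plus a 64-bit general-purpose load and an "ins" to merge
            // the halves, a pattern the in-order A55 can dual-issue with the
            // FMLAs.  With b2 in v6 as declared below, the idiom used
            // throughout the loop is:
            //     ldr  d6,  [b_ptr, #32]    // low half of b2
            //     ldr  x20, [b_ptr, #40]    // high half, via the GP pipe
            //     ins  v6.d[1], x20         // merge into the top of b2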
            register float16x8_t a0  asm("v0");
            register float16x8_t a1  asm("v1");
            register float16x8_t a0a asm("v2");
            register float16x8_t a1a asm("v3");
            register float16x8_t b0  asm("v4");
            register float16x8_t b1  asm("v5");
            register float16x8_t b2  asm("v6");

            __asm __volatile (
                // Initialize result registers, load initial operands, prime prefetches.
                "movi	v8.8h, #0x0\n"
                "ldr	%d[a0], [%[a_ptr]]\n"
                "movi	v9.8h, #0x0\n"
                "ldr	%q[b0], [%[b_ptr]]\n"
                "movi	v10.8h, #0x0\n"
                "ldr	%d[a1], [%[a_ptr], #8]\n"
                "movi	v11.8h, #0x0\n"
                "ldr	%q[b1], [%[b_ptr], #16]\n"
                "movi	v12.8h, #0x0\n"
                "movi	v13.8h, #0x0\n"
                ASM_PREFETCH("[%[b_ptr], #64]")
                "movi	v14.8h, #0x0\n"
                "movi	v15.8h, #0x0\n"
                ASM_PREFETCH("[%[b_ptr], #128]")
                "movi	v16.8h, #0x0\n"
                "movi	v17.8h, #0x0\n"
                ASM_PREFETCH("[%[a_ptr], #64]")
                "movi	v18.8h, #0x0\n"
                "movi	v19.8h, #0x0\n"
                ASM_PREFETCH("[%[b_ptr], #192]")
                "movi	v20.8h, #0x0\n"
                "movi	v21.8h, #0x0\n"
                ASM_PREFETCH("[%[b_ptr], #256]")
                "movi	v22.8h, #0x0\n"
                "movi	v23.8h, #0x0\n"
                ASM_PREFETCH("[%[b_ptr], #320]")
                "movi	v24.8h, #0x0\n"
                "movi	v25.8h, #0x0\n"
                "movi	v26.8h, #0x0\n"
                "movi	v27.8h, #0x0\n"
                "movi	v28.8h, #0x0\n"
                "movi	v29.8h, #0x0\n"
                "movi	v30.8h, #0x0\n"
                "movi	v31.8h, #0x0\n"

                // The loop is offset by these two instructions which must
                // always be executed.
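                // (These are the first FMLA and B load of the loop body;
                // each iteration issues the same pair for the next one just
                // before branching back, so iteration 0's copy is peeled
                // out here.)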
                "fmla 	v8.8h , %[b0].8h, %[a0].h[0]\n"
                "ldr	%d[b2], [%[b_ptr], #32]\n"

                // Skip loop if we are doing zero iterations of it.
                "cbz	%w[k], 4f\n"

                "1:\n"
                "fmla  	v9.8h , %[b0].8h, %[a0].h[1]\n"
                "ldr	x20, [%[b_ptr], #40]\n"
                "fmla	v10.8h, %[b0].8h, %[a0].h[2]\n"
                "subs	%w[k], %w[k], #1\n"
                "fmla	v11.8h, %[b0].8h, %[a0].h[3]\n"
                "ldr	%d[a0a], [%[a_ptr], #16]\n"

                "fmla 	v12.8h, %[b0].8h, %[a1].h[0]\n"
                "ins	%[b2].d[1], x20\n"
                "fmla	v13.8h, %[b0].8h, %[a1].h[1]\n"
                "fmla	v14.8h, %[b0].8h, %[a1].h[2]\n"
                "fmla	v15.8h, %[b0].8h, %[a1].h[3]\n"
                "ldr	%d[a1a], [%[a_ptr], #24]\n"

                "fmla	v16.8h, %[b1].8h, %[a0].h[0]\n"
                "fmla	v17.8h, %[b1].8h, %[a0].h[1]\n"
                "fmla	v18.8h, %[b1].8h, %[a0].h[2]\n"
                "fmla	v19.8h, %[b1].8h, %[a0].h[3]\n"
                "ldr	%d[b0], [%[b_ptr], #48]\n"

                "fmla	v20.8h, %[b1].8h, %[a1].h[0]\n"
                "fmla	v21.8h, %[b1].8h, %[a1].h[1]\n"
                "ldr	x20, [%[b_ptr], #56]\n"
                "fmla	v22.8h, %[b1].8h, %[a1].h[2]\n"
                "fmla	v23.8h, %[b1].8h, %[a1].h[3]\n"
                "ldr	%d[b1], [%[b_ptr], #64]\n"

                "fmla	v24.8h, %[b2].8h, %[a0].h[0]\n"
                "ins	%[b0].d[1], x20\n"
                "fmla	v25.8h, %[b2].8h, %[a0].h[1]\n"
                "ldr	x20, [%[b_ptr], #72]\n"
                "fmla	v26.8h, %[b2].8h, %[a0].h[2]\n"
                "fmla	v27.8h, %[b2].8h, %[a0].h[3]\n"
                ASM_PREFETCH("[%[a_ptr], #128]")

                "fmla	v28.8h, %[b2].8h, %[a1].h[0]\n"
                "fmla	v29.8h, %[b2].8h, %[a1].h[1]\n"
                ASM_PREFETCH("[%[b_ptr], #384]")
                "fmla	v30.8h, %[b2].8h, %[a1].h[2]\n"
                "fmla	v31.8h, %[b2].8h, %[a1].h[3]\n"
                "ldr	%d[b2], [%[b_ptr], #80]\n"

                // Unroll 1
                "fmla 	v8.8h , %[b0].8h, %[a0a].h[0]\n"
                "ins	%[b1].d[1], x20\n"
                "fmla	v9.8h , %[b0].8h, %[a0a].h[1]\n"
                "ldr	x20, [%[b_ptr], #88]\n"
                "fmla	v10.8h, %[b0].8h, %[a0a].h[2]\n"
                "fmla	v11.8h, %[b0].8h, %[a0a].h[3]\n"
                "ldr	%d[a0], [%[a_ptr], #32]\n"

                "fmla 	v12.8h, %[b0].8h, %[a1a].h[0]\n"
                "ins	%[b2].d[1], x20\n"
                "fmla	v13.8h, %[b0].8h, %[a1a].h[1]\n"
                "fmla	v14.8h, %[b0].8h, %[a1a].h[2]\n"
                "fmla	v15.8h, %[b0].8h, %[a1a].h[3]\n"
                "ldr	%d[a1], [%[a_ptr], #40]\n"

                "fmla	v16.8h, %[b1].8h, %[a0a].h[0]\n"
                "add	%[a_ptr], %[a_ptr], #32\n"
                "fmla	v17.8h, %[b1].8h, %[a0a].h[1]\n"
                "fmla	v18.8h, %[b1].8h, %[a0a].h[2]\n"
                "fmla	v19.8h, %[b1].8h, %[a0a].h[3]\n"
                "ldr	%d[b0], [%[b_ptr], #96]\n"

                "fmla	v20.8h, %[b1].8h, %[a1a].h[0]\n"
                "fmla	v21.8h, %[b1].8h, %[a1a].h[1]\n"
                "ldr	x20, [%[b_ptr], #104]\n"
                "fmla	v22.8h, %[b1].8h, %[a1a].h[2]\n"
                "fmla	v23.8h, %[b1].8h, %[a1a].h[3]\n"
                "ldr	%d[b1], [%[b_ptr], #112]\n"

                "fmla	v24.8h, %[b2].8h, %[a0a].h[0]\n"
                "ins	%[b0].d[1], x20\n"
                "fmla	v25.8h, %[b2].8h, %[a0a].h[1]\n"
                "ldr	x20, [%[b_ptr], #120]\n"
                "fmla	v26.8h, %[b2].8h, %[a0a].h[2]\n"
                "fmla	v27.8h, %[b2].8h, %[a0a].h[3]\n"

                "fmla	v28.8h, %[b2].8h, %[a1a].h[0]\n"
                ASM_PREFETCH("[%[b_ptr], #448]")
                "fmla	v29.8h, %[b2].8h, %[a1a].h[1]\n"
                "add	%[b_ptr], %[b_ptr], #96\n"
                "fmla	v30.8h, %[b2].8h, %[a1a].h[2]\n"
                "ins	%[b1].d[1], x20\n"
                "fmla	v31.8h, %[b2].8h, %[a1a].h[3]\n"
                "ldr	%d[b2], [%[b_ptr], #32]\n"

                "fmla 	v8.8h , %[b0].8h, %[a0].h[0]\n"
                "bne	1b\n"

                "4:\n"

                // Start final iteration - branch off to "odd" code before we load a0a
                "fmla  	v9.8h , %[b0].8h, %[a0].h[1]\n"
                "ldr	x20, [%[b_ptr], #40]\n"
                "fmla	v10.8h, %[b0].8h, %[a0].h[2]\n"
                "cbnz	%w[oddk], 2f\n"

                // Even K continuation
                "fmla	v11.8h, %[b0].8h, %[a0].h[3]\n"
                "ldr	%d[a0a], [%[a_ptr], #16]\n"

                "fmla 	v12.8h, %[b0].8h, %[a1].h[0]\n"
                "ins	%[b2].d[1], x20\n"
                "fmla	v13.8h, %[b0].8h, %[a1].h[1]\n"
                ASM_PREFETCHW("[%[c_ptr]]")
                "fmla	v14.8h, %[b0].8h, %[a1].h[2]\n"
                "fmla	v15.8h, %[b0].8h, %[a1].h[3]\n"
                "ldr	%d[a1a], [%[a_ptr], #24]\n"

                "fmla	v16.8h, %[b1].8h, %[a0].h[0]\n"
                "fmla	v17.8h, %[b1].8h, %[a0].h[1]\n"
                ASM_PREFETCHW("[%[c_ptr], #64]")
                "fmla	v18.8h, %[b1].8h, %[a0].h[2]\n"
                "fmla	v19.8h, %[b1].8h, %[a0].h[3]\n"
                "ldr	%d[b0], [%[b_ptr], #48]\n"

                "fmla	v20.8h, %[b1].8h, %[a1].h[0]\n"
                "fmla	v21.8h, %[b1].8h, %[a1].h[1]\n"
                "ldr	x20, [%[b_ptr], #56]\n"
                "fmla	v22.8h, %[b1].8h, %[a1].h[2]\n"
                "fmla	v23.8h, %[b1].8h, %[a1].h[3]\n"
                "ldr	%d[b1], [%[b_ptr], #64]\n"

                "fmla	v24.8h, %[b2].8h, %[a0].h[0]\n"
                "ins	%[b0].d[1], x20\n"
                "fmla	v25.8h, %[b2].8h, %[a0].h[1]\n"
                "ldr	x20, [%[b_ptr], #72]\n"
                "fmla	v26.8h, %[b2].8h, %[a0].h[2]\n"
                "fmla	v27.8h, %[b2].8h, %[a0].h[3]\n"
                ASM_PREFETCHW("[%[c_ptr], #128]")

                "fmla	v28.8h, %[b2].8h, %[a1].h[0]\n"
                "fmla	v29.8h, %[b2].8h, %[a1].h[1]\n"
                ASM_PREFETCHW("[%[c_ptr], #192]")
                "fmla	v30.8h, %[b2].8h, %[a1].h[2]\n"
                "fmla	v31.8h, %[b2].8h, %[a1].h[3]\n"
                "ldr	%d[b2], [%[b_ptr], #80]\n"

                "fmla 	v8.8h , %[b0].8h, %[a0a].h[0]\n"
                "ins	%[b1].d[1], x20\n"
                "fmla	v9.8h , %[b0].8h, %[a0a].h[1]\n"
                "ldr	x20, [%[b_ptr], #88]\n"
                "fmla	v10.8h, %[b0].8h, %[a0a].h[2]\n"
                "fmla	v11.8h, %[b0].8h, %[a0a].h[3]\n"
                ASM_PREFETCHW("[%[c_ptr], #256]")

                "fmla 	v12.8h, %[b0].8h, %[a1a].h[0]\n"
                "ins	%[b2].d[1], x20\n"
                "fmla	v13.8h, %[b0].8h, %[a1a].h[1]\n"
                ASM_PREFETCHW("[%[c_ptr], #320]")
                "fmla	v14.8h, %[b0].8h, %[a1a].h[2]\n"
                "fmla	v15.8h, %[b0].8h, %[a1a].h[3]\n"
                "ldr	%d[a1], [%[a_ptr], #40]\n"

                "fmla	v16.8h, %[b1].8h, %[a0a].h[0]\n"
                "add	%[a_ptr], %[a_ptr], #32\n"
                "fmla	v17.8h, %[b1].8h, %[a0a].h[1]\n"
                ASM_PREFETCHWL2("[%[c_ptr], #384]")
                "fmla	v18.8h, %[b1].8h, %[a0a].h[2]\n"
                "fmla	v19.8h, %[b1].8h, %[a0a].h[3]\n"
                ASM_PREFETCHWL2("[%[c_ptr], #448]")

                "fmla	v20.8h, %[b1].8h, %[a1a].h[0]\n"
                "fmla	v21.8h, %[b1].8h, %[a1a].h[1]\n"
                ASM_PREFETCHWL2("[%[c_ptr], #512]")
                "fmla	v22.8h, %[b1].8h, %[a1a].h[2]\n"
                "fmla	v23.8h, %[b1].8h, %[a1a].h[3]\n"
                ASM_PREFETCHWL2("[%[c_ptr], #576]")

                "fmla	v24.8h, %[b2].8h, %[a0a].h[0]\n"
                "fmla	v25.8h, %[b2].8h, %[a0a].h[1]\n"
                ASM_PREFETCHWL2("[%[c_ptr], #640]")
                "fmla	v26.8h, %[b2].8h, %[a0a].h[2]\n"
                "fmla	v27.8h, %[b2].8h, %[a0a].h[3]\n"
                ASM_PREFETCHWL2("[%[c_ptr], #704]")

                "fmla	v28.8h, %[b2].8h, %[a1a].h[0]\n"
                "fmla	v29.8h, %[b2].8h, %[a1a].h[1]\n"
                "add	%[b_ptr], %[b_ptr], #96\n"
                "fmla	v30.8h, %[b2].8h, %[a1a].h[2]\n"
                "fmla	v31.8h, %[b2].8h, %[a1a].h[3]\n"
                "b	3f\n"

                "2:\n"

                // Odd tail
                "fmla	v11.8h, %[b0].8h, %[a0].h[3]\n"
                ASM_PREFETCHW("[%[c_ptr]]")

                "fmla 	v12.8h, %[b0].8h, %[a1].h[0]\n"
                "ins	%[b2].d[1], x20\n"
                "fmla	v13.8h, %[b0].8h, %[a1].h[1]\n"
                ASM_PREFETCHW("[%[c_ptr], #64]")
                "fmla	v14.8h, %[b0].8h, %[a1].h[2]\n"
                "add	%[a_ptr], %[a_ptr], #16\n"
                "fmla	v15.8h, %[b0].8h, %[a1].h[3]\n"
                ASM_PREFETCHW("[%[c_ptr], #128]")

                "fmla	v16.8h, %[b1].8h, %[a0].h[0]\n"
                "add	%[b_ptr], %[b_ptr], #48\n"
                "fmla	v17.8h, %[b1].8h, %[a0].h[1]\n"
                ASM_PREFETCHW("[%[c_ptr], #192]")
                "fmla	v18.8h, %[b1].8h, %[a0].h[2]\n"
                "fmla	v19.8h, %[b1].8h, %[a0].h[3]\n"
                ASM_PREFETCHW("[%[c_ptr], #256]")

                "fmla	v20.8h, %[b1].8h, %[a1].h[0]\n"
                "fmla	v21.8h, %[b1].8h, %[a1].h[1]\n"
                ASM_PREFETCHW("[%[c_ptr], #320]")
                "fmla	v22.8h, %[b1].8h, %[a1].h[2]\n"
                "fmla	v23.8h, %[b1].8h, %[a1].h[3]\n"
                ASM_PREFETCHWL2("[%[c_ptr], #384]")

                "fmla	v24.8h, %[b2].8h, %[a0].h[0]\n"
                "fmla	v25.8h, %[b2].8h, %[a0].h[1]\n"
                "fmla	v26.8h, %[b2].8h, %[a0].h[2]\n"
                "fmla	v27.8h, %[b2].8h, %[a0].h[3]\n"
                ASM_PREFETCHWL2("[%[c_ptr], #448]")

                "fmla	v28.8h, %[b2].8h, %[a1].h[0]\n"
                ASM_PREFETCHWL2("[%[c_ptr], #512]")
                "fmla	v29.8h, %[b2].8h, %[a1].h[1]\n"
                ASM_PREFETCHWL2("[%[c_ptr], #576]")
                "fmla	v30.8h, %[b2].8h, %[a1].h[2]\n"
                ASM_PREFETCHWL2("[%[c_ptr], #640]")
                "fmla	v31.8h, %[b2].8h, %[a1].h[3]\n"
                ASM_PREFETCHWL2("[%[c_ptr], #704]")

                // Common tail
                // A55 won't dual issue these stores with anything else, so
                // simplest to do them all in this common code.
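                //
                // Each row of the 8x24 output block lives in three
                // consecutive vectors - row 0 in {q8, q16, q24}, row 1 in
                // {q9, q17, q25}, and so on - so the stores below write the
                // block out in plain row-major order.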
                "3:\n"
                "str	q8,  [%[c_ptr]]\n"
                "str	q16, [%[c_ptr], #16]\n"
                "str	q24, [%[c_ptr], #32]\n"
                "str	q9,  [%[c_ptr], #48]\n"
                "str	q17, [%[c_ptr], #64]\n"
                "str	q25, [%[c_ptr], #80]\n"
                "str	q10, [%[c_ptr], #96]\n"
                "str	q18, [%[c_ptr], #112]\n"
                "str	q26, [%[c_ptr], #128]\n"
                "str	q11, [%[c_ptr], #144]\n"
                "str	q19, [%[c_ptr], #160]\n"
                "str	q27, [%[c_ptr], #176]\n"
                "str	q12, [%[c_ptr], #192]\n"
                "str	q20, [%[c_ptr], #208]\n"
                "str	q28, [%[c_ptr], #224]\n"
                "str	q13, [%[c_ptr], #240]\n"
                "str	q21, [%[c_ptr], #256]\n"
                "str	q29, [%[c_ptr], #272]\n"
                "str	q14, [%[c_ptr], #288]\n"
                "str	q22, [%[c_ptr], #304]\n"
                "str	q30, [%[c_ptr], #320]\n"
                "str	q15, [%[c_ptr], #336]\n"
                "str	q23, [%[c_ptr], #352]\n"
                "str	q31, [%[c_ptr], #368]\n"
                "5:\n"
                "add	%[c_ptr], %[c_ptr], #384\n"
            :
              [a_ptr] "+r" (a_ptr), [b_ptr] "+r" (b_ptr), [c_ptr] "+r" (c_ptr),
              [a0] "=w" (a0), [a0a] "=w" (a0a), [a1] "=w" (a1), [a1a] "=w" (a1a),
              [b0] "=w" (b0), [b1] "=w" (b1), [b2] "=w" (b2), [k] "+r" (k)
            : [oddk] "r" (oddk)
            : "x20", "x21", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18",
              "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc", "memory"
            );
        }
    }
}

} // namespace arm_gemm

#endif // __aarch64__ && (FP16_KERNELS || __ARM_FEATURE_FP16_VECTOR_ARITHMETIC)