xref: /aosp_15_r20/external/ComputeLibrary/src/core/NEON/kernels/arm_gemm/kernels/a64_hgemm_8x24/generic.cpp (revision c217d954acce2dbc11938adb493fc0abd69584f3)
1 /*
2  * Copyright (c) 2017-2018 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 
25 // Build on AArch64 where either FP16_KERNELS is set or FP16 is explicitly supported.
26 #if defined(__aarch64__) && (defined(FP16_KERNELS) || defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC))
27 
28 #include <arm_neon.h>
29 
30 #include "../../asmlib.hpp"
31 
32 // Kernel implementation.
33 //
34 // Assume that "Apanel" points to a chunk of A blocks (each size 8xK) in read-order.
35 // Assume that "Bpanel" points to a chunk of B blocks (each size 24xK) in read-order.
36 // Assume that "Cpanel" points to a chunk of C output blocks (each size
37 // 8x24), the chunks being arranged in a row major fashion.
38 //
39 // Note: the intent is that either ablocks or bblocks will be 1; this
40 // construction allows the output loop to proceed in either order.
41 
42 namespace arm_gemm {
43 
a64_hgemm_asimd_8x24(const __fp16 * Apanel,const __fp16 * Bpanel,__fp16 * Cpanel,int ablocks,int bblocks,int K)44 void a64_hgemm_asimd_8x24(const __fp16 *Apanel, const __fp16 *Bpanel, __fp16 *Cpanel, int ablocks, int bblocks, int K) {
45     const __fp16 *a_ptr = Apanel;
46     __fp16 *c_ptr = Cpanel;
47 
48     for (int yb=0; yb<ablocks; yb++) {
49         const __fp16 *a_ptr0 = a_ptr;
50         const __fp16 *b_ptr = Bpanel;
51 
52         for (int xb=0; xb<bblocks; xb++) {
53             a_ptr = a_ptr0;
54             // Fix up for odd lengths - set a flag if K is odd, but make
55             // sure we round up the iteration count.
56             int oddk = (K & 1);
57             int k = ((K+1)/2) - 1;
58 
59             register float16x8_t a0  asm("v0");
60             register float16x8_t a0a asm("v1");
61             register float16x8_t b0  asm("v2");
62             register float16x8_t b1  asm("v3");
63             register float16x8_t b2  asm("v4");
64             register float16x8_t b0a asm("v5");
65             register float16x8_t b1a asm("v6");
66             register float16x8_t b2a asm("v7");
67 
68             __asm __volatile (
69                 // Initialize result registers, load initial operands, prime prefetches.
70                 "movi	v8.8h, #0x0\n"
71                 "ldr	%q[a0], [%[a_ptr]]\n"
72                 "movi	v9.8h, #0x0\n"
73                 "ldr	%q[b0], [%[b_ptr]]\n"
74                 "movi	v10.8h, #0x0\n"
75                 "ldr	%q[b1], [%[b_ptr], #16]\n"
76                 "movi	v11.8h, #0x0\n"
77                 "ldr	%q[b2], [%[b_ptr], #32]\n"
78                 "movi	v12.8h, #0x0\n"
79                 "ldr	%q[b0a], [%[b_ptr], #48]\n"
80                 "movi	v13.8h, #0x0\n"
81                 "ldr	%q[b1a], [%[b_ptr], #64]\n"
82                 "movi	v14.8h, #0x0\n"
83                 ASM_PREFETCH("[%[b_ptr], #64]")
84                 "movi	v15.8h, #0x0\n"
85                 ASM_PREFETCH("[%[b_ptr], #128]")
86                 "movi	v16.8h, #0x0\n"
87                 ASM_PREFETCH("[%[a_ptr], #64]")
88                 "movi	v17.8h, #0x0\n"
89                 ASM_PREFETCH("[%[b_ptr], #192]")
90                 "movi	v18.8h, #0x0\n"
91                 ASM_PREFETCH("[%[b_ptr], #256]")
92                 "movi	v19.8h, #0x0\n"
93                 ASM_PREFETCH("[%[b_ptr], #320]")
94                 "movi	v20.8h, #0x0\n"
95                 "movi	v21.8h, #0x0\n"
96                 "movi	v22.8h, #0x0\n"
97                 "movi	v23.8h, #0x0\n"
98                 "movi	v24.8h, #0x0\n"
99                 "movi	v25.8h, #0x0\n"
100                 "movi	v26.8h, #0x0\n"
101                 "movi	v27.8h, #0x0\n"
102                 "movi	v28.8h, #0x0\n"
103                 "movi	v29.8h, #0x0\n"
104                 "movi	v30.8h, #0x0\n"
105                 "movi	v31.8h, #0x0\n"
106 
107                 // Skip loop if we are doing zero iterations of it.
108                 "cbz	%w[k], 4f\n"
109 
110                 "1:\n"
111                 "fmla 	v8.8h , %[b0].8h, %[a0].h[0]\n"
112                 "fmla  	v9.8h , %[b0].8h, %[a0].h[1]\n"
113                 "ldr	%q[a0a], [%[a_ptr], #16]\n"
114                 "fmla	v10.8h, %[b0].8h, %[a0].h[2]\n"
115                 "fmla	v11.8h, %[b0].8h, %[a0].h[3]\n"
116                 "ldr	%q[b2a], [%[b_ptr], #80]\n"
117                 "fmla 	v12.8h, %[b0].8h, %[a0].h[4]\n"
118                 "fmla	v13.8h, %[b0].8h, %[a0].h[5]\n"
119                 "fmla	v14.8h, %[b0].8h, %[a0].h[6]\n"
120                 "fmla	v15.8h, %[b0].8h, %[a0].h[7]\n"
121                 "ldr	%q[b0], [%[b_ptr], #96]\n"
122 
123                 "fmla	v16.8h, %[b1].8h, %[a0].h[0]\n"
124                 "fmla	v17.8h, %[b1].8h, %[a0].h[1]\n"
125                 ASM_PREFETCH("[%[a_ptr], #128]")
126                 "fmla	v18.8h, %[b1].8h, %[a0].h[2]\n"
127                 "fmla	v19.8h, %[b1].8h, %[a0].h[3]\n"
128                 "add	%[b_ptr], %[b_ptr], #96\n"
129                 "fmla	v20.8h, %[b1].8h, %[a0].h[4]\n"
130                 "fmla	v21.8h, %[b1].8h, %[a0].h[5]\n"
131                 "fmla	v22.8h, %[b1].8h, %[a0].h[6]\n"
132                 "fmla	v23.8h, %[b1].8h, %[a0].h[7]\n"
133                 "ldr	%q[b1], [%[b_ptr], #16]\n"
134 
135                 "fmla	v24.8h, %[b2].8h, %[a0].h[0]\n"
136                 "fmla	v25.8h, %[b2].8h, %[a0].h[1]\n"
137                 ASM_PREFETCH("[%[b_ptr], #288]")
138                 "fmla	v26.8h, %[b2].8h, %[a0].h[2]\n"
139                 "fmla	v27.8h, %[b2].8h, %[a0].h[3]\n"
140                 "fmla	v28.8h, %[b2].8h, %[a0].h[4]\n"
141                 "fmla	v29.8h, %[b2].8h, %[a0].h[5]\n"
142                 "fmla	v30.8h, %[b2].8h, %[a0].h[6]\n"
143                 "fmla	v31.8h, %[b2].8h, %[a0].h[7]\n"
144                 "ldr	%q[a0], [%[a_ptr], #32]\n"
145 
146                 "fmla 	v8.8h , %[b0a].8h, %[a0a].h[0]\n"
147                 "fmla	v9.8h , %[b0a].8h, %[a0a].h[1]\n"
148                 "ldr	%q[b2], [%[b_ptr], #32]\n"
149                 "fmla	v10.8h, %[b0a].8h, %[a0a].h[2]\n"
150                 "fmla	v11.8h, %[b0a].8h, %[a0a].h[3]\n"
151                 "fmla 	v12.8h, %[b0a].8h, %[a0a].h[4]\n"
152                 "fmla	v13.8h, %[b0a].8h, %[a0a].h[5]\n"
153                 "fmla	v14.8h, %[b0a].8h, %[a0a].h[6]\n"
154                 "fmla	v15.8h, %[b0a].8h, %[a0a].h[7]\n"
155                 "ldr	%q[b0a], [%[b_ptr], #48]\n"
156 
157                 "fmla	v16.8h, %[b1a].8h, %[a0a].h[0]\n"
158                 "fmla	v17.8h, %[b1a].8h, %[a0a].h[1]\n"
159                 ASM_PREFETCH("[%[b_ptr], #352]")
160                 "fmla	v18.8h, %[b1a].8h, %[a0a].h[2]\n"
161                 "fmla	v19.8h, %[b1a].8h, %[a0a].h[3]\n"
162                 "fmla	v20.8h, %[b1a].8h, %[a0a].h[4]\n"
163                 "fmla	v21.8h, %[b1a].8h, %[a0a].h[5]\n"
164                 "fmla	v22.8h, %[b1a].8h, %[a0a].h[6]\n"
165                 "fmla	v23.8h, %[b1a].8h, %[a0a].h[7]\n"
166                 "ldr	%q[b1a], [%[b_ptr], #64]\n"
167 
168                 "fmla	v24.8h, %[b2a].8h, %[a0a].h[0]\n"
169                 "fmla	v25.8h, %[b2a].8h, %[a0a].h[1]\n"
170                 "add	%[a_ptr], %[a_ptr], #32\n"
171                 "fmla	v26.8h, %[b2a].8h, %[a0a].h[2]\n"
172                 "fmla	v27.8h, %[b2a].8h, %[a0a].h[3]\n"
173                 "fmla	v28.8h, %[b2a].8h, %[a0a].h[4]\n"
174                 "fmla	v29.8h, %[b2a].8h, %[a0a].h[5]\n"
175                 "subs	%w[k], %w[k], #1\n"
176                 "fmla	v30.8h, %[b2a].8h, %[a0a].h[6]\n"
177                 "fmla	v31.8h, %[b2a].8h, %[a0a].h[7]\n"
178 
179                 "bne	1b\n"
180                 "4:\n"
181 
182                 // Jump to odd tail if necessary.
183                 "cbnz	%w[oddk], 2f\n"
184 
185                 // Even tail.
186                 "fmla 	v8.8h , %[b0].8h, %[a0].h[0]\n"
187                 "fmla   v9.8h , %[b0].8h, %[a0].h[1]\n"
188                 "ldr	%q[a0a], [%[a_ptr], #16]\n"
189                 "fmla	v10.8h, %[b0].8h, %[a0].h[2]\n"
190                 "fmla	v11.8h, %[b0].8h, %[a0].h[3]\n"
191                 "ldr	%q[b2a], [%[b_ptr], #80]\n"
192                 "fmla 	v12.8h, %[b0].8h, %[a0].h[4]\n"
193                 "fmla   v13.8h, %[b0].8h, %[a0].h[5]\n"
194                 "fmla	v14.8h, %[b0].8h, %[a0].h[6]\n"
195                 "fmla	v15.8h, %[b0].8h, %[a0].h[7]\n"
196 
197                 "fmla	v16.8h, %[b1].8h, %[a0].h[0]\n"
198                 "fmla	v17.8h, %[b1].8h, %[a0].h[1]\n"
199                 "add	%[b_ptr], %[b_ptr], #96\n"
200                 "fmla	v18.8h, %[b1].8h, %[a0].h[2]\n"
201                 "fmla	v19.8h, %[b1].8h, %[a0].h[3]\n"
202                 "fmla	v20.8h, %[b1].8h, %[a0].h[4]\n"
203                 "fmla	v21.8h, %[b1].8h, %[a0].h[5]\n"
204                 "add	%[a_ptr], %[a_ptr], #32\n"
205                 "fmla	v22.8h, %[b1].8h, %[a0].h[6]\n"
206                 "fmla	v23.8h, %[b1].8h, %[a0].h[7]\n"
207 
208                 "fmla	v24.8h, %[b2].8h, %[a0].h[0]\n"
209                 "fmla	v25.8h, %[b2].8h, %[a0].h[1]\n"
210                 "fmla	v26.8h, %[b2].8h, %[a0].h[2]\n"
211                 "fmla	v27.8h, %[b2].8h, %[a0].h[3]\n"
212                 "fmla	v28.8h, %[b2].8h, %[a0].h[4]\n"
213                 "fmla	v29.8h, %[b2].8h, %[a0].h[5]\n"
214                 "fmla	v30.8h, %[b2].8h, %[a0].h[6]\n"
215                 "fmla	v31.8h, %[b2].8h, %[a0].h[7]\n"
216 
217                 "fmla 	v8.8h , %[b0a].8h, %[a0a].h[0]\n"
218                 "fmla	v16.8h, %[b1a].8h, %[a0a].h[0]\n"
219                 "str	q8, [%[c_ptr]]\n"
220                 "fmla	v24.8h, %[b2a].8h, %[a0a].h[0]\n"
221                 "str	q16, [%[c_ptr], #16]\n"
222 
223                 "fmla  	v9.8h , %[b0a].8h, %[a0a].h[1]\n"
224                 "str	q24, [%[c_ptr], #32]\n"
225                 "fmla	v17.8h, %[b1a].8h, %[a0a].h[1]\n"
226                 "str	q9, [%[c_ptr], #48]\n"
227                 "fmla	v25.8h, %[b2a].8h, %[a0a].h[1]\n"
228                 "str	q17, [%[c_ptr], #64]\n"
229 
230                 "fmla	v10.8h, %[b0a].8h, %[a0a].h[2]\n"
231                 "str	q25, [%[c_ptr], #80]\n"
232                 "fmla	v18.8h, %[b1a].8h, %[a0a].h[2]\n"
233                 "str	q10, [%[c_ptr], #96]\n"
234                 "fmla	v26.8h, %[b2a].8h, %[a0a].h[2]\n"
235                 "str	q18, [%[c_ptr], #112]\n"
236 
237                 "fmla	v11.8h, %[b0a].8h, %[a0a].h[3]\n"
238                 "str	q26, [%[c_ptr], #128]\n"
239                 "fmla	v19.8h, %[b1a].8h, %[a0a].h[3]\n"
240                 "str	q11, [%[c_ptr], #144]\n"
241                 "fmla	v27.8h, %[b2a].8h, %[a0a].h[3]\n"
242                 "str	q19, [%[c_ptr], #160]\n"
243 
244                 "fmla 	v12.8h, %[b0a].8h, %[a0a].h[4]\n"
245                 "str	q27, [%[c_ptr], #176]\n"
246                 "fmla	v20.8h, %[b1a].8h, %[a0a].h[4]\n"
247                 "str	q12, [%[c_ptr], #192]\n"
248                 "fmla	v28.8h, %[b2a].8h, %[a0a].h[4]\n"
249                 "str	q20, [%[c_ptr], #208]\n"
250 
251                 "fmla  	v13.8h, %[b0a].8h, %[a0a].h[5]\n"
252                 "str	q28, [%[c_ptr], #224]\n"
253                 "fmla	v21.8h, %[b1a].8h, %[a0a].h[5]\n"
254                 "str	q13, [%[c_ptr], #240]\n"
255                 "fmla	v29.8h, %[b2a].8h, %[a0a].h[5]\n"
256                 "str	q21, [%[c_ptr], #256]\n"
257 
258                 "fmla	v14.8h, %[b0a].8h, %[a0a].h[6]\n"
259                 "str	q29, [%[c_ptr], #272]\n"
260                 "fmla	v22.8h, %[b1a].8h, %[a0a].h[6]\n"
261                 "str	q14, [%[c_ptr], #288]\n"
262                 "fmla	v30.8h, %[b2a].8h, %[a0a].h[6]\n"
263                 "str	q22, [%[c_ptr], #304]\n"
264 
265                 "fmla	v15.8h, %[b0a].8h, %[a0a].h[7]\n"
266                 "str	q30, [%[c_ptr], #320]\n"
267                 "fmla	v23.8h, %[b1a].8h, %[a0a].h[7]\n"
268                 "str	q15, [%[c_ptr], #336]\n"
269                 "fmla	v31.8h, %[b2a].8h, %[a0a].h[7]\n"
270                 "b	3f\n"
271 
272                 // Odd tail
273                 "2:\n"
274                 "fmla 	v8.8h , %[b0].8h, %[a0].h[0]\n"
275                 "add	%[b_ptr], %[b_ptr], #48\n"
276                 "fmla	v16.8h, %[b1].8h, %[a0].h[0]\n"
277                 "add	%[a_ptr], %[a_ptr], #16\n"
278                 "str	q8, [%[c_ptr]]\n"
279                 "fmla	v24.8h, %[b2].8h, %[a0].h[0]\n"
280                 "str	q16, [%[c_ptr], #16]\n"
281 
282                 "fmla  	v9.8h , %[b0].8h, %[a0].h[1]\n"
283                 "str	q24, [%[c_ptr], #32]\n"
284                 "fmla	v17.8h, %[b1].8h, %[a0].h[1]\n"
285                 "str	q9, [%[c_ptr], #48]\n"
286                 "fmla	v25.8h, %[b2].8h, %[a0].h[1]\n"
287                 "str	q17, [%[c_ptr], #64]\n"
288 
289                 "fmla	v10.8h, %[b0].8h, %[a0].h[2]\n"
290                 "str	q25, [%[c_ptr], #80]\n"
291                 "fmla	v18.8h, %[b1].8h, %[a0].h[2]\n"
292                 "str	q10, [%[c_ptr], #96]\n"
293                 "fmla	v26.8h, %[b2].8h, %[a0].h[2]\n"
294                 "str	q18, [%[c_ptr], #112]\n"
295 
296                 "fmla	v11.8h, %[b0].8h, %[a0].h[3]\n"
297                 "str	q26, [%[c_ptr], #128]\n"
298                 "fmla	v19.8h, %[b1].8h, %[a0].h[3]\n"
299                 "str	q11, [%[c_ptr], #144]\n"
300                 "fmla	v27.8h, %[b2].8h, %[a0].h[3]\n"
301                 "str	q19, [%[c_ptr], #160]\n"
302 
303                 "fmla 	v12.8h, %[b0].8h, %[a0].h[4]\n"
304                 "str	q27, [%[c_ptr], #176]\n"
305                 "fmla	v20.8h, %[b1].8h, %[a0].h[4]\n"
306                 "str	q12, [%[c_ptr], #192]\n"
307                 "fmla	v28.8h, %[b2].8h, %[a0].h[4]\n"
308                 "str	q20, [%[c_ptr], #208]\n"
309 
310                 "fmla  	v13.8h, %[b0].8h, %[a0].h[5]\n"
311                 "str	q28, [%[c_ptr], #224]\n"
312                 "fmla	v21.8h, %[b1].8h, %[a0].h[5]\n"
313                 "str	q13, [%[c_ptr], #240]\n"
314                 "fmla	v29.8h, %[b2].8h, %[a0].h[5]\n"
315                 "str	q21, [%[c_ptr], #256]\n"
316 
317                 "fmla	v14.8h, %[b0].8h, %[a0].h[6]\n"
318                 "str	q29, [%[c_ptr], #272]\n"
319                 "fmla	v22.8h, %[b1].8h, %[a0].h[6]\n"
320                 "str	q14, [%[c_ptr], #288]\n"
321                 "fmla	v30.8h, %[b2].8h, %[a0].h[6]\n"
322                 "str	q22, [%[c_ptr], #304]\n"
323 
324                 "fmla	v15.8h, %[b0].8h, %[a0].h[7]\n"
325                 "str	q30, [%[c_ptr], #320]\n"
326                 "fmla	v23.8h, %[b1].8h, %[a0].h[7]\n"
327                 "str	q15, [%[c_ptr], #336]\n"
328                 "fmla	v31.8h, %[b2].8h, %[a0].h[7]\n"
329 
330                 "3:\n"
331                 "str	q23, [%[c_ptr], #352]\n"
332                 "str	q31, [%[c_ptr], #368]\n"
333                 "add	%[c_ptr], %[c_ptr], #384\n"
334             :
335               [a_ptr] "+r" (a_ptr), [b_ptr] "+r" (b_ptr), [c_ptr] "+r" (c_ptr),
336               [a0] "+w" (a0), [a0a] "+w" (a0a),
337               [b0] "+w" (b0), [b1] "+w" (b1), [b2] "+w" (b2), [k] "+r" (k),
338               [b0a] "+w" (b0a), [b1a] "+w" (b1a), [b2a] "+w" (b2a)
339             : [oddk] "r" (oddk)
340             : "x20", "x21", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18",
341               "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc"
342             );
343         }
344     }
345 }
346 
347 } // namespace arm_gemm
348 
349 #endif // __aarch64__ && (FP16_KERNELS || __ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
350