/*
 * Copyright (c) 2017-2018 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#ifdef __arm__

#include <arm_neon.h>

#include "../../asmlib.hpp"

// Kernel implementation.
//
// Assume that "Apanel" points to a chunk of A blocks (each size 6xK) in read-order.
// Assume that "Bpanel" points to a chunk of B blocks (each size 8xK) in read-order.
// Assume that "Cpanel" points to a chunk of C output blocks (each size
// 8x6), the chunks being arranged in a row major fashion.
//
// Note that the intent of this is that either ablocks or bblocks will be 1
// - this construction allows the output loop to proceed in either order.

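// As a reading aid, here is a minimal scalar sketch of the computation the
// assembly kernel below performs, assuming the panel layout described above:
// 6 A values and 8 B values are consumed per k-step, and each 6x8 output
// block is written out row major (matching the kernel's store order of q4,
// q10, q5, q11 and so on).  It is an illustration only: the function name is
// arbitrary and nothing in the library calls it.
__attribute__((unused)) static void a32_sgemm_8x6_scalar_ref(const float *Apanel, const float *Bpanel, float *Cpanel,
                                                             int ablocks, int bblocks, int K) {
    const float *a_ptr = Apanel;
    float *c_ptr = Cpanel;

    for (int yb=0; yb<ablocks; yb++) {
        const float *a_ptr0 = a_ptr;
        const float *b_ptr = Bpanel;

        for (int xb=0; xb<bblocks; xb++) {
            a_ptr = a_ptr0;
            float acc[6][8] = {};

            // One rank-1 update of the 6x8 accumulator per k-step.
            for (int k=0; k<K; k++) {
                for (int i=0; i<6; i++) {
                    for (int j=0; j<8; j++) {
                        acc[i][j] += a_ptr[i] * b_ptr[j];
                    }
                }
                a_ptr += 6;
                b_ptr += 8;
            }

            // Write the finished block out row major.
            for (int i=0; i<6; i++) {
                for (int j=0; j<8; j++) {
                    *c_ptr++ = acc[i][j];
                }
            }
        }
    }
}
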
namespace arm_gemm {

void a32_sgemm_8x6_a55r1(const float *Apanel, const float *Bpanel, float *Cpanel, int ablocks, int bblocks, int K) {
    const float *a_ptr = Apanel;
    float *c_ptr = Cpanel;

    /* Work out starting values for "k" and "tails" in the inner loop. */
    int tails_initial = (K & 3);
    if (tails_initial == 0) {
        tails_initial = 4;
    }

    int k_initial = ((K+3)/4) - 1;
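    // For example, with K=10: tails_initial is 2 and k_initial is 2, giving
    // two full 4-deep iterations of the main loop plus a 2-block detached
    // final iteration (2*4 + 2 == 10 k-steps in total).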

    for (int yb=0; yb<ablocks; yb++) {
        const float *a_ptr0 = a_ptr;
        const float *b_ptr = Bpanel;

        for (int xb=0; xb<bblocks; xb++) {
            int tails = tails_initial;
            int k = k_initial;

            a_ptr = a_ptr0;

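            // Register roles in the assembly below: q0-q1 (d0-d3) carry A
            // values, q2-q3 (d4-d7) carry B values, and q4-q15 hold the
            // twelve accumulators for the 6x8 output block.  Loads and
            // prefetches are interleaved with the multiplies, which suits
            // the in-order, dual-issue Cortex-A55 r1 that this kernel, per
            // its name, is scheduled for.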
            __asm __volatile (
                "vldr		d0, [%[a_ptr]]\n"
                "vmov.i32	q4, #0\n"
                "vldr		d1, [%[a_ptr], #0x08]\n"
                "vmov.i32	q5, #0\n"
                "vldr		d4, [%[b_ptr]]\n"
                "vmov.i32	q6, #0\n"
                "vldr		d5, [%[b_ptr], #0x08]\n"
                "vmov.i32	q7, #0\n"
                "vldr		d2, [%[a_ptr], #0x10]\n"
                "vmov.i32	q8, #0\n"
                ASM_PREFETCH("[%[b_ptr], #0x40]")
                "vmov.i32	q9, #0\n"
                ASM_PREFETCH("[%[a_ptr], #0x40]")
                "vmov.i32	q10, #0\n"
                ASM_PREFETCH("[%[b_ptr], #0x80]")
                "vmov.i32	q11, #0\n"
                ASM_PREFETCH("[%[a_ptr], #0x80]")
                "vmov.i32	q12, #0\n"
                ASM_PREFETCH("[%[b_ptr], #0xC0]")
                "vmov.i32	q13, #0\n"
                ASM_PREFETCH("[%[a_ptr], #0xC0]")
                "vmov.i32	q14, #0\n"
                ASM_PREFETCH("[%[b_ptr], #0x100]")
                "vmov.i32	q15, #0\n"
                ASM_PREFETCH("[%[a_ptr], #0x100]")
                "cmp		%[k], #0\n"
                ASM_PREFETCH("[%[b_ptr], #0x140]")
                "beq		6f\n"
                ASM_PREFETCH("[%[b_ptr], #0x180]")

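                // Main loop: each iteration performs four k-steps (Unrolls
                // 0-3), advancing a_ptr by 0x60 (4 * 6 floats) and b_ptr by
                // 0x80 (4 * 8 floats).  Most multiplies are paired with a
                // load or a prefetch to keep both issue slots busy.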
                "1:\n"
                // Unroll 0
                "vmla.f32	q4, q2, d0[0]\n"
                "vldr		d6, [%[b_ptr], #0x10]\n"
                "vmla.f32	q5, q2, d0[1]\n"
                "vldr		d7, [%[b_ptr], #0x18]\n"
                "vmla.f32	q6, q2, d1[0]\n"
                "vldr		d3, [%[a_ptr], #0x18]\n"
                "vmla.f32	q7, q2, d1[1]\n"
                ASM_PREFETCH("[%[a_ptr], #0x140]")
                "vmla.f32	q8, q2, d2[0]\n"
                "subs		%[k], %[k], #1\n"
                "vmla.f32	q9, q2, d2[1]\n"
                "vldr		d4, [%[b_ptr], #0x20]\n"
                "vmla.f32	q10, q3, d0[0]\n"
                "vldr		d5, [%[b_ptr], #0x28]\n"
                "vmla.f32	q11, q3, d0[1]\n"
                "vldr		d0, [%[a_ptr], #0x20]\n"
                "vmla.f32	q12, q3, d1[0]\n"

                "vmla.f32	q13, q3, d1[1]\n"
                "vldr		d1, [%[a_ptr], #0x28]\n"
                "vmla.f32	q14, q3, d2[0]\n"

                "vmla.f32	q15, q3, d2[1]\n"
                "vldr		d6, [%[b_ptr], #0x30]\n"

                // Unroll 1
                "vmla.f32	q4, q2, d3[0]\n"
                "vldr		d7, [%[b_ptr], #0x38]\n"
                "vmla.f32	q5, q2, d3[1]\n"
                "vldr		d2, [%[a_ptr], #0x30]\n"
                "vmla.f32	q6, q2, d0[0]\n"

                "vmla.f32	q7, q2, d0[1]\n"
                ASM_PREFETCH("[%[b_ptr], #0x1C0]")
                "vmla.f32	q8, q2, d1[0]\n"

                "vmla.f32	q9, q2, d1[1]\n"
                "vldr		d4, [%[b_ptr], #0x40]\n"
                "vmla.f32	q10, q3, d3[0]\n"
                "vldr		d5, [%[b_ptr], #0x48]\n"
                "vmla.f32	q11, q3, d3[1]\n"
                "vldr		d3, [%[a_ptr], #0x38]\n"
                "vmla.f32	q12, q3, d0[0]\n"

                "vmla.f32	q13, q3, d0[1]\n"
                "vldr		d0, [%[a_ptr], #0x40]\n"
                "vmla.f32	q14, q3, d1[0]\n"

                "vmla.f32	q15, q3, d1[1]\n"
                "vldr		d6, [%[b_ptr], #0x50]\n"

                // Unroll 2
                "vmla.f32	q4, q2, d2[0]\n"
                "vldr		d7, [%[b_ptr], #0x58]\n"
                "vmla.f32	q5, q2, d2[1]\n"
                "vldr		d1, [%[a_ptr], #0x48]\n"
                "vmla.f32	q6, q2, d3[0]\n"

                "vmla.f32	q7, q2, d3[1]\n"
                ASM_PREFETCH("[%[a_ptr], #0x180]")
                "vmla.f32	q8, q2, d0[0]\n"

                "vmla.f32	q9, q2, d0[1]\n"
                "vldr		d4, [%[b_ptr], #0x60]\n"
                "vmla.f32	q10, q3, d2[0]\n"
                "vldr		d5, [%[b_ptr], #0x68]\n"
                "vmla.f32	q11, q3, d2[1]\n"
                "vldr		d2, [%[a_ptr], #0x50]\n"
                "vmla.f32	q12, q3, d3[0]\n"

                "vmla.f32	q13, q3, d3[1]\n"
                "vldr		d3, [%[a_ptr], #0x58]\n"
                "vmla.f32	q14, q3, d0[0]\n"
                "add		%[a_ptr], %[a_ptr], #0x60\n"
                "vmla.f32	q15, q3, d0[1]\n"
                "vldr		d6, [%[b_ptr], #0x70]\n"

                // Unroll 3
                "vmla.f32	q4, q2, d1[0]\n"
                "vldr		d7, [%[b_ptr], #0x78]\n"
                "vmla.f32	q5, q2, d1[1]\n"
                "add		%[b_ptr], %[b_ptr], #0x80\n"
                "vmla.f32	q6, q2, d2[0]\n"
                "vldr		d0, [%[a_ptr], #0x00]\n"
                "vmla.f32	q7, q2, d2[1]\n"
                ASM_PREFETCH("[%[b_ptr], #0x180]")
                "vmla.f32	q8, q2, d3[0]\n"

                "vmla.f32	q9, q2, d3[1]\n"
                "vldr		d4, [%[b_ptr], #0x00]\n"
                "vmla.f32	q10, q3, d1[0]\n"
                "vldr		d5, [%[b_ptr], #0x08]\n"
                "vmla.f32	q11, q3, d1[1]\n"
                "vldr		d1, [%[a_ptr], #0x08]\n"
                "vmla.f32	q12, q3, d2[0]\n"

                "vmla.f32	q13, q3, d2[1]\n"
                "vldr		d2, [%[a_ptr], #0x10]\n"
                "vmla.f32	q14, q3, d3[0]\n"

                "vmla.f32	q15, q3, d3[1]\n"
                "bne		1b\n"

                // "tails" is the number of multiply blocks still needed at
                // the end; it must be 1-4 inclusive.  Bail out to the
                // alternative tail immediately if it is 1.
                "6:\n"
                "subs		%[tails], %[tails], #1\n"
                "beq		3f\n"
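                // Tail dispatch: label 3 handles a 1-block tail.  Otherwise
                // fall through: Unroll 0 then branch to 4 (2 blocks), Unroll
                // 1 then branch to 5 (3 blocks), or run all four unrolls for
                // a full 4-block tail.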

                // Detached final iteration

                // Unroll 0
                "vmla.f32	q4, q2, d0[0]\n"
                "vldr		d6, [%[b_ptr], #0x10]\n"
                "vmla.f32	q5, q2, d0[1]\n"
                "vldr		d7, [%[b_ptr], #0x18]\n"
                "vmla.f32	q6, q2, d1[0]\n"
                "vldr		d3, [%[a_ptr], #0x18]\n"
                "vmla.f32	q7, q2, d1[1]\n"
                "subs		%[tails], %[tails], #1\n"
                "vmla.f32	q8, q2, d2[0]\n"
                "vmla.f32	q9, q2, d2[1]\n"
                "vldr		d4, [%[b_ptr], #0x20]\n"

                "vmla.f32	q10, q3, d0[0]\n"
                "vldr		d5, [%[b_ptr], #0x28]\n"
                "vmla.f32	q11, q3, d0[1]\n"
                "vldr		d0, [%[a_ptr], #0x20]\n"
                "vmla.f32	q12, q3, d1[0]\n"
                "vmla.f32	q13, q3, d1[1]\n"
                "vldr		d1, [%[a_ptr], #0x28]\n"
                "vmla.f32	q14, q3, d2[0]\n"
                "vmla.f32	q15, q3, d2[1]\n"
                "beq		4f\n"

                // Unroll 1
                "vmla.f32	q4, q2, d3[0]\n"
                "vldr		d6, [%[b_ptr], #0x30]\n"
                "vmla.f32	q5, q2, d3[1]\n"
                "vldr		d7, [%[b_ptr], #0x38]\n"
                "vmla.f32	q6, q2, d0[0]\n"
                "vldr		d2, [%[a_ptr], #0x30]\n"
                "vmla.f32	q7, q2, d0[1]\n"
                "subs		%[tails], %[tails], #1\n"
                "vmla.f32	q8, q2, d1[0]\n"

                "vmla.f32	q9, q2, d1[1]\n"

                "vmla.f32	q10, q3, d3[0]\n"
                "vldr		d4, [%[b_ptr], #0x40]\n"
                "vmla.f32	q11, q3, d3[1]\n"
                "vldr		d5, [%[b_ptr], #0x48]\n"
                "vmla.f32	q12, q3, d0[0]\n"
                "vldr		d3, [%[a_ptr], #0x38]\n"
                "vmla.f32	q13, q3, d0[1]\n"
                "vldr		d0, [%[a_ptr], #0x40]\n"
                "vmla.f32	q14, q3, d1[0]\n"
                "vmla.f32	q15, q3, d1[1]\n"
                "beq		5f\n"

                // Unroll 2
                "vmla.f32	q4, q2, d2[0]\n"
                "vldr		d6, [%[b_ptr], #0x50]\n"
                "vmla.f32	q5, q2, d2[1]\n"
                "vldr		d7, [%[b_ptr], #0x58]\n"
                "vmla.f32	q6, q2, d3[0]\n"
                "vldr		d1, [%[a_ptr], #0x48]\n"
                "vmla.f32	q7, q2, d3[1]\n"
                "vmla.f32	q8, q2, d0[0]\n"
                "vmla.f32	q9, q2, d0[1]\n"

                "vmla.f32	q10, q3, d2[0]\n"
                "vldr		d4, [%[b_ptr], #0x60]\n"
                "vmla.f32	q11, q3, d2[1]\n"
                "vldr		d5, [%[b_ptr], #0x68]\n"
                "vmla.f32	q12, q3, d3[0]\n"
                "vldr		d2, [%[a_ptr], #0x50]\n"
                "vmla.f32	q13, q3, d3[1]\n"
                "vldr		d3, [%[a_ptr], #0x58]\n"
                "vmla.f32	q14, q3, d0[0]\n"
                "vmla.f32	q15, q3, d0[1]\n"

                // Unroll 3
                "vmla.f32	q4, q2, d1[0]\n"
                "vldr		d6, [%[b_ptr], #0x70]\n"
                "vmla.f32	q5, q2, d1[1]\n"
                "vldr		d7, [%[b_ptr], #0x78]\n"
                "vmla.f32	q10, q3, d1[0]\n"
                "vst1.32	{d8-d9}, [%[c_ptr] :128]!\n"
                "vmla.f32	q11, q3, d1[1]\n"
                "vst1.32	{d20-d21}, [%[c_ptr] :128]!\n"
                "vmla.f32	q6, q2, d2[0]\n"
                "vst1.32	{d10-d11}, [%[c_ptr] :128]!\n"
                "vmla.f32	q12, q3, d2[0]\n"
                "vst1.32	{d22-d23}, [%[c_ptr] :128]!\n"
                "vmla.f32	q7, q2, d2[1]\n"
                "vst1.32	{d12-d13}, [%[c_ptr] :128]!\n"
                "vmla.f32	q13, q3, d2[1]\n"
                "vst1.32	{d24-d25}, [%[c_ptr] :128]!\n"
                "vmla.f32	q8, q2, d3[0]\n"
                "vst1.32	{d14-d15}, [%[c_ptr] :128]!\n"
                "vmla.f32	q14, q3, d3[0]\n"
                "vst1.32	{d26-d27}, [%[c_ptr] :128]!\n"
                "vmla.f32	q9, q2, d3[1]\n"
                "vst1.32	{d16-d17}, [%[c_ptr] :128]!\n"
                "vmla.f32	q15, q3, d3[1]\n"
                "vst1.32	{d28-d29}, [%[c_ptr] :128]!\n"
                "add		%[a_ptr], %[a_ptr], #0x60\n"
                "vst1.32	{d18-d19}, [%[c_ptr] :128]!\n"
                "add		%[b_ptr], %[b_ptr], #0x80\n"
                "b		2f\n"

                // tails==1 final tail
                "3:\n"
                "vmla.f32	q4, q2, d0[0]\n"
                "vldr		d6, [%[b_ptr], #0x10]\n"
                "vmla.f32	q5, q2, d0[1]\n"
                "vldr		d7, [%[b_ptr], #0x18]\n"
                "vmla.f32	q6, q2, d1[0]\n"
                "vst1.32	{d8-d9}, [%[c_ptr] :128]!\n"
                "vmla.f32	q10, q3, d0[0]\n"
                "vst1.32	{d20-d21}, [%[c_ptr] :128]!\n"
                "vmla.f32	q11, q3, d0[1]\n"
                "vst1.32	{d10-d11}, [%[c_ptr] :128]!\n"
                "vmla.f32	q12, q3, d1[0]\n"
                "vst1.32	{d22-d23}, [%[c_ptr] :128]!\n"
                "vmla.f32	q7, q2, d1[1]\n"
                "vst1.32	{d12-d13}, [%[c_ptr] :128]!\n"
                "vmla.f32	q13, q3, d1[1]\n"
                "vst1.32	{d24-d25}, [%[c_ptr] :128]!\n"
                "vmla.f32	q8, q2, d2[0]\n"
                "vst1.32	{d14-d15}, [%[c_ptr] :128]!\n"
                "vmla.f32	q14, q3, d2[0]\n"
                "vst1.32	{d26-d27}, [%[c_ptr] :128]!\n"
                "vmla.f32	q9, q2, d2[1]\n"
                "vst1.32	{d16-d17}, [%[c_ptr] :128]!\n"
                "vmla.f32	q15, q3, d2[1]\n"
                "vst1.32	{d28-d29}, [%[c_ptr] :128]!\n"
                "add		%[a_ptr], %[a_ptr], #0x18\n"
                "vst1.32	{d18-d19}, [%[c_ptr] :128]!\n"
                "add		%[b_ptr], %[b_ptr], #0x20\n"
                "b		2f\n"

                // tails==2 final tail
                "4:\n"
                "vmla.f32	q4, q2, d3[0]\n"
                "vldr		d6, [%[b_ptr], #0x30]\n"
                "vmla.f32	q5, q2, d3[1]\n"
                "vldr		d7, [%[b_ptr], #0x38]\n"
                "vmla.f32	q10, q3, d3[0]\n"
                "vst1.32	{d8-d9}, [%[c_ptr] :128]!\n"
                "vmla.f32	q11, q3, d3[1]\n"
                "vst1.32	{d20-d21}, [%[c_ptr] :128]!\n"
                "vmla.f32	q6, q2, d0[0]\n"
                "vst1.32	{d10-d11}, [%[c_ptr] :128]!\n"
                "vmla.f32	q12, q3, d0[0]\n"
                "vst1.32	{d22-d23}, [%[c_ptr] :128]!\n"
                "vmla.f32	q7, q2, d0[1]\n"
                "vst1.32	{d12-d13}, [%[c_ptr] :128]!\n"
                "vmla.f32	q13, q3, d0[1]\n"
                "vst1.32	{d24-d25}, [%[c_ptr] :128]!\n"
                "vmla.f32	q8, q2, d1[0]\n"
                "vst1.32	{d14-d15}, [%[c_ptr] :128]!\n"
                "vmla.f32	q14, q3, d1[0]\n"
                "vst1.32	{d26-d27}, [%[c_ptr] :128]!\n"
                "vmla.f32	q9, q2, d1[1]\n"
                "vst1.32	{d16-d17}, [%[c_ptr] :128]!\n"
                "vmla.f32	q15, q3, d1[1]\n"
                "vst1.32	{d28-d29}, [%[c_ptr] :128]!\n"
                "add		%[b_ptr], %[b_ptr], #0x40\n"
                "vst1.32	{d18-d19}, [%[c_ptr] :128]!\n"
                "add		%[a_ptr], %[a_ptr], #0x30\n"
                "b		2f\n"

                // tails==3 final tail
                "5:\n"
                "vmla.f32	q4, q2, d2[0]\n"
                "vldr		d6, [%[b_ptr], #0x50]\n"
                "vmla.f32	q5, q2, d2[1]\n"
                "vldr		d7, [%[b_ptr], #0x58]\n"
                "vmla.f32	q6, q2, d3[0]\n"
                "vst1.32	{d8-d9}, [%[c_ptr] :128]!\n"
                "vmla.f32	q10, q3, d2[0]\n"
                "vst1.32	{d20-d21}, [%[c_ptr] :128]!\n"
                "vmla.f32	q11, q3, d2[1]\n"
                "vst1.32	{d10-d11}, [%[c_ptr] :128]!\n"
                "vmla.f32	q12, q3, d3[0]\n"
                "vst1.32	{d22-d23}, [%[c_ptr] :128]!\n"
                "vmla.f32	q7, q2, d3[1]\n"
                "vst1.32	{d12-d13}, [%[c_ptr] :128]!\n"
                "vmla.f32	q13, q3, d3[1]\n"
                "vst1.32	{d24-d25}, [%[c_ptr] :128]!\n"
                "vmla.f32	q8, q2, d0[0]\n"
                "vst1.32	{d14-d15}, [%[c_ptr] :128]!\n"
                "vmla.f32	q14, q3, d0[0]\n"
                "vst1.32	{d26-d27}, [%[c_ptr] :128]!\n"
                "vmla.f32	q9, q2, d0[1]\n"
                "vst1.32	{d16-d17}, [%[c_ptr] :128]!\n"
                "vmla.f32	q15, q3, d0[1]\n"
                "vst1.32	{d28-d29}, [%[c_ptr] :128]!\n"
                "add		%[a_ptr], %[a_ptr], #0x48\n"
                "vst1.32	{d18-d19}, [%[c_ptr] :128]!\n"
                "add		%[b_ptr], %[b_ptr], #0x60\n"

                "2:\n"
                "vst1.32	{d30-d31}, [%[c_ptr] :128]!\n"
            : [a_ptr] "+r" (a_ptr), [b_ptr] "+r" (b_ptr), [c_ptr] "+r" (c_ptr), [k] "+r" (k), [tails] "+r" (tails)
            :
            : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15",
              "r0", "r1", "cc", "memory"
            );
        }
    }
}

} // namespace arm_gemm

#endif