xref: /aosp_15_r20/external/ComputeLibrary/src/core/NEON/kernels/arm_gemm/kernels/a32_sgemm_8x6/a53.cpp (revision c217d954acce2dbc11938adb493fc0abd69584f3)
1 /*
2  * Copyright (c) 2017-2018 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #ifdef __arm__
25 
26 #include <arm_neon.h>
27 
28 #include "../../asmlib.hpp"
29 
30 // Kernel implementation.
31 //
32 // Assume that "Apanel" points to a chunk of A blocks (each size 6xK) in read-order.
33 // Assume that "Bpanel" points to a chunk of B blocks (each size 8xK) in read-order.
34 // Assume that "Cpanel" points to a chunk of C output blocks (each size
35 // 8x6), the chunks being arranged in a row major fashion.
36 //
37 // Note that the intent of this is that either ablocks or bblocks will be 1
38 // - this construction allows the output loop to proceed in either order.
39 
40 namespace arm_gemm {
41 
a32_sgemm_8x6_a53(const float * Apanel,const float * Bpanel,float * Cpanel,int ablocks,int bblocks,int K)42 void a32_sgemm_8x6_a53(const float *Apanel, const float *Bpanel, float *Cpanel, int ablocks, int bblocks, int K) {
43     const float *a_ptr = Apanel;
44     float *c_ptr = Cpanel;
45 
46     for (int yb=0; yb<ablocks; yb++) {
47         const float *a_ptr0 = a_ptr;
48         const float *b_ptr = Bpanel;
49 
50         for (int xb=0; xb<bblocks; xb++) {
51             a_ptr = a_ptr0;
52             int tails = (K & 3);
53             if (tails == 0) {
54                 tails = 4;
55             }
56             int k = ((K+3)/4) - 1;
57 
58             __asm __volatile (
59                 "vmov.i32	q4, #0\n"
60                 "vld1.32	{d0-d1}, [%[a_ptr] :64]\n"
61                 "vmov.i32	q5, #0\n"
62                 "vld1.32	{d4-d5}, [%[b_ptr] :128]\n"
63                 "vmov.i32	q6, #0\n"
64                 "ldr		r0, [%[a_ptr], #0x10]\n"
65                 "vmov.i32	q7, #0\n"
66                 "ldr		r1, [%[a_ptr], #0x14]\n"
67                 "vmov.i32	q8, #0\n"
68                 ASM_PREFETCH("[%[a_ptr], #0x40]")
69                 "vmov.i32	q9, #0\n"
70                 ASM_PREFETCH("[%[b_ptr], #0x40]")
71                 "vmov.i32	q10, #0\n"
72                 ASM_PREFETCH("[%[a_ptr], #0x80]")
73                 "vmov.i32	q11, #0\n"
74                 ASM_PREFETCH("[%[b_ptr], #0x80]")
75                 "vmov.i32	q12, #0\n"
76                 "vmov.i32	q13, #0\n"
77                 ASM_PREFETCH("[%[a_ptr], #0xC0]")
78                 "vmov.i32	q14, #0\n"
79                 ASM_PREFETCH("[%[b_ptr], #0XC0]")
80                 "vmov.i32	q15, #0\n"
81                 "cmp		%[k], #0\n"
82                 "beq		6f\n"
83 
84                 "1:\n"
85                 // Unroll 0
86                 "vldr		d6, [%[b_ptr], #0x10]\n"
87                 "vmov		d2, r0, r1\n"
88                 "vmla.f32	q4, q2, d0[0]\n"
89                 "ldr		r0, [%[b_ptr], #0x18]\n"
90                 "vmla.f32	q5, q2, d0[1]\n"
91                 "ldr		r1, [%[b_ptr], #0x1C]\n"
92                 "vmla.f32	q6, q2, d1[0]\n"
93 
94                 "vldr		d3, [%[a_ptr], #0x18]\n"
95                 "vmov		d7, r0, r1\n"
96                 "vmla.f32	q7, q2, d1[1]\n"
97                 ASM_PREFETCH("[%[a_ptr], #0x100]")
98                 "vmla.f32	q8, q2, d2[0]\n"
99                 "vmla.f32	q9, q2, d2[1]\n"
100 
101                 "vldr		d4, [%[b_ptr], #0x20]\n"
102                 "vmla.f32	q10, q3, d0[0]\n"
103                 "ldr		r0, [%[b_ptr], #0x28]\n"
104                 "vmla.f32	q11, q3, d0[1]\n"
105                 "ldr		r1, [%[b_ptr], #0x2C]\n"
106                 "vmla.f32	q12, q3, d1[0]\n"
107 
108                 "vldr		d0, [%[a_ptr], #0x20]\n"
109                 "vmov		d5, r0, r1\n"
110                 "vmla.f32	q13, q3, d1[1]\n"
111                 "ldr		r0, [%[a_ptr], #0x28]\n"
112                 "vmla.f32	q14, q3, d2[0]\n"
113                 "ldr		r1, [%[a_ptr], #0x2C]\n"
114                 "vmla.f32	q15, q3, d2[1]\n"
115 
116                 // Unroll 1
117                 "vldr		d6, [%[b_ptr], #0x30]\n"
118                 "vmov		d1, r0, r1\n"
119                 "vmla.f32	q4, q2, d3[0]\n"
120                 "ldr		r0, [%[b_ptr], #0x38]\n"
121                 "vmla.f32	q5, q2, d3[1]\n"
122                 "ldr		r1, [%[b_ptr], #0x3C]\n"
123                 "vmla.f32	q6, q2, d0[0]\n"
124 
125                 "vldr		d2, [%[a_ptr], #0x30]\n"
126                 "vmov		d7, r0, r1\n"
127                 "vmla.f32	q7, q2, d0[1]\n"
128                 ASM_PREFETCH("[%[b_ptr], #0x100]")
129                 "vmla.f32	q8, q2, d1[0]\n"
130                 "vmla.f32	q9, q2, d1[1]\n"
131 
132                 "vldr		d4, [%[b_ptr], #0x40]\n"
133                 "vmla.f32	q10, q3, d3[0]\n"
134                 "ldr		r0, [%[b_ptr], #0x48]\n"
135                 "vmla.f32	q11, q3, d3[1]\n"
136                 "ldr		r1, [%[b_ptr], #0x4C]\n"
137                 "vmla.f32	q12, q3, d0[0]\n"
138 
139                 "vldr		d3, [%[a_ptr], #0x38]\n"
140                 "vmov		d5, r0, r1\n"
141                 "vmla.f32	q13, q3, d0[1]\n"
142                 "ldr		r0, [%[a_ptr], #0x40]\n"
143                 "vmla.f32	q14, q3, d1[0]\n"
144                 "ldr		r1, [%[a_ptr], #0x44]\n"
145                 "vmla.f32	q15, q3, d1[1]\n"
146 
147                 // Unroll 2
148                 "vldr		d6, [%[b_ptr], #0x50]\n"
149                 "vmov		d0, r0, r1\n"
150                 "vmla.f32	q4, q2, d2[0]\n"
151                 "ldr		r0, [%[b_ptr], #0x58]\n"
152                 "vmla.f32	q5, q2, d2[1]\n"
153                 "ldr		r1, [%[b_ptr], #0x5C]\n"
154                 "vmla.f32	q6, q2, d3[0]\n"
155 
156                 "vldr		d1, [%[a_ptr], #0x48]\n"
157                 "vmov		d7, r0, r1\n"
158                 "vmla.f32	q7, q2, d3[1]\n"
159                 ASM_PREFETCH("[%[a_ptr], #0x140]")
160                 "vmla.f32	q8, q2, d0[0]\n"
161                 "vmla.f32	q9, q2, d0[1]\n"
162 
163                 "vldr		d4, [%[b_ptr], #0x60]\n"
164                 "vmla.f32	q10, q3, d2[0]\n"
165                 "ldr		r0, [%[b_ptr], #0x68]\n"
166                 "vmla.f32	q11, q3, d2[1]\n"
167                 "ldr		r1, [%[b_ptr], #0x6C]\n"
168                 "vmla.f32	q12, q3, d3[0]\n"
169 
170                 "vldr		d2, [%[a_ptr], #0x50]\n"
171                 "vmov		d5, r0, r1\n"
172                 "vmla.f32	q13, q3, d3[1]\n"
173                 "ldr		r0, [%[a_ptr], #0x58]\n"
174                 "vmla.f32	q14, q3, d0[0]\n"
175                 "ldr		r1, [%[a_ptr], #0x5C]\n"
176                 "vmla.f32	q15, q3, d0[1]\n"
177                 "add		%[a_ptr], %[a_ptr], #0x60\n"
178 
179                 // Unroll 3
180                 "vldr		d6, [%[b_ptr], #0x70]\n"
181                 "vmov		d3, r0, r1\n"
182                 "vmla.f32	q4, q2, d1[0]\n"
183                 "ldr		r0, [%[b_ptr], #0x78]\n"
184                 "vmla.f32	q5, q2, d1[1]\n"
185                 "ldr		r1, [%[b_ptr], #0x7C]\n"
186                 "vmla.f32	q6, q2, d2[0]\n"
187                 "add		%[b_ptr], %[b_ptr], #0x80\n"
188 
189                 "vldr		d0, [%[a_ptr], #0x00]\n"
190                 "vmov		d7, r0, r1\n"
191                 "vmla.f32	q7, q2, d2[1]\n"
192                 ASM_PREFETCH("[%[b_ptr], #0xC0]")
193                 "vmla.f32	q8, q2, d3[0]\n"
194                 "vmla.f32	q9, q2, d3[1]\n"
195 
196                 "vldr		d4, [%[b_ptr], #0x00]\n"
197                 "vmla.f32	q10, q3, d1[0]\n"
198                 "ldr		r0, [%[b_ptr], #0x08]\n"
199                 "vmla.f32	q11, q3, d1[1]\n"
200                 "ldr		r1, [%[b_ptr], #0x0C]\n"
201                 "vmla.f32	q12, q3, d2[0]\n"
202                 "subs		%[k], %[k], #1\n"
203 
204                 "vldr		d1, [%[a_ptr], #0x08]\n"
205                 "vmov		d5, r0, r1\n"
206                 "vmla.f32	q13, q3, d2[1]\n"
207                 "ldr		r0, [%[a_ptr], #0x10]\n"
208                 "vmla.f32	q14, q3, d3[0]\n"
209                 "ldr		r1, [%[a_ptr], #0x14]\n"
210                 "vmla.f32	q15, q3, d3[1]\n"
211                 "bne		1b\n"
212 
213                 // "Tails" shows how many multiply blocks are needed at the
214                 // end, must be 1-4 inclusive.  Bail out to alternative tail
215                 // immediately if it's 1.
216                 "6:\n"
217                 "subs		%[tails], %[tails], #1\n"
218                 "beq		3f\n"
219 
220                 // Detached final iteration - for now adapt the generic
221                 // tails rather than reimplementing for A53.
222 
223                 // Unroll 0
224                 "vmov		d2, r0, r1\n"
225                 "add		%[a_ptr], %[a_ptr], #0x18\n"
226                 "vmla.f32	q4, q2, d0[0]\n"
227                 "vld1.32	{d3}, [%[a_ptr] :64]!\n"
228                 "vmla.f32	q5, q2, d0[1]\n"
229                 "add		%[b_ptr], %[b_ptr], #0x10\n"
230                 "vmla.f32	q6, q2, d1[0]\n"
231                 "vld1.32	{d6-d7}, [%[b_ptr] :128]!\n"
232                 "vmla.f32	q7, q2, d1[1]\n"
233                 "vmla.f32	q8, q2, d2[0]\n"
234                 "subs		%[tails], %[tails], #1\n"
235                 "vmla.f32	q9, q2, d2[1]\n"
236                 "vld1.32	{d4-d5}, [%[b_ptr] :128]!\n"
237 
238                 "vmla.f32	q10, q3, d0[0]\n"
239                 "vmla.f32	q11, q3, d0[1]\n"
240                 "vmla.f32	q12, q3, d1[0]\n"
241                 "vmla.f32	q13, q3, d1[1]\n"
242                 "vld1.32	{d0-d1}, [%[a_ptr] :64]!\n"
243                 "vmla.f32	q14, q3, d2[0]\n"
244                 "vmla.f32	q15, q3, d2[1]\n"
245                 "vld1.32	{d6-d7}, [%[b_ptr] :128]!\n"
246                 "beq		4f\n"
247 
248                 // Unroll 1
249                 "vmla.f32	q4, q2, d3[0]\n"
250                 "vmla.f32	q5, q2, d3[1]\n"
251                 "subs		%[tails], %[tails], #1\n"
252                 "vmla.f32	q6, q2, d0[0]\n"
253                 "vmla.f32	q7, q2, d0[1]\n"
254                 "vmla.f32	q8, q2, d1[0]\n"
255                 "vmla.f32	q9, q2, d1[1]\n"
256                 "vld1.32	{d4-d5}, [%[b_ptr] :128]!\n"
257 
258                 "vmla.f32	q10, q3, d3[0]\n"
259                 "vmla.f32	q11, q3, d3[1]\n"
260                 "vld1.32	{d2-d3}, [%[a_ptr] :64]!\n"
261                 "vmla.f32	q12, q3, d0[0]\n"
262                 "vmla.f32	q13, q3, d0[1]\n"
263                 "vmla.f32	q14, q3, d1[0]\n"
264                 "vmla.f32	q15, q3, d1[1]\n"
265                 "vld1.32	{d6-d7}, [%[b_ptr] :128]!\n"
266                 "beq		5f\n"
267 
268                 // Unroll 2
269                 "vld1.32	{d0-d1}, [%[a_ptr] :64]!\n"
270                 "vmla.f32	q4, q2, d2[0]\n"
271                 "vmla.f32	q5, q2, d2[1]\n"
272                 "vmla.f32	q6, q2, d3[0]\n"
273                 "vmla.f32	q7, q2, d3[1]\n"
274                 "vmla.f32	q8, q2, d0[0]\n"
275                 "vmla.f32	q9, q2, d0[1]\n"
276                 "vld1.32	{d4-d5}, [%[b_ptr] :128]!\n"
277 
278                 "vmla.f32	q10, q3, d2[0]\n"
279                 "vmla.f32	q11, q3, d2[1]\n"
280                 "vmla.f32	q12, q3, d3[0]\n"
281                 "vmla.f32	q13, q3, d3[1]\n"
282                 "vld1.32	{d2-d3}, [%[a_ptr] :64]!\n"
283                 "vmla.f32	q14, q3, d0[0]\n"
284                 "vmla.f32	q15, q3, d0[1]\n"
285                 "vld1.32	{d6-d7}, [%[b_ptr] :128]!\n"
286 
287                 // Unroll 3
288                 "vmla.f32	q4, q2, d1[0]\n"
289                 "vmla.f32	q10, q3, d1[0]\n"
290                 "vst1.32	{d8-d9}, [%[c_ptr] :128]!\n"
291                 "vmla.f32	q5, q2, d1[1]\n"
292                 "vst1.32	{d20-d21}, [%[c_ptr] :128]!\n"
293                 "vmla.f32	q11, q3, d1[1]\n"
294                 "vst1.32	{d10-d11}, [%[c_ptr] :128]!\n"
295                 "vmla.f32	q6, q2, d2[0]\n"
296                 "vst1.32	{d22-d23}, [%[c_ptr] :128]!\n"
297                 "vmla.f32	q12, q3, d2[0]\n"
298                 "vst1.32	{d12-d13}, [%[c_ptr] :128]!\n"
299                 "vmla.f32	q7, q2, d2[1]\n"
300                 "vst1.32	{d24-d25}, [%[c_ptr] :128]!\n"
301                 "vmla.f32	q13, q3, d2[1]\n"
302                 "vst1.32	{d14-d15}, [%[c_ptr] :128]!\n"
303                 "vmla.f32	q8, q2, d3[0]\n"
304                 "vst1.32	{d26-d27}, [%[c_ptr] :128]!\n"
305                 "vmla.f32	q14, q3, d3[0]\n"
306                 "vst1.32	{d16-d17}, [%[c_ptr] :128]!\n"
307                 "vmla.f32	q9, q2, d3[1]\n"
308                 "vst1.32	{d28-d29}, [%[c_ptr] :128]!\n"
309                 "vmla.f32	q15, q3, d3[1]\n"
310                 "vst1.32	{d18-d19}, [%[c_ptr] :128]!\n"
311                 "b		2f\n"
312 
313                 // tails==1 final tail
314                 "3:\n"
315                 "vmov		d2, r0, r1\n"
316                 "add		%[b_ptr], %[b_ptr], #0x10\n"
317                 "vmla.f32	q4, q2, d0[0]\n"
318                 "add		%[a_ptr], %[a_ptr], #0x18\n"
319                 "vmla.f32	q5, q2, d0[1]\n"
320                 "vld1.32	{d6-d7}, [%[b_ptr] :128]!\n"
321                 "vmla.f32	q6, q2, d1[0]\n"
322                 "vst1.32	{d8-d9}, [%[c_ptr] :128]!\n"
323                 "vmla.f32	q10, q3, d0[0]\n"
324                 "vst1.32	{d20-d21}, [%[c_ptr] :128]!\n"
325                 "vmla.f32	q11, q3, d0[1]\n"
326                 "vst1.32	{d10-d11}, [%[c_ptr] :128]!\n"
327                 "vmla.f32	q12, q3, d1[0]\n"
328                 "vst1.32	{d22-d23}, [%[c_ptr] :128]!\n"
329                 "vmla.f32	q7, q2, d1[1]\n"
330                 "vst1.32	{d12-d13}, [%[c_ptr] :128]!\n"
331                 "vmla.f32	q13, q3, d1[1]\n"
332                 "vst1.32	{d24-d25}, [%[c_ptr] :128]!\n"
333                 "vmla.f32	q8, q2, d2[0]\n"
334                 "vst1.32	{d14-d15}, [%[c_ptr] :128]!\n"
335                 "vmla.f32	q14, q3, d2[0]\n"
336                 "vst1.32	{d26-d27}, [%[c_ptr] :128]!\n"
337                 "vmla.f32	q9, q2, d2[1]\n"
338                 "vst1.32	{d16-d17}, [%[c_ptr] :128]!\n"
339                 "vmla.f32	q15, q3, d2[1]\n"
340                 "vst1.32	{d28-d29}, [%[c_ptr] :128]!\n"
341                 "vst1.32	{d18-d19}, [%[c_ptr] :128]!\n"
342                 "b		2f\n"
343 
344                 // tails==2 final tail
345                 "4:\n"
346                 "vmla.f32	q4, q2, d3[0]\n"
347                 "vmla.f32	q10, q3, d3[0]\n"
348                 "vst1.32	{d8-d9}, [%[c_ptr] :128]!\n"
349                 "vmla.f32	q5, q2, d3[1]\n"
350                 "vst1.32	{d20-d21}, [%[c_ptr] :128]!\n"
351                 "vmla.f32	q11, q3, d3[1]\n"
352                 "vst1.32	{d10-d11}, [%[c_ptr] :128]!\n"
353                 "vmla.f32	q6, q2, d0[0]\n"
354                 "vst1.32	{d22-d23}, [%[c_ptr] :128]!\n"
355                 "vmla.f32	q12, q3, d0[0]\n"
356                 "vst1.32	{d12-d13}, [%[c_ptr] :128]!\n"
357                 "vmla.f32	q7, q2, d0[1]\n"
358                 "vst1.32	{d24-d25}, [%[c_ptr] :128]!\n"
359                 "vmla.f32	q13, q3, d0[1]\n"
360                 "vst1.32	{d14-d15}, [%[c_ptr] :128]!\n"
361                 "vmla.f32	q8, q2, d1[0]\n"
362                 "vst1.32	{d26-d27}, [%[c_ptr] :128]!\n"
363                 "vmla.f32	q14, q3, d1[0]\n"
364                 "vst1.32	{d16-d17}, [%[c_ptr] :128]!\n"
365                 "vmla.f32	q9, q2, d1[1]\n"
366                 "vst1.32	{d28-d29}, [%[c_ptr] :128]!\n"
367                 "vmla.f32	q15, q3, d1[1]\n"
368                 "vst1.32	{d18-d19}, [%[c_ptr] :128]!\n"
369                 "b		2f\n"
370 
371                 // tails==3 final tail
372                 "5:\n"
373                 "vmla.f32	q4, q2, d2[0]\n"
374                 "vld1.32	{d0}, [%[a_ptr] :64]!\n"
375                 "vmla.f32	q5, q2, d2[1]\n"
376                 "vmla.f32	q6, q2, d3[0]\n"
377                 "vst1.32	{d8-d9}, [%[c_ptr] :128]!\n"
378                 "vmla.f32	q10, q3, d2[0]\n"
379                 "vst1.32	{d20-d21}, [%[c_ptr] :128]!\n"
380                 "vmla.f32	q11, q3, d2[1]\n"
381                 "vst1.32	{d10-d11}, [%[c_ptr] :128]!\n"
382                 "vmla.f32	q12, q3, d3[0]\n"
383                 "vst1.32	{d22-d23}, [%[c_ptr] :128]!\n"
384                 "vmla.f32	q7, q2, d3[1]\n"
385                 "vst1.32	{d12-d13}, [%[c_ptr] :128]!\n"
386                 "vmla.f32	q13, q3, d3[1]\n"
387                 "vst1.32	{d24-d25}, [%[c_ptr] :128]!\n"
388                 "vmla.f32	q8, q2, d0[0]\n"
389                 "vst1.32	{d14-d15}, [%[c_ptr] :128]!\n"
390                 "vmla.f32	q14, q3, d0[0]\n"
391                 "vst1.32	{d26-d27}, [%[c_ptr] :128]!\n"
392                 "vmla.f32	q9, q2, d0[1]\n"
393                 "vst1.32	{d16-d17}, [%[c_ptr] :128]!\n"
394                 "vmla.f32	q15, q3, d0[1]\n"
395                 "vst1.32	{d28-d29}, [%[c_ptr] :128]!\n"
396                 "vst1.32	{d18-d19}, [%[c_ptr] :128]!\n"
397 
398                 "2:\n"
399                 "vst1.32	{d30-d31}, [%[c_ptr] :128]!\n"
400             : [a_ptr] "+r" (a_ptr), [b_ptr] "+r" (b_ptr), [c_ptr] "+r" (c_ptr), [k] "+r" (k), [tails] "+r" (tails)
401             :
402             : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15",
403               "r0", "r1", "cc", "memory"
404             );
405         }
406     }
407 }
408 
409 } // namespace arm_gemm
410 
411 #endif
412