// xref: /aosp_15_r20/external/ComputeLibrary/src/core/NEON/kernels/arm_gemm/kernels/a32_sgemm_8x6/generic.cpp (revision c217d954acce2dbc11938adb493fc0abd69584f3)
/*
 * Copyright (c) 2017-2018 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24 #ifdef __arm__
25 
26 #include <arm_neon.h>
27 
28 #include "../../asmlib.hpp"
29 
// Kernel implementation.
//
// Assume that "Apanel" points to a chunk of A blocks (each 6xK floats) in read order.
// Assume that "Bpanel" points to a chunk of B blocks (each 8xK floats) in read order.
// Assume that "Cpanel" points to a chunk of C output blocks (each 8 wide x 6 high,
// i.e. 6 rows of 8 floats = 48 floats), the blocks being arranged in row-major order.
//
// Note that the intent is that either ablocks or bblocks will be 1; this
// construction allows the output loop to proceed in either order.
39 
40 namespace arm_gemm {
41 
a32_sgemm_8x6(const float * Apanel,const float * Bpanel,float * Cpanel,int ablocks,int bblocks,int K)42 void a32_sgemm_8x6(const float *Apanel, const float *Bpanel, float *Cpanel, int ablocks, int bblocks, int K) {
43     const float *a_ptr = Apanel;
44     float *c_ptr = Cpanel;
45 
46     for (int yb=0; yb<ablocks; yb++) {
47         const float *a_ptr0 = a_ptr;
48         const float *b_ptr = Bpanel;
49 
50         for (int xb=0; xb<bblocks; xb++) {
51             a_ptr = a_ptr0;
52             int tails = (K & 3);
53             if (tails == 0) {
54                 tails = 4;
55             }
56             int k = ((K+3)/4) - 1;
57 
58             __asm __volatile (
59                 "vmov.i32	q4, #0\n"
60                 "vld1.32	{d0-d1}, [%[a_ptr] :64]!\n"
61                 "vmov.i32	q5, #0\n"
62                 "vld1.32	{d4-d5}, [%[b_ptr] :128]!\n"
63                 "vmov.i32	q6, #0\n"
64                 ASM_PREFETCH("[%[a_ptr], #48]")
65                 "vmov.i32	q7, #0\n"
66                 ASM_PREFETCH("[%[b_ptr], #48]")
67                 "vmov.i32	q8, #0\n"
68                 ASM_PREFETCH("[%[a_ptr], #112]")
69                 "vmov.i32	q9, #0\n"
70                 ASM_PREFETCH("[%[b_ptr], #112]")
71                 "vmov.i32	q10, #0\n"
72                 "vmov.i32	q11, #0\n"
73                 "vmov.i32	q12, #0\n"
74                 "vmov.i32	q13, #0\n"
75                 ASM_PREFETCH("[%[a_ptr], #176]")
76                 "vmov.i32	q14, #0\n"
77                 ASM_PREFETCH("[%[b_ptr], #176]")
78                 "vmov.i32	q15, #0\n"
79 
80                 "cmp		%[k], #0\n"
81                 "beq		6f\n"
82 
83                 "1:\n"
84                 // Unroll 0
85                 "vmla.f32	q4, q2, d0[0]\n"
86                 "vld1.32	{d2-d3}, [%[a_ptr] :64]!\n"
87                 "vmla.f32	q5, q2, d0[1]\n"
88                 "vmla.f32	q6, q2, d1[0]\n"
89                 "vld1.32	{d6-d7}, [%[b_ptr] :128]!\n"
90                 "vmla.f32	q7, q2, d1[1]\n"
91                 "vmla.f32	q8, q2, d2[0]\n"
92                 "vmla.f32	q9, q2, d2[1]\n"
93                 "vld1.32	{d4-d5}, [%[b_ptr] :128]!\n"
94 
95                 "vmla.f32	q10, q3, d0[0]\n"
96                 "vmla.f32	q11, q3, d0[1]\n"
97                 "vmla.f32	q12, q3, d1[0]\n"
98                 "vmla.f32	q13, q3, d1[1]\n"
99                 "vld1.32	{d0-d1}, [%[a_ptr] :64]!\n"
100                 "vmla.f32	q14, q3, d2[0]\n"
101                 "vmla.f32	q15, q3, d2[1]\n"
102                 "vld1.32	{d6-d7}, [%[b_ptr] :128]!\n"
103 
104                 // Unroll 1
105                 "vmla.f32	q4, q2, d3[0]\n"
106                 "subs		%[k], %[k], #1\n"
107                 "vmla.f32	q5, q2, d3[1]\n"
108                 ASM_PREFETCH("[%[a_ptr], #208]")
109                 "vmla.f32	q6, q2, d0[0]\n"
110                 "vmla.f32	q7, q2, d0[1]\n"
111                 ASM_PREFETCH("[%[b_ptr], #192]")
112                 "vmla.f32	q8, q2, d1[0]\n"
113                 "vmla.f32	q9, q2, d1[1]\n"
114                 "vld1.32	{d4-d5}, [%[b_ptr] :128]!\n"
115 
116                 "vmla.f32	q10, q3, d3[0]\n"
117                 "vmla.f32	q11, q3, d3[1]\n"
118                 "vld1.32	{d2-d3}, [%[a_ptr] :64]!\n"
119                 "vmla.f32	q12, q3, d0[0]\n"
120                 "vmla.f32	q13, q3, d0[1]\n"
121                 "vmla.f32	q14, q3, d1[0]\n"
122                 "vmla.f32	q15, q3, d1[1]\n"
123                 "vld1.32	{d0-d1}, [%[a_ptr] :64]!\n"
124 
125                 // Unroll 2
126                 "vmla.f32	q4, q2, d2[0]\n"
127                 "vmla.f32	q5, q2, d2[1]\n"
128                 "vld1.32	{d6-d7}, [%[b_ptr] :128]!\n"
129                 "vmla.f32	q6, q2, d3[0]\n"
130                 "vmla.f32	q7, q2, d3[1]\n"
131                 ASM_PREFETCH("[%[a_ptr], #240]")
132                 "vmla.f32	q8, q2, d0[0]\n"
133                 "vmla.f32	q9, q2, d0[1]\n"
134                 "vld1.32	{d4-d5}, [%[b_ptr] :128]!\n"
135 
136                 "vmla.f32	q10, q3, d2[0]\n"
137                 "vmla.f32	q11, q3, d2[1]\n"
138                 ASM_PREFETCH("[%[b_ptr], #208]")
139                 "vmla.f32	q12, q3, d3[0]\n"
140                 "vmla.f32	q13, q3, d3[1]\n"
141                 "vld1.32	{d2-d3}, [%[a_ptr] :64]!\n"
142                 "vmla.f32	q14, q3, d0[0]\n"
143                 "vmla.f32	q15, q3, d0[1]\n"
144                 "vld1.32	{d6-d7}, [%[b_ptr] :128]!\n"
145 
146                 // Unroll 3
147                 "vmla.f32	q4, q2, d1[0]\n"
148                 "vmla.f32	q5, q2, d1[1]\n"
149                 "vmla.f32	q6, q2, d2[0]\n"
150                 "vmla.f32	q7, q2, d2[1]\n"
151                 "vmla.f32	q8, q2, d3[0]\n"
152                 "vmla.f32	q9, q2, d3[1]\n"
153                 "vld1.32	{d4-d5}, [%[b_ptr] :128]!\n"
154 
155                 "vmla.f32	q10, q3, d1[0]\n"
156                 "vmla.f32	q11, q3, d1[1]\n"
157                 "vld1.32	{d0-d1}, [%[a_ptr] :64]!\n"
158                 "vmla.f32	q12, q3, d2[0]\n"
159                 "vmla.f32	q13, q3, d2[1]\n"
160                 "vmla.f32	q14, q3, d3[0]\n"
161                 "vmla.f32	q15, q3, d3[1]\n"
162                 "bne		1b\n"
163 
164                 // Branch here if we never execute main loop.
165                 "6:\n"
166 
167                 // "Tails" shows how many multiply blocks are needed at the
168                 // end, must be 1-4 inclusive.  Bail out to alternative tail
169                 // immediately if it's 1.
170                 "subs		%[tails], %[tails], #1\n"
171                 "beq		3f\n"
172 
173                 // Detached final iteration
174                 // Unroll 0
175                 "vmla.f32	q4, q2, d0[0]\n"
176                 "vld1.32	{d2-d3}, [%[a_ptr] :64]!\n"
177                 "vmla.f32	q5, q2, d0[1]\n"
178                 "vmla.f32	q6, q2, d1[0]\n"
179                 "vld1.32	{d6-d7}, [%[b_ptr] :128]!\n"
180                 "vmla.f32	q7, q2, d1[1]\n"
181                 "vmla.f32	q8, q2, d2[0]\n"
182                 "subs		%[tails], %[tails], #1\n"
183                 "vmla.f32	q9, q2, d2[1]\n"
184                 "vld1.32	{d4-d5}, [%[b_ptr] :128]!\n"
185 
186                 "vmla.f32	q10, q3, d0[0]\n"
187                 "vmla.f32	q11, q3, d0[1]\n"
188                 "vmla.f32	q12, q3, d1[0]\n"
189                 "vmla.f32	q13, q3, d1[1]\n"
190                 "vld1.32	{d0-d1}, [%[a_ptr] :64]!\n"
191                 "vmla.f32	q14, q3, d2[0]\n"
192                 "vmla.f32	q15, q3, d2[1]\n"
193                 "vld1.32	{d6-d7}, [%[b_ptr] :128]!\n"
194                 "beq		4f\n"
195 
196                 // Unroll 1
197                 "vmla.f32	q4, q2, d3[0]\n"
198                 "vmla.f32	q5, q2, d3[1]\n"
199                 "subs		%[tails], %[tails], #1\n"
200                 "vmla.f32	q6, q2, d0[0]\n"
201                 "vmla.f32	q7, q2, d0[1]\n"
202                 "vmla.f32	q8, q2, d1[0]\n"
203                 "vmla.f32	q9, q2, d1[1]\n"
204                 "vld1.32	{d4-d5}, [%[b_ptr] :128]!\n"
205 
206                 "vmla.f32	q10, q3, d3[0]\n"
207                 "vmla.f32	q11, q3, d3[1]\n"
208                 "vld1.32	{d2-d3}, [%[a_ptr] :64]!\n"
209                 "vmla.f32	q12, q3, d0[0]\n"
210                 "vmla.f32	q13, q3, d0[1]\n"
211                 "vmla.f32	q14, q3, d1[0]\n"
212                 "vmla.f32	q15, q3, d1[1]\n"
213                 "vld1.32	{d6-d7}, [%[b_ptr] :128]!\n"
214                 "beq		5f\n"
215 
216                 // Unroll 2
217                 "vld1.32	{d0-d1}, [%[a_ptr] :64]!\n"
218                 "vmla.f32	q4, q2, d2[0]\n"
219                 "vmla.f32	q5, q2, d2[1]\n"
220                 "vmla.f32	q6, q2, d3[0]\n"
221                 "vmla.f32	q7, q2, d3[1]\n"
222                 "vmla.f32	q8, q2, d0[0]\n"
223                 "vmla.f32	q9, q2, d0[1]\n"
224                 "vld1.32	{d4-d5}, [%[b_ptr] :128]!\n"
225 
226                 "vmla.f32	q10, q3, d2[0]\n"
227                 "vmla.f32	q11, q3, d2[1]\n"
228                 "vmla.f32	q12, q3, d3[0]\n"
229                 "vmla.f32	q13, q3, d3[1]\n"
230                 "vld1.32	{d2-d3}, [%[a_ptr] :64]!\n"
231                 "vmla.f32	q14, q3, d0[0]\n"
232                 "vmla.f32	q15, q3, d0[1]\n"
233                 "vld1.32	{d6-d7}, [%[b_ptr] :128]!\n"
234 
235                 // Unroll 3
236                 "vmla.f32	q4, q2, d1[0]\n"
237                 "vmla.f32	q10, q3, d1[0]\n"
238                 "vst1.32	{d8-d9}, [%[c_ptr] :128]!\n"
239                 "vmla.f32	q5, q2, d1[1]\n"
240                 "vst1.32	{d20-d21}, [%[c_ptr] :128]!\n"
241                 "vmla.f32	q11, q3, d1[1]\n"
242                 "vst1.32	{d10-d11}, [%[c_ptr] :128]!\n"
243                 "vmla.f32	q6, q2, d2[0]\n"
244                 "vst1.32	{d22-d23}, [%[c_ptr] :128]!\n"
245                 "vmla.f32	q12, q3, d2[0]\n"
246                 "vst1.32	{d12-d13}, [%[c_ptr] :128]!\n"
247                 "vmla.f32	q7, q2, d2[1]\n"
248                 "vst1.32	{d24-d25}, [%[c_ptr] :128]!\n"
249                 "vmla.f32	q13, q3, d2[1]\n"
250                 "vst1.32	{d14-d15}, [%[c_ptr] :128]!\n"
251                 "vmla.f32	q8, q2, d3[0]\n"
252                 "vst1.32	{d26-d27}, [%[c_ptr] :128]!\n"
253                 "vmla.f32	q14, q3, d3[0]\n"
254                 "vst1.32	{d16-d17}, [%[c_ptr] :128]!\n"
255                 "vmla.f32	q9, q2, d3[1]\n"
256                 "vst1.32	{d28-d29}, [%[c_ptr] :128]!\n"
257                 "vmla.f32	q15, q3, d3[1]\n"
258                 "vst1.32	{d18-d19}, [%[c_ptr] :128]!\n"
259                 "b		2f\n"
260 
261                 // tails==1 final tail
262                 "3:\n"
263                 "vmla.f32	q4, q2, d0[0]\n"
264                 "vld1.32	{d2}, [%[a_ptr] :64]!\n"
265                 "vmla.f32	q5, q2, d0[1]\n"
266                 "vld1.32	{d6-d7}, [%[b_ptr] :128]!\n"
267                 "vmla.f32	q6, q2, d1[0]\n"
268                 "vst1.32	{d8-d9}, [%[c_ptr] :128]!\n"
269                 "vmla.f32	q10, q3, d0[0]\n"
270                 "vst1.32	{d20-d21}, [%[c_ptr] :128]!\n"
271                 "vmla.f32	q11, q3, d0[1]\n"
272                 "vst1.32	{d10-d11}, [%[c_ptr] :128]!\n"
273                 "vmla.f32	q12, q3, d1[0]\n"
274                 "vst1.32	{d22-d23}, [%[c_ptr] :128]!\n"
275                 "vmla.f32	q7, q2, d1[1]\n"
276                 "vst1.32	{d12-d13}, [%[c_ptr] :128]!\n"
277                 "vmla.f32	q13, q3, d1[1]\n"
278                 "vst1.32	{d24-d25}, [%[c_ptr] :128]!\n"
279                 "vmla.f32	q8, q2, d2[0]\n"
280                 "vst1.32	{d14-d15}, [%[c_ptr] :128]!\n"
281                 "vmla.f32	q14, q3, d2[0]\n"
282                 "vst1.32	{d26-d27}, [%[c_ptr] :128]!\n"
283                 "vmla.f32	q9, q2, d2[1]\n"
284                 "vst1.32	{d16-d17}, [%[c_ptr] :128]!\n"
285                 "vmla.f32	q15, q3, d2[1]\n"
286                 "vst1.32	{d28-d29}, [%[c_ptr] :128]!\n"
287                 "vst1.32	{d18-d19}, [%[c_ptr] :128]!\n"
288                 "b		2f\n"
289 
290                 // tails==2 final tail
291                 "4:\n"
292                 "vmla.f32	q4, q2, d3[0]\n"
293                 "vmla.f32	q10, q3, d3[0]\n"
294                 "vst1.32	{d8-d9}, [%[c_ptr] :128]!\n"
295                 "vmla.f32	q5, q2, d3[1]\n"
296                 "vst1.32	{d20-d21}, [%[c_ptr] :128]!\n"
297                 "vmla.f32	q11, q3, d3[1]\n"
298                 "vst1.32	{d10-d11}, [%[c_ptr] :128]!\n"
299                 "vmla.f32	q6, q2, d0[0]\n"
300                 "vst1.32	{d22-d23}, [%[c_ptr] :128]!\n"
301                 "vmla.f32	q12, q3, d0[0]\n"
302                 "vst1.32	{d12-d13}, [%[c_ptr] :128]!\n"
303                 "vmla.f32	q7, q2, d0[1]\n"
304                 "vst1.32	{d24-d25}, [%[c_ptr] :128]!\n"
305                 "vmla.f32	q13, q3, d0[1]\n"
306                 "vst1.32	{d14-d15}, [%[c_ptr] :128]!\n"
307                 "vmla.f32	q8, q2, d1[0]\n"
308                 "vst1.32	{d26-d27}, [%[c_ptr] :128]!\n"
309                 "vmla.f32	q14, q3, d1[0]\n"
310                 "vst1.32	{d16-d17}, [%[c_ptr] :128]!\n"
311                 "vmla.f32	q9, q2, d1[1]\n"
312                 "vst1.32	{d28-d29}, [%[c_ptr] :128]!\n"
313                 "vmla.f32	q15, q3, d1[1]\n"
314                 "vst1.32	{d18-d19}, [%[c_ptr] :128]!\n"
315                 "b		2f\n"
316 
317                 // tails==3 final tail
318                 "5:\n"
319                 "vmla.f32	q4, q2, d2[0]\n"
320                 "vld1.32	{d0}, [%[a_ptr] :64]!\n"
321                 "vmla.f32	q5, q2, d2[1]\n"
322                 "vmla.f32	q6, q2, d3[0]\n"
323                 "vst1.32	{d8-d9}, [%[c_ptr] :128]!\n"
324                 "vmla.f32	q10, q3, d2[0]\n"
325                 "vst1.32	{d20-d21}, [%[c_ptr] :128]!\n"
326                 "vmla.f32	q11, q3, d2[1]\n"
327                 "vst1.32	{d10-d11}, [%[c_ptr] :128]!\n"
328                 "vmla.f32	q12, q3, d3[0]\n"
329                 "vst1.32	{d22-d23}, [%[c_ptr] :128]!\n"
330                 "vmla.f32	q7, q2, d3[1]\n"
331                 "vst1.32	{d12-d13}, [%[c_ptr] :128]!\n"
332                 "vmla.f32	q13, q3, d3[1]\n"
333                 "vst1.32	{d24-d25}, [%[c_ptr] :128]!\n"
334                 "vmla.f32	q8, q2, d0[0]\n"
335                 "vst1.32	{d14-d15}, [%[c_ptr] :128]!\n"
336                 "vmla.f32	q14, q3, d0[0]\n"
337                 "vst1.32	{d26-d27}, [%[c_ptr] :128]!\n"
338                 "vmla.f32	q9, q2, d0[1]\n"
339                 "vst1.32	{d16-d17}, [%[c_ptr] :128]!\n"
340                 "vmla.f32	q15, q3, d0[1]\n"
341                 "vst1.32	{d28-d29}, [%[c_ptr] :128]!\n"
342                 "vst1.32	{d18-d19}, [%[c_ptr] :128]!\n"
343 
344                 "2:\n"
345                 "vst1.32	{d30-d31}, [%[c_ptr] :128]!\n"
346             : [a_ptr] "+r" (a_ptr), [b_ptr] "+r" (b_ptr), [c_ptr] "+r" (c_ptr), [k] "+r" (k), [tails] "+r" (tails)
347             :
348             : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15",
349               "cc", "memory"
350             );
351         }
352     }
353 }
354 
355 } // namespace arm_gemm
356 
357 #endif
358