/*
 * Copyright (c) 2017-2018 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#ifdef __aarch64__

#include <algorithm>

#include <arm_neon.h>

#include "../../asmlib.hpp"
#include "../../utils.hpp"

namespace arm_gemm {

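// Pretransposed SGEMV kernel: computes Y = beta*Y + X.A, where the M x N
// matrix A has been rearranged ahead of time into panels 32 outputs wide.
// Within a panel, each of the M rows contributes 32 consecutive floats, and
// 'lda' is the stride (in floats) between consecutive panels.  Y is
// processed 32 outputs (8 NEON vectors) at a time.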
void a64_sgemv_pretransposed(const float *A, int lda, const float *X, float *Y, float beta, int M, int N) {
    const bool beta0 = (beta==0.0f);
    const bool beta1 = (beta==1.0f);

    for (int x=0; x<N; x+=32) {
        float *y_ptr = Y + x;

        // How many elements are we processing in this loop?
        int l = std::min(N - x, 32);

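        // Accumulators for the 32 outputs.  These are pinned to v24-v31 so
        // that the inline assembly below can use v2-v23 freely for A data.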
        register float32x4_t r0 asm("v24");
        register float32x4_t r1 asm("v25");
        register float32x4_t r2 asm("v26");
        register float32x4_t r3 asm("v27");
        register float32x4_t r4 asm("v28");
        register float32x4_t r5 asm("v29");
        register float32x4_t r6 asm("v30");
        register float32x4_t r7 asm("v31");

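        // Eight consecutive X values, pinned to v0/v1 for the assembly blocks.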
        register float32x4_t x0  asm("v0");
        register float32x4_t x0a asm("v1");

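        // X is read from the start for every panel; A starts at this panel's base.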
        const float *x_ptr = X;
        const float *a_ptr = A + ((x/32) * lda);

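        // Set up the accumulators: zero them for beta==0, otherwise load the
        // existing Y values (scaling them by beta afterwards unless beta==1).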
        if (beta0) {
            r0=r1=r2=r3=r4=r5=r6=r7=vdupq_n_f32(0.0f);
        } else {
            if (l==32) {
                // Fastest path - load all 8 vectors
                r0 = vld1q_f32(y_ptr);
                r1 = vld1q_f32(y_ptr + 4);
                r2 = vld1q_f32(y_ptr + 8);
                r3 = vld1q_f32(y_ptr + 12);
                r4 = vld1q_f32(y_ptr + 16);
                r5 = vld1q_f32(y_ptr + 20);
                r6 = vld1q_f32(y_ptr + 24);
                r7 = vld1q_f32(y_ptr + 28);
            } else {
                // Slow case - leftovers.  Note that we don't care about
                // out-of-range vectors and lanes as we will throw them away at
                // the end.
                int vecs=l/4; // How many leftover vectors?
                int oddbits=l%4; // And how many odd single values?

                if (oddbits) {
                    // Load the outstanding odd values into a vector first
                    float32x4_t oddvec = vdupq_n_f32(0.0f); // This does not really need to be initialized, but the compiler has a hard time with that.
                    float *oddbase = y_ptr + l - oddbits;

                    switch (oddbits) {
                        case 3:
                            oddvec = vld1q_lane_f32(oddbase + 2, oddvec, 2);
                            // fall through
                        case 2:
                            oddvec = vld1q_lane_f32(oddbase + 1, oddvec, 1);
                            // fall through
                        case 1:
                            oddvec = vld1q_lane_f32(oddbase, oddvec, 0);
                            break;

                        default:
                            UNREACHABLE("Impossible case in switch.");
                    }

                    // Now load the whole vectors, putting the oddments in when we run out.
                    do {
                        if (vecs==0) { r0 = oddvec; break; }

                        r0 = vld1q_f32(y_ptr);
                        if (--vecs==0) { r1 = oddvec; break; }

                        r1 = vld1q_f32(y_ptr + 4);
                        if (--vecs==0) { r2 = oddvec; break; }

                        r2 = vld1q_f32(y_ptr + 8);
                        if (--vecs==0) { r3 = oddvec; break; }

                        r3 = vld1q_f32(y_ptr + 12);
                        if (--vecs==0) { r4 = oddvec; break; }

                        r4 = vld1q_f32(y_ptr + 16);
                        if (--vecs==0) { r5 = oddvec; break; }

                        r5 = vld1q_f32(y_ptr + 20);
                        if (--vecs==0) { r6 = oddvec; break; }

                        r6 = vld1q_f32(y_ptr + 24);
                        r7 = oddvec;
                    } while (0);
                } else {
                    // Slightly less slow path - just load the whole vectors
                    do {
                        // It can't be the case that oddbits==0 AND vecs==0 or we wouldn't be here.
                        if (vecs==0) { UNREACHABLE("Impossible lack of work to do"); }

                        r0 = vld1q_f32(y_ptr);
                        if (--vecs==0) { break; }

                        r1 = vld1q_f32(y_ptr + 4);
                        if (--vecs==0) { break; }

                        r2 = vld1q_f32(y_ptr + 8);
                        if (--vecs==0) { break; }

                        r3 = vld1q_f32(y_ptr + 12);
                        if (--vecs==0) { break; }

                        r4 = vld1q_f32(y_ptr + 16);
                        if (--vecs==0) { break; }

                        r5 = vld1q_f32(y_ptr + 20);
                        if (--vecs==0) { break; }

                        r6 = vld1q_f32(y_ptr + 24);
                    } while (0);
                }
            }

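            // The Y values are loaded; now scale them by beta unless beta is 1.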
            if (!beta1) {
                const float32x4_t vb = vdupq_n_f32(beta);

                r0 = vmulq_f32(r0, vb);
                r1 = vmulq_f32(r1, vb);
                r2 = vmulq_f32(r2, vb);
                r3 = vmulq_f32(r3, vb);
                r4 = vmulq_f32(r4, vb);
                r5 = vmulq_f32(r5, vb);
                r6 = vmulq_f32(r6, vb);
                r7 = vmulq_f32(r7, vb);
            }
        }

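        // Main loop: process eight rows of A (and eight X values) per
        // iteration, fully unrolled, with A loads and prefetches interleaved
        // with the FMLAs.  Each iteration consumes 8 x 32 floats (1KB) of A.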
        if (M>=8) {
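            // Peel off the final 8-row block: the loop at "1:" runs (M/8)-1
            // times and the detached final iteration after "2:" handles the
            // last block.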
            int k = (M/8)-1;
            x0 = vld1q_f32(x_ptr);

            __asm __volatile (
                "ldr	q2, [%[a_ptr], #0]\n"
                "ldr	q3, [%[a_ptr], #16]\n"
                "ldr	q4, [%[a_ptr], #32]\n"
                "ldr	q5, [%[a_ptr], #48]\n"
                "ldr	q6, [%[a_ptr], #64]\n"
                "ldr	q7, [%[a_ptr], #80]\n"
                "ldr	q8, [%[a_ptr], #96]\n"
                "ldr	q9, [%[a_ptr], #112]\n"
                "ldr	q10, [%[a_ptr], #128]\n"
                "ldr	q11, [%[a_ptr], #144]\n"
                "ldr	q12, [%[a_ptr], #160]\n"
                "ldr	q13, [%[a_ptr], #176]\n"
                "ldr	q14, [%[a_ptr], #192]\n"
                "ldr	q15, [%[a_ptr], #208]\n"
                "ldr	q16, [%[a_ptr], #224]\n"
                "ldr	q17, [%[a_ptr], #240]\n"
                "ldr	q18, [%[a_ptr], #256]\n"
                "ldr	q19, [%[a_ptr], #272]\n"
                "ldr	q20, [%[a_ptr], #288]\n"
                "ldr	q21, [%[a_ptr], #304]\n"
                "ldr	q22, [%[a_ptr], #320]\n"
                "ldr	q23, [%[a_ptr], #336]\n"
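                // Prefetch well ahead of the loads above before entering the loop.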
                ASM_PREFETCH("[%[a_ptr], #384]")
                ASM_PREFETCH("[%[a_ptr], #448]")
                ASM_PREFETCH("[%[a_ptr], #512]")
                ASM_PREFETCH("[%[a_ptr], #576]")
                ASM_PREFETCH("[%[a_ptr], #640]")
                ASM_PREFETCH("[%[a_ptr], #704]")
                ASM_PREFETCH("[%[a_ptr], #768]")
                ASM_PREFETCH("[%[a_ptr], #832]")
                ASM_PREFETCH("[%[a_ptr], #896]")
                ASM_PREFETCH("[%[a_ptr], #960]")
                ASM_PREFETCH("[%[a_ptr], #1024]")
                ASM_PREFETCH("[%[a_ptr], #1088]")
                ASM_PREFETCH("[%[a_ptr], #1152]")
                ASM_PREFETCH("[%[a_ptr], #1216]")
                ASM_PREFETCH("[%[a_ptr], #1280]")
                ASM_PREFETCH("[%[a_ptr], #1344]")
                ASM_PREFETCH("[%[a_ptr], #1408]")
                ASM_PREFETCH("[%[a_ptr], #1472]")
                ASM_PREFETCH("[%[a_ptr], #1536]")
                ASM_PREFETCH("[%[a_ptr], #1600]")
                ASM_PREFETCH("[%[a_ptr], #1664]")
                ASM_PREFETCH("[%[a_ptr], #1728]")
                ASM_PREFETCH("[%[a_ptr], #1792]")
                ASM_PREFETCH("[%[a_ptr], #1856]")
                ASM_PREFETCH("[%[a_ptr], #1920]")
                ASM_PREFETCH("[%[a_ptr], #1984]")
                "add	%[a_ptr], %[a_ptr], #352\n"

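                // Only one 8-row block?  Skip straight to the detached final iteration.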
                "cbz	%w[k], 2f\n"

                "1:\n"
                // Unroll 0
                "fmla	%[r0].4s, v2.4s, %[x0].s[0]\n"
                "ldr	%q[x0a], [%[x_ptr], #16]\n"
                "fmla	%[r1].4s, v3.4s, %[x0].s[0]\n"
                "ldr	q3, [%[a_ptr], #0]\n"
                "subs	%w[k], %w[k], #1\n"
                "fmla	%[r2].4s, v4.4s, %[x0].s[0]\n"
                "ldr	q4, [%[a_ptr], #16]\n"
                "fmla	%[r3].4s, v5.4s, %[x0].s[0]\n"
                "ldr	q5, [%[a_ptr], #32]\n"
                "add	%[x_ptr], %[x_ptr], #32\n"
                ASM_PREFETCH("[%[a_ptr], #1664]")
                "fmla	%[r4].4s, v6.4s, %[x0].s[0]\n"
                "ldr	q6, [%[a_ptr], #48]\n"
                "fmla	%[r5].4s, v7.4s, %[x0].s[0]\n"
                "ldr	q7, [%[a_ptr], #64]\n"
                "fmla	%[r6].4s, v8.4s, %[x0].s[0]\n"
                "ldr	q8, [%[a_ptr], #80]\n"
                "fmla	%[r7].4s, v9.4s, %[x0].s[0]\n"
                "ldr	q9, [%[a_ptr], #96]\n"
                ASM_PREFETCH("[%[a_ptr], #1728]")

                // Unroll 1
                "fmla	%[r0].4s, v10.4s, %[x0].s[1]\n"
                "ldr	q10, [%[a_ptr], #112]\n"
                "fmla	%[r1].4s, v11.4s, %[x0].s[1]\n"
                "ldr	q11, [%[a_ptr], #128]\n"
                "fmla	%[r2].4s, v12.4s, %[x0].s[1]\n"
                "ldr	q12, [%[a_ptr], #144]\n"
                "fmla	%[r3].4s, v13.4s, %[x0].s[1]\n"
                "ldr	q13, [%[a_ptr], #160]\n"
                ASM_PREFETCH("[%[a_ptr], #1792]")
                "fmla	%[r4].4s, v14.4s, %[x0].s[1]\n"
                "ldr	q14, [%[a_ptr], #176]\n"
                "fmla	%[r5].4s, v15.4s, %[x0].s[1]\n"
                "ldr	q15, [%[a_ptr], #192]\n"
                "fmla	%[r6].4s, v16.4s, %[x0].s[1]\n"
                "ldr	q16, [%[a_ptr], #208]\n"
                "fmla	%[r7].4s, v17.4s, %[x0].s[1]\n"
                "ldr	q17, [%[a_ptr], #224]\n"
                ASM_PREFETCH("[%[a_ptr], #1856]")

                // Unroll 2
                "fmla	%[r0].4s, v18.4s, %[x0].s[2]\n"
                "ldr	q18, [%[a_ptr], #240]\n"
                "fmla	%[r1].4s, v19.4s, %[x0].s[2]\n"
                "ldr	q19, [%[a_ptr], #256]\n"
                "fmla	%[r2].4s, v20.4s, %[x0].s[2]\n"
                "ldr	q20, [%[a_ptr], #272]\n"
                "fmla	%[r3].4s, v21.4s, %[x0].s[2]\n"
                "ldr	q21, [%[a_ptr], #288]\n"
                ASM_PREFETCH("[%[a_ptr], #1920]")
                "fmla	%[r4].4s, v22.4s, %[x0].s[2]\n"
                "ldr	q22, [%[a_ptr], #304]\n"
                "fmla	%[r5].4s, v23.4s, %[x0].s[2]\n"
                "ldr	q23, [%[a_ptr], #320]\n"
                "fmla	%[r6].4s, v3.4s, %[x0].s[2]\n"
                "ldr	q2, [%[a_ptr], #336]\n"
                "ldr	q3, [%[a_ptr], #352]\n"
                "fmla	%[r7].4s, v4.4s, %[x0].s[2]\n"
                "ldr	q4, [%[a_ptr], #368]\n"
                ASM_PREFETCH("[%[a_ptr], #1984]")

                // Unroll 3
                "fmla	%[r0].4s, v5.4s, %[x0].s[3]\n"
                "ldr	q5, [%[a_ptr], #384]\n"
                "fmla	%[r1].4s, v6.4s, %[x0].s[3]\n"
                "ldr	q6, [%[a_ptr], #400]\n"
                "fmla	%[r2].4s, v7.4s, %[x0].s[3]\n"
                "ldr	q7, [%[a_ptr], #416]\n"
                "fmla	%[r3].4s, v8.4s, %[x0].s[3]\n"
                ASM_PREFETCH("[%[a_ptr], #2048]")
                "ldr	q8, [%[a_ptr], #432]\n"
                "fmla	%[r4].4s, v9.4s, %[x0].s[3]\n"
                "ldr	q9, [%[a_ptr], #448]\n"
                "fmla	%[r5].4s, v10.4s, %[x0].s[3]\n"
                "ldr	q10, [%[a_ptr], #464]\n"
                "fmla	%[r6].4s, v11.4s, %[x0].s[3]\n"
                "ldr	q11, [%[a_ptr], #480]\n"
                "fmla	%[r7].4s, v12.4s, %[x0].s[3]\n"
                "ldr	q12, [%[a_ptr], #496]\n"
                ASM_PREFETCH("[%[a_ptr], #2112]")

                // Unroll 4
                "fmla	%[r0].4s, v13.4s, %[x0a].s[0]\n"
                "ldr	%q[x0], [%[x_ptr]]\n"
                "fmla	%[r1].4s, v14.4s, %[x0a].s[0]\n"
                "ldr	q14, [%[a_ptr], #512]\n"
                "fmla	%[r2].4s, v15.4s, %[x0a].s[0]\n"
                "ldr	q15, [%[a_ptr], #528]\n"
                "fmla	%[r3].4s, v16.4s, %[x0a].s[0]\n"
                ASM_PREFETCH("[%[a_ptr], #2176]")
                "ldr	q16, [%[a_ptr], #544]\n"
                "fmla	%[r4].4s, v17.4s, %[x0a].s[0]\n"
                "ldr	q17, [%[a_ptr], #560]\n"
                "fmla	%[r5].4s, v18.4s, %[x0a].s[0]\n"
                "ldr	q18, [%[a_ptr], #576]\n"
                "fmla	%[r6].4s, v19.4s, %[x0a].s[0]\n"
                "ldr	q19, [%[a_ptr], #592]\n"
                "fmla	%[r7].4s, v20.4s, %[x0a].s[0]\n"
                "ldr	q20, [%[a_ptr], #608]\n"
                ASM_PREFETCH("[%[a_ptr], #2240]")

                // Unroll 5
                "fmla	%[r0].4s, v21.4s, %[x0a].s[1]\n"
                "ldr	q21, [%[a_ptr], #624]\n"
                "fmla	%[r1].4s, v22.4s, %[x0a].s[1]\n"
                "ldr	q22, [%[a_ptr], #640]\n"
                "fmla	%[r2].4s, v23.4s, %[x0a].s[1]\n"
                "ldr	q23, [%[a_ptr], #656]\n"
                "fmla	%[r3].4s, v2.4s, %[x0a].s[1]\n"
                "ldr	q2, [%[a_ptr], #672]\n"
                ASM_PREFETCH("[%[a_ptr], #2304]")
                "fmla	%[r4].4s, v3.4s, %[x0a].s[1]\n"
                "ldr	q3, [%[a_ptr], #688]\n"
                "fmla	%[r5].4s, v4.4s, %[x0a].s[1]\n"
                "ldr	q4, [%[a_ptr], #704]\n"
                "fmla	%[r6].4s, v5.4s, %[x0a].s[1]\n"
                "ldr	q5, [%[a_ptr], #720]\n"
                "fmla	%[r7].4s, v6.4s, %[x0a].s[1]\n"
                "ldr	q6, [%[a_ptr], #736]\n"
                ASM_PREFETCH("[%[a_ptr], #2368]")

                // Unroll 6
                "fmla	%[r0].4s, v7.4s, %[x0a].s[2]\n"
                "ldr	q7, [%[a_ptr], #752]\n"
                "fmla	%[r1].4s, v8.4s, %[x0a].s[2]\n"
                "ldr	q8, [%[a_ptr], #768]\n"
                "fmla	%[r2].4s, v9.4s, %[x0a].s[2]\n"
                "ldr	q9, [%[a_ptr], #784]\n"
                "fmla	%[r3].4s, v10.4s, %[x0a].s[2]\n"
                "ldr	q10, [%[a_ptr], #800]\n"
                ASM_PREFETCH("[%[a_ptr], #2432]")
                "fmla	%[r4].4s, v11.4s, %[x0a].s[2]\n"
                "ldr	q11, [%[a_ptr], #816]\n"
                "fmla	%[r5].4s, v12.4s, %[x0a].s[2]\n"
                "ldr	q12, [%[a_ptr], #832]\n"
                "fmla	%[r6].4s, v14.4s, %[x0a].s[2]\n"
                "ldr	q13, [%[a_ptr], #848]\n"
                "ldr	q14, [%[a_ptr], #864]\n"
                "fmla	%[r7].4s, v15.4s, %[x0a].s[2]\n"
                "ldr	q15, [%[a_ptr], #880]\n"
                ASM_PREFETCH("[%[a_ptr], #2496]")

                // Unroll 7
                "fmla	%[r0].4s, v16.4s, %[x0a].s[3]\n"
                "ldr	q16, [%[a_ptr], #896]\n"
                "fmla	%[r1].4s, v17.4s, %[x0a].s[3]\n"
                "ldr	q17, [%[a_ptr], #912]\n"
                "fmla	%[r2].4s, v18.4s, %[x0a].s[3]\n"
                "ldr	q18, [%[a_ptr], #928]\n"
                "fmla	%[r3].4s, v19.4s, %[x0a].s[3]\n"
                ASM_PREFETCH("[%[a_ptr], #2560]")
                "ldr	q19, [%[a_ptr], #944]\n"
                "fmla	%[r4].4s, v20.4s, %[x0a].s[3]\n"
                "ldr	q20, [%[a_ptr], #960]\n"
                "fmla	%[r5].4s, v21.4s, %[x0a].s[3]\n"
                "ldr	q21, [%[a_ptr], #976]\n"
                "add	%[a_ptr], %[a_ptr], #1024\n"
                "fmla	%[r6].4s, v22.4s, %[x0a].s[3]\n"
                "ldr	q22, [%[a_ptr], #-32]\n"
                "fmla	%[r7].4s, v23.4s, %[x0a].s[3]\n"
                "ldr	q23, [%[a_ptr], #-16]\n"
                ASM_PREFETCH("[%[a_ptr], #1600]")
                "bne	1b\n"

                // Detached final iteration
                "2:\n"

                // Unroll 0
                "fmla	%[r0].4s, v2.4s, %[x0].s[0]\n"
                "ldr	%q[x0a], [%[x_ptr], #16]\n"
                "fmla	%[r1].4s, v3.4s, %[x0].s[0]\n"
                "ldr	q3, [%[a_ptr], #0]\n"
                "subs	%w[k], %w[k], #1\n"
                "fmla	%[r2].4s, v4.4s, %[x0].s[0]\n"
                "ldr	q4, [%[a_ptr], #16]\n"
                "fmla	%[r3].4s, v5.4s, %[x0].s[0]\n"
                "ldr	q5, [%[a_ptr], #32]\n"
                "add	%[x_ptr], %[x_ptr], #32\n"
                "fmla	%[r4].4s, v6.4s, %[x0].s[0]\n"
                "ldr	q6, [%[a_ptr], #48]\n"
                "fmla	%[r5].4s, v7.4s, %[x0].s[0]\n"
                "ldr	q7, [%[a_ptr], #64]\n"
                "fmla	%[r6].4s, v8.4s, %[x0].s[0]\n"
                "ldr	q8, [%[a_ptr], #80]\n"
                "fmla	%[r7].4s, v9.4s, %[x0].s[0]\n"
                "ldr	q9, [%[a_ptr], #96]\n"

                // Unroll 1
                "fmla	%[r0].4s, v10.4s, %[x0].s[1]\n"
                "ldr	q10, [%[a_ptr], #112]\n"
                "fmla	%[r1].4s, v11.4s, %[x0].s[1]\n"
                "ldr	q11, [%[a_ptr], #128]\n"
                "fmla	%[r2].4s, v12.4s, %[x0].s[1]\n"
                "ldr	q12, [%[a_ptr], #144]\n"
                "fmla	%[r3].4s, v13.4s, %[x0].s[1]\n"
                "ldr	q13, [%[a_ptr], #160]\n"
                "fmla	%[r4].4s, v14.4s, %[x0].s[1]\n"
                "ldr	q14, [%[a_ptr], #176]\n"
                "fmla	%[r5].4s, v15.4s, %[x0].s[1]\n"
                "ldr	q15, [%[a_ptr], #192]\n"
                "fmla	%[r6].4s, v16.4s, %[x0].s[1]\n"
                "ldr	q16, [%[a_ptr], #208]\n"
                "fmla	%[r7].4s, v17.4s, %[x0].s[1]\n"
                "ldr	q17, [%[a_ptr], #224]\n"

                // Unroll 2
                "fmla	%[r0].4s, v18.4s, %[x0].s[2]\n"
                "ldr	q18, [%[a_ptr], #240]\n"
                "fmla	%[r1].4s, v19.4s, %[x0].s[2]\n"
                "ldr	q19, [%[a_ptr], #256]\n"
                "fmla	%[r2].4s, v20.4s, %[x0].s[2]\n"
                "ldr	q20, [%[a_ptr], #272]\n"
                "fmla	%[r3].4s, v21.4s, %[x0].s[2]\n"
                "ldr	q21, [%[a_ptr], #288]\n"
                "fmla	%[r4].4s, v22.4s, %[x0].s[2]\n"
                "ldr	q22, [%[a_ptr], #304]\n"
                "fmla	%[r5].4s, v23.4s, %[x0].s[2]\n"
                "ldr	q23, [%[a_ptr], #320]\n"
                "fmla	%[r6].4s, v3.4s, %[x0].s[2]\n"
                "ldr	q2, [%[a_ptr], #336]\n"
                "ldr	q3, [%[a_ptr], #352]\n"
                "fmla	%[r7].4s, v4.4s, %[x0].s[2]\n"
                "ldr	q4, [%[a_ptr], #368]\n"

                // Unroll 3
                "fmla	%[r0].4s, v5.4s, %[x0].s[3]\n"
                "ldr	q5, [%[a_ptr], #384]\n"
                "fmla	%[r1].4s, v6.4s, %[x0].s[3]\n"
                "ldr	q6, [%[a_ptr], #400]\n"
                "fmla	%[r2].4s, v7.4s, %[x0].s[3]\n"
                "ldr	q7, [%[a_ptr], #416]\n"
                "fmla	%[r3].4s, v8.4s, %[x0].s[3]\n"
                "ldr	q8, [%[a_ptr], #432]\n"
                "fmla	%[r4].4s, v9.4s, %[x0].s[3]\n"
                "ldr	q9, [%[a_ptr], #448]\n"
                "fmla	%[r5].4s, v10.4s, %[x0].s[3]\n"
                "ldr	q10, [%[a_ptr], #464]\n"
                "fmla	%[r6].4s, v11.4s, %[x0].s[3]\n"
                "ldr	q11, [%[a_ptr], #480]\n"
                "fmla	%[r7].4s, v12.4s, %[x0].s[3]\n"
                "ldr	q12, [%[a_ptr], #496]\n"

                // Unroll 4
                "fmla	%[r0].4s, v13.4s, %[x0a].s[0]\n"
                "fmla	%[r1].4s, v14.4s, %[x0a].s[0]\n"
                "ldr	q14, [%[a_ptr], #512]\n"
                "fmla	%[r2].4s, v15.4s, %[x0a].s[0]\n"
                "ldr	q15, [%[a_ptr], #528]\n"
                "fmla	%[r3].4s, v16.4s, %[x0a].s[0]\n"
                "ldr	q16, [%[a_ptr], #544]\n"
                "fmla	%[r4].4s, v17.4s, %[x0a].s[0]\n"
                "ldr	q17, [%[a_ptr], #560]\n"
                "fmla	%[r5].4s, v18.4s, %[x0a].s[0]\n"
                "ldr	q18, [%[a_ptr], #576]\n"
                "fmla	%[r6].4s, v19.4s, %[x0a].s[0]\n"
                "ldr	q19, [%[a_ptr], #592]\n"
                "fmla	%[r7].4s, v20.4s, %[x0a].s[0]\n"
                "ldr	q20, [%[a_ptr], #608]\n"

                // Unroll 5
                "fmla	%[r0].4s, v21.4s, %[x0a].s[1]\n"
                "ldr	q21, [%[a_ptr], #624]\n"
                "fmla	%[r1].4s, v22.4s, %[x0a].s[1]\n"
                "ldr	q22, [%[a_ptr], #640]\n"
                "fmla	%[r2].4s, v23.4s, %[x0a].s[1]\n"
                "ldr	q23, [%[a_ptr], #656]\n"
                "fmla	%[r3].4s, v2.4s, %[x0a].s[1]\n"
                "add	%[a_ptr], %[a_ptr], #672\n"
                "fmla	%[r4].4s, v3.4s, %[x0a].s[1]\n"
                "fmla	%[r5].4s, v4.4s, %[x0a].s[1]\n"
                "fmla	%[r6].4s, v5.4s, %[x0a].s[1]\n"
                "fmla	%[r7].4s, v6.4s, %[x0a].s[1]\n"

                // Unroll 6
                "fmla	%[r0].4s, v7.4s, %[x0a].s[2]\n"
                "fmla	%[r1].4s, v8.4s, %[x0a].s[2]\n"
                "fmla	%[r2].4s, v9.4s, %[x0a].s[2]\n"
                "fmla	%[r3].4s, v10.4s, %[x0a].s[2]\n"
                "fmla	%[r4].4s, v11.4s, %[x0a].s[2]\n"
                "fmla	%[r5].4s, v12.4s, %[x0a].s[2]\n"
                "fmla	%[r6].4s, v14.4s, %[x0a].s[2]\n"
                "fmla	%[r7].4s, v15.4s, %[x0a].s[2]\n"

                // Unroll 7
                "fmla	%[r0].4s, v16.4s, %[x0a].s[3]\n"
                "fmla	%[r1].4s, v17.4s, %[x0a].s[3]\n"
                "fmla	%[r2].4s, v18.4s, %[x0a].s[3]\n"
                "fmla	%[r3].4s, v19.4s, %[x0a].s[3]\n"
                "fmla	%[r4].4s, v20.4s, %[x0a].s[3]\n"
                "fmla	%[r5].4s, v21.4s, %[x0a].s[3]\n"
                "fmla	%[r6].4s, v22.4s, %[x0a].s[3]\n"
                "fmla	%[r7].4s, v23.4s, %[x0a].s[3]\n"
            :
              [a_ptr] "+r" (a_ptr), [x_ptr] "+r" (x_ptr),
              [x0] "+w" (x0), [x0a] "+w" (x0a), [k] "+r" (k),
              [r0] "+w" (r0), [r1] "+w" (r1), [r2] "+w" (r2), [r3] "+w" (r3),
              [r4] "+w" (r4), [r5] "+w" (r5), [r6] "+w" (r6), [r7] "+w" (r7)
            :
            : "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14",
              "v15", "v16", "v17", "v18", "v19", "v20", "v21", "v22", "v23", "x20", "x21", "cc", "memory");
        }

        // Deal with ragged M
        if (M % 8) {
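            // The last leftover row is peeled off here too: the loop at "1:"
            // runs (M%8)-1 times and the code after "2:" handles the final row.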
            int l=(M%8)-1;

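            // Accumulate one leftover row per iteration: eight A vectors
            // scaled by a single X value.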
            __asm __volatile (
                "ldr	q2, [%[a_ptr], #0]\n"
                "ldr	q3, [%[a_ptr], #16]\n"
                "ldr	q4, [%[a_ptr], #32]\n"
                "ldr	q5, [%[a_ptr], #48]\n"
                "ldr	q6, [%[a_ptr], #64]\n"
                "ldr	q7, [%[a_ptr], #80]\n"
                "ldr	q8, [%[a_ptr], #96]\n"
                "ldr	q9, [%[a_ptr], #112]\n"
                "ldr	%s[x0], [%[x_ptr]]\n"
                "add	%[a_ptr], %[a_ptr], #128\n"
                "add	%[x_ptr], %[x_ptr], #4\n"

                "cbz	%w[l], 2f\n"

                "1:\n"
                "fmla	%[r0].4s, v2.4s, %[x0].s[0]\n"
                "ldr	q2, [%[a_ptr], #0]\n"
                "subs	%w[l], %w[l], #1\n"
                "fmla	%[r1].4s, v3.4s, %[x0].s[0]\n"
                "ldr	q3, [%[a_ptr], #16]\n"
                "fmla	%[r2].4s, v4.4s, %[x0].s[0]\n"
                "ldr	q4, [%[a_ptr], #32]\n"
                "fmla	%[r3].4s, v5.4s, %[x0].s[0]\n"
                "ldr	q5, [%[a_ptr], #48]\n"
                "fmla	%[r4].4s, v6.4s, %[x0].s[0]\n"
                "ldr	q6, [%[a_ptr], #64]\n"
                "fmla	%[r5].4s, v7.4s, %[x0].s[0]\n"
                "ldr	q7, [%[a_ptr], #80]\n"
                "fmla	%[r6].4s, v8.4s, %[x0].s[0]\n"
                "ldr	q8, [%[a_ptr], #96]\n"
                "fmla	%[r7].4s, v9.4s, %[x0].s[0]\n"
                "ldr	q9, [%[a_ptr], #112]\n"
                "ldr	%s[x0], [%[x_ptr]]\n"
                "add	%[a_ptr], %[a_ptr], #128\n"
                "add	%[x_ptr], %[x_ptr], #4\n"
                "bne	1b\n"

                "2:\n"

                "fmla	%[r0].4s, v2.4s, %[x0].s[0]\n"
                "fmla	%[r1].4s, v3.4s, %[x0].s[0]\n"
                "fmla	%[r2].4s, v4.4s, %[x0].s[0]\n"
                "fmla	%[r3].4s, v5.4s, %[x0].s[0]\n"
                "fmla	%[r4].4s, v6.4s, %[x0].s[0]\n"
                "fmla	%[r5].4s, v7.4s, %[x0].s[0]\n"
                "fmla	%[r6].4s, v8.4s, %[x0].s[0]\n"
                "fmla	%[r7].4s, v9.4s, %[x0].s[0]\n"
            :
              [a_ptr] "+r" (a_ptr), [x_ptr] "+r" (x_ptr),
              [x0] "+w" (x0), [l] "+r" (l),
              [r0] "+w" (r0), [r1] "+w" (r1), [r2] "+w" (r2), [r3] "+w" (r3),
              [r4] "+w" (r4), [r5] "+w" (r5), [r6] "+w" (r6), [r7] "+w" (r7)
            :
            : "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "cc", "memory");
        }

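        // Write the accumulators back to Y, trimming to the panel length l
        // just as the loads above did.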
        if (l==32) {
            // Fast path
            vst1q_f32(y_ptr, r0);
            vst1q_f32(y_ptr + 4, r1);
            vst1q_f32(y_ptr + 8, r2);
            vst1q_f32(y_ptr + 12, r3);
            vst1q_f32(y_ptr + 16, r4);
            vst1q_f32(y_ptr + 20, r5);
            vst1q_f32(y_ptr + 24, r6);
            vst1q_f32(y_ptr + 28, r7);
        } else {
            int vecs=l/4;
            int oddbits=l%4;

            if (oddbits) {
                // As above - slowest path deals with vectors plus odd bits
                float32x4_t oddvec;

                do {
                    if (vecs==0) { oddvec=r0; break; }

                    vst1q_f32(y_ptr, r0);
                    if (--vecs==0) { oddvec=r1; break; }

                    vst1q_f32(y_ptr + 4, r1);
                    if (--vecs==0) { oddvec=r2; break; }

                    vst1q_f32(y_ptr + 8, r2);
                    if (--vecs==0) { oddvec=r3; break; }

                    vst1q_f32(y_ptr + 12, r3);
                    if (--vecs==0) { oddvec=r4; break; }

                    vst1q_f32(y_ptr + 16, r4);
                    if (--vecs==0) { oddvec=r5; break; }

                    vst1q_f32(y_ptr + 20, r5);
                    if (--vecs==0) { oddvec=r6; break; }

                    vst1q_f32(y_ptr + 24, r6);
                    oddvec=r7;
                } while (0);

                float *oddbase = y_ptr + l - oddbits;

                switch(oddbits) {
                    case 3:
                        vst1q_lane_f32(oddbase + 2, oddvec, 2);
                        // fall through
                    case 2:
                        vst1q_lane_f32(oddbase + 1, oddvec, 1);
                        // fall through
                    case 1:
                        vst1q_lane_f32(oddbase, oddvec, 0);
                        break;

                    default:
                        // oddbits must be 1, 2 or 3.
                        UNREACHABLE("Impossible case in switch.");
                }
            } else {
                // As above - medium path deals with vectors only
                do {
                    if (vecs==0) { UNREACHABLE("vecs and oddbits can't both be 0"); }

                    vst1q_f32(y_ptr, r0);
                    if (--vecs==0) { break; }

                    vst1q_f32(y_ptr + 4, r1);
                    if (--vecs==0) { break; }

                    vst1q_f32(y_ptr + 8, r2);
                    if (--vecs==0) { break; }

                    vst1q_f32(y_ptr + 12, r3);
                    if (--vecs==0) { break; }

                    vst1q_f32(y_ptr + 16, r4);
                    if (--vecs==0) { break; }

                    vst1q_f32(y_ptr + 20, r5);
                    if (--vecs==0) { break; }

                    vst1q_f32(y_ptr + 24, r6);
                } while (0);
            }
        }
    }
}

} // namespace arm_gemm

#endif // aarch64