/*
 * Copyright (c) 2017-2018 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#ifdef __arm__

#include <arm_neon.h>

#include "../../asmlib.hpp"

// Kernel implementation.
//
// Assume that "Apanel" points to a chunk of A blocks (each size 6xK) in read-order.
// Assume that "Bpanel" points to a chunk of B blocks (each size 8xK) in read-order.
// Assume that "Cpanel" points to a chunk of C output blocks (each size
// 8x6), the chunks being arranged in a row-major fashion.
//
// Note that the intent of this is that either ablocks or bblocks will be 1
// - this construction allows the output loop to proceed in either order.
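//
// As a rough reference (a sketch of the semantics only, not the optimized
// path below; "a", "b" and "out" are hypothetical names for one packed 6xK
// A block, one packed 8xK B block and the matching 8x6 C block, with "out"
// starting from zero just as the accumulators do below), each block pair
// computes:
//
//   for (int k = 0; k < K; k++) {
//     for (int row = 0; row < 6; row++) {
//       for (int col = 0; col < 8; col++) {
//         out[row * 8 + col] += a[k * 6 + row] * b[k * 8 + col];
//       }
//     }
//   }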

namespace arm_gemm {

void a32_sgemm_8x6(const float *Apanel, const float *Bpanel, float *Cpanel, int ablocks, int bblocks, int K) {
    const float *a_ptr = Apanel;
    float *c_ptr = Cpanel;

    for (int yb=0; yb<ablocks; yb++) {
        const float *a_ptr0 = a_ptr;
        const float *b_ptr = Bpanel;

        for (int xb=0; xb<bblocks; xb++) {
            a_ptr = a_ptr0;
            int tails = (K & 3);
            if (tails == 0) {
                tails = 4;
            }
            int k = ((K+3)/4) - 1;

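            // The main loop below consumes K in groups of four k-steps, with
            // the final group peeled off into dedicated tail code; "tails"
            // (1-4 inclusive) is the size of that final group. For example,
            // K=10 gives k=2 full loop passes (8 k-steps) plus tails=2.
            //
            // Register scheme (inferred from the code below): q0/q1 hold A
            // values, q2/q3 hold B values, and q4-q15 are the twelve 4-float
            // accumulators that together form the 8x6 output block.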
            __asm __volatile (
                "vmov.i32 q4, #0\n"
                "vld1.32 {d0-d1}, [%[a_ptr] :64]!\n"
                "vmov.i32 q5, #0\n"
                "vld1.32 {d4-d5}, [%[b_ptr] :128]!\n"
                "vmov.i32 q6, #0\n"
                ASM_PREFETCH("[%[a_ptr], #48]")
                "vmov.i32 q7, #0\n"
                ASM_PREFETCH("[%[b_ptr], #48]")
                "vmov.i32 q8, #0\n"
                ASM_PREFETCH("[%[a_ptr], #112]")
                "vmov.i32 q9, #0\n"
                ASM_PREFETCH("[%[b_ptr], #112]")
                "vmov.i32 q10, #0\n"
                "vmov.i32 q11, #0\n"
                "vmov.i32 q12, #0\n"
                "vmov.i32 q13, #0\n"
                ASM_PREFETCH("[%[a_ptr], #176]")
                "vmov.i32 q14, #0\n"
                ASM_PREFETCH("[%[b_ptr], #176]")
                "vmov.i32 q15, #0\n"

                "cmp %[k], #0\n"
                "beq 6f\n"

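                // Main loop: each pass performs four unrolled k-steps, with
                // the loads for upcoming steps interleaved among the
                // multiplies (software pipelining).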
83 "1:\n"
84 // Unroll 0
85 "vmla.f32 q4, q2, d0[0]\n"
86 "vld1.32 {d2-d3}, [%[a_ptr] :64]!\n"
87 "vmla.f32 q5, q2, d0[1]\n"
88 "vmla.f32 q6, q2, d1[0]\n"
89 "vld1.32 {d6-d7}, [%[b_ptr] :128]!\n"
90 "vmla.f32 q7, q2, d1[1]\n"
91 "vmla.f32 q8, q2, d2[0]\n"
92 "vmla.f32 q9, q2, d2[1]\n"
93 "vld1.32 {d4-d5}, [%[b_ptr] :128]!\n"
94
95 "vmla.f32 q10, q3, d0[0]\n"
96 "vmla.f32 q11, q3, d0[1]\n"
97 "vmla.f32 q12, q3, d1[0]\n"
98 "vmla.f32 q13, q3, d1[1]\n"
99 "vld1.32 {d0-d1}, [%[a_ptr] :64]!\n"
100 "vmla.f32 q14, q3, d2[0]\n"
101 "vmla.f32 q15, q3, d2[1]\n"
102 "vld1.32 {d6-d7}, [%[b_ptr] :128]!\n"
103
104 // Unroll 1
105 "vmla.f32 q4, q2, d3[0]\n"
106 "subs %[k], %[k], #1\n"
107 "vmla.f32 q5, q2, d3[1]\n"
108 ASM_PREFETCH("[%[a_ptr], #208]")
109 "vmla.f32 q6, q2, d0[0]\n"
110 "vmla.f32 q7, q2, d0[1]\n"
111 ASM_PREFETCH("[%[b_ptr], #192]")
112 "vmla.f32 q8, q2, d1[0]\n"
113 "vmla.f32 q9, q2, d1[1]\n"
114 "vld1.32 {d4-d5}, [%[b_ptr] :128]!\n"
115
116 "vmla.f32 q10, q3, d3[0]\n"
117 "vmla.f32 q11, q3, d3[1]\n"
118 "vld1.32 {d2-d3}, [%[a_ptr] :64]!\n"
119 "vmla.f32 q12, q3, d0[0]\n"
120 "vmla.f32 q13, q3, d0[1]\n"
121 "vmla.f32 q14, q3, d1[0]\n"
122 "vmla.f32 q15, q3, d1[1]\n"
123 "vld1.32 {d0-d1}, [%[a_ptr] :64]!\n"
124
125 // Unroll 2
126 "vmla.f32 q4, q2, d2[0]\n"
127 "vmla.f32 q5, q2, d2[1]\n"
128 "vld1.32 {d6-d7}, [%[b_ptr] :128]!\n"
129 "vmla.f32 q6, q2, d3[0]\n"
130 "vmla.f32 q7, q2, d3[1]\n"
131 ASM_PREFETCH("[%[a_ptr], #240]")
132 "vmla.f32 q8, q2, d0[0]\n"
133 "vmla.f32 q9, q2, d0[1]\n"
134 "vld1.32 {d4-d5}, [%[b_ptr] :128]!\n"
135
136 "vmla.f32 q10, q3, d2[0]\n"
137 "vmla.f32 q11, q3, d2[1]\n"
138 ASM_PREFETCH("[%[b_ptr], #208]")
139 "vmla.f32 q12, q3, d3[0]\n"
140 "vmla.f32 q13, q3, d3[1]\n"
141 "vld1.32 {d2-d3}, [%[a_ptr] :64]!\n"
142 "vmla.f32 q14, q3, d0[0]\n"
143 "vmla.f32 q15, q3, d0[1]\n"
144 "vld1.32 {d6-d7}, [%[b_ptr] :128]!\n"
145
146 // Unroll 3
147 "vmla.f32 q4, q2, d1[0]\n"
148 "vmla.f32 q5, q2, d1[1]\n"
149 "vmla.f32 q6, q2, d2[0]\n"
150 "vmla.f32 q7, q2, d2[1]\n"
151 "vmla.f32 q8, q2, d3[0]\n"
152 "vmla.f32 q9, q2, d3[1]\n"
153 "vld1.32 {d4-d5}, [%[b_ptr] :128]!\n"
154
155 "vmla.f32 q10, q3, d1[0]\n"
156 "vmla.f32 q11, q3, d1[1]\n"
157 "vld1.32 {d0-d1}, [%[a_ptr] :64]!\n"
158 "vmla.f32 q12, q3, d2[0]\n"
159 "vmla.f32 q13, q3, d2[1]\n"
160 "vmla.f32 q14, q3, d3[0]\n"
161 "vmla.f32 q15, q3, d3[1]\n"
162 "bne 1b\n"

                // Branch here if we never execute the main loop.
                "6:\n"

                // "tails" holds how many multiply blocks remain at the end
                // (1-4 inclusive). Bail out to the dedicated tail code
                // immediately if it's 1.
                "subs %[tails], %[tails], #1\n"
                "beq 3f\n"

                // Detached final iteration
                // Unroll 0
                "vmla.f32 q4, q2, d0[0]\n"
                "vld1.32 {d2-d3}, [%[a_ptr] :64]!\n"
                "vmla.f32 q5, q2, d0[1]\n"
                "vmla.f32 q6, q2, d1[0]\n"
                "vld1.32 {d6-d7}, [%[b_ptr] :128]!\n"
                "vmla.f32 q7, q2, d1[1]\n"
                "vmla.f32 q8, q2, d2[0]\n"
                "subs %[tails], %[tails], #1\n"
                "vmla.f32 q9, q2, d2[1]\n"
                "vld1.32 {d4-d5}, [%[b_ptr] :128]!\n"

                "vmla.f32 q10, q3, d0[0]\n"
                "vmla.f32 q11, q3, d0[1]\n"
                "vmla.f32 q12, q3, d1[0]\n"
                "vmla.f32 q13, q3, d1[1]\n"
                "vld1.32 {d0-d1}, [%[a_ptr] :64]!\n"
                "vmla.f32 q14, q3, d2[0]\n"
                "vmla.f32 q15, q3, d2[1]\n"
                "vld1.32 {d6-d7}, [%[b_ptr] :128]!\n"
                "beq 4f\n"

                // Unroll 1
                "vmla.f32 q4, q2, d3[0]\n"
                "vmla.f32 q5, q2, d3[1]\n"
                "subs %[tails], %[tails], #1\n"
                "vmla.f32 q6, q2, d0[0]\n"
                "vmla.f32 q7, q2, d0[1]\n"
                "vmla.f32 q8, q2, d1[0]\n"
                "vmla.f32 q9, q2, d1[1]\n"
                "vld1.32 {d4-d5}, [%[b_ptr] :128]!\n"

                "vmla.f32 q10, q3, d3[0]\n"
                "vmla.f32 q11, q3, d3[1]\n"
                "vld1.32 {d2-d3}, [%[a_ptr] :64]!\n"
                "vmla.f32 q12, q3, d0[0]\n"
                "vmla.f32 q13, q3, d0[1]\n"
                "vmla.f32 q14, q3, d1[0]\n"
                "vmla.f32 q15, q3, d1[1]\n"
                "vld1.32 {d6-d7}, [%[b_ptr] :128]!\n"
                "beq 5f\n"

                // Unroll 2
                "vld1.32 {d0-d1}, [%[a_ptr] :64]!\n"
                "vmla.f32 q4, q2, d2[0]\n"
                "vmla.f32 q5, q2, d2[1]\n"
                "vmla.f32 q6, q2, d3[0]\n"
                "vmla.f32 q7, q2, d3[1]\n"
                "vmla.f32 q8, q2, d0[0]\n"
                "vmla.f32 q9, q2, d0[1]\n"
                "vld1.32 {d4-d5}, [%[b_ptr] :128]!\n"

                "vmla.f32 q10, q3, d2[0]\n"
                "vmla.f32 q11, q3, d2[1]\n"
                "vmla.f32 q12, q3, d3[0]\n"
                "vmla.f32 q13, q3, d3[1]\n"
                "vld1.32 {d2-d3}, [%[a_ptr] :64]!\n"
                "vmla.f32 q14, q3, d0[0]\n"
                "vmla.f32 q15, q3, d0[1]\n"
                "vld1.32 {d6-d7}, [%[b_ptr] :128]!\n"

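                // Unroll 3 below interleaves the final multiplies with stores
                // of accumulators as each one receives its last update,
                // hiding store latency behind the remaining arithmetic.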
                // Unroll 3
                "vmla.f32 q4, q2, d1[0]\n"
                "vmla.f32 q10, q3, d1[0]\n"
                "vst1.32 {d8-d9}, [%[c_ptr] :128]!\n"
                "vmla.f32 q5, q2, d1[1]\n"
                "vst1.32 {d20-d21}, [%[c_ptr] :128]!\n"
                "vmla.f32 q11, q3, d1[1]\n"
                "vst1.32 {d10-d11}, [%[c_ptr] :128]!\n"
                "vmla.f32 q6, q2, d2[0]\n"
                "vst1.32 {d22-d23}, [%[c_ptr] :128]!\n"
                "vmla.f32 q12, q3, d2[0]\n"
                "vst1.32 {d12-d13}, [%[c_ptr] :128]!\n"
                "vmla.f32 q7, q2, d2[1]\n"
                "vst1.32 {d24-d25}, [%[c_ptr] :128]!\n"
                "vmla.f32 q13, q3, d2[1]\n"
                "vst1.32 {d14-d15}, [%[c_ptr] :128]!\n"
                "vmla.f32 q8, q2, d3[0]\n"
                "vst1.32 {d26-d27}, [%[c_ptr] :128]!\n"
                "vmla.f32 q14, q3, d3[0]\n"
                "vst1.32 {d16-d17}, [%[c_ptr] :128]!\n"
                "vmla.f32 q9, q2, d3[1]\n"
                "vst1.32 {d28-d29}, [%[c_ptr] :128]!\n"
                "vmla.f32 q15, q3, d3[1]\n"
                "vst1.32 {d18-d19}, [%[c_ptr] :128]!\n"
                "b 2f\n"

                // tails==1 final tail
                "3:\n"
                "vmla.f32 q4, q2, d0[0]\n"
                "vld1.32 {d2}, [%[a_ptr] :64]!\n"
                "vmla.f32 q5, q2, d0[1]\n"
                "vld1.32 {d6-d7}, [%[b_ptr] :128]!\n"
                "vmla.f32 q6, q2, d1[0]\n"
                "vst1.32 {d8-d9}, [%[c_ptr] :128]!\n"
                "vmla.f32 q10, q3, d0[0]\n"
                "vst1.32 {d20-d21}, [%[c_ptr] :128]!\n"
                "vmla.f32 q11, q3, d0[1]\n"
                "vst1.32 {d10-d11}, [%[c_ptr] :128]!\n"
                "vmla.f32 q12, q3, d1[0]\n"
                "vst1.32 {d22-d23}, [%[c_ptr] :128]!\n"
                "vmla.f32 q7, q2, d1[1]\n"
                "vst1.32 {d12-d13}, [%[c_ptr] :128]!\n"
                "vmla.f32 q13, q3, d1[1]\n"
                "vst1.32 {d24-d25}, [%[c_ptr] :128]!\n"
                "vmla.f32 q8, q2, d2[0]\n"
                "vst1.32 {d14-d15}, [%[c_ptr] :128]!\n"
                "vmla.f32 q14, q3, d2[0]\n"
                "vst1.32 {d26-d27}, [%[c_ptr] :128]!\n"
                "vmla.f32 q9, q2, d2[1]\n"
                "vst1.32 {d16-d17}, [%[c_ptr] :128]!\n"
                "vmla.f32 q15, q3, d2[1]\n"
                "vst1.32 {d28-d29}, [%[c_ptr] :128]!\n"
                "vst1.32 {d18-d19}, [%[c_ptr] :128]!\n"
                "b 2f\n"

                // tails==2 final tail
                "4:\n"
                "vmla.f32 q4, q2, d3[0]\n"
                "vmla.f32 q10, q3, d3[0]\n"
                "vst1.32 {d8-d9}, [%[c_ptr] :128]!\n"
                "vmla.f32 q5, q2, d3[1]\n"
                "vst1.32 {d20-d21}, [%[c_ptr] :128]!\n"
                "vmla.f32 q11, q3, d3[1]\n"
                "vst1.32 {d10-d11}, [%[c_ptr] :128]!\n"
                "vmla.f32 q6, q2, d0[0]\n"
                "vst1.32 {d22-d23}, [%[c_ptr] :128]!\n"
                "vmla.f32 q12, q3, d0[0]\n"
                "vst1.32 {d12-d13}, [%[c_ptr] :128]!\n"
                "vmla.f32 q7, q2, d0[1]\n"
                "vst1.32 {d24-d25}, [%[c_ptr] :128]!\n"
                "vmla.f32 q13, q3, d0[1]\n"
                "vst1.32 {d14-d15}, [%[c_ptr] :128]!\n"
                "vmla.f32 q8, q2, d1[0]\n"
                "vst1.32 {d26-d27}, [%[c_ptr] :128]!\n"
                "vmla.f32 q14, q3, d1[0]\n"
                "vst1.32 {d16-d17}, [%[c_ptr] :128]!\n"
                "vmla.f32 q9, q2, d1[1]\n"
                "vst1.32 {d28-d29}, [%[c_ptr] :128]!\n"
                "vmla.f32 q15, q3, d1[1]\n"
                "vst1.32 {d18-d19}, [%[c_ptr] :128]!\n"
                "b 2f\n"

                // tails==3 final tail
                "5:\n"
                "vmla.f32 q4, q2, d2[0]\n"
                "vld1.32 {d0}, [%[a_ptr] :64]!\n"
                "vmla.f32 q5, q2, d2[1]\n"
                "vmla.f32 q6, q2, d3[0]\n"
                "vst1.32 {d8-d9}, [%[c_ptr] :128]!\n"
                "vmla.f32 q10, q3, d2[0]\n"
                "vst1.32 {d20-d21}, [%[c_ptr] :128]!\n"
                "vmla.f32 q11, q3, d2[1]\n"
                "vst1.32 {d10-d11}, [%[c_ptr] :128]!\n"
                "vmla.f32 q12, q3, d3[0]\n"
                "vst1.32 {d22-d23}, [%[c_ptr] :128]!\n"
                "vmla.f32 q7, q2, d3[1]\n"
                "vst1.32 {d12-d13}, [%[c_ptr] :128]!\n"
                "vmla.f32 q13, q3, d3[1]\n"
                "vst1.32 {d24-d25}, [%[c_ptr] :128]!\n"
                "vmla.f32 q8, q2, d0[0]\n"
                "vst1.32 {d14-d15}, [%[c_ptr] :128]!\n"
                "vmla.f32 q14, q3, d0[0]\n"
                "vst1.32 {d26-d27}, [%[c_ptr] :128]!\n"
                "vmla.f32 q9, q2, d0[1]\n"
                "vst1.32 {d16-d17}, [%[c_ptr] :128]!\n"
                "vmla.f32 q15, q3, d0[1]\n"
                "vst1.32 {d28-d29}, [%[c_ptr] :128]!\n"
                "vst1.32 {d18-d19}, [%[c_ptr] :128]!\n"

                "2:\n"
                "vst1.32 {d30-d31}, [%[c_ptr] :128]!\n"
                : [a_ptr] "+r" (a_ptr), [b_ptr] "+r" (b_ptr), [c_ptr] "+r" (c_ptr), [k] "+r" (k), [tails] "+r" (tails)
                :
                : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", "q12", "q13", "q14", "q15",
                  "cc", "memory"
            );
        }
    }
}

} // namespace arm_gemm

#endif