1 /*
2 * Copyright (c) 2017-2018 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
24
25 // Build on AArch64 where either FP16_KERNELS is set or FP16 is explicitly supported.
26 #if defined(__aarch64__) && (defined(FP16_KERNELS) || defined(__ARM_FEATURE_FP16_VECTOR_ARITHMETIC))
27
28 #include <arm_neon.h>
29
30 #include "../../asmlib.hpp"
31
// Kernel implementation.
//
// Assume that "Apanel" points to a chunk of A blocks (each size 8xK) in read-order.
// Assume that "Bpanel" points to a chunk of B blocks (each size 24xK) in read-order.
// Assume that "Cpanel" points to a chunk of C output blocks (each size
// 24x8, i.e. 8 rows of 24 values), the chunks being arranged in a row major fashion.
//
// Note that the intent of this is that either ablocks or bblocks will be 1
// - this construction allows the output loop to proceed in either order.
41
42 namespace arm_gemm {
43
a64_hgemm_asimd_8x24_a55r1(const __fp16 * Apanel,const __fp16 * Bpanel,__fp16 * Cpanel,int ablocks,int bblocks,int K)44 void a64_hgemm_asimd_8x24_a55r1(const __fp16 *Apanel, const __fp16 *Bpanel, __fp16 *Cpanel, int ablocks, int bblocks, int K) {
45 const __fp16 *a_ptr = Apanel;
46 __fp16 *c_ptr = Cpanel;
47
48 // Fix up for odd lengths - set a flag if K is odd, but make
49 // sure we round up the iteration count.
50 int oddk = (K & 1);
51 int k_iters = ((K+1)/2) - 1;
52
53 for (int yb=0; yb<ablocks; yb++) {
54 const __fp16 *a_ptr0 = a_ptr;
55 const __fp16 *b_ptr = Bpanel;
56
57 for (int xb=0; xb<bblocks; xb++) {
58 int k = k_iters;
59 a_ptr = a_ptr0;
60
61 // As A55 requires 64-bit loads anyway, just use 64 bits of the
62 // "A" operands to save on "ins" instructions. Since A55 is
63 // in-order, two sets of "A" operands and one set of "B" is
64 // sufficient.
65 register float16x8_t a0 asm("v0");
66 register float16x8_t a1 asm("v1");
67 register float16x8_t a0a asm("v2");
68 register float16x8_t a1a asm("v3");
69 register float16x8_t b0 asm("v4");
70 register float16x8_t b1 asm("v5");
71 register float16x8_t b2 asm("v6");
72
73 __asm __volatile (
74 // Initialize result registers, load initial operands, prime prefetches.
75 "movi v8.8h, #0x0\n"
76 "ldr %d[a0], [%[a_ptr]]\n"
77 "movi v9.8h, #0x0\n"
78 "ldr %q[b0], [%[b_ptr]]\n"
79 "movi v10.8h, #0x0\n"
80 "ldr %d[a1], [%[a_ptr], #8]\n"
81 "movi v11.8h, #0x0\n"
82 "ldr %q[b1], [%[b_ptr], #16]\n"
83 "movi v12.8h, #0x0\n"
84 "movi v13.8h, #0x0\n"
85 ASM_PREFETCH("[%[b_ptr], #64]")
86 "movi v14.8h, #0x0\n"
87 "movi v15.8h, #0x0\n"
88 ASM_PREFETCH("[%[b_ptr], #128]")
89 "movi v16.8h, #0x0\n"
90 "movi v17.8h, #0x0\n"
91 ASM_PREFETCH("[%[a_ptr], #64]")
92 "movi v18.8h, #0x0\n"
93 "movi v19.8h, #0x0\n"
94 ASM_PREFETCH("[%[b_ptr], #192]")
95 "movi v20.8h, #0x0\n"
96 "movi v21.8h, #0x0\n"
97 ASM_PREFETCH("[%[b_ptr], #256]")
98 "movi v22.8h, #0x0\n"
99 "movi v23.8h, #0x0\n"
100 ASM_PREFETCH("[%[b_ptr], #320]")
101 "movi v24.8h, #0x0\n"
102 "movi v25.8h, #0x0\n"
103 "movi v26.8h, #0x0\n"
104 "movi v27.8h, #0x0\n"
105 "movi v28.8h, #0x0\n"
106 "movi v29.8h, #0x0\n"
107 "movi v30.8h, #0x0\n"
108 "movi v31.8h, #0x0\n"
109
110 // The loop is offset by these two instructions which must
111 // always be executed.
112 "fmla v8.8h , %[b0].8h, %[a0].h[0]\n"
113 "ldr %d[b2], [%[b_ptr], #32]\n"
114
115 // Skip loop if we are doing zero iterations of it.
116 "cbz %w[k], 4f\n"
117
118 "1:\n"
119 "fmla v9.8h , %[b0].8h, %[a0].h[1]\n"
120 "ldr x20, [%[b_ptr], #40]\n"
121 "fmla v10.8h, %[b0].8h, %[a0].h[2]\n"
122 "subs %w[k], %w[k], #1\n"
123 "fmla v11.8h, %[b0].8h, %[a0].h[3]\n"
124 "ldr %d[a0a], [%[a_ptr], #16]\n"
125
126 "fmla v12.8h, %[b0].8h, %[a1].h[0]\n"
127 "ins %[b2].d[1], x20\n"
128 "fmla v13.8h, %[b0].8h, %[a1].h[1]\n"
129 "fmla v14.8h, %[b0].8h, %[a1].h[2]\n"
130 "fmla v15.8h, %[b0].8h, %[a1].h[3]\n"
131 "ldr %d[a1a], [%[a_ptr], #24]\n"
132
133 "fmla v16.8h, %[b1].8h, %[a0].h[0]\n"
134 "fmla v17.8h, %[b1].8h, %[a0].h[1]\n"
135 "fmla v18.8h, %[b1].8h, %[a0].h[2]\n"
136 "fmla v19.8h, %[b1].8h, %[a0].h[3]\n"
137 "ldr %d[b0], [%[b_ptr], #48]\n"
138
139 "fmla v20.8h, %[b1].8h, %[a1].h[0]\n"
140 "fmla v21.8h, %[b1].8h, %[a1].h[1]\n"
141 "ldr x20, [%[b_ptr], #56]\n"
142 "fmla v22.8h, %[b1].8h, %[a1].h[2]\n"
143 "fmla v23.8h, %[b1].8h, %[a1].h[3]\n"
144 "ldr %d[b1], [%[b_ptr], #64]\n"
145
146 "fmla v24.8h, %[b2].8h, %[a0].h[0]\n"
147 "ins %[b0].d[1], x20\n"
148 "fmla v25.8h, %[b2].8h, %[a0].h[1]\n"
149 "ldr x20, [%[b_ptr], #72]\n"
150 "fmla v26.8h, %[b2].8h, %[a0].h[2]\n"
151 "fmla v27.8h, %[b2].8h, %[a0].h[3]\n"
152 ASM_PREFETCH("[%[a_ptr], #128]")
153
154 "fmla v28.8h, %[b2].8h, %[a1].h[0]\n"
155 "fmla v29.8h, %[b2].8h, %[a1].h[1]\n"
156 ASM_PREFETCH("[%[b_ptr], #384]")
157 "fmla v30.8h, %[b2].8h, %[a1].h[2]\n"
158 "fmla v31.8h, %[b2].8h, %[a1].h[3]\n"
159 "ldr %d[b2], [%[b_ptr], #80]\n"
160
161 // Unroll 1
162 "fmla v8.8h , %[b0].8h, %[a0a].h[0]\n"
163 "ins %[b1].d[1], x20\n"
164 "fmla v9.8h , %[b0].8h, %[a0a].h[1]\n"
165 "ldr x20, [%[b_ptr], #88]\n"
166 "fmla v10.8h, %[b0].8h, %[a0a].h[2]\n"
167 "fmla v11.8h, %[b0].8h, %[a0a].h[3]\n"
168 "ldr %d[a0], [%[a_ptr], #32]\n"
169
170 "fmla v12.8h, %[b0].8h, %[a1a].h[0]\n"
171 "ins %[b2].d[1], x20\n"
172 "fmla v13.8h, %[b0].8h, %[a1a].h[1]\n"
173 "fmla v14.8h, %[b0].8h, %[a1a].h[2]\n"
174 "fmla v15.8h, %[b0].8h, %[a1a].h[3]\n"
175 "ldr %d[a1], [%[a_ptr], #40]\n"
176
177 "fmla v16.8h, %[b1].8h, %[a0a].h[0]\n"
178 "add %[a_ptr], %[a_ptr], #32\n"
179 "fmla v17.8h, %[b1].8h, %[a0a].h[1]\n"
180 "fmla v18.8h, %[b1].8h, %[a0a].h[2]\n"
181 "fmla v19.8h, %[b1].8h, %[a0a].h[3]\n"
182 "ldr %d[b0], [%[b_ptr], #96]\n"
183
184 "fmla v20.8h, %[b1].8h, %[a1a].h[0]\n"
185 "fmla v21.8h, %[b1].8h, %[a1a].h[1]\n"
186 "ldr x20, [%[b_ptr], #104]\n"
187 "fmla v22.8h, %[b1].8h, %[a1a].h[2]\n"
188 "fmla v23.8h, %[b1].8h, %[a1a].h[3]\n"
189 "ldr %d[b1], [%[b_ptr], #112]\n"
190
191 "fmla v24.8h, %[b2].8h, %[a0a].h[0]\n"
192 "ins %[b0].d[1], x20\n"
193 "fmla v25.8h, %[b2].8h, %[a0a].h[1]\n"
194 "ldr x20, [%[b_ptr], #120]\n"
195 "fmla v26.8h, %[b2].8h, %[a0a].h[2]\n"
196 "fmla v27.8h, %[b2].8h, %[a0a].h[3]\n"
197
198 "fmla v28.8h, %[b2].8h, %[a1a].h[0]\n"
199 ASM_PREFETCH("[%[b_ptr], #448]")
200 "fmla v29.8h, %[b2].8h, %[a1a].h[1]\n"
201 "add %[b_ptr], %[b_ptr], #96\n"
202 "fmla v30.8h, %[b2].8h, %[a1a].h[2]\n"
203 "ins %[b1].d[1], x20\n"
204 "fmla v31.8h, %[b2].8h, %[a1a].h[3]\n"
205 "ldr %d[b2], [%[b_ptr], #32]\n"
206
207 "fmla v8.8h , %[b0].8h, %[a0].h[0]\n"
208 "bne 1b\n"
209
210 "4:\n"
211
212 // Start final iteration - branch off to "odd" code before we load a0a
213 "fmla v9.8h , %[b0].8h, %[a0].h[1]\n"
214 "ldr x20, [%[b_ptr], #40]\n"
215 "fmla v10.8h, %[b0].8h, %[a0].h[2]\n"
216 "cbnz %w[oddk], 2f\n"
217
218 // Even K continuation
219 "fmla v11.8h, %[b0].8h, %[a0].h[3]\n"
220 "ldr %d[a0a], [%[a_ptr], #16]\n"
221
222 "fmla v12.8h, %[b0].8h, %[a1].h[0]\n"
223 "ins %[b2].d[1], x20\n"
224 "fmla v13.8h, %[b0].8h, %[a1].h[1]\n"
225 ASM_PREFETCHW("[%[c_ptr]]")
226 "fmla v14.8h, %[b0].8h, %[a1].h[2]\n"
227 "fmla v15.8h, %[b0].8h, %[a1].h[3]\n"
228 "ldr %d[a1a], [%[a_ptr], #24]\n"
229
230 "fmla v16.8h, %[b1].8h, %[a0].h[0]\n"
231 "fmla v17.8h, %[b1].8h, %[a0].h[1]\n"
232 ASM_PREFETCHW("[%[c_ptr], #64]")
233 "fmla v18.8h, %[b1].8h, %[a0].h[2]\n"
234 "fmla v19.8h, %[b1].8h, %[a0].h[3]\n"
235 "ldr %d[b0], [%[b_ptr], #48]\n"
236
237 "fmla v20.8h, %[b1].8h, %[a1].h[0]\n"
238 "fmla v21.8h, %[b1].8h, %[a1].h[1]\n"
239 "ldr x20, [%[b_ptr], #56]\n"
240 "fmla v22.8h, %[b1].8h, %[a1].h[2]\n"
241 "fmla v23.8h, %[b1].8h, %[a1].h[3]\n"
242 "ldr %d[b1], [%[b_ptr], #64]\n"
243
244 "fmla v24.8h, %[b2].8h, %[a0].h[0]\n"
245 "ins %[b0].d[1], x20\n"
246 "fmla v25.8h, %[b2].8h, %[a0].h[1]\n"
247 "ldr x20, [%[b_ptr], #72]\n"
248 "fmla v26.8h, %[b2].8h, %[a0].h[2]\n"
249 "fmla v27.8h, %[b2].8h, %[a0].h[3]\n"
250 ASM_PREFETCHW("[%[c_ptr], #128]")
251
252 "fmla v28.8h, %[b2].8h, %[a1].h[0]\n"
253 "fmla v29.8h, %[b2].8h, %[a1].h[1]\n"
254 ASM_PREFETCHW("[%[c_ptr], #192]")
255 "fmla v30.8h, %[b2].8h, %[a1].h[2]\n"
256 "fmla v31.8h, %[b2].8h, %[a1].h[3]\n"
257 "ldr %d[b2], [%[b_ptr], #80]\n"
258
259 "fmla v8.8h , %[b0].8h, %[a0a].h[0]\n"
260 "ins %[b1].d[1], x20\n"
261 "fmla v9.8h , %[b0].8h, %[a0a].h[1]\n"
262 "ldr x20, [%[b_ptr], #88]\n"
263 "fmla v10.8h, %[b0].8h, %[a0a].h[2]\n"
264 "fmla v11.8h, %[b0].8h, %[a0a].h[3]\n"
265 ASM_PREFETCHW("[%[c_ptr], #256]")
266
267 "fmla v12.8h, %[b0].8h, %[a1a].h[0]\n"
268 "ins %[b2].d[1], x20\n"
269 "fmla v13.8h, %[b0].8h, %[a1a].h[1]\n"
270 ASM_PREFETCHW("[%[c_ptr], #320]")
271 "fmla v14.8h, %[b0].8h, %[a1a].h[2]\n"
272 "fmla v15.8h, %[b0].8h, %[a1a].h[3]\n"
273 "ldr %d[a1], [%[a_ptr], #40]\n"
274
275 "fmla v16.8h, %[b1].8h, %[a0a].h[0]\n"
276 "add %[a_ptr], %[a_ptr], #32\n"
277 "fmla v17.8h, %[b1].8h, %[a0a].h[1]\n"
278 ASM_PREFETCHWL2("[%[c_ptr], #384]")
279 "fmla v18.8h, %[b1].8h, %[a0a].h[2]\n"
280 "fmla v19.8h, %[b1].8h, %[a0a].h[3]\n"
281 ASM_PREFETCHWL2("[%[c_ptr], #448]")
282
283 "fmla v20.8h, %[b1].8h, %[a1a].h[0]\n"
284 "fmla v21.8h, %[b1].8h, %[a1a].h[1]\n"
285 ASM_PREFETCHWL2("[%[c_ptr], #512]")
286 "fmla v22.8h, %[b1].8h, %[a1a].h[2]\n"
287 "fmla v23.8h, %[b1].8h, %[a1a].h[3]\n"
288 ASM_PREFETCHWL2("[%[c_ptr], #576]")
289
290 "fmla v24.8h, %[b2].8h, %[a0a].h[0]\n"
291 "fmla v25.8h, %[b2].8h, %[a0a].h[1]\n"
292 ASM_PREFETCHWL2("[%[c_ptr], #640]")
293 "fmla v26.8h, %[b2].8h, %[a0a].h[2]\n"
294 "fmla v27.8h, %[b2].8h, %[a0a].h[3]\n"
295 ASM_PREFETCHWL2("[%[c_ptr], #704]")
296
297 "fmla v28.8h, %[b2].8h, %[a1a].h[0]\n"
298 "fmla v29.8h, %[b2].8h, %[a1a].h[1]\n"
299 "add %[b_ptr], %[b_ptr], #96\n"
300 "fmla v30.8h, %[b2].8h, %[a1a].h[2]\n"
301 "fmla v31.8h, %[b2].8h, %[a1a].h[3]\n"
302 "b 3f\n"
303
304 "2:\n"
305
306 // Odd tail
307 "fmla v11.8h, %[b0].8h, %[a0].h[3]\n"
308 ASM_PREFETCHW("[%[c_ptr]]")
309
310 "fmla v12.8h, %[b0].8h, %[a1].h[0]\n"
311 "ins %[b2].d[1], x20\n"
312 "fmla v13.8h, %[b0].8h, %[a1].h[1]\n"
313 ASM_PREFETCHW("[%[c_ptr], #64]")
314 "fmla v14.8h, %[b0].8h, %[a1].h[2]\n"
315 "add %[a_ptr], %[a_ptr], #16\n"
316 "fmla v15.8h, %[b0].8h, %[a1].h[3]\n"
317 ASM_PREFETCHW("[%[c_ptr], #128]")
318
319 "fmla v16.8h, %[b1].8h, %[a0].h[0]\n"
320 "add %[b_ptr], %[b_ptr], #48\n"
321 "fmla v17.8h, %[b1].8h, %[a0].h[1]\n"
322 ASM_PREFETCHW("[%[c_ptr], #192]")
323 "fmla v18.8h, %[b1].8h, %[a0].h[2]\n"
324 "fmla v19.8h, %[b1].8h, %[a0].h[3]\n"
325 ASM_PREFETCHW("[%[c_ptr], #256]")
326
327 "fmla v20.8h, %[b1].8h, %[a1].h[0]\n"
328 "fmla v21.8h, %[b1].8h, %[a1].h[1]\n"
329 ASM_PREFETCHW("[%[c_ptr], #320]")
330 "fmla v22.8h, %[b1].8h, %[a1].h[2]\n"
331 "fmla v23.8h, %[b1].8h, %[a1].h[3]\n"
332 ASM_PREFETCHWL2("[%[c_ptr], #384]")
333
334 "fmla v24.8h, %[b2].8h, %[a0].h[0]\n"
335 "fmla v25.8h, %[b2].8h, %[a0].h[1]\n"
336 ASM_PREFETCHWL2("[%[c_ptr], #384]")
337 "fmla v26.8h, %[b2].8h, %[a0].h[2]\n"
338 "fmla v27.8h, %[b2].8h, %[a0].h[3]\n"
339 ASM_PREFETCHWL2("[%[c_ptr], #448]")
340
341 "fmla v28.8h, %[b2].8h, %[a1].h[0]\n"
342 ASM_PREFETCHWL2("[%[c_ptr], #512]")
343 "fmla v29.8h, %[b2].8h, %[a1].h[1]\n"
344 ASM_PREFETCHWL2("[%[c_ptr], #576]")
345 "fmla v30.8h, %[b2].8h, %[a1].h[2]\n"
346 ASM_PREFETCHWL2("[%[c_ptr], #640]")
347 "fmla v31.8h, %[b2].8h, %[a1].h[3]\n"
348 ASM_PREFETCHWL2("[%[c_ptr], #704]")
349
350 // Common tail
351 // A55 won't dual issue these stores with anything else, so
352 // simplest to do them all in this common code.
353 "3:\n"
354 "str q8, [%[c_ptr]]\n"
355 "str q16, [%[c_ptr], #16]\n"
356 "str q24, [%[c_ptr], #32]\n"
357 "str q9, [%[c_ptr], #48]\n"
358 "str q17, [%[c_ptr], #64]\n"
359 "str q25, [%[c_ptr], #80]\n"
360 "str q10, [%[c_ptr], #96]\n"
361 "str q18, [%[c_ptr], #112]\n"
362 "str q26, [%[c_ptr], #128]\n"
363 "str q11, [%[c_ptr], #144]\n"
364 "str q19, [%[c_ptr], #160]\n"
365 "str q27, [%[c_ptr], #176]\n"
366 "str q12, [%[c_ptr], #192]\n"
367 "str q20, [%[c_ptr], #208]\n"
368 "str q28, [%[c_ptr], #224]\n"
369 "str q13, [%[c_ptr], #240]\n"
370 "str q21, [%[c_ptr], #256]\n"
371 "str q29, [%[c_ptr], #272]\n"
372 "str q14, [%[c_ptr], #288]\n"
373 "str q22, [%[c_ptr], #304]\n"
374 "str q30, [%[c_ptr], #320]\n"
375 "str q15, [%[c_ptr], #336]\n"
376 "str q23, [%[c_ptr], #352]\n"
377 "str q31, [%[c_ptr], #368]\n"
378 "5:\n"
379 "add %[c_ptr], %[c_ptr], #384\n"
380 :
381 [a_ptr] "+r" (a_ptr), [b_ptr] "+r" (b_ptr), [c_ptr] "+r" (c_ptr),
382 [a0] "=w" (a0), [a0a] "=w" (a0a), [a1] "=w" (a1), [a1a] "=w" (a1a),
383 [b0] "=w" (b0), [b1] "=w" (b1), [b2] "=w" (b2), [k] "+r" (k)
384 : [oddk] "r" (oddk)
385 : "x20", "x21", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18",
386 "v19", "v20", "v21", "v22", "v23", "v24", "v25", "v26", "v27", "v28", "v29", "v30", "v31", "cc", "memory"
387 );
388 }
389 }
390 }
391
392 } // namespace arm_gemm
393
394 #endif // __aarch64__ && (FP16_KERNELS || __ARM_FEATURE_FP16_VECTOR_ARITHMETIC)
395