xref: /aosp_15_r20/external/ComputeLibrary/src/core/NEON/kernels/arm_gemm/merges/a64_merge_u32_4x4.hpp (revision c217d954acce2dbc11938adb493fc0abd69584f3)
/*
 * Copyright (c) 2019-2020 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
24 #pragma once
25 
26 #ifdef __aarch64__
27 
28 template<>
MergeResults(uint32_t * out,const uint32_t * in,const int ldout,const int y0,const int ymax,const int x0,const int xmax,const uint32_t * bias,Activation,bool append)29 void MergeResults<4, 4, false>(uint32_t *out, const uint32_t *in, const int ldout, const int y0, const int ymax, const int x0, const int xmax, const uint32_t *bias, Activation , bool append)
30 {
31     const uint32_t *inptr = in;
32     uint32_t nullbias[4];
33 
34 
35     if (!append && !bias)
36     {
37         memset(nullbias, 0, (4 * sizeof(uint32_t)));
38     }
39 
40     for (int y=y0; y<ymax; y+=4)
41     {
42         uint32_t *outptr0 = out + (y * ldout) + x0;
43         uint32_t *outptr1 = outptr0 + ldout;
44         uint32_t *outptr2 = outptr1 + ldout;
45         uint32_t *outptr3 = outptr2 + ldout;
46 
47         const int height = ymax - y;
48 
49         for (int i=x0; i<xmax; i+=4)
50         {
51             if (append)
52             {
53                 switch(height)
54                 {
55                 case 1:
56                     {
57                         if ((i+3) >= xmax)
58                         {
59                             for (int xi=0; xi<3; xi++)
60                             {
61                                 if ((i+xi) < xmax)
62                                 {
63                                     *outptr0 += inptr[xi];
64                                     outptr0++;
65                                 }
66                             }
67                             inptr += 16;
68                         } else {
69                             /* Optimized routine to copy an entire block */
70                             __asm __volatile (
71                                 "ldr q2, [%[outptr0]]\n"
72                                 "prfm PLDL1KEEP, [%[inptr], #0x40]\n"
73                                 "ldr q10, [%[inptr]]\n"
74                                 "prfm PLDL1KEEP, [%[outptr0], #0x20]\n"
75                                 "add %[inptr], %[inptr], #0x40\n"
76                                 "add v10.4s, v10.4s, v2.4s\n"
77                                 "str q10, [%[outptr0]]\n"
78                                 "add %[outptr0], %[outptr0], #0x10\n"
79                             : [outptr0] "+r" (outptr0), [outptr1] "+r" (outptr1), [outptr2] "+r" (outptr2), [outptr3] "+r" (outptr3),
80                               [inptr] "+r" (inptr)
81                             :
82                             : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "memory"
83                             );
84                         }
85                     }
86                     break;
87 
88                 case 2:
89                     {
90                         if ((i+3) >= xmax)
91                         {
92                             for (int xi=0; xi<3; xi++)
93                             {
94                                 if ((i+xi) < xmax)
95                                 {
96                                     *outptr0 += inptr[xi];
97                                     outptr0++;
98                                     *outptr1 += inptr[xi + 4];
99                                     outptr1++;
100                                 }
101                             }
102                             inptr += 16;
103                         } else {
104                             /* Optimized routine to copy an entire block */
105                             __asm __volatile (
106                                 "ldr q2, [%[outptr0]]\n"
107                                 "prfm PLDL1KEEP, [%[inptr], #0x40]\n"
108                                 "ldr q10, [%[inptr]]\n"
109                                 "prfm PLDL1KEEP, [%[outptr0], #0x20]\n"
110                                 "ldr q3, [%[outptr1]]\n"
111                                 "prfm PLDL1KEEP, [%[outptr1], #0x20]\n"
112                                 "add v10.4s, v10.4s, v2.4s\n"
113                                 "ldr q11, [%[inptr], #0x10]\n"
114                                 "add %[inptr], %[inptr], #0x40\n"
115                                 "add v11.4s, v11.4s, v3.4s\n"
116                                 "str q10, [%[outptr0]]\n"
117                                 "add %[outptr0], %[outptr0], #0x10\n"
118                                 "str q11, [%[outptr1]]\n"
119                                 "add %[outptr1], %[outptr1], #0x10\n"
120                             : [outptr0] "+r" (outptr0), [outptr1] "+r" (outptr1), [outptr2] "+r" (outptr2), [outptr3] "+r" (outptr3),
121                               [inptr] "+r" (inptr)
122                             :
123                             : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "memory"
124                             );
125                         }
126                     }
127                     break;
128 
129                 case 3:
130                     {
131                         if ((i+3) >= xmax)
132                         {
133                             for (int xi=0; xi<3; xi++)
134                             {
135                                 if ((i+xi) < xmax)
136                                 {
137                                     *outptr0 += inptr[xi];
138                                     outptr0++;
139                                     *outptr1 += inptr[xi + 4];
140                                     outptr1++;
141                                     *outptr2 += inptr[xi + 8];
142                                     outptr2++;
143                                 }
144                             }
145                             inptr += 16;
146                         } else {
147                             /* Optimized routine to copy an entire block */
148                             __asm __volatile (
149                                 "ldr q2, [%[outptr0]]\n"
150                                 "prfm PLDL1KEEP, [%[inptr], #0x40]\n"
151                                 "ldr q10, [%[inptr]]\n"
152                                 "prfm PLDL1KEEP, [%[outptr0], #0x20]\n"
153                                 "ldr q3, [%[outptr1]]\n"
154                                 "prfm PLDL1KEEP, [%[outptr1], #0x20]\n"
155                                 "add v10.4s, v10.4s, v2.4s\n"
156                                 "ldr q11, [%[inptr], #0x10]\n"
157                                 "ldr q4, [%[outptr2]]\n"
158                                 "prfm PLDL1KEEP, [%[outptr2], #0x20]\n"
159                                 "ldr q12, [%[inptr], #0x20]\n"
160                                 "add %[inptr], %[inptr], #0x40\n"
161                                 "add v11.4s, v11.4s, v3.4s\n"
162                                 "str q10, [%[outptr0]]\n"
163                                 "add %[outptr0], %[outptr0], #0x10\n"
164                                 "add v12.4s, v12.4s, v4.4s\n"
165                                 "str q11, [%[outptr1]]\n"
166                                 "add %[outptr1], %[outptr1], #0x10\n"
167                                 "str q12, [%[outptr2]]\n"
168                                 "add %[outptr2], %[outptr2], #0x10\n"
169                             : [outptr0] "+r" (outptr0), [outptr1] "+r" (outptr1), [outptr2] "+r" (outptr2), [outptr3] "+r" (outptr3),
170                               [inptr] "+r" (inptr)
171                             :
172                             : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "memory"
173                             );
174                         }
175                     }
176                     break;
177 
178                 default:
179                 case 4:
180                     {
181                         if ((i+3) >= xmax)
182                         {
183                             for (int xi=0; xi<3; xi++)
184                             {
185                                 if ((i+xi) < xmax)
186                                 {
187                                     *outptr0 += inptr[xi];
188                                     outptr0++;
189                                     *outptr1 += inptr[xi + 4];
190                                     outptr1++;
191                                     *outptr2 += inptr[xi + 8];
192                                     outptr2++;
193                                     *outptr3 += inptr[xi + 12];
194                                     outptr3++;
195                                 }
196                             }
197                             inptr += 16;
198                         } else {
199                             /* Optimized routine to copy an entire block */
200                             __asm __volatile (
201                                 "ldr q2, [%[outptr0]]\n"
202                                 "prfm PLDL1KEEP, [%[inptr], #0x40]\n"
203                                 "ldr q10, [%[inptr]]\n"
204                                 "prfm PLDL1KEEP, [%[outptr0], #0x20]\n"
205                                 "ldr q3, [%[outptr1]]\n"
206                                 "prfm PLDL1KEEP, [%[outptr1], #0x20]\n"
207                                 "add v10.4s, v10.4s, v2.4s\n"
208                                 "ldr q11, [%[inptr], #0x10]\n"
209                                 "ldr q4, [%[outptr2]]\n"
210                                 "prfm PLDL1KEEP, [%[outptr2], #0x20]\n"
211                                 "ldr q12, [%[inptr], #0x20]\n"
212                                 "prfm PLDL1KEEP, [%[outptr3], #0x20]\n"
213                                 "add v11.4s, v11.4s, v3.4s\n"
214                                 "str q10, [%[outptr0]]\n"
215                                 "ldr q5, [%[outptr3]]\n"
216                                 "add %[outptr0], %[outptr0], #0x10\n"
217                                 "add v12.4s, v12.4s, v4.4s\n"
218                                 "str q11, [%[outptr1]]\n"
219                                 "ldr q13, [%[inptr], #0x30]\n"
220                                 "add %[outptr1], %[outptr1], #0x10\n"
221                                 "add %[inptr], %[inptr], #0x40\n"
222                                 "str q12, [%[outptr2]]\n"
223                                 "add %[outptr2], %[outptr2], #0x10\n"
224                                 "add v13.4s, v13.4s, v5.4s\n"
225                                 "str q13, [%[outptr3]]\n"
226                                 "add %[outptr3], %[outptr3], #0x10\n"
227                             : [outptr0] "+r" (outptr0), [outptr1] "+r" (outptr1), [outptr2] "+r" (outptr2), [outptr3] "+r" (outptr3),
228                               [inptr] "+r" (inptr)
229                             :
230                             : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "memory"
231                             );
232                         }
233                     }
234                     break;
235 
236 
237                 }
238             }
239             else
240             {
241                 const uint32_t *biasptr = bias ? bias + i : nullbias;
242 
243                 switch(height)
244                 {
245                 case 1:
246                     {
247                         if ((i+3) >= xmax)
248                         {
249                             for (int xi=0; xi<3; xi++)
250                             {
251                                 if ((i+xi) < xmax)
252                                 {
253                                     *outptr0 = biasptr[xi] + inptr[xi];
254                                     outptr0++;
255                                 }
256                             }
257                             inptr += 16;
258                         } else {
259                             /* Optimized routine to copy an entire block */
260                             __asm __volatile (
261                                 "ldr q2, [%[biasptr]]\n"
262                                 "prfm PLDL1KEEP, [%[inptr], #0x40]\n"
263                                 "ldr q11, [%[inptr]]\n"
264                                 "prfm PSTL1KEEP, [%[outptr0], #0x20]\n"
265                                 "add %[inptr], %[inptr], #0x40\n"
266                                 "add v11.4s, v11.4s, v2.4s\n"
267                                 "str q11, [%[outptr0]]\n"
268                                 "add %[outptr0], %[outptr0], #0x10\n"
269                             : [outptr0] "+r" (outptr0), [outptr1] "+r" (outptr1), [outptr2] "+r" (outptr2), [outptr3] "+r" (outptr3),
270                               [inptr] "+r" (inptr)
271                             : [biasptr] "r" (biasptr)
272                             : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "memory"
273                             );
274                         }
275                     }
276                     break;
277 
278                 case 2:
279                     {
280                         if ((i+3) >= xmax)
281                         {
282                             for (int xi=0; xi<3; xi++)
283                             {
284                                 if ((i+xi) < xmax)
285                                 {
286                                     *outptr0 = biasptr[xi] + inptr[xi];
287                                     outptr0++;
288                                     *outptr1 = biasptr[xi] + inptr[xi + 4];
289                                     outptr1++;
290                                 }
291                             }
292                             inptr += 16;
293                         } else {
294                             /* Optimized routine to copy an entire block */
295                             __asm __volatile (
296                                 "ldr q2, [%[biasptr]]\n"
297                                 "prfm PLDL1KEEP, [%[inptr], #0x40]\n"
298                                 "ldr q11, [%[inptr]]\n"
299                                 "prfm PSTL1KEEP, [%[outptr0], #0x20]\n"
300                                 "ldr q12, [%[inptr], #0x10]\n"
301                                 "prfm PSTL1KEEP, [%[outptr1], #0x20]\n"
302                                 "add v11.4s, v11.4s, v2.4s\n"
303                                 "add %[inptr], %[inptr], #0x40\n"
304                                 "add v12.4s, v12.4s, v2.4s\n"
305                                 "str q11, [%[outptr0]]\n"
306                                 "add %[outptr0], %[outptr0], #0x10\n"
307                                 "str q12, [%[outptr1]]\n"
308                                 "add %[outptr1], %[outptr1], #0x10\n"
309                             : [outptr0] "+r" (outptr0), [outptr1] "+r" (outptr1), [outptr2] "+r" (outptr2), [outptr3] "+r" (outptr3),
310                               [inptr] "+r" (inptr)
311                             : [biasptr] "r" (biasptr)
312                             : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "memory"
313                             );
314                         }
315                     }
316                     break;
317 
318                 case 3:
319                     {
320                         if ((i+3) >= xmax)
321                         {
322                             for (int xi=0; xi<3; xi++)
323                             {
324                                 if ((i+xi) < xmax)
325                                 {
326                                     *outptr0 = biasptr[xi] + inptr[xi];
327                                     outptr0++;
328                                     *outptr1 = biasptr[xi] + inptr[xi + 4];
329                                     outptr1++;
330                                     *outptr2 = biasptr[xi] + inptr[xi + 8];
331                                     outptr2++;
332                                 }
333                             }
334                             inptr += 16;
335                         } else {
336                             /* Optimized routine to copy an entire block */
337                             __asm __volatile (
338                                 "ldr q2, [%[biasptr]]\n"
339                                 "prfm PLDL1KEEP, [%[inptr], #0x40]\n"
340                                 "ldr q11, [%[inptr]]\n"
341                                 "prfm PSTL1KEEP, [%[outptr0], #0x20]\n"
342                                 "ldr q12, [%[inptr], #0x10]\n"
343                                 "prfm PSTL1KEEP, [%[outptr1], #0x20]\n"
344                                 "add v11.4s, v11.4s, v2.4s\n"
345                                 "ldr q13, [%[inptr], #0x20]\n"
346                                 "prfm PSTL1KEEP, [%[outptr2], #0x20]\n"
347                                 "add v12.4s, v12.4s, v2.4s\n"
348                                 "add %[inptr], %[inptr], #0x40\n"
349                                 "add v13.4s, v13.4s, v2.4s\n"
350                                 "str q11, [%[outptr0]]\n"
351                                 "add %[outptr0], %[outptr0], #0x10\n"
352                                 "str q12, [%[outptr1]]\n"
353                                 "add %[outptr1], %[outptr1], #0x10\n"
354                                 "str q13, [%[outptr2]]\n"
355                                 "add %[outptr2], %[outptr2], #0x10\n"
356                             : [outptr0] "+r" (outptr0), [outptr1] "+r" (outptr1), [outptr2] "+r" (outptr2), [outptr3] "+r" (outptr3),
357                               [inptr] "+r" (inptr)
358                             : [biasptr] "r" (biasptr)
359                             : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "memory"
360                             );
361                         }
362                     }
363                     break;
364 
365                 default:
366                 case 4:
367                     {
368                         if ((i+3) >= xmax)
369                         {
370                             for (int xi=0; xi<3; xi++)
371                             {
372                                 if ((i+xi) < xmax)
373                                 {
374                                     *outptr0 = biasptr[xi] + inptr[xi];
375                                     outptr0++;
376                                     *outptr1 = biasptr[xi] + inptr[xi + 4];
377                                     outptr1++;
378                                     *outptr2 = biasptr[xi] + inptr[xi + 8];
379                                     outptr2++;
380                                     *outptr3 = biasptr[xi] + inptr[xi + 12];
381                                     outptr3++;
382                                 }
383                             }
384                             inptr += 16;
385                         } else {
386                             /* Optimized routine to copy an entire block */
387                             __asm __volatile (
388                                 "ldr q2, [%[biasptr]]\n"
389                                 "prfm PLDL1KEEP, [%[inptr], #0x40]\n"
390                                 "ldr q11, [%[inptr]]\n"
391                                 "prfm PSTL1KEEP, [%[outptr0], #0x20]\n"
392                                 "ldr q12, [%[inptr], #0x10]\n"
393                                 "prfm PSTL1KEEP, [%[outptr1], #0x20]\n"
394                                 "add v11.4s, v11.4s, v2.4s\n"
395                                 "ldr q13, [%[inptr], #0x20]\n"
396                                 "ldr q14, [%[inptr], #0x30]\n"
397                                 "prfm PSTL1KEEP, [%[outptr2], #0x20]\n"
398                                 "add v12.4s, v12.4s, v2.4s\n"
399                                 "str q11, [%[outptr0]]\n"
400                                 "add v13.4s, v13.4s, v2.4s\n"
401                                 "add %[outptr0], %[outptr0], #0x10\n"
402                                 "add v14.4s, v14.4s, v2.4s\n"
403                                 "str q12, [%[outptr1]]\n"
404                                 "add %[outptr1], %[outptr1], #0x10\n"
405                                 "prfm PSTL1KEEP, [%[outptr3], #0x20]\n"
406                                 "add %[inptr], %[inptr], #0x40\n"
407                                 "str q13, [%[outptr2]]\n"
408                                 "add %[outptr2], %[outptr2], #0x10\n"
409                                 "str q14, [%[outptr3]]\n"
410                                 "add %[outptr3], %[outptr3], #0x10\n"
411                             : [outptr0] "+r" (outptr0), [outptr1] "+r" (outptr1), [outptr2] "+r" (outptr2), [outptr3] "+r" (outptr3),
412                               [inptr] "+r" (inptr)
413                             : [biasptr] "r" (biasptr)
414                             : "v0", "v1", "v2", "v3", "v4", "v5", "v6", "v7", "v8", "v9", "v10", "v11", "v12", "v13", "v14", "v15", "v16", "v17", "v18", "memory"
415                             );
416                         }
417                     }
418                     break;
419 
420 
421                 }
422             }
423         }
424     }
425 }
426 
427 #endif // __aarch64__
428