/*
 * Copyright (c) 2017-2018 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#pragma once

#ifdef __arm__

#include <arm_neon.h>

#include <algorithm>
#include <cstring>
#include <limits>

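/*
 * Merge a block of computed GEMM results into the output array.
 *
 * The input buffer holds blocks of up to 6 rows by 8 columns, stored
 * contiguously as 48 floats per block.  In append mode each block is added to
 * the values already in 'out'; otherwise a per-column bias (or zero, if no
 * bias is supplied) is added instead.  Results are then clamped to the
 * activation range and written to 'out' with row stride 'ldout'.
 */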
template<>
void MergeResults<8, 6, false>(float *out, const float *in, const int ldout, const int y0, const int ymax, const int x0, const int xmax, const float *bias, Activation act, bool append) {
    const float *inptr = in;
    prefetch_6x(inptr);
    prefetch_6x(inptr + 96);

    float nullbias[8];
    float minval = - std::numeric_limits<float>::infinity();
    float maxval =   std::numeric_limits<float>::infinity();

    switch(act.type)
    {
        default:
        case Activation::Type::None:
            break;
        case Activation::Type::BoundedReLU:
            maxval = static_cast<float>(act.param1);
            /* fall through */
        case Activation::Type::ReLU:
            minval = 0.0f;
            break;
    }
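    /* The switch above reduces the activation to a clamp range: None leaves
     * (-inf, +inf), ReLU gives [0, +inf) and BoundedReLU gives [0, act.param1]. */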

    float32x4_t minv = vdupq_n_f32(minval);
    float32x4_t maxv = vdupq_n_f32(maxval);

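    /* If we are neither appending nor given a bias, use a zeroed dummy bias so
     * the bias code path below can be used unchanged. */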
    if (!append && !bias)
    {
        memset(nullbias, 0, (8 * sizeof(float)));
    }

    for (int y=y0; y<ymax; y+=6) {
        float *outptr0 = out + (y * ldout) + x0;
        float *outptr1 = outptr0 + ldout;
        float *outptr2 = outptr1 + ldout;
        float *outptr3 = outptr2 + ldout;
        float *outptr4 = outptr3 + ldout;
        float *outptr5 = outptr4 + ldout;

        prefetch_2x(outptr0);
        prefetch_2x(outptr1);
        prefetch_2x(outptr2);
        prefetch_2x(outptr3);
        prefetch_2x(outptr4);
        prefetch_2x(outptr5);

        for (int i=x0; i<xmax; i+=8) {
            float dummyres[8];

            /* Make sure we throw away results if Y isn't a multiple of 6.
             * We do this by pointing the result pointer at a dummy buffer
             * we later discard.  */
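            /* ((y + 5) - ymax) is the number of missing rows: with only one
             * valid row it is 4, so outptr1..outptr5 all fall through to the
             * dummy buffer; with five valid rows it is 0 and only outptr5 is
             * redirected. */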
            if ((y+5) >= ymax) {
                switch ((y + 5) - ymax) {
                    case 4:
                        outptr1 = dummyres;
                        /* fall through */
                    case 3:
                        outptr2 = dummyres;
                        /* fall through */
                    case 2:
                        outptr3 = dummyres;
                        /* fall through */
                    case 1:
                        outptr4 = dummyres;
                        /* fall through */
                    case 0:
                        outptr5 = dummyres;
                        break;

                    default:
                        UNREACHABLE("Impossible.");
                }
            }

            if (append) {
                /* Append mode: Read, activate, write. */

                /* For ragged X, manually copy over the valid results. */
                if ((i+7) >= xmax) {
                    for (int xi=0; xi<8; xi++) {
                        if ((i+xi) < xmax) {
                            *outptr0 = std::min(std::max(minval, inptr[xi] + *outptr0), maxval);
                            outptr0++;
                            *outptr1 = std::min(std::max(minval, inptr[xi + 8] + *outptr1), maxval);
                            outptr1++;
                            *outptr2 = std::min(std::max(minval, inptr[xi + 16] + *outptr2), maxval);
                            outptr2++;
                            *outptr3 = std::min(std::max(minval, inptr[xi + 24] + *outptr3), maxval);
                            outptr3++;
                            *outptr4 = std::min(std::max(minval, inptr[xi + 32] + *outptr4), maxval);
                            outptr4++;
                            *outptr5 = std::min(std::max(minval, inptr[xi + 40] + *outptr5), maxval);
                            outptr5++;
                        }
                    }
                    inptr += 48;
                } else {
                    /* Optimized routine to copy an entire block */
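                    /* q0-q3 hold two rows of new results and q4-q7 the
                     * corresponding existing output values: each pair of rows
                     * is summed, clamped to [minval, maxval] and stored back. */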
                    __asm __volatile (
                        // Rows 0-1
                        "VLD1.32	{d0-d3},   [%[inptr]]!\n"
                        "VLD1.32	{d8-d11},  [%[outptr0]]\n"
                        "VLD1.32	{d4-d7},   [%[inptr]]!\n"
                        "VLD1.32	{d12-d15}, [%[outptr1]]\n"

                        "VADD.f32	q4, q4, q0\n"
                        ASM_PREFETCH("[%[inptr], #352]")
                        "VADD.f32	q5, q5, q1\n"
                        "VADD.f32	q6, q6, q2\n"
                        "VADD.f32	q7, q7, q3\n"
                        ASM_PREFETCH("[%[inptr], #416]")
                        "VMAX.f32	q4, q4, %q[minv]\n"
                        "VMAX.f32	q5, q5, %q[minv]\n"
                        "VMAX.f32	q6, q6, %q[minv]\n"
                        ASM_PREFETCH("[%[inptr], #480]")
                        "VMAX.f32	q7, q7, %q[minv]\n"
                        "VMIN.f32	q4, q4, %q[maxv]\n"
                        "VMIN.f32	q5, q5, %q[maxv]\n"
                        "VST1.32	{d8-d11}, [%[outptr0]]!\n"
                        "VMIN.f32	q6, q6, %q[maxv]\n"
                        "VMIN.f32	q7, q7, %q[maxv]\n"
                        "VST1.32	{d12-d15}, [%[outptr1]]!\n"

                        // Rows 2-3
                        "VLD1.32	{d0-d3},   [%[inptr]]!\n"
                        "VLD1.32	{d8-d11},  [%[outptr2]]\n"
                        "VLD1.32	{d4-d7},   [%[inptr]]!\n"
                        "VLD1.32	{d12-d15}, [%[outptr3]]\n"

                        "VADD.f32	q4, q4, q0\n"
                        ASM_PREFETCH("[%[outptr0], #96]")
                        "VADD.f32	q5, q5, q1\n"
                        "VADD.f32	q6, q6, q2\n"
                        "VADD.f32	q7, q7, q3\n"
                        ASM_PREFETCH("[%[outptr1], #96]")
                        "VMAX.f32	q4, q4, %q[minv]\n"
                        "VMAX.f32	q5, q5, %q[minv]\n"
                        "VMAX.f32	q6, q6, %q[minv]\n"
                        ASM_PREFETCH("[%[outptr2], #128]")
                        "VMAX.f32	q7, q7, %q[minv]\n"
                        "VMIN.f32	q4, q4, %q[maxv]\n"
                        "VMIN.f32	q5, q5, %q[maxv]\n"
                        "VST1.32	{d8-d11}, [%[outptr2]]!\n"
                        "VMIN.f32	q6, q6, %q[maxv]\n"
                        "VMIN.f32	q7, q7, %q[maxv]\n"
                        "VST1.32	{d12-d15}, [%[outptr3]]!\n"

                        // Rows 4-5
                        "VLD1.32	{d0-d3},   [%[inptr]]!\n"
                        "VLD1.32	{d8-d11},  [%[outptr4]]\n"
                        "VLD1.32	{d4-d7},   [%[inptr]]!\n"
                        "VLD1.32	{d12-d15}, [%[outptr5]]\n"

                        "VADD.f32	q4, q4, q0\n"
                        ASM_PREFETCH("[%[outptr3], #96]")
                        "VADD.f32	q5, q5, q1\n"
                        "VADD.f32	q6, q6, q2\n"
                        "VADD.f32	q7, q7, q3\n"
                        ASM_PREFETCH("[%[outptr4], #128]")
                        "VMAX.f32	q4, q4, %q[minv]\n"
                        "VMAX.f32	q5, q5, %q[minv]\n"
                        "VMAX.f32	q6, q6, %q[minv]\n"
                        ASM_PREFETCH("[%[outptr5], #128]")
                        "VMAX.f32	q7, q7, %q[minv]\n"
                        "VMIN.f32	q4, q4, %q[maxv]\n"
                        "VMIN.f32	q5, q5, %q[maxv]\n"
                        "VST1.32	{d8-d11}, [%[outptr4]]!\n"
                        "VMIN.f32	q6, q6, %q[maxv]\n"
                        "VMIN.f32	q7, q7, %q[maxv]\n"
                        "VST1.32	{d12-d15}, [%[outptr5]]!\n"
                    : [outptr0] "+r" (outptr0), [outptr1] "+r" (outptr1), [outptr2] "+r" (outptr2), [outptr3] "+r" (outptr3),
                      [outptr4] "+r" (outptr4), [outptr5] "+r" (outptr5), [inptr] "+r" (inptr)
                    : [minv] "w" (minv), [maxv] "w" (maxv)
                    : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "memory"
                    );
                }
            } else {
                /* Bias mode: Add bias to everything, then min/max/write as before. */
                const float *biasptr = bias ? bias + i : nullbias;

                /* For ragged X, manually copy over the valid results. */
                if ((i+7) >= xmax) {
                    for (int xi=0; xi<7; xi++) {
                        if ((i+xi) < xmax) {
                            *outptr0 = std::min(std::max(minval, inptr[xi] + biasptr[xi]), maxval);
                            outptr0++;
                            *outptr1 = std::min(std::max(minval, inptr[xi + 8] + biasptr[xi]), maxval);
                            outptr1++;
                            *outptr2 = std::min(std::max(minval, inptr[xi + 16] + biasptr[xi]), maxval);
                            outptr2++;
                            *outptr3 = std::min(std::max(minval, inptr[xi + 24] + biasptr[xi]), maxval);
                            outptr3++;
                            *outptr4 = std::min(std::max(minval, inptr[xi + 32] + biasptr[xi]), maxval);
                            outptr4++;
                            *outptr5 = std::min(std::max(minval, inptr[xi + 40] + biasptr[xi]), maxval);
                            outptr5++;
                        }
                    }
                    inptr += 48;
                } else {
                    /* Optimized routine to copy an entire block */
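                    /* q0/q1 hold the per-column bias, loaded once and reused
                     * for every row; q4-q7 hold two rows of results at a time,
                     * which are biased, clamped to [minval, maxval] and stored. */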
                    __asm __volatile (
                        // Rows 0-1
                        "VLD1.32	{d8-d11},   [%[inptr]]!\n"
                        "VLD1.32	{d0-d3},   [%[biasptr]]\n"
                        "VLD1.32	{d12-d15},  [%[inptr]]!\n"

                        "VADD.f32	q4, q4, q0\n"
                        ASM_PREFETCH("[%[inptr], #352]")
                        "VADD.f32	q5, q5, q1\n"
                        "VADD.f32	q6, q6, q0\n"
                        "VADD.f32	q7, q7, q1\n"
                        ASM_PREFETCH("[%[inptr], #416]")
                        "VMAX.f32	q4, q4, %q[minv]\n"
                        "VMAX.f32	q5, q5, %q[minv]\n"
                        "VMAX.f32	q6, q6, %q[minv]\n"
                        ASM_PREFETCH("[%[inptr], #480]")
                        "VMAX.f32	q7, q7, %q[minv]\n"
                        "VMIN.f32	q4, q4, %q[maxv]\n"
                        "VMIN.f32	q5, q5, %q[maxv]\n"
                        "VST1.32	{d8-d11}, [%[outptr0]]!\n"
                        "VMIN.f32	q6, q6, %q[maxv]\n"
                        "VMIN.f32	q7, q7, %q[maxv]\n"
                        "VST1.32	{d12-d15}, [%[outptr1]]!\n"

                        // Rows 2-3
                        "VLD1.32	{d8-d11},   [%[inptr]]!\n"
                        "VLD1.32	{d12-d15},  [%[inptr]]!\n"

                        "VADD.f32	q4, q4, q0\n"
                        ASM_PREFETCH("[%[outptr0], #96]")
                        "VADD.f32	q5, q5, q1\n"
                        "VADD.f32	q6, q6, q0\n"
                        "VADD.f32	q7, q7, q1\n"
                        ASM_PREFETCH("[%[outptr1], #96]")
                        "VMAX.f32	q4, q4, %q[minv]\n"
                        "VMAX.f32	q5, q5, %q[minv]\n"
                        "VMAX.f32	q6, q6, %q[minv]\n"
                        ASM_PREFETCH("[%[outptr2], #128]")
                        "VMAX.f32	q7, q7, %q[minv]\n"
                        "VMIN.f32	q4, q4, %q[maxv]\n"
                        "VMIN.f32	q5, q5, %q[maxv]\n"
                        "VST1.32	{d8-d11}, [%[outptr2]]!\n"
                        "VMIN.f32	q6, q6, %q[maxv]\n"
                        "VMIN.f32	q7, q7, %q[maxv]\n"
                        "VST1.32	{d12-d15}, [%[outptr3]]!\n"

                        // Rows 4-5
                        "VLD1.32	{d8-d11},   [%[inptr]]!\n"
                        "VLD1.32	{d12-d15},  [%[inptr]]!\n"

                        "VADD.f32	q4, q4, q0\n"
                        ASM_PREFETCH("[%[outptr3], #96]")
                        "VADD.f32	q5, q5, q1\n"
                        "VADD.f32	q6, q6, q0\n"
                        "VADD.f32	q7, q7, q1\n"
                        ASM_PREFETCH("[%[outptr4], #128]")
                        "VMAX.f32	q4, q4, %q[minv]\n"
                        "VMAX.f32	q5, q5, %q[minv]\n"
                        "VMAX.f32	q6, q6, %q[minv]\n"
                        ASM_PREFETCH("[%[outptr5], #128]")
                        "VMAX.f32	q7, q7, %q[minv]\n"
                        "VMIN.f32	q4, q4, %q[maxv]\n"
                        "VMIN.f32	q5, q5, %q[maxv]\n"
                        "VST1.32	{d8-d11}, [%[outptr4]]!\n"
                        "VMIN.f32	q6, q6, %q[maxv]\n"
                        "VMIN.f32	q7, q7, %q[maxv]\n"
                        "VST1.32	{d12-d15}, [%[outptr5]]!\n"
                    : [outptr0] "+r" (outptr0), [outptr1] "+r" (outptr1), [outptr2] "+r" (outptr2), [outptr3] "+r" (outptr3),
                      [outptr4] "+r" (outptr4), [outptr5] "+r" (outptr5), [inptr] "+r" (inptr)
                    : [minv] "w" (minv), [maxv] "w" (maxv), [biasptr] "r" (biasptr)
                    : "q0", "q1", "q2", "q3", "q4", "q5", "q6", "q7", "memory"
                    );
                }
            }
        }
    }
}
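
/* Illustrative call (not part of the original file): a caller holding a
 * block-interleaved result buffer 'blocked_c' for rows [y0, ymax) and columns
 * [x0, xmax) could merge it into a row-major output 'C' (leading dimension
 * 'ldc') with a plain ReLU and no append roughly as:
 *
 *   MergeResults<8, 6, false>(C, blocked_c, ldc, y0, ymax, x0, xmax,
 *                             bias_ptr, Activation(Activation::Type::ReLU), false);
 *
 * The names 'blocked_c', 'bias_ptr', 'ldc' and the Activation constructor
 * usage are assumptions made for this sketch.
 */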

#endif // __arm__