xref: /aosp_15_r20/external/ComputeLibrary/src/cpu/kernels/CpuGemmLowpMatrixMultiplyKernel.cpp (revision c217d954acce2dbc11938adb493fc0abd69584f3)
1 /*
2  * Copyright (c) 2017-2021 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #include "src/cpu/kernels/CpuGemmLowpMatrixMultiplyKernel.h"
25 
26 #include "arm_compute/core/Error.h"
27 #include "arm_compute/core/Helpers.h"
28 #include "arm_compute/core/ITensor.h"
29 #include "arm_compute/core/TensorInfo.h"
30 #include "arm_compute/core/Types.h"
31 #include "arm_compute/core/Utils.h"
32 #include "arm_compute/core/Validate.h"
33 #include "arm_compute/core/Window.h"
34 #include "src/core/helpers/AutoConfiguration.h"
35 #include "src/core/helpers/WindowHelpers.h"
36 
37 #include <arm_neon.h>
38 
39 namespace arm_compute
40 {
41 namespace cpu
42 {
43 namespace kernels
44 {
45 namespace
46 {
47 void inline vector_matrix_multiply_u8(Iterator &ina, Iterator &inb, Iterator &out, int width_a, int width_b, int width_out, size_t stride_b, const Window &window)
48 {
49     execute_window_loop(window, [&](const Coordinates & id)
50     {
51         if(id.x() > width_b)
52         {
53             return;
54         }
55 
56         // Note: Since the inputs are all positive, we can use uint32_t
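        // (Worked bound, for illustration: each u16*u16 product passed to vmlal_u16 is at most
        //  255 * 255 = 65025, so the uint32_t accumulators remain exact for up to
        //  2^32 / 65025, roughly 66000, accumulated terms along the K dimension.)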
57         // Accumulators for the block 0
58         uint32x4x4_t c0 =
59         {
60             {
61                 vdupq_n_u32(0),
62                 vdupq_n_u32(0),
63                 vdupq_n_u32(0),
64                 vdupq_n_u32(0)
65             }
66         };
67 
68         auto vec_a          = reinterpret_cast<const uint8_t *>(ina.ptr());
69         auto matrix_b       = reinterpret_cast<const uint8_t *>(inb.ptr());
70         auto vec_a_end_addr = vec_a + width_a;
71 
72         // This for loop performs 8 accumulations per iteration, consuming 8 elements of vector A at a time
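        // Scalar sketch of what the NEON block below computes per iteration (illustrative only):
        //
        //   for(int i = 0; i < 8; ++i)
        //   {
        //       for(int j = 0; j < 16; ++j)
        //       {
        //           acc[j] += uint32_t(vec_a[i]) * uint32_t(matrix_b[i * stride_b + j]);
        //       }
        //   }
        //
        // i.e. 8 elements of the vector A are multiplied against an 8-row x 16-column tile of B,
        // where acc[0..15] corresponds to the lanes of c0.val[0..3].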
73         for(; vec_a <= (vec_a_end_addr - 8);)
74         {
75             const uint8x8_t  a00_u8 = vld1_u8(vec_a);
76             const uint8x16_t b00_u8 = vld1q_u8(matrix_b + 0 * stride_b);
77             const uint8x16_t b10_u8 = vld1q_u8(matrix_b + 1 * stride_b);
78             const uint8x16_t b20_u8 = vld1q_u8(matrix_b + 2 * stride_b);
79             const uint8x16_t b30_u8 = vld1q_u8(matrix_b + 3 * stride_b);
80             const uint8x16_t b40_u8 = vld1q_u8(matrix_b + 4 * stride_b);
81             const uint8x16_t b50_u8 = vld1q_u8(matrix_b + 5 * stride_b);
82             const uint8x16_t b60_u8 = vld1q_u8(matrix_b + 6 * stride_b);
83             const uint8x16_t b70_u8 = vld1q_u8(matrix_b + 7 * stride_b);
84 
85             // Convert a00_u8 to uint16_t and split it into the lower and upper halves
86             const uint16x4x2_t a00_u16 =
87             {
88                 {
89                     vget_low_u16(vmovl_u8(a00_u8)),
90                     vget_high_u16(vmovl_u8(a00_u8))
91                 }
92             };
93 
94             const uint16x4x4_t b00_u16 =
95             {
96                 {
97                     vget_low_u16(vmovl_u8(vget_low_u8(b00_u8))),
98                     vget_high_u16(vmovl_u8(vget_low_u8(b00_u8))),
99                     vget_low_u16(vmovl_u8(vget_high_u8(b00_u8))),
100                     vget_high_u16(vmovl_u8(vget_high_u8(b00_u8)))
101                 }
102             };
103 
104             const uint16x4x4_t b10_u16 =
105             {
106                 {
107                     vget_low_u16(vmovl_u8(vget_low_u8(b10_u8))),
108                     vget_high_u16(vmovl_u8(vget_low_u8(b10_u8))),
109                     vget_low_u16(vmovl_u8(vget_high_u8(b10_u8))),
110                     vget_high_u16(vmovl_u8(vget_high_u8(b10_u8)))
111                 }
112             };
113 
114             const uint16x4x4_t b20_u16 =
115             {
116                 {
117                     vget_low_u16(vmovl_u8(vget_low_u8(b20_u8))),
118                     vget_high_u16(vmovl_u8(vget_low_u8(b20_u8))),
119                     vget_low_u16(vmovl_u8(vget_high_u8(b20_u8))),
120                     vget_high_u16(vmovl_u8(vget_high_u8(b20_u8)))
121                 }
122             };
123 
124             const uint16x4x4_t b30_u16 =
125             {
126                 {
127                     vget_low_u16(vmovl_u8(vget_low_u8(b30_u8))),
128                     vget_high_u16(vmovl_u8(vget_low_u8(b30_u8))),
129                     vget_low_u16(vmovl_u8(vget_high_u8(b30_u8))),
130                     vget_high_u16(vmovl_u8(vget_high_u8(b30_u8)))
131                 }
132             };
133 
134             const uint16x4x4_t b40_u16 =
135             {
136                 {
137                     vget_low_u16(vmovl_u8(vget_low_u8(b40_u8))),
138                     vget_high_u16(vmovl_u8(vget_low_u8(b40_u8))),
139                     vget_low_u16(vmovl_u8(vget_high_u8(b40_u8))),
140                     vget_high_u16(vmovl_u8(vget_high_u8(b40_u8)))
141                 }
142             };
143 
144             const uint16x4x4_t b50_u16 =
145             {
146                 {
147                     vget_low_u16(vmovl_u8(vget_low_u8(b50_u8))),
148                     vget_high_u16(vmovl_u8(vget_low_u8(b50_u8))),
149                     vget_low_u16(vmovl_u8(vget_high_u8(b50_u8))),
150                     vget_high_u16(vmovl_u8(vget_high_u8(b50_u8)))
151                 }
152             };
153 
154             const uint16x4x4_t b60_u16 =
155             {
156                 {
157                     vget_low_u16(vmovl_u8(vget_low_u8(b60_u8))),
158                     vget_high_u16(vmovl_u8(vget_low_u8(b60_u8))),
159                     vget_low_u16(vmovl_u8(vget_high_u8(b60_u8))),
160                     vget_high_u16(vmovl_u8(vget_high_u8(b60_u8)))
161                 }
162             };
163 
164             const uint16x4x4_t b70_u16 =
165             {
166                 {
167                     vget_low_u16(vmovl_u8(vget_low_u8(b70_u8))),
168                     vget_high_u16(vmovl_u8(vget_low_u8(b70_u8))),
169                     vget_low_u16(vmovl_u8(vget_high_u8(b70_u8))),
170                     vget_high_u16(vmovl_u8(vget_high_u8(b70_u8)))
171                 }
172             };
173 
174             // Accumulate 0:
175             c0.val[0] = vmlal_lane_u16(c0.val[0], b00_u16.val[0], a00_u16.val[0], 0);
176             c0.val[1] = vmlal_lane_u16(c0.val[1], b00_u16.val[1], a00_u16.val[0], 0);
177             c0.val[2] = vmlal_lane_u16(c0.val[2], b00_u16.val[2], a00_u16.val[0], 0);
178             c0.val[3] = vmlal_lane_u16(c0.val[3], b00_u16.val[3], a00_u16.val[0], 0);
179 
180             // Accumulate 1:
181             c0.val[0] = vmlal_lane_u16(c0.val[0], b10_u16.val[0], a00_u16.val[0], 1);
182             c0.val[1] = vmlal_lane_u16(c0.val[1], b10_u16.val[1], a00_u16.val[0], 1);
183             c0.val[2] = vmlal_lane_u16(c0.val[2], b10_u16.val[2], a00_u16.val[0], 1);
184             c0.val[3] = vmlal_lane_u16(c0.val[3], b10_u16.val[3], a00_u16.val[0], 1);
185 
186             // Accumulate 2:
187             c0.val[0] = vmlal_lane_u16(c0.val[0], b20_u16.val[0], a00_u16.val[0], 2);
188             c0.val[1] = vmlal_lane_u16(c0.val[1], b20_u16.val[1], a00_u16.val[0], 2);
189             c0.val[2] = vmlal_lane_u16(c0.val[2], b20_u16.val[2], a00_u16.val[0], 2);
190             c0.val[3] = vmlal_lane_u16(c0.val[3], b20_u16.val[3], a00_u16.val[0], 2);
191 
192             // Accumulate 3:
193             c0.val[0] = vmlal_lane_u16(c0.val[0], b30_u16.val[0], a00_u16.val[0], 3);
194             c0.val[1] = vmlal_lane_u16(c0.val[1], b30_u16.val[1], a00_u16.val[0], 3);
195             c0.val[2] = vmlal_lane_u16(c0.val[2], b30_u16.val[2], a00_u16.val[0], 3);
196             c0.val[3] = vmlal_lane_u16(c0.val[3], b30_u16.val[3], a00_u16.val[0], 3);
197 
198             // Accumulate 4:
199             c0.val[0] = vmlal_lane_u16(c0.val[0], b40_u16.val[0], a00_u16.val[1], 0);
200             c0.val[1] = vmlal_lane_u16(c0.val[1], b40_u16.val[1], a00_u16.val[1], 0);
201             c0.val[2] = vmlal_lane_u16(c0.val[2], b40_u16.val[2], a00_u16.val[1], 0);
202             c0.val[3] = vmlal_lane_u16(c0.val[3], b40_u16.val[3], a00_u16.val[1], 0);
203 
204             // Accumulate 5:
205             c0.val[0] = vmlal_lane_u16(c0.val[0], b50_u16.val[0], a00_u16.val[1], 1);
206             c0.val[1] = vmlal_lane_u16(c0.val[1], b50_u16.val[1], a00_u16.val[1], 1);
207             c0.val[2] = vmlal_lane_u16(c0.val[2], b50_u16.val[2], a00_u16.val[1], 1);
208             c0.val[3] = vmlal_lane_u16(c0.val[3], b50_u16.val[3], a00_u16.val[1], 1);
209 
210             // Accumulate 6:
211             c0.val[0] = vmlal_lane_u16(c0.val[0], b60_u16.val[0], a00_u16.val[1], 2);
212             c0.val[1] = vmlal_lane_u16(c0.val[1], b60_u16.val[1], a00_u16.val[1], 2);
213             c0.val[2] = vmlal_lane_u16(c0.val[2], b60_u16.val[2], a00_u16.val[1], 2);
214             c0.val[3] = vmlal_lane_u16(c0.val[3], b60_u16.val[3], a00_u16.val[1], 2);
215 
216             // Accumulate 7:
217             c0.val[0] = vmlal_lane_u16(c0.val[0], b70_u16.val[0], a00_u16.val[1], 3);
218             c0.val[1] = vmlal_lane_u16(c0.val[1], b70_u16.val[1], a00_u16.val[1], 3);
219             c0.val[2] = vmlal_lane_u16(c0.val[2], b70_u16.val[2], a00_u16.val[1], 3);
220             c0.val[3] = vmlal_lane_u16(c0.val[3], b70_u16.val[3], a00_u16.val[1], 3);
221 
222             vec_a += 8;
223             matrix_b += 8 * stride_b;
224         }
225 
226         // This for loop performs the left-over accumulations
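        // (Tail handling: when width_a is not a multiple of 8, the remaining elements of vec_a
        //  are processed one at a time, each broadcast against a single 16-wide row of B.)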
227         for(; vec_a < vec_a_end_addr;)
228         {
229             const uint8x8_t  a00_u8 = vld1_dup_u8(vec_a);
230             const uint8x16_t b00_u8 = vld1q_u8(matrix_b);
231 
232             const uint16x4x4_t b00_u16 =
233             {
234                 {
235                     vget_low_u16(vmovl_u8(vget_low_u8(b00_u8))),
236                     vget_high_u16(vmovl_u8(vget_low_u8(b00_u8))),
237                     vget_low_u16(vmovl_u8(vget_high_u8(b00_u8))),
238                     vget_high_u16(vmovl_u8(vget_high_u8(b00_u8)))
239                 }
240             };
241 
242             // Convert a00_u8 to uint16_t and get the lower part
243             const uint16x4_t a00_u16 = vget_low_u16(vmovl_u8(a00_u8));
244 
245             // Accumulate 0:
246             c0.val[0] = vmlal_lane_u16(c0.val[0], b00_u16.val[0], a00_u16, 0);
247             c0.val[1] = vmlal_lane_u16(c0.val[1], b00_u16.val[1], a00_u16, 0);
248             c0.val[2] = vmlal_lane_u16(c0.val[2], b00_u16.val[2], a00_u16, 0);
249             c0.val[3] = vmlal_lane_u16(c0.val[3], b00_u16.val[3], a00_u16, 0);
250 
251             vec_a += 1;
252             matrix_b += stride_b;
253         }
254 
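        // Store the 16 accumulated values, falling back to a scalar tail store when 16 or
        // fewer output columns remain.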
255         auto vec_out = reinterpret_cast<int32_t *>(out.ptr());
256         if(id.x() < (width_out - 16))
257         {
258             vst1q_s32(vec_out + 0, vreinterpretq_s32_u32(c0.val[0]));
259             vst1q_s32(vec_out + 4, vreinterpretq_s32_u32(c0.val[1]));
260             vst1q_s32(vec_out + 8, vreinterpretq_s32_u32(c0.val[2]));
261             vst1q_s32(vec_out + 12, vreinterpretq_s32_u32(c0.val[3]));
262         }
263         else
264         {
265             auto left_over = width_out - id.x();
266             for(auto k = 0; k < 4 && left_over; ++k)
267             {
268                 for(auto j = 0; j < 4 && left_over; ++j, --left_over)
269                 {
270                     *(vec_out + k * 4 + j) = c0.val[k][j];
271                 }
272             }
273         }
274     },
275     ina, inb, out);
276 }
277 
278 void inline vector_matrix_multiply_s8(Iterator &ina, Iterator &inb, Iterator &out, int width_a, int width_b, int width_out, size_t stride_b, const Window &window)
279 {
280     execute_window_loop(window, [&](const Coordinates & id)
281     {
282         if(id.x() > width_b)
283         {
284             return;
285         }
286 
287         // Accumulators for the block 0
288         int32x4x4_t c0 =
289         {
290             {
291                 vdupq_n_s32(0),
292                 vdupq_n_s32(0),
293                 vdupq_n_s32(0),
294                 vdupq_n_s32(0)
295             }
296         };
297 
298         auto vec_a          = reinterpret_cast<const int8_t *>(ina.ptr());
299         auto matrix_b       = reinterpret_cast<const int8_t *>(inb.ptr());
300         auto vec_a_end_addr = vec_a + width_a;
301 
302         // This for loop performs 8 accumulations per iteration, consuming 8 elements of vector A at a time
303         for(; vec_a <= (vec_a_end_addr - 8);)
304         {
305             const int8x8_t  a00_s8 = vld1_s8(vec_a);
306             const int8x16_t b00_s8 = vld1q_s8(matrix_b + 0 * stride_b);
307             const int8x16_t b10_s8 = vld1q_s8(matrix_b + 1 * stride_b);
308             const int8x16_t b20_s8 = vld1q_s8(matrix_b + 2 * stride_b);
309             const int8x16_t b30_s8 = vld1q_s8(matrix_b + 3 * stride_b);
310             const int8x16_t b40_s8 = vld1q_s8(matrix_b + 4 * stride_b);
311             const int8x16_t b50_s8 = vld1q_s8(matrix_b + 5 * stride_b);
312             const int8x16_t b60_s8 = vld1q_s8(matrix_b + 6 * stride_b);
313             const int8x16_t b70_s8 = vld1q_s8(matrix_b + 7 * stride_b);
314 
315             // Convert a00_s8 to int16_t and split it into the lower and upper halves
316             const int16x4x2_t a00_s16 =
317             {
318                 {
319                     vget_low_s16(vmovl_s8(a00_s8)),
320                     vget_high_s16(vmovl_s8(a00_s8))
321                 }
322             };
323 
324             const int16x4x4_t b00_s16 =
325             {
326                 {
327                     vget_low_s16(vmovl_s8(vget_low_s8(b00_s8))),
328                     vget_high_s16(vmovl_s8(vget_low_s8(b00_s8))),
329                     vget_low_s16(vmovl_s8(vget_high_s8(b00_s8))),
330                     vget_high_s16(vmovl_s8(vget_high_s8(b00_s8)))
331                 }
332             };
333 
334             const int16x4x4_t b10_s16 =
335             {
336                 {
337                     vget_low_s16(vmovl_s8(vget_low_s8(b10_s8))),
338                     vget_high_s16(vmovl_s8(vget_low_s8(b10_s8))),
339                     vget_low_s16(vmovl_s8(vget_high_s8(b10_s8))),
340                     vget_high_s16(vmovl_s8(vget_high_s8(b10_s8)))
341                 }
342             };
343 
344             const int16x4x4_t b20_s16 =
345             {
346                 {
347                     vget_low_s16(vmovl_s8(vget_low_s8(b20_s8))),
348                     vget_high_s16(vmovl_s8(vget_low_s8(b20_s8))),
349                     vget_low_s16(vmovl_s8(vget_high_s8(b20_s8))),
350                     vget_high_s16(vmovl_s8(vget_high_s8(b20_s8)))
351                 }
352             };
353 
354             const int16x4x4_t b30_s16 =
355             {
356                 {
357                     vget_low_s16(vmovl_s8(vget_low_s8(b30_s8))),
358                     vget_high_s16(vmovl_s8(vget_low_s8(b30_s8))),
359                     vget_low_s16(vmovl_s8(vget_high_s8(b30_s8))),
360                     vget_high_s16(vmovl_s8(vget_high_s8(b30_s8)))
361                 }
362             };
363 
364             const int16x4x4_t b40_s16 =
365             {
366                 {
367                     vget_low_s16(vmovl_s8(vget_low_s8(b40_s8))),
368                     vget_high_s16(vmovl_s8(vget_low_s8(b40_s8))),
369                     vget_low_s16(vmovl_s8(vget_high_s8(b40_s8))),
370                     vget_high_s16(vmovl_s8(vget_high_s8(b40_s8)))
371                 }
372             };
373 
374             const int16x4x4_t b50_s16 =
375             {
376                 {
377                     vget_low_s16(vmovl_s8(vget_low_s8(b50_s8))),
378                     vget_high_s16(vmovl_s8(vget_low_s8(b50_s8))),
379                     vget_low_s16(vmovl_s8(vget_high_s8(b50_s8))),
380                     vget_high_s16(vmovl_s8(vget_high_s8(b50_s8)))
381                 }
382             };
383 
384             const int16x4x4_t b60_s16 =
385             {
386                 {
387                     vget_low_s16(vmovl_s8(vget_low_s8(b60_s8))),
388                     vget_high_s16(vmovl_s8(vget_low_s8(b60_s8))),
389                     vget_low_s16(vmovl_s8(vget_high_s8(b60_s8))),
390                     vget_high_s16(vmovl_s8(vget_high_s8(b60_s8)))
391                 }
392             };
393 
394             const int16x4x4_t b70_s16 =
395             {
396                 {
397                     vget_low_s16(vmovl_s8(vget_low_s8(b70_s8))),
398                     vget_high_s16(vmovl_s8(vget_low_s8(b70_s8))),
399                     vget_low_s16(vmovl_s8(vget_high_s8(b70_s8))),
400                     vget_high_s16(vmovl_s8(vget_high_s8(b70_s8)))
401                 }
402             };
403 
404             // Accumulate 0:
405             c0.val[0] = vmlal_lane_s16(c0.val[0], b00_s16.val[0], a00_s16.val[0], 0);
406             c0.val[1] = vmlal_lane_s16(c0.val[1], b00_s16.val[1], a00_s16.val[0], 0);
407             c0.val[2] = vmlal_lane_s16(c0.val[2], b00_s16.val[2], a00_s16.val[0], 0);
408             c0.val[3] = vmlal_lane_s16(c0.val[3], b00_s16.val[3], a00_s16.val[0], 0);
409 
410             // Accumulate 1:
411             c0.val[0] = vmlal_lane_s16(c0.val[0], b10_s16.val[0], a00_s16.val[0], 1);
412             c0.val[1] = vmlal_lane_s16(c0.val[1], b10_s16.val[1], a00_s16.val[0], 1);
413             c0.val[2] = vmlal_lane_s16(c0.val[2], b10_s16.val[2], a00_s16.val[0], 1);
414             c0.val[3] = vmlal_lane_s16(c0.val[3], b10_s16.val[3], a00_s16.val[0], 1);
415 
416             // Accumulate 2:
417             c0.val[0] = vmlal_lane_s16(c0.val[0], b20_s16.val[0], a00_s16.val[0], 2);
418             c0.val[1] = vmlal_lane_s16(c0.val[1], b20_s16.val[1], a00_s16.val[0], 2);
419             c0.val[2] = vmlal_lane_s16(c0.val[2], b20_s16.val[2], a00_s16.val[0], 2);
420             c0.val[3] = vmlal_lane_s16(c0.val[3], b20_s16.val[3], a00_s16.val[0], 2);
421 
422             // Accumulate 3:
423             c0.val[0] = vmlal_lane_s16(c0.val[0], b30_s16.val[0], a00_s16.val[0], 3);
424             c0.val[1] = vmlal_lane_s16(c0.val[1], b30_s16.val[1], a00_s16.val[0], 3);
425             c0.val[2] = vmlal_lane_s16(c0.val[2], b30_s16.val[2], a00_s16.val[0], 3);
426             c0.val[3] = vmlal_lane_s16(c0.val[3], b30_s16.val[3], a00_s16.val[0], 3);
427 
428             // Accumulate 4:
429             c0.val[0] = vmlal_lane_s16(c0.val[0], b40_s16.val[0], a00_s16.val[1], 0);
430             c0.val[1] = vmlal_lane_s16(c0.val[1], b40_s16.val[1], a00_s16.val[1], 0);
431             c0.val[2] = vmlal_lane_s16(c0.val[2], b40_s16.val[2], a00_s16.val[1], 0);
432             c0.val[3] = vmlal_lane_s16(c0.val[3], b40_s16.val[3], a00_s16.val[1], 0);
433 
434             // Accumulate 5:
435             c0.val[0] = vmlal_lane_s16(c0.val[0], b50_s16.val[0], a00_s16.val[1], 1);
436             c0.val[1] = vmlal_lane_s16(c0.val[1], b50_s16.val[1], a00_s16.val[1], 1);
437             c0.val[2] = vmlal_lane_s16(c0.val[2], b50_s16.val[2], a00_s16.val[1], 1);
438             c0.val[3] = vmlal_lane_s16(c0.val[3], b50_s16.val[3], a00_s16.val[1], 1);
439 
440             // Accumulate 6:
441             c0.val[0] = vmlal_lane_s16(c0.val[0], b60_s16.val[0], a00_s16.val[1], 2);
442             c0.val[1] = vmlal_lane_s16(c0.val[1], b60_s16.val[1], a00_s16.val[1], 2);
443             c0.val[2] = vmlal_lane_s16(c0.val[2], b60_s16.val[2], a00_s16.val[1], 2);
444             c0.val[3] = vmlal_lane_s16(c0.val[3], b60_s16.val[3], a00_s16.val[1], 2);
445 
446             // Accumulate 7:
447             c0.val[0] = vmlal_lane_s16(c0.val[0], b70_s16.val[0], a00_s16.val[1], 3);
448             c0.val[1] = vmlal_lane_s16(c0.val[1], b70_s16.val[1], a00_s16.val[1], 3);
449             c0.val[2] = vmlal_lane_s16(c0.val[2], b70_s16.val[2], a00_s16.val[1], 3);
450             c0.val[3] = vmlal_lane_s16(c0.val[3], b70_s16.val[3], a00_s16.val[1], 3);
451 
452             vec_a += 8;
453             matrix_b += 8 * stride_b;
454         }
455 
456         // This for loop performs the left-over accumulations
457         for(; vec_a < vec_a_end_addr;)
458         {
459             const int8x8_t  a00_s8 = vld1_dup_s8(vec_a);
460             const int8x16_t b00_s8 = vld1q_s8(matrix_b);
461 
462             const int16x4x4_t b00_s16 =
463             {
464                 {
465                     vget_low_s16(vmovl_s8(vget_low_s8(b00_s8))),
466                     vget_high_s16(vmovl_s8(vget_low_s8(b00_s8))),
467                     vget_low_s16(vmovl_s8(vget_high_s8(b00_s8))),
468                     vget_high_s16(vmovl_s8(vget_high_s8(b00_s8)))
469                 }
470             };
471 
472             // Convert a00_s8 to int16_t and get the lower part
473             const int16x4_t a00_s16 = vget_low_s16(vmovl_s8(a00_s8));
474 
475             // Accumulate 0:
476             c0.val[0] = vmlal_lane_s16(c0.val[0], b00_s16.val[0], a00_s16, 0);
477             c0.val[1] = vmlal_lane_s16(c0.val[1], b00_s16.val[1], a00_s16, 0);
478             c0.val[2] = vmlal_lane_s16(c0.val[2], b00_s16.val[2], a00_s16, 0);
479             c0.val[3] = vmlal_lane_s16(c0.val[3], b00_s16.val[3], a00_s16, 0);
480 
481             vec_a += 1;
482             matrix_b += stride_b;
483         }
484 
485         auto vec_out = reinterpret_cast<int32_t *>(out.ptr());
486         if(id.x() < (width_out - 16))
487         {
488             vst1q_s32(vec_out + 0, c0.val[0]);
489             vst1q_s32(vec_out + 4, c0.val[1]);
490             vst1q_s32(vec_out + 8, c0.val[2]);
491             vst1q_s32(vec_out + 12, c0.val[3]);
492         }
493         else
494         {
495             auto left_over = width_out - id.x();
496             for(auto k = 0; k < 4 && left_over; ++k)
497             {
498                 for(auto j = 0; j < 4 && left_over; ++j, --left_over)
499                 {
500                     *(vec_out + k * 4 + j) = c0.val[k][j];
501                 }
502             }
503         }
504     },
505     ina, inb, out);
506 }
507 
508 void inline matrix_multiply_u8(Iterator &ina, Iterator &inb, Iterator &out, int width_b, const TensorInfo &out_info, const Window &window)
509 {
510     const auto   width_out  = static_cast<int>(out_info.dimension(0));
511     const auto   height_out = static_cast<int>(out_info.dimension(1));
512     const size_t out_stride = out_info.strides_in_bytes()[1] / out_info.element_size();
513     execute_window_loop(window, [&](const Coordinates & id)
514     {
515         const uint8_t *mtx_a0 = ina.ptr();
516         const uint8_t *mtx_b0 = inb.ptr();
517 
518         // Note: Since the inputs are all positive, we can use uint32_t
519         // Accumulators for the block 0
520         uint32x4x4_t c0 =
521         {
522             {
523                 vdupq_n_u32(0),
524                 vdupq_n_u32(0),
525                 vdupq_n_u32(0),
526                 vdupq_n_u32(0)
527             }
528         };
529 
530         // Accumulators for the block 1
531         uint32x4x4_t c1 =
532         {
533             {
534                 vdupq_n_u32(0),
535                 vdupq_n_u32(0),
536                 vdupq_n_u32(0),
537                 vdupq_n_u32(0)
538             }
539         };
540 
541         // Accumulators for the block 2
542         uint32x4x4_t c2 =
543         {
544             {
545                 vdupq_n_u32(0),
546                 vdupq_n_u32(0),
547                 vdupq_n_u32(0),
548                 vdupq_n_u32(0)
549             }
550         };
551 
552         // Accumulators for the block 3
553         uint32x4x4_t c3 =
554         {
555             {
556                 vdupq_n_u32(0),
557                 vdupq_n_u32(0),
558                 vdupq_n_u32(0),
559                 vdupq_n_u32(0)
560             }
561         };
562 
563         for(int k = 0; k < width_b; k += 16, mtx_a0 += 4, mtx_b0 += 16)
564         {
565             const uint8x8_t  a00_u8 = vld1_u8(mtx_a0);
566             const uint8x16_t b00_u8 = vld1q_u8(mtx_b0);
567 
568             // Convert a00_u8 to uint16_t and get the lower part
569             const uint16x4_t a00_u16 = vget_low_u16(vmovl_u8(a00_u8));
570 
571             // Convert b00_u8 to uint16_t
572             const uint16x4x4_t b00_u16 =
573             {
574                 {
575                     vget_low_u16(vmovl_u8(vget_low_u8(b00_u8))),
576                     vget_high_u16(vmovl_u8(vget_low_u8(b00_u8))),
577                     vget_low_u16(vmovl_u8(vget_high_u8(b00_u8))),
578                     vget_high_u16(vmovl_u8(vget_high_u8(b00_u8)))
579                 }
580             };
581 
582             // 4x4 block 0
583             c0.val[0] = vmlal_lane_u16(c0.val[0], b00_u16.val[0], a00_u16, 0);
584             c0.val[1] = vmlal_lane_u16(c0.val[1], b00_u16.val[1], a00_u16, 0);
585             c0.val[2] = vmlal_lane_u16(c0.val[2], b00_u16.val[2], a00_u16, 0);
586             c0.val[3] = vmlal_lane_u16(c0.val[3], b00_u16.val[3], a00_u16, 0);
587 
588             // 4x4 block 1
589             c1.val[0] = vmlal_lane_u16(c1.val[0], b00_u16.val[0], a00_u16, 1);
590             c1.val[1] = vmlal_lane_u16(c1.val[1], b00_u16.val[1], a00_u16, 1);
591             c1.val[2] = vmlal_lane_u16(c1.val[2], b00_u16.val[2], a00_u16, 1);
592             c1.val[3] = vmlal_lane_u16(c1.val[3], b00_u16.val[3], a00_u16, 1);
593 
594             // 4x4 block 2
595             c2.val[0] = vmlal_lane_u16(c2.val[0], b00_u16.val[0], a00_u16, 2);
596             c2.val[1] = vmlal_lane_u16(c2.val[1], b00_u16.val[1], a00_u16, 2);
597             c2.val[2] = vmlal_lane_u16(c2.val[2], b00_u16.val[2], a00_u16, 2);
598             c2.val[3] = vmlal_lane_u16(c2.val[3], b00_u16.val[3], a00_u16, 2);
599 
600             // 4x4 block 3
601             c3.val[0] = vmlal_lane_u16(c3.val[0], b00_u16.val[0], a00_u16, 3);
602             c3.val[1] = vmlal_lane_u16(c3.val[1], b00_u16.val[1], a00_u16, 3);
603             c3.val[2] = vmlal_lane_u16(c3.val[2], b00_u16.val[2], a00_u16, 3);
604             c3.val[3] = vmlal_lane_u16(c3.val[3], b00_u16.val[3], a00_u16, 3);
605         }
606 
607         auto mtx_out = reinterpret_cast<int32_t *>(out.ptr());
608 
609         if(id.y() < height_out && id.x() < (width_out - 16))
610         {
611             vst1q_s32(mtx_out + 0 * out_stride + 0, vreinterpretq_s32_u32(c0.val[0]));
612             vst1q_s32(mtx_out + 0 * out_stride + 4, vreinterpretq_s32_u32(c0.val[1]));
613             vst1q_s32(mtx_out + 0 * out_stride + 8, vreinterpretq_s32_u32(c0.val[2]));
614             vst1q_s32(mtx_out + 0 * out_stride + 12, vreinterpretq_s32_u32(c0.val[3]));
615             if(id.y() + 1 < height_out)
616             {
617                 vst1q_s32(mtx_out + 1 * out_stride + 0, vreinterpretq_s32_u32(c1.val[0]));
618                 vst1q_s32(mtx_out + 1 * out_stride + 4, vreinterpretq_s32_u32(c1.val[1]));
619                 vst1q_s32(mtx_out + 1 * out_stride + 8, vreinterpretq_s32_u32(c1.val[2]));
620                 vst1q_s32(mtx_out + 1 * out_stride + 12, vreinterpretq_s32_u32(c1.val[3]));
621                 if(id.y() + 2 < height_out)
622                 {
623                     vst1q_s32(mtx_out + 2 * out_stride + 0, vreinterpretq_s32_u32(c2.val[0]));
624                     vst1q_s32(mtx_out + 2 * out_stride + 4, vreinterpretq_s32_u32(c2.val[1]));
625                     vst1q_s32(mtx_out + 2 * out_stride + 8, vreinterpretq_s32_u32(c2.val[2]));
626                     vst1q_s32(mtx_out + 2 * out_stride + 12, vreinterpretq_s32_u32(c2.val[3]));
627                     if(id.y() + 3 < height_out)
628                     {
629                         vst1q_s32(mtx_out + 3 * out_stride + 0, vreinterpretq_s32_u32(c3.val[0]));
630                         vst1q_s32(mtx_out + 3 * out_stride + 4, vreinterpretq_s32_u32(c3.val[1]));
631                         vst1q_s32(mtx_out + 3 * out_stride + 8, vreinterpretq_s32_u32(c3.val[2]));
632                         vst1q_s32(mtx_out + 3 * out_stride + 12, vreinterpretq_s32_u32(c3.val[3]));
633                     }
634                 }
635             }
636         }
637         else
638         {
639             const auto left_over_value = width_out - id.x();
640             auto       left_over       = left_over_value;
641             for(auto k = 0; k < 4 && left_over; ++k)
642             {
643                 for(auto j = 0; j < 4 && left_over; ++j, --left_over)
644                 {
645                     *(mtx_out + k * 4 + j) = c0.val[k][j];
646                 }
647             }
648             if(id.y() + 1 < height_out)
649             {
650                 left_over = left_over_value;
651                 for(auto k = 0; k < 4 && left_over; ++k)
652                 {
653                     for(auto j = 0; j < 4 && left_over; ++j, --left_over)
654                     {
655                         *(mtx_out + out_stride + k * 4 + j) = c1.val[k][j];
656                     }
657                 }
658                 if(id.y() + 2 < height_out)
659                 {
660                     left_over = left_over_value;
661                     for(auto k = 0; k < 4 && left_over; ++k)
662                     {
663                         for(auto j = 0; j < 4 && left_over; ++j, --left_over)
664                         {
665                             *(mtx_out + out_stride * 2 + k * 4 + j) = c2.val[k][j];
666                         }
667                     }
668                     if(id.y() + 3 < height_out)
669                     {
670                         left_over = left_over_value;
671                         for(auto k = 0; k < 4 && left_over; ++k)
672                         {
673                             for(auto j = 0; j < 4 && left_over; ++j, --left_over)
674                             {
675                                 *(mtx_out + out_stride * 3 + k * 4 + j) = c3.val[k][j];
676                             }
677                         }
678                     }
679                 }
680             }
681         }
682     },
683     ina, inb, out);
684 }
685 
686 void inline matrix_multiply_s8(Iterator &ina, Iterator &inb, Iterator &out, int width_b, const TensorInfo &out_info, const Window &window)
687 {
688     const auto   width_out  = static_cast<int>(out_info.dimension(0));
689     const auto   height_out = static_cast<int>(out_info.dimension(1));
690     const size_t out_stride = out_info.strides_in_bytes()[1] / out_info.element_size();
691     // The implementation assumes that matrix A and matrix B have been reshaped with CpuGemmInterleave4x4 and CpuGemmTranspose1xW respectively
692     // The reshaping makes the implementation cache-friendly and avoids the data re-arrangements otherwise needed to compute 16x4 elements per iteration
693     // All the values needed to compute a single 4x4 block are read from consecutive memory positions
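    // Layout sketch (illustrative, assuming the reshape kernels above have been run):
    //   - Interleaved A : each group of 4 consecutive bytes holds the same column element of
    //                     4 consecutive rows of the original A, i.e. { a(y+0,k), a(y+1,k), a(y+2,k), a(y+3,k) }
    //   - Transposed B  : each group of 16 consecutive bytes holds 16 consecutive columns of
    //                     one row of the original B, i.e. { b(k,x+0), ..., b(k,x+15) }
    // Each iteration of the inner loop below therefore updates a 4x16 block of the int32 output.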
694     execute_window_loop(window, [&](const Coordinates & id)
695     {
696         auto *mtx_a0 = reinterpret_cast<const int8_t *>(ina.ptr());
697         auto *mtx_b0 = reinterpret_cast<const int8_t *>(inb.ptr());
698 
699         // Note: Since the inputs are signed, the accumulators are int32_t
700         // Accumulators for the block 0
701         int32x4x4_t c0 =
702         {
703             {
704                 vdupq_n_s32(0),
705                 vdupq_n_s32(0),
706                 vdupq_n_s32(0),
707                 vdupq_n_s32(0)
708             }
709         };
710 
711         // Accumulators for the block 1
712         int32x4x4_t c1 =
713         {
714             {
715                 vdupq_n_s32(0),
716                 vdupq_n_s32(0),
717                 vdupq_n_s32(0),
718                 vdupq_n_s32(0)
719             }
720         };
721 
722         // Accumulators for the block 2
723         int32x4x4_t c2 =
724         {
725             {
726                 vdupq_n_s32(0),
727                 vdupq_n_s32(0),
728                 vdupq_n_s32(0),
729                 vdupq_n_s32(0)
730             }
731         };
732 
733         // Accumulators for the block 3
734         int32x4x4_t c3 =
735         {
736             {
737                 vdupq_n_s32(0),
738                 vdupq_n_s32(0),
739                 vdupq_n_s32(0),
740                 vdupq_n_s32(0)
741             }
742         };
743 
744         for(int k = 0; k < width_b; k += 16, mtx_a0 += 4, mtx_b0 += 16)
745         {
746             const int8x8_t  a00_s8 = vld1_s8(mtx_a0);
747             const int8x16_t b00_s8 = vld1q_s8(mtx_b0);
748 
749             // Convert a00_s8 to int16_t and get the lower part
750             const int16x4_t a00_s16 = vget_low_s16(vmovl_s8(a00_s8));
751 
752             // Convert b00_s8 to int16_t
753             const int16x4x4_t b00_s16 =
754             {
755                 {
756                     vget_low_s16(vmovl_s8(vget_low_s8(b00_s8))),
757                     vget_high_s16(vmovl_s8(vget_low_s8(b00_s8))),
758                     vget_low_s16(vmovl_s8(vget_high_s8(b00_s8))),
759                     vget_high_s16(vmovl_s8(vget_high_s8(b00_s8)))
760                 }
761             };
762 
763             // 4x4 block 0
764             c0.val[0] = vmlal_lane_s16(c0.val[0], b00_s16.val[0], a00_s16, 0);
765             c0.val[1] = vmlal_lane_s16(c0.val[1], b00_s16.val[1], a00_s16, 0);
766             c0.val[2] = vmlal_lane_s16(c0.val[2], b00_s16.val[2], a00_s16, 0);
767             c0.val[3] = vmlal_lane_s16(c0.val[3], b00_s16.val[3], a00_s16, 0);
768 
769             // 4x4 block 1
770             c1.val[0] = vmlal_lane_s16(c1.val[0], b00_s16.val[0], a00_s16, 1);
771             c1.val[1] = vmlal_lane_s16(c1.val[1], b00_s16.val[1], a00_s16, 1);
772             c1.val[2] = vmlal_lane_s16(c1.val[2], b00_s16.val[2], a00_s16, 1);
773             c1.val[3] = vmlal_lane_s16(c1.val[3], b00_s16.val[3], a00_s16, 1);
774 
775             // 4x4 block 2
776             c2.val[0] = vmlal_lane_s16(c2.val[0], b00_s16.val[0], a00_s16, 2);
777             c2.val[1] = vmlal_lane_s16(c2.val[1], b00_s16.val[1], a00_s16, 2);
778             c2.val[2] = vmlal_lane_s16(c2.val[2], b00_s16.val[2], a00_s16, 2);
779             c2.val[3] = vmlal_lane_s16(c2.val[3], b00_s16.val[3], a00_s16, 2);
780 
781             // 4x4 block 3
782             c3.val[0] = vmlal_lane_s16(c3.val[0], b00_s16.val[0], a00_s16, 3);
783             c3.val[1] = vmlal_lane_s16(c3.val[1], b00_s16.val[1], a00_s16, 3);
784             c3.val[2] = vmlal_lane_s16(c3.val[2], b00_s16.val[2], a00_s16, 3);
785             c3.val[3] = vmlal_lane_s16(c3.val[3], b00_s16.val[3], a00_s16, 3);
786         }
787         auto mtx_out = reinterpret_cast<int32_t *>(out.ptr());
788         if(id.y() < height_out && id.x() < (width_out - 16))
789         {
790             vst1q_s32(mtx_out + 0 * out_stride + 0, c0.val[0]);
791             vst1q_s32(mtx_out + 0 * out_stride + 4, c0.val[1]);
792             vst1q_s32(mtx_out + 0 * out_stride + 8, c0.val[2]);
793             vst1q_s32(mtx_out + 0 * out_stride + 12, c0.val[3]);
794             if(id.y() + 1 < height_out)
795             {
796                 vst1q_s32(mtx_out + 1 * out_stride + 0, c1.val[0]);
797                 vst1q_s32(mtx_out + 1 * out_stride + 4, c1.val[1]);
798                 vst1q_s32(mtx_out + 1 * out_stride + 8, c1.val[2]);
799                 vst1q_s32(mtx_out + 1 * out_stride + 12, c1.val[3]);
800                 if(id.y() + 2 < height_out)
801                 {
802                     vst1q_s32(mtx_out + 2 * out_stride + 0, c2.val[0]);
803                     vst1q_s32(mtx_out + 2 * out_stride + 4, c2.val[1]);
804                     vst1q_s32(mtx_out + 2 * out_stride + 8, c2.val[2]);
805                     vst1q_s32(mtx_out + 2 * out_stride + 12, c2.val[3]);
806                     if(id.y() + 3 < height_out)
807                     {
808                         vst1q_s32(mtx_out + 3 * out_stride + 0, c3.val[0]);
809                         vst1q_s32(mtx_out + 3 * out_stride + 4, c3.val[1]);
810                         vst1q_s32(mtx_out + 3 * out_stride + 8, c3.val[2]);
811                         vst1q_s32(mtx_out + 3 * out_stride + 12, c3.val[3]);
812                     }
813                 }
814             }
815         }
816         else if(id.y() < height_out)
817         {
818             const auto left_over_value = width_out - id.x();
819             auto       left_over       = left_over_value;
820             for(auto k = 0; k < 4 && left_over; ++k)
821             {
822                 for(auto j = 0; j < 4 && left_over; ++j, --left_over)
823                 {
824                     *(mtx_out + k * 4 + j) = c0.val[k][j];
825                 }
826             }
827             if(id.y() + 1 < height_out)
828             {
829                 left_over = left_over_value;
830                 for(auto k = 0; k < 4 && left_over; ++k)
831                 {
832                     for(auto j = 0; j < 4 && left_over; ++j, --left_over)
833                     {
834                         *(mtx_out + out_stride + k * 4 + j) = c1.val[k][j];
835                     }
836                 }
837                 if(id.y() + 2 < height_out)
838                 {
839                     left_over = left_over_value;
840                     for(auto k = 0; k < 4 && left_over; ++k)
841                     {
842                         for(auto j = 0; j < 4 && left_over; ++j, --left_over)
843                         {
844                             *(mtx_out + out_stride * 2 + k * 4 + j) = c2.val[k][j];
845                         }
846                     }
847                     if(id.y() + 3 < height_out)
848                     {
849                         left_over = left_over_value;
850                         for(auto k = 0; k < 4 && left_over; ++k)
851                         {
852                             for(auto j = 0; j < 4 && left_over; ++j, --left_over)
853                             {
854                                 *(mtx_out + out_stride * 3 + k * 4 + j) = c3.val[k][j];
855                             }
856                         }
857                     }
858                 }
859             }
860         }
861 
862     },
863     ina, inb, out);
864 }
865 
866 Status validate_arguments(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *dst)
867 {
868     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src0, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::S8, DataType::U8);
869     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src1, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::QSYMM8, DataType::QSYMM8_PER_CHANNEL, DataType::S8, DataType::U8);
870     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(dst, 1, DataType::S32);
871 
872     TensorShape in0_shape = src0->tensor_shape();
873     TensorShape in1_shape = src1->tensor_shape();
874     TensorShape out_shape = dst->tensor_shape();
875 
876     // Check vector-by-matrix case
877     if(out_shape[1] == 1)
878     {
879         ARM_COMPUTE_RETURN_ERROR_ON_MSG(in0_shape[0] != in1_shape[1], "The number of input0's columns must be equal to input1's rows");
880     }
881     else
882     {
883         in0_shape.collapse(2);
884         in1_shape.collapse(2);
885         out_shape.collapse(2);
886 
887         ARM_COMPUTE_RETURN_ERROR_ON_MSG(in0_shape[2] != out_shape[2], "Output tensor must have the same number of batches as the input0 tensor");
888         ARM_COMPUTE_RETURN_ERROR_ON_MSG(in1_shape[2] != 1 && in0_shape[2] != in1_shape[2], "Input1 tensor must have the same number of batches as input0 or its number of batches must be set to 1");
889         ARM_COMPUTE_RETURN_ERROR_ON_MSG(in1_shape[0] % 16, "Input1's width must be a multiple of 16");
890     }
891 
892     return Status{};
893 }
894 } // namespace
895 
896 void CpuGemmLowpMatrixMultiplyKernel::configure(const ITensorInfo *src0, const ITensorInfo *src1, ITensorInfo *dst)
897 {
898     ARM_COMPUTE_UNUSED(src0);
899     ARM_COMPUTE_ERROR_ON_NULLPTR(src0, src1, dst);
900     ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(src0, src1, dst));
901 
902     TensorShape in1_shape = src1->tensor_shape();
903     in1_shape.collapse(2);
904 
905     _slide_matrix_b = in1_shape[2] != 1;
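    // Matrix B is slid along the batch (Z) dimension only when it carries its own batches;
    // a 2-D matrix B (e.g. shared weights when the GEMM implements a convolution) is reused
    // for every batch of matrix A.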
906 
907     constexpr unsigned int num_elems_processed_per_iteration_x = 16;
908     constexpr unsigned int num_elems_processed_per_iteration_y = 4;
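    // Each iteration of the kernel produces a 16-element-wide block of int32 results:
    // a 16x4 block in the matrix-matrix path and a 16x1 block in the vector-matrix path.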
909 
910     Window win;
911     // Check if the output tensor is a vector. If so, the kernel runs the vector-matrix multiplication path
912     if((dst->dimension(1) == 1))
913     {
914         // Configure kernel window
915         win = calculate_max_window(*dst, Steps(num_elems_processed_per_iteration_x));
916     }
917     else
918     {
919         win = calculate_max_window(*dst, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
920     }
921 
922     ICpuKernel::configure(win);
923 }
924 
925 Status CpuGemmLowpMatrixMultiplyKernel::validate(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *dst)
926 {
927     ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(src0, src1, dst));
928     return Status{};
929 }
930 
931 void CpuGemmLowpMatrixMultiplyKernel::run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info)
932 {
933     ARM_COMPUTE_UNUSED(info);
934     ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
935     ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(ICpuKernel::window(), window);
936 
937     auto src0 = tensors.get_const_tensor(TensorType::ACL_SRC_0);
938     auto src1 = tensors.get_const_tensor(TensorType::ACL_SRC_1);
939     auto dst  = tensors.get_tensor(TensorType::ACL_DST);
940 
941     // Check if the output tensor is a vector. If so, the kernel runs the vector-matrix multiplication path
942     if((dst->info()->dimension(1) == 1))
943     {
944         const auto width_matrix_a = static_cast<int>(src0->info()->dimension(0));
945         const auto width_matrix_b = static_cast<int>(src1->info()->dimension(0));
946         const auto width_out      = static_cast<int>(dst->info()->dimension(0));
947         const auto in_b_stride    = static_cast<int>(src1->info()->strides_in_bytes()[1] / data_size_from_type(src1->info()->data_type()));
948 
949         // The implementation computes 16 elements per iteration
950         const int window_start_x = 16 * info.thread_id;
951         const int window_step_x  = 16 * info.num_threads;
952         // Make sure (window_end_x - window_start_x) is a multiple of window_step_x
953         const int window_end_x = ceil_to_multiple(width_matrix_b - window_start_x, window_step_x) + window_start_x;
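        // (Illustrative example: with num_threads = 2 and width_matrix_b = 40, thread 0 visits
        //  x = 0 and 32 while thread 1 visits x = 16, so the 16-wide output chunks are
        //  statically interleaved across the threads.)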
954 
955         Window win_out(window);
956         win_out.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x));
957         win_out.set(Window::DimY, Window::Dimension(0, 1, 1));
958 
959         Window win_a(window);
960         win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
961         win_a.set(Window::DimY, Window::Dimension(0, 0, 0));
962 
963         Window win_b;
964         // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
965         // This scenario can happen when the matrix multiplication is used to perform a convolution operation
966         if(src1->info()->num_dimensions() >= 3)
967         {
968             win_b = window;
969         }
970         win_b.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x));
971         win_b.set(Window::DimY, Window::Dimension(0, 1, 1));
972 
973         Iterator ina(src0, win_a);
974         Iterator inb(src1, win_b);
975         Iterator out(dst, win_out);
976 
977         switch(src0->info()->data_type())
978         {
979             case DataType::S8:
980             case DataType::QASYMM8_SIGNED:
981             {
982                 vector_matrix_multiply_s8(ina, inb, out, width_matrix_a, width_matrix_b, width_out, in_b_stride, window);
983                 break;
984             }
985             case DataType::U8:
986             case DataType::QASYMM8:
987             {
988                 vector_matrix_multiply_u8(ina, inb, out, width_matrix_a, width_matrix_b, width_out, in_b_stride, window);
989                 break;
990             }
991             default:
992             {
993                 ARM_COMPUTE_ERROR("Not supported");
994                 break;
995             }
996         }
997     }
998     else
999     {
1000         const size_t in_b_stride = src1->info()->strides_in_bytes()[1];
1001         const int    width_b     = src1->info()->dimension(0);
1002 
1003         // Set step_x and step_y for matrix A. Scale the Y range by a factor of 4 as the interleaved input matrix A has a quarter of the rows of the output matrix
1004         Window win_a(window);
1005         win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
1006         win_a.set(Window::DimY, Window::Dimension(window.y().start() / 4, window.y().end() / 4, 1));
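        // (For example, the four output rows [y, y + 4) produced in one iteration are all read
        //  from row y / 4 of the interleaved matrix A.)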
1007 
1008         // Set step_x and step_y for matrix B. Scale the X range by a factor of 16 as the transposed input matrix B has one sixteenth of the columns of the output matrix
1009         Window win_b;
1010         // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
1011         // This scenario can happen when the matrix multiplication is used to perform a convolution operation
1012         if(_slide_matrix_b)
1013         {
1014             win_b = window;
1015         }
1016         win_b.set(Window::DimX, Window::Dimension(window.x().start() / 16, window.x().end() / 16, in_b_stride));
1017         win_b.set(Window::DimY, Window::Dimension(0, 0, 0));
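        // (Likewise, the 16 output columns [x, x + 16) of one iteration are read from
        //  position x / 16 along the X dimension of the transposed matrix B.)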
1018 
1019         // The step x and step y for the output matrix have already been set in configure()
1020         Iterator ina(src0, win_a);
1021         Iterator inb(src1, win_b);
1022         Iterator out(dst, window);
1023 
1024         switch(src0->info()->data_type())
1025         {
1026             case DataType::S8:
1027             case DataType::QASYMM8_SIGNED:
1028             {
1029                 matrix_multiply_s8(ina, inb, out, width_b, *dst->info(), window);
1030                 break;
1031             }
1032             case DataType::U8:
1033             case DataType::QASYMM8:
1034             {
1035                 matrix_multiply_u8(ina, inb, out, width_b, *dst->info(), window);
1036                 break;
1037             }
1038             default:
1039             {
1040                 ARM_COMPUTE_ERROR("Not supported");
1041                 break;
1042             }
1043         }
1044     }
1045 }
1046 
1047 const char *CpuGemmLowpMatrixMultiplyKernel::name() const
1048 {
1049     return "CpuGemmLowpMatrixMultiplyKernel";
1050 }
1051 } // namespace kernels
1052 } // namespace cpu
1053 } // namespace arm_compute