/*
 * Copyright (c) 2017-2021 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "src/cpu/kernels/CpuGemmLowpMatrixMultiplyKernel.h"

#include "arm_compute/core/Error.h"
#include "arm_compute/core/Helpers.h"
#include "arm_compute/core/ITensor.h"
#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Types.h"
#include "arm_compute/core/Utils.h"
#include "arm_compute/core/Validate.h"
#include "arm_compute/core/Window.h"
#include "src/core/helpers/AutoConfiguration.h"
#include "src/core/helpers/WindowHelpers.h"

#include <arm_neon.h>

namespace arm_compute
{
namespace cpu
{
namespace kernels
{
namespace
{
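// Vector-by-matrix multiplication for unsigned 8-bit inputs: multiplies a row vector A (1 x width_a) by matrix B
// and produces 16 int32 results per window iteration. Since both inputs are unsigned, the accumulation is done on
// uint32_t and reinterpreted as int32_t when stored.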
void inline vector_matrix_multiply_u8(Iterator &ina, Iterator &inb, Iterator &out, int width_a, int width_b, int width_out, size_t stride_b, const Window &window)
{
    execute_window_loop(window, [&](const Coordinates & id)
    {
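        // The X range of the window was rounded up to a multiple of the window step by the caller,
        // so iterations that start past the end of matrix B have nothing to do and are skipped.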
        if(id.x() > width_b)
        {
            return;
        }

        // Note: Since the inputs are all positive, we can use uint32_t accumulators
        // Accumulators for the block 0
        uint32x4x4_t c0 =
        {
            {
                vdupq_n_u32(0),
                vdupq_n_u32(0),
                vdupq_n_u32(0),
                vdupq_n_u32(0)
            }
        };

        auto vec_a = reinterpret_cast<const uint8_t *>(ina.ptr());
        auto matrix_b = reinterpret_cast<const uint8_t *>(inb.ptr());
        auto vec_a_end_addr = vec_a + width_a;

        // This for loop performs 8 accumulations per iteration
        for(; vec_a <= (vec_a_end_addr - 8);)
        {
            const uint8x8_t a00_u8 = vld1_u8(vec_a);
            const uint8x16_t b00_u8 = vld1q_u8(matrix_b + 0 * stride_b);
            const uint8x16_t b10_u8 = vld1q_u8(matrix_b + 1 * stride_b);
            const uint8x16_t b20_u8 = vld1q_u8(matrix_b + 2 * stride_b);
            const uint8x16_t b30_u8 = vld1q_u8(matrix_b + 3 * stride_b);
            const uint8x16_t b40_u8 = vld1q_u8(matrix_b + 4 * stride_b);
            const uint8x16_t b50_u8 = vld1q_u8(matrix_b + 5 * stride_b);
            const uint8x16_t b60_u8 = vld1q_u8(matrix_b + 6 * stride_b);
            const uint8x16_t b70_u8 = vld1q_u8(matrix_b + 7 * stride_b);

            // Convert a00_u8 to uint16_t and split it into the lower and higher halves
            const uint16x4x2_t a00_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(a00_u8)),
                    vget_high_u16(vmovl_u8(a00_u8))
                }
            };

            const uint16x4x4_t b00_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(vget_low_u8(b00_u8))),
                    vget_high_u16(vmovl_u8(vget_low_u8(b00_u8))),
                    vget_low_u16(vmovl_u8(vget_high_u8(b00_u8))),
                    vget_high_u16(vmovl_u8(vget_high_u8(b00_u8)))
                }
            };

            const uint16x4x4_t b10_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(vget_low_u8(b10_u8))),
                    vget_high_u16(vmovl_u8(vget_low_u8(b10_u8))),
                    vget_low_u16(vmovl_u8(vget_high_u8(b10_u8))),
                    vget_high_u16(vmovl_u8(vget_high_u8(b10_u8)))
                }
            };

            const uint16x4x4_t b20_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(vget_low_u8(b20_u8))),
                    vget_high_u16(vmovl_u8(vget_low_u8(b20_u8))),
                    vget_low_u16(vmovl_u8(vget_high_u8(b20_u8))),
                    vget_high_u16(vmovl_u8(vget_high_u8(b20_u8)))
                }
            };

            const uint16x4x4_t b30_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(vget_low_u8(b30_u8))),
                    vget_high_u16(vmovl_u8(vget_low_u8(b30_u8))),
                    vget_low_u16(vmovl_u8(vget_high_u8(b30_u8))),
                    vget_high_u16(vmovl_u8(vget_high_u8(b30_u8)))
                }
            };

            const uint16x4x4_t b40_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(vget_low_u8(b40_u8))),
                    vget_high_u16(vmovl_u8(vget_low_u8(b40_u8))),
                    vget_low_u16(vmovl_u8(vget_high_u8(b40_u8))),
                    vget_high_u16(vmovl_u8(vget_high_u8(b40_u8)))
                }
            };

            const uint16x4x4_t b50_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(vget_low_u8(b50_u8))),
                    vget_high_u16(vmovl_u8(vget_low_u8(b50_u8))),
                    vget_low_u16(vmovl_u8(vget_high_u8(b50_u8))),
                    vget_high_u16(vmovl_u8(vget_high_u8(b50_u8)))
                }
            };

            const uint16x4x4_t b60_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(vget_low_u8(b60_u8))),
                    vget_high_u16(vmovl_u8(vget_low_u8(b60_u8))),
                    vget_low_u16(vmovl_u8(vget_high_u8(b60_u8))),
                    vget_high_u16(vmovl_u8(vget_high_u8(b60_u8)))
                }
            };

            const uint16x4x4_t b70_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(vget_low_u8(b70_u8))),
                    vget_high_u16(vmovl_u8(vget_low_u8(b70_u8))),
                    vget_low_u16(vmovl_u8(vget_high_u8(b70_u8))),
                    vget_high_u16(vmovl_u8(vget_high_u8(b70_u8)))
                }
            };

            // Accumulate 0:
            c0.val[0] = vmlal_lane_u16(c0.val[0], b00_u16.val[0], a00_u16.val[0], 0);
            c0.val[1] = vmlal_lane_u16(c0.val[1], b00_u16.val[1], a00_u16.val[0], 0);
            c0.val[2] = vmlal_lane_u16(c0.val[2], b00_u16.val[2], a00_u16.val[0], 0);
            c0.val[3] = vmlal_lane_u16(c0.val[3], b00_u16.val[3], a00_u16.val[0], 0);

            // Accumulate 1:
            c0.val[0] = vmlal_lane_u16(c0.val[0], b10_u16.val[0], a00_u16.val[0], 1);
            c0.val[1] = vmlal_lane_u16(c0.val[1], b10_u16.val[1], a00_u16.val[0], 1);
            c0.val[2] = vmlal_lane_u16(c0.val[2], b10_u16.val[2], a00_u16.val[0], 1);
            c0.val[3] = vmlal_lane_u16(c0.val[3], b10_u16.val[3], a00_u16.val[0], 1);

            // Accumulate 2:
            c0.val[0] = vmlal_lane_u16(c0.val[0], b20_u16.val[0], a00_u16.val[0], 2);
            c0.val[1] = vmlal_lane_u16(c0.val[1], b20_u16.val[1], a00_u16.val[0], 2);
            c0.val[2] = vmlal_lane_u16(c0.val[2], b20_u16.val[2], a00_u16.val[0], 2);
            c0.val[3] = vmlal_lane_u16(c0.val[3], b20_u16.val[3], a00_u16.val[0], 2);

            // Accumulate 3:
            c0.val[0] = vmlal_lane_u16(c0.val[0], b30_u16.val[0], a00_u16.val[0], 3);
            c0.val[1] = vmlal_lane_u16(c0.val[1], b30_u16.val[1], a00_u16.val[0], 3);
            c0.val[2] = vmlal_lane_u16(c0.val[2], b30_u16.val[2], a00_u16.val[0], 3);
            c0.val[3] = vmlal_lane_u16(c0.val[3], b30_u16.val[3], a00_u16.val[0], 3);

            // Accumulate 4:
            c0.val[0] = vmlal_lane_u16(c0.val[0], b40_u16.val[0], a00_u16.val[1], 0);
            c0.val[1] = vmlal_lane_u16(c0.val[1], b40_u16.val[1], a00_u16.val[1], 0);
            c0.val[2] = vmlal_lane_u16(c0.val[2], b40_u16.val[2], a00_u16.val[1], 0);
            c0.val[3] = vmlal_lane_u16(c0.val[3], b40_u16.val[3], a00_u16.val[1], 0);

            // Accumulate 5:
            c0.val[0] = vmlal_lane_u16(c0.val[0], b50_u16.val[0], a00_u16.val[1], 1);
            c0.val[1] = vmlal_lane_u16(c0.val[1], b50_u16.val[1], a00_u16.val[1], 1);
            c0.val[2] = vmlal_lane_u16(c0.val[2], b50_u16.val[2], a00_u16.val[1], 1);
            c0.val[3] = vmlal_lane_u16(c0.val[3], b50_u16.val[3], a00_u16.val[1], 1);

            // Accumulate 6:
            c0.val[0] = vmlal_lane_u16(c0.val[0], b60_u16.val[0], a00_u16.val[1], 2);
            c0.val[1] = vmlal_lane_u16(c0.val[1], b60_u16.val[1], a00_u16.val[1], 2);
            c0.val[2] = vmlal_lane_u16(c0.val[2], b60_u16.val[2], a00_u16.val[1], 2);
            c0.val[3] = vmlal_lane_u16(c0.val[3], b60_u16.val[3], a00_u16.val[1], 2);

            // Accumulate 7:
            c0.val[0] = vmlal_lane_u16(c0.val[0], b70_u16.val[0], a00_u16.val[1], 3);
            c0.val[1] = vmlal_lane_u16(c0.val[1], b70_u16.val[1], a00_u16.val[1], 3);
            c0.val[2] = vmlal_lane_u16(c0.val[2], b70_u16.val[2], a00_u16.val[1], 3);
            c0.val[3] = vmlal_lane_u16(c0.val[3], b70_u16.val[3], a00_u16.val[1], 3);

            vec_a += 8;
            matrix_b += 8 * stride_b;
        }

        // This for loop performs the left-over accumulations
        for(; vec_a < vec_a_end_addr;)
        {
            const uint8x8_t a00_u8 = vld1_dup_u8(vec_a);
            const uint8x16_t b00_u8 = vld1q_u8(matrix_b);

            const uint16x4x4_t b00_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(vget_low_u8(b00_u8))),
                    vget_high_u16(vmovl_u8(vget_low_u8(b00_u8))),
                    vget_low_u16(vmovl_u8(vget_high_u8(b00_u8))),
                    vget_high_u16(vmovl_u8(vget_high_u8(b00_u8)))
                }
            };

            // Convert a00_u8 to uint16_t and get the lower part
            const uint16x4_t a00_u16 = vget_low_u16(vmovl_u8(a00_u8));

            // Accumulate 0:
            c0.val[0] = vmlal_lane_u16(c0.val[0], b00_u16.val[0], a00_u16, 0);
            c0.val[1] = vmlal_lane_u16(c0.val[1], b00_u16.val[1], a00_u16, 0);
            c0.val[2] = vmlal_lane_u16(c0.val[2], b00_u16.val[2], a00_u16, 0);
            c0.val[3] = vmlal_lane_u16(c0.val[3], b00_u16.val[3], a00_u16, 0);

            vec_a += 1;
            matrix_b += stride_b;
        }

        auto vec_out = reinterpret_cast<int32_t *>(out.ptr());
        if(id.x() < (width_out - 16))
        {
            vst1q_s32(vec_out + 0, vreinterpretq_s32_u32(c0.val[0]));
            vst1q_s32(vec_out + 4, vreinterpretq_s32_u32(c0.val[1]));
            vst1q_s32(vec_out + 8, vreinterpretq_s32_u32(c0.val[2]));
            vst1q_s32(vec_out + 12, vreinterpretq_s32_u32(c0.val[3]));
        }
        else
        {
            // Left-over columns: store the remaining output elements one at a time
            auto left_over = width_out - id.x();
            for(auto k = 0; k < 4 && left_over; ++k)
            {
                for(auto j = 0; j < 4 && left_over; ++j, --left_over)
                {
                    *(vec_out + k * 4 + j) = c0.val[k][j];
                }
            }
        }
    },
    ina, inb, out);
}

void inline vector_matrix_multiply_s8(Iterator &ina, Iterator &inb, Iterator &out, int width_a, int width_b, int width_out, size_t stride_b, const Window &window)
{
    execute_window_loop(window, [&](const Coordinates & id)
    {
        if(id.x() > width_b)
        {
            return;
        }

        // Accumulators for the block 0
        int32x4x4_t c0 =
        {
            {
                vdupq_n_s32(0),
                vdupq_n_s32(0),
                vdupq_n_s32(0),
                vdupq_n_s32(0)
            }
        };

        auto vec_a = reinterpret_cast<const int8_t *>(ina.ptr());
        auto matrix_b = reinterpret_cast<const int8_t *>(inb.ptr());
        auto vec_a_end_addr = vec_a + width_a;

        // This for loop performs 8 accumulations per iteration
        for(; vec_a <= (vec_a_end_addr - 8);)
        {
            const int8x8_t a00_s8 = vld1_s8(vec_a);
            const int8x16_t b00_s8 = vld1q_s8(matrix_b + 0 * stride_b);
            const int8x16_t b10_s8 = vld1q_s8(matrix_b + 1 * stride_b);
            const int8x16_t b20_s8 = vld1q_s8(matrix_b + 2 * stride_b);
            const int8x16_t b30_s8 = vld1q_s8(matrix_b + 3 * stride_b);
            const int8x16_t b40_s8 = vld1q_s8(matrix_b + 4 * stride_b);
            const int8x16_t b50_s8 = vld1q_s8(matrix_b + 5 * stride_b);
            const int8x16_t b60_s8 = vld1q_s8(matrix_b + 6 * stride_b);
            const int8x16_t b70_s8 = vld1q_s8(matrix_b + 7 * stride_b);

            // Convert a00_s8 to int16_t and split it into the lower and higher halves
            const int16x4x2_t a00_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(a00_s8)),
                    vget_high_s16(vmovl_s8(a00_s8))
                }
            };

            const int16x4x4_t b00_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(vget_low_s8(b00_s8))),
                    vget_high_s16(vmovl_s8(vget_low_s8(b00_s8))),
                    vget_low_s16(vmovl_s8(vget_high_s8(b00_s8))),
                    vget_high_s16(vmovl_s8(vget_high_s8(b00_s8)))
                }
            };

            const int16x4x4_t b10_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(vget_low_s8(b10_s8))),
                    vget_high_s16(vmovl_s8(vget_low_s8(b10_s8))),
                    vget_low_s16(vmovl_s8(vget_high_s8(b10_s8))),
                    vget_high_s16(vmovl_s8(vget_high_s8(b10_s8)))
                }
            };

            const int16x4x4_t b20_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(vget_low_s8(b20_s8))),
                    vget_high_s16(vmovl_s8(vget_low_s8(b20_s8))),
                    vget_low_s16(vmovl_s8(vget_high_s8(b20_s8))),
                    vget_high_s16(vmovl_s8(vget_high_s8(b20_s8)))
                }
            };

            const int16x4x4_t b30_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(vget_low_s8(b30_s8))),
                    vget_high_s16(vmovl_s8(vget_low_s8(b30_s8))),
                    vget_low_s16(vmovl_s8(vget_high_s8(b30_s8))),
                    vget_high_s16(vmovl_s8(vget_high_s8(b30_s8)))
                }
            };

            const int16x4x4_t b40_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(vget_low_s8(b40_s8))),
                    vget_high_s16(vmovl_s8(vget_low_s8(b40_s8))),
                    vget_low_s16(vmovl_s8(vget_high_s8(b40_s8))),
                    vget_high_s16(vmovl_s8(vget_high_s8(b40_s8)))
                }
            };

            const int16x4x4_t b50_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(vget_low_s8(b50_s8))),
                    vget_high_s16(vmovl_s8(vget_low_s8(b50_s8))),
                    vget_low_s16(vmovl_s8(vget_high_s8(b50_s8))),
                    vget_high_s16(vmovl_s8(vget_high_s8(b50_s8)))
                }
            };

            const int16x4x4_t b60_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(vget_low_s8(b60_s8))),
                    vget_high_s16(vmovl_s8(vget_low_s8(b60_s8))),
                    vget_low_s16(vmovl_s8(vget_high_s8(b60_s8))),
                    vget_high_s16(vmovl_s8(vget_high_s8(b60_s8)))
                }
            };

            const int16x4x4_t b70_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(vget_low_s8(b70_s8))),
                    vget_high_s16(vmovl_s8(vget_low_s8(b70_s8))),
                    vget_low_s16(vmovl_s8(vget_high_s8(b70_s8))),
                    vget_high_s16(vmovl_s8(vget_high_s8(b70_s8)))
                }
            };

            // Accumulate 0:
            c0.val[0] = vmlal_lane_s16(c0.val[0], b00_s16.val[0], a00_s16.val[0], 0);
            c0.val[1] = vmlal_lane_s16(c0.val[1], b00_s16.val[1], a00_s16.val[0], 0);
            c0.val[2] = vmlal_lane_s16(c0.val[2], b00_s16.val[2], a00_s16.val[0], 0);
            c0.val[3] = vmlal_lane_s16(c0.val[3], b00_s16.val[3], a00_s16.val[0], 0);

            // Accumulate 1:
            c0.val[0] = vmlal_lane_s16(c0.val[0], b10_s16.val[0], a00_s16.val[0], 1);
            c0.val[1] = vmlal_lane_s16(c0.val[1], b10_s16.val[1], a00_s16.val[0], 1);
            c0.val[2] = vmlal_lane_s16(c0.val[2], b10_s16.val[2], a00_s16.val[0], 1);
            c0.val[3] = vmlal_lane_s16(c0.val[3], b10_s16.val[3], a00_s16.val[0], 1);

            // Accumulate 2:
            c0.val[0] = vmlal_lane_s16(c0.val[0], b20_s16.val[0], a00_s16.val[0], 2);
            c0.val[1] = vmlal_lane_s16(c0.val[1], b20_s16.val[1], a00_s16.val[0], 2);
            c0.val[2] = vmlal_lane_s16(c0.val[2], b20_s16.val[2], a00_s16.val[0], 2);
            c0.val[3] = vmlal_lane_s16(c0.val[3], b20_s16.val[3], a00_s16.val[0], 2);

            // Accumulate 3:
            c0.val[0] = vmlal_lane_s16(c0.val[0], b30_s16.val[0], a00_s16.val[0], 3);
            c0.val[1] = vmlal_lane_s16(c0.val[1], b30_s16.val[1], a00_s16.val[0], 3);
            c0.val[2] = vmlal_lane_s16(c0.val[2], b30_s16.val[2], a00_s16.val[0], 3);
            c0.val[3] = vmlal_lane_s16(c0.val[3], b30_s16.val[3], a00_s16.val[0], 3);

            // Accumulate 4:
            c0.val[0] = vmlal_lane_s16(c0.val[0], b40_s16.val[0], a00_s16.val[1], 0);
            c0.val[1] = vmlal_lane_s16(c0.val[1], b40_s16.val[1], a00_s16.val[1], 0);
            c0.val[2] = vmlal_lane_s16(c0.val[2], b40_s16.val[2], a00_s16.val[1], 0);
            c0.val[3] = vmlal_lane_s16(c0.val[3], b40_s16.val[3], a00_s16.val[1], 0);

            // Accumulate 5:
            c0.val[0] = vmlal_lane_s16(c0.val[0], b50_s16.val[0], a00_s16.val[1], 1);
            c0.val[1] = vmlal_lane_s16(c0.val[1], b50_s16.val[1], a00_s16.val[1], 1);
            c0.val[2] = vmlal_lane_s16(c0.val[2], b50_s16.val[2], a00_s16.val[1], 1);
            c0.val[3] = vmlal_lane_s16(c0.val[3], b50_s16.val[3], a00_s16.val[1], 1);

            // Accumulate 6:
            c0.val[0] = vmlal_lane_s16(c0.val[0], b60_s16.val[0], a00_s16.val[1], 2);
            c0.val[1] = vmlal_lane_s16(c0.val[1], b60_s16.val[1], a00_s16.val[1], 2);
            c0.val[2] = vmlal_lane_s16(c0.val[2], b60_s16.val[2], a00_s16.val[1], 2);
            c0.val[3] = vmlal_lane_s16(c0.val[3], b60_s16.val[3], a00_s16.val[1], 2);

            // Accumulate 7:
            c0.val[0] = vmlal_lane_s16(c0.val[0], b70_s16.val[0], a00_s16.val[1], 3);
            c0.val[1] = vmlal_lane_s16(c0.val[1], b70_s16.val[1], a00_s16.val[1], 3);
            c0.val[2] = vmlal_lane_s16(c0.val[2], b70_s16.val[2], a00_s16.val[1], 3);
            c0.val[3] = vmlal_lane_s16(c0.val[3], b70_s16.val[3], a00_s16.val[1], 3);

            vec_a += 8;
            matrix_b += 8 * stride_b;
        }

        // This for loop performs the left-over accumulations
        for(; vec_a < vec_a_end_addr;)
        {
            const int8x8_t a00_s8 = vld1_dup_s8(vec_a);
            const int8x16_t b00_s8 = vld1q_s8(matrix_b);

            const int16x4x4_t b00_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(vget_low_s8(b00_s8))),
                    vget_high_s16(vmovl_s8(vget_low_s8(b00_s8))),
                    vget_low_s16(vmovl_s8(vget_high_s8(b00_s8))),
                    vget_high_s16(vmovl_s8(vget_high_s8(b00_s8)))
                }
            };

            // Convert a00_s8 to int16_t and get the lower part
            const int16x4_t a00_s16 = vget_low_s16(vmovl_s8(a00_s8));

            // Accumulate 0:
            c0.val[0] = vmlal_lane_s16(c0.val[0], b00_s16.val[0], a00_s16, 0);
            c0.val[1] = vmlal_lane_s16(c0.val[1], b00_s16.val[1], a00_s16, 0);
            c0.val[2] = vmlal_lane_s16(c0.val[2], b00_s16.val[2], a00_s16, 0);
            c0.val[3] = vmlal_lane_s16(c0.val[3], b00_s16.val[3], a00_s16, 0);

            vec_a += 1;
            matrix_b += stride_b;
        }

        auto vec_out = reinterpret_cast<int32_t *>(out.ptr());
        if(id.x() < (width_out - 16))
        {
            vst1q_s32(vec_out + 0, c0.val[0]);
            vst1q_s32(vec_out + 4, c0.val[1]);
            vst1q_s32(vec_out + 8, c0.val[2]);
            vst1q_s32(vec_out + 12, c0.val[3]);
        }
        else
        {
            auto left_over = width_out - id.x();
            for(auto k = 0; k < 4 && left_over; ++k)
            {
                for(auto j = 0; j < 4 && left_over; ++j, --left_over)
                {
                    *(vec_out + k * 4 + j) = c0.val[k][j];
                }
            }
        }
    },
    ina, inb, out);
}

void inline matrix_multiply_u8(Iterator &ina, Iterator &inb, Iterator &out, int width_b, const TensorInfo &out_info, const Window &window)
{
    const auto width_out = static_cast<int>(out_info.dimension(0));
    const auto height_out = static_cast<int>(out_info.dimension(1));
    const size_t out_stride = out_info.strides_in_bytes()[1] / out_info.element_size();
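    // The implementation assumes that matrix A and matrix B have been reshaped with CpuGemmInterleave4x4 and
    // CpuGemmTranspose1xW respectively, as in matrix_multiply_s8 below, so that all the values needed for a
    // 4x4 output block are read from consecutive memory positions.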
    execute_window_loop(window, [&](const Coordinates & id)
    {
        const uint8_t *mtx_a0 = ina.ptr();
        const uint8_t *mtx_b0 = inb.ptr();

        // Note: Since the inputs are all positive, we can use uint32_t accumulators
        // Accumulators for the block 0
        uint32x4x4_t c0 =
        {
            {
                vdupq_n_u32(0),
                vdupq_n_u32(0),
                vdupq_n_u32(0),
                vdupq_n_u32(0)
            }
        };

        // Accumulators for the block 1
        uint32x4x4_t c1 =
        {
            {
                vdupq_n_u32(0),
                vdupq_n_u32(0),
                vdupq_n_u32(0),
                vdupq_n_u32(0)
            }
        };

        // Accumulators for the block 2
        uint32x4x4_t c2 =
        {
            {
                vdupq_n_u32(0),
                vdupq_n_u32(0),
                vdupq_n_u32(0),
                vdupq_n_u32(0)
            }
        };

        // Accumulators for the block 3
        uint32x4x4_t c3 =
        {
            {
                vdupq_n_u32(0),
                vdupq_n_u32(0),
                vdupq_n_u32(0),
                vdupq_n_u32(0)
            }
        };

        for(int k = 0; k < width_b; k += 16, mtx_a0 += 4, mtx_b0 += 16)
        {
            const uint8x8_t a00_u8 = vld1_u8(mtx_a0);
            const uint8x16_t b00_u8 = vld1q_u8(mtx_b0);

            // Convert a00_u8 to uint16_t and get the lower part
            const uint16x4_t a00_u16 = vget_low_u16(vmovl_u8(a00_u8));

            // Convert b00_u8 to uint16_t
            const uint16x4x4_t b00_u16 =
            {
                {
                    vget_low_u16(vmovl_u8(vget_low_u8(b00_u8))),
                    vget_high_u16(vmovl_u8(vget_low_u8(b00_u8))),
                    vget_low_u16(vmovl_u8(vget_high_u8(b00_u8))),
                    vget_high_u16(vmovl_u8(vget_high_u8(b00_u8)))
                }
            };

            // 4x4 block 0
            c0.val[0] = vmlal_lane_u16(c0.val[0], b00_u16.val[0], a00_u16, 0);
            c0.val[1] = vmlal_lane_u16(c0.val[1], b00_u16.val[1], a00_u16, 0);
            c0.val[2] = vmlal_lane_u16(c0.val[2], b00_u16.val[2], a00_u16, 0);
            c0.val[3] = vmlal_lane_u16(c0.val[3], b00_u16.val[3], a00_u16, 0);

            // 4x4 block 1
            c1.val[0] = vmlal_lane_u16(c1.val[0], b00_u16.val[0], a00_u16, 1);
            c1.val[1] = vmlal_lane_u16(c1.val[1], b00_u16.val[1], a00_u16, 1);
            c1.val[2] = vmlal_lane_u16(c1.val[2], b00_u16.val[2], a00_u16, 1);
            c1.val[3] = vmlal_lane_u16(c1.val[3], b00_u16.val[3], a00_u16, 1);

            // 4x4 block 2
            c2.val[0] = vmlal_lane_u16(c2.val[0], b00_u16.val[0], a00_u16, 2);
            c2.val[1] = vmlal_lane_u16(c2.val[1], b00_u16.val[1], a00_u16, 2);
            c2.val[2] = vmlal_lane_u16(c2.val[2], b00_u16.val[2], a00_u16, 2);
            c2.val[3] = vmlal_lane_u16(c2.val[3], b00_u16.val[3], a00_u16, 2);

            // 4x4 block 3
            c3.val[0] = vmlal_lane_u16(c3.val[0], b00_u16.val[0], a00_u16, 3);
            c3.val[1] = vmlal_lane_u16(c3.val[1], b00_u16.val[1], a00_u16, 3);
            c3.val[2] = vmlal_lane_u16(c3.val[2], b00_u16.val[2], a00_u16, 3);
            c3.val[3] = vmlal_lane_u16(c3.val[3], b00_u16.val[3], a00_u16, 3);
        }

        auto mtx_out = reinterpret_cast<int32_t *>(out.ptr());

        if(id.y() < height_out && id.x() < (width_out - 16))
        {
            vst1q_s32(mtx_out + 0 * out_stride + 0, vreinterpretq_s32_u32(c0.val[0]));
            vst1q_s32(mtx_out + 0 * out_stride + 4, vreinterpretq_s32_u32(c0.val[1]));
            vst1q_s32(mtx_out + 0 * out_stride + 8, vreinterpretq_s32_u32(c0.val[2]));
            vst1q_s32(mtx_out + 0 * out_stride + 12, vreinterpretq_s32_u32(c0.val[3]));
            if(id.y() + 1 < height_out)
            {
                vst1q_s32(mtx_out + 1 * out_stride + 0, vreinterpretq_s32_u32(c1.val[0]));
                vst1q_s32(mtx_out + 1 * out_stride + 4, vreinterpretq_s32_u32(c1.val[1]));
                vst1q_s32(mtx_out + 1 * out_stride + 8, vreinterpretq_s32_u32(c1.val[2]));
                vst1q_s32(mtx_out + 1 * out_stride + 12, vreinterpretq_s32_u32(c1.val[3]));
                if(id.y() + 2 < height_out)
                {
                    vst1q_s32(mtx_out + 2 * out_stride + 0, vreinterpretq_s32_u32(c2.val[0]));
                    vst1q_s32(mtx_out + 2 * out_stride + 4, vreinterpretq_s32_u32(c2.val[1]));
                    vst1q_s32(mtx_out + 2 * out_stride + 8, vreinterpretq_s32_u32(c2.val[2]));
                    vst1q_s32(mtx_out + 2 * out_stride + 12, vreinterpretq_s32_u32(c2.val[3]));
                    if(id.y() + 3 < height_out)
                    {
                        vst1q_s32(mtx_out + 3 * out_stride + 0, vreinterpretq_s32_u32(c3.val[0]));
                        vst1q_s32(mtx_out + 3 * out_stride + 4, vreinterpretq_s32_u32(c3.val[1]));
                        vst1q_s32(mtx_out + 3 * out_stride + 8, vreinterpretq_s32_u32(c3.val[2]));
                        vst1q_s32(mtx_out + 3 * out_stride + 12, vreinterpretq_s32_u32(c3.val[3]));
                    }
                }
            }
        }
        else
        {
            // Left-over columns: store row by row, one element at a time
            const auto left_over_value = width_out - id.x();
            auto left_over = left_over_value;
            for(auto k = 0; k < 4 && left_over; ++k)
            {
                for(auto j = 0; j < 4 && left_over; ++j, --left_over)
                {
                    *(mtx_out + k * 4 + j) = c0.val[k][j];
                }
            }
            if(id.y() + 1 < height_out)
            {
                left_over = left_over_value;
                for(auto k = 0; k < 4 && left_over; ++k)
                {
                    for(auto j = 0; j < 4 && left_over; ++j, --left_over)
                    {
                        *(mtx_out + out_stride + k * 4 + j) = c1.val[k][j];
                    }
                }
                if(id.y() + 2 < height_out)
                {
                    left_over = left_over_value;
                    for(auto k = 0; k < 4 && left_over; ++k)
                    {
                        for(auto j = 0; j < 4 && left_over; ++j, --left_over)
                        {
                            *(mtx_out + out_stride * 2 + k * 4 + j) = c2.val[k][j];
                        }
                    }
                    if(id.y() + 3 < height_out)
                    {
                        left_over = left_over_value;
                        for(auto k = 0; k < 4 && left_over; ++k)
                        {
                            for(auto j = 0; j < 4 && left_over; ++j, --left_over)
                            {
                                *(mtx_out + out_stride * 3 + k * 4 + j) = c3.val[k][j];
                            }
                        }
                    }
                }
            }
        }
    },
    ina, inb, out);
}

void inline matrix_multiply_s8(Iterator &ina, Iterator &inb, Iterator &out, int width_b, const TensorInfo &out_info, const Window &window)
{
    const auto width_out = static_cast<int>(out_info.dimension(0));
    const auto height_out = static_cast<int>(out_info.dimension(1));
    const size_t out_stride = out_info.strides_in_bytes()[1] / out_info.element_size();
    // The implementation assumes that matrix A and matrix B have been reshaped with CpuGemmInterleave4x4 and CpuGemmTranspose1xW respectively.
    // The reshaping makes the implementation cache friendly and avoids the data re-arrangements needed to compute 16x4 elements per iteration.
    // All the values needed to compute a single 4x4 output block are read from consecutive memory positions.
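    // Concretely, as the loop below reads them: each group of 4 consecutive bytes of the reshaped A holds one value from
    // each of the 4 interleaved rows (a[y+0][k], a[y+1][k], a[y+2][k], a[y+3][k]), and each group of 16 consecutive bytes
    // of the reshaped B holds b[k][x..x+15] for a single k, with consecutive groups stepping through k.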
    execute_window_loop(window, [&](const Coordinates & id)
    {
        auto *mtx_a0 = reinterpret_cast<const int8_t *>(ina.ptr());
        auto *mtx_b0 = reinterpret_cast<const int8_t *>(inb.ptr());

        // Accumulators for the block 0
        int32x4x4_t c0 =
        {
            {
                vdupq_n_s32(0),
                vdupq_n_s32(0),
                vdupq_n_s32(0),
                vdupq_n_s32(0)
            }
        };

        // Accumulators for the block 1
        int32x4x4_t c1 =
        {
            {
                vdupq_n_s32(0),
                vdupq_n_s32(0),
                vdupq_n_s32(0),
                vdupq_n_s32(0)
            }
        };

        // Accumulators for the block 2
        int32x4x4_t c2 =
        {
            {
                vdupq_n_s32(0),
                vdupq_n_s32(0),
                vdupq_n_s32(0),
                vdupq_n_s32(0)
            }
        };

        // Accumulators for the block 3
        int32x4x4_t c3 =
        {
            {
                vdupq_n_s32(0),
                vdupq_n_s32(0),
                vdupq_n_s32(0),
                vdupq_n_s32(0)
            }
        };

        for(int k = 0; k < width_b; k += 16, mtx_a0 += 4, mtx_b0 += 16)
        {
            const int8x8_t a00_s8 = vld1_s8(mtx_a0);
            const int8x16_t b00_s8 = vld1q_s8(mtx_b0);

            // Convert a00_s8 to int16_t and get the lower part
            const int16x4_t a00_s16 = vget_low_s16(vmovl_s8(a00_s8));

            // Convert b00_s8 to int16_t
            const int16x4x4_t b00_s16 =
            {
                {
                    vget_low_s16(vmovl_s8(vget_low_s8(b00_s8))),
                    vget_high_s16(vmovl_s8(vget_low_s8(b00_s8))),
                    vget_low_s16(vmovl_s8(vget_high_s8(b00_s8))),
                    vget_high_s16(vmovl_s8(vget_high_s8(b00_s8)))
                }
            };

            // 4x4 block 0
            c0.val[0] = vmlal_lane_s16(c0.val[0], b00_s16.val[0], a00_s16, 0);
            c0.val[1] = vmlal_lane_s16(c0.val[1], b00_s16.val[1], a00_s16, 0);
            c0.val[2] = vmlal_lane_s16(c0.val[2], b00_s16.val[2], a00_s16, 0);
            c0.val[3] = vmlal_lane_s16(c0.val[3], b00_s16.val[3], a00_s16, 0);

            // 4x4 block 1
            c1.val[0] = vmlal_lane_s16(c1.val[0], b00_s16.val[0], a00_s16, 1);
            c1.val[1] = vmlal_lane_s16(c1.val[1], b00_s16.val[1], a00_s16, 1);
            c1.val[2] = vmlal_lane_s16(c1.val[2], b00_s16.val[2], a00_s16, 1);
            c1.val[3] = vmlal_lane_s16(c1.val[3], b00_s16.val[3], a00_s16, 1);

            // 4x4 block 2
            c2.val[0] = vmlal_lane_s16(c2.val[0], b00_s16.val[0], a00_s16, 2);
            c2.val[1] = vmlal_lane_s16(c2.val[1], b00_s16.val[1], a00_s16, 2);
            c2.val[2] = vmlal_lane_s16(c2.val[2], b00_s16.val[2], a00_s16, 2);
            c2.val[3] = vmlal_lane_s16(c2.val[3], b00_s16.val[3], a00_s16, 2);

            // 4x4 block 3
            c3.val[0] = vmlal_lane_s16(c3.val[0], b00_s16.val[0], a00_s16, 3);
            c3.val[1] = vmlal_lane_s16(c3.val[1], b00_s16.val[1], a00_s16, 3);
            c3.val[2] = vmlal_lane_s16(c3.val[2], b00_s16.val[2], a00_s16, 3);
            c3.val[3] = vmlal_lane_s16(c3.val[3], b00_s16.val[3], a00_s16, 3);
        }
        auto mtx_out = reinterpret_cast<int32_t *>(out.ptr());
        if(id.y() < height_out && id.x() < (width_out - 16))
        {
            vst1q_s32(mtx_out + 0 * out_stride + 0, c0.val[0]);
            vst1q_s32(mtx_out + 0 * out_stride + 4, c0.val[1]);
            vst1q_s32(mtx_out + 0 * out_stride + 8, c0.val[2]);
            vst1q_s32(mtx_out + 0 * out_stride + 12, c0.val[3]);
            if(id.y() + 1 < height_out)
            {
                vst1q_s32(mtx_out + 1 * out_stride + 0, c1.val[0]);
                vst1q_s32(mtx_out + 1 * out_stride + 4, c1.val[1]);
                vst1q_s32(mtx_out + 1 * out_stride + 8, c1.val[2]);
                vst1q_s32(mtx_out + 1 * out_stride + 12, c1.val[3]);
                if(id.y() + 2 < height_out)
                {
                    vst1q_s32(mtx_out + 2 * out_stride + 0, c2.val[0]);
                    vst1q_s32(mtx_out + 2 * out_stride + 4, c2.val[1]);
                    vst1q_s32(mtx_out + 2 * out_stride + 8, c2.val[2]);
                    vst1q_s32(mtx_out + 2 * out_stride + 12, c2.val[3]);
                    if(id.y() + 3 < height_out)
                    {
                        vst1q_s32(mtx_out + 3 * out_stride + 0, c3.val[0]);
                        vst1q_s32(mtx_out + 3 * out_stride + 4, c3.val[1]);
                        vst1q_s32(mtx_out + 3 * out_stride + 8, c3.val[2]);
                        vst1q_s32(mtx_out + 3 * out_stride + 12, c3.val[3]);
                    }
                }
            }
        }
        else if(id.y() < height_out)
        {
            const auto left_over_value = width_out - id.x();
            auto left_over = left_over_value;
            for(auto k = 0; k < 4 && left_over; ++k)
            {
                for(auto j = 0; j < 4 && left_over; ++j, --left_over)
                {
                    *(mtx_out + k * 4 + j) = c0.val[k][j];
                }
            }
            if(id.y() + 1 < height_out)
            {
                left_over = left_over_value;
                for(auto k = 0; k < 4 && left_over; ++k)
                {
                    for(auto j = 0; j < 4 && left_over; ++j, --left_over)
                    {
                        *(mtx_out + out_stride + k * 4 + j) = c1.val[k][j];
                    }
                }
                if(id.y() + 2 < height_out)
                {
                    left_over = left_over_value;
                    for(auto k = 0; k < 4 && left_over; ++k)
                    {
                        for(auto j = 0; j < 4 && left_over; ++j, --left_over)
                        {
                            *(mtx_out + out_stride * 2 + k * 4 + j) = c2.val[k][j];
                        }
                    }
                    if(id.y() + 3 < height_out)
                    {
                        left_over = left_over_value;
                        for(auto k = 0; k < 4 && left_over; ++k)
                        {
                            for(auto j = 0; j < 4 && left_over; ++j, --left_over)
                            {
                                *(mtx_out + out_stride * 3 + k * 4 + j) = c3.val[k][j];
                            }
                        }
                    }
                }
            }
        }
    },
    ina, inb, out);
}

Status validate_arguments(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *dst)
{
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src0, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::S8, DataType::U8);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(src1, 1, DataType::QASYMM8, DataType::QASYMM8_SIGNED, DataType::QSYMM8, DataType::QSYMM8_PER_CHANNEL, DataType::S8, DataType::U8);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(dst, 1, DataType::S32);

    TensorShape in0_shape = src0->tensor_shape();
    TensorShape in1_shape = src1->tensor_shape();
    TensorShape out_shape = dst->tensor_shape();

    // Check vector-by-matrix case
    if(out_shape[1] == 1)
    {
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(in0_shape[0] != in1_shape[1], "The number of input0's columns must be equal to input1's rows");
    }
    else
    {
        in0_shape.collapse(2);
        in1_shape.collapse(2);
        out_shape.collapse(2);

        ARM_COMPUTE_RETURN_ERROR_ON_MSG(in0_shape[2] != out_shape[2], "Output tensor must have the same number of batches as the input0 tensor");
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(in1_shape[2] != 1 && in0_shape[2] != in1_shape[2], "Input1 tensor must have the same number of batches as input0, or its number of batches must be set to 1");
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(in1_shape[0] % 16, "Input1's width must be a multiple of 16");
    }

    return Status{};
}
} // namespace

void CpuGemmLowpMatrixMultiplyKernel::configure(const ITensorInfo *src0, const ITensorInfo *src1, ITensorInfo *dst)
{
    ARM_COMPUTE_UNUSED(src0);
    ARM_COMPUTE_ERROR_ON_NULLPTR(src0, src1, dst);
    ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(src0, src1, dst));

    TensorShape in1_shape = src1->tensor_shape();
    in1_shape.collapse(2);

    _slide_matrix_b = in1_shape[2] != 1;

    constexpr unsigned int num_elems_processed_per_iteration_x = 16;
    constexpr unsigned int num_elems_processed_per_iteration_y = 4;

    Window win;
    // Check if the output tensor is a vector. If so, the kernel runs the vector-matrix multiplication path
    if((dst->dimension(1) == 1))
    {
        // Configure kernel window
        win = calculate_max_window(*dst, Steps(num_elems_processed_per_iteration_x));
    }
    else
    {
        win = calculate_max_window(*dst, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
    }

    ICpuKernel::configure(win);
}

Status CpuGemmLowpMatrixMultiplyKernel::validate(const ITensorInfo *src0, const ITensorInfo *src1, const ITensorInfo *dst)
{
    ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(src0, src1, dst));
    return Status{};
}

void CpuGemmLowpMatrixMultiplyKernel::run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info)
{
    ARM_COMPUTE_UNUSED(info);
    ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
    ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(ICpuKernel::window(), window);

    auto src0 = tensors.get_const_tensor(TensorType::ACL_SRC_0);
    auto src1 = tensors.get_const_tensor(TensorType::ACL_SRC_1);
    auto dst = tensors.get_tensor(TensorType::ACL_DST);

    // Check if the output tensor is a vector. If so, the kernel runs the vector-matrix multiplication path
    if((dst->info()->dimension(1) == 1))
    {
        const auto width_matrix_a = static_cast<int>(src0->info()->dimension(0));
        const auto width_matrix_b = static_cast<int>(src1->info()->dimension(0));
        const auto width_out = static_cast<int>(dst->info()->dimension(0));
        const auto in_b_stride = static_cast<int>(src1->info()->strides_in_bytes()[1] / data_size_from_type(src1->info()->data_type()));

        // The implementation computes 16 elements per iteration
        const int window_start_x = 16 * info.thread_id;
        const int window_step_x = 16 * info.num_threads;
        // Make sure (window_end_x - window_start_x) is a multiple of window_step_x
        const int window_end_x = ceil_to_multiple(width_matrix_b - window_start_x, window_step_x) + window_start_x;
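        // Each thread therefore starts at column 16 * thread_id of matrix B and advances by 16 * num_threads columns per iteration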

        Window win_out(window);
        win_out.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x));
        win_out.set(Window::DimY, Window::Dimension(0, 1, 1));

        Window win_a(window);
        win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
        win_a.set(Window::DimY, Window::Dimension(0, 0, 0));

        Window win_b;
        // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
        // This scenario can happen when the matrix multiplication is used to perform a convolution operation
        if(src1->info()->num_dimensions() >= 3)
        {
            win_b = window;
        }
        win_b.set(Window::DimX, Window::Dimension(window_start_x, window_end_x, window_step_x));
        win_b.set(Window::DimY, Window::Dimension(0, 1, 1));

        Iterator ina(src0, win_a);
        Iterator inb(src1, win_b);
        Iterator out(dst, win_out);

        switch(src0->info()->data_type())
        {
            case DataType::S8:
            case DataType::QASYMM8_SIGNED:
            {
                vector_matrix_multiply_s8(ina, inb, out, width_matrix_a, width_matrix_b, width_out, in_b_stride, window);
                break;
            }
            case DataType::U8:
            case DataType::QASYMM8:
            {
                vector_matrix_multiply_u8(ina, inb, out, width_matrix_a, width_matrix_b, width_out, in_b_stride, window);
                break;
            }
            default:
            {
                ARM_COMPUTE_ERROR("Not supported");
                break;
            }
        }
    }
    else
    {
        const size_t in_b_stride = src1->info()->strides_in_bytes()[1];
        const int width_b = src1->info()->dimension(0);

        // Set step_x and step_y for matrix A. Scale the Y range by a factor of 4, as the interleaved input matrix A has 4 times fewer rows than the output matrix
        Window win_a(window);
        win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
        win_a.set(Window::DimY, Window::Dimension(window.y().start() / 4, window.y().end() / 4, 1));

        // Set step_x and step_y for matrix B. Scale the X range by a factor of 16, as the transposed input matrix B has 16 times fewer columns than the output matrix
        Window win_b;
        // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
        // This scenario can happen when the matrix multiplication is used to perform a convolution operation
        if(_slide_matrix_b)
        {
            win_b = window;
        }
        win_b.set(Window::DimX, Window::Dimension(window.x().start() / 16, window.x().end() / 16, in_b_stride));
        win_b.set(Window::DimY, Window::Dimension(0, 0, 0));

        // The step x and step y for the output matrix have already been set in configure()
        Iterator ina(src0, win_a);
        Iterator inb(src1, win_b);
        Iterator out(dst, window);

        switch(src0->info()->data_type())
        {
            case DataType::S8:
            case DataType::QASYMM8_SIGNED:
            {
                matrix_multiply_s8(ina, inb, out, width_b, *dst->info(), window);
                break;
            }
            case DataType::U8:
            case DataType::QASYMM8:
            {
                matrix_multiply_u8(ina, inb, out, width_b, *dst->info(), window);
                break;
            }
            default:
            {
                ARM_COMPUTE_ERROR("Not supported");
                break;
            }
        }
    }
}

const char *CpuGemmLowpMatrixMultiplyKernel::name() const
{
    return "CpuGemmLowpMatrixMultiplyKernel";
}
} // namespace kernels
} // namespace cpu
} // namespace arm_compute