/*
 * Copyright (c) 2017-2022 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

#include "src/cpu/kernels/gemm_matrix_mul/generic/neon/impl.h"
#include "src/core/utils/helpers/float_ops.h"

#include <arm_neon.h>

namespace arm_compute
{
namespace cpu
{
#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
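// Computes dst = alpha * (lhs * rhs) in FP16, where lhs is a 1 x K vector and rhs is a K x N matrix.
// Output columns are distributed across threads in chunks of 32 elements.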
void vector_matrix_multiply_f16(const ITensor *lhs, const ITensor *rhs, ITensor *dst, const Window &window, const ThreadInfo &info, float alpha)
{
    const auto width_matrix_b  = static_cast<int>(dst->info()->dimension(0));
    const auto in_b_stride     = static_cast<int>(rhs->info()->strides_in_bytes()[1] / rhs->info()->element_size());
    const auto num_elems_vec_a = static_cast<int>(lhs->info()->dimension(0));

    // The implementation computes 32 elements per iteration
    const int window_start_x = 32 * info.thread_id;
    const int window_step_x  = 32 * info.num_threads;
    const int window_end_x   = ceil_to_multiple(width_matrix_b - window_start_x, window_step_x) + window_start_x;
    ARM_COMPUTE_ERROR_ON_MSG((window_end_x - window_start_x) % window_step_x, " (window_end_x - window_start_x) must be multiple of window_step_x");

    Window win_out(window);
    win_out.set(Window::DimX, Window::Dimension(0, 1, 1));
    win_out.set(Window::DimY, Window::Dimension(0, 1, 1));

    Window win_a(window);
    win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
    win_a.set(Window::DimY, Window::Dimension(0, 0, 0));

    Window win_b;
    // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
    // This scenario can happen when the matrix multiplication is used to perform a convolution operation
    if(rhs->info()->num_dimensions() >= 3)
    {
        win_b = window;
    }
    win_b.set(Window::DimX, Window::Dimension(0, 1, 1));
    win_b.set(Window::DimY, Window::Dimension(0, 1, 1));

    Iterator ina(lhs, win_a);
    Iterator inb(rhs, win_b);
    Iterator out(dst, win_out);

    const bool multiply_alpha = !(helpers::float_ops::is_one(alpha));

    const float16x8_t alpha_f16 = vdupq_n_f16(alpha);

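    // The vectorized loop below produces 32 output columns per iteration (four float16x8_t accumulators);
    // the scalar loop that follows handles the left-over columns.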
    execute_window_loop(win_out, [&](const Coordinates &)
    {
        int x = window_start_x;
        // Here we don't check for x <= (window_end_x - window_step_x) because window_end_x is rounded up
        // above; taking a full 32-wide step at the last position could write out of bounds in dst, so the
        // left-over loop below handles the remaining columns instead.
        for(; x < (window_end_x - window_step_x); x += window_step_x)
        {
            if(x > width_matrix_b)
            {
                return;
            }

            auto matrix_b = reinterpret_cast<const float16_t *>(inb.ptr()) + x;

            float16x8_t acc0 = vdupq_n_f16(0.f);
            float16x8_t acc1 = vdupq_n_f16(0.f);
            float16x8_t acc2 = vdupq_n_f16(0.f);
            float16x8_t acc3 = vdupq_n_f16(0.f);

            auto             vec_a          = reinterpret_cast<const float16_t *>(ina.ptr());
            const float16_t *vec_a_end_addr = vec_a + num_elems_vec_a;
            for(; vec_a <= (vec_a_end_addr - 4);)
            {
                const float16x4_t a0l = vld1_f16(vec_a);

                float16x8_t b00 = vld1q_f16(matrix_b + 0 + 0 * in_b_stride);
                float16x8_t b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride);
                float16x8_t b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride);
                float16x8_t b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride);
                float16x8_t b10 = vld1q_f16(matrix_b + 0 + 1 * in_b_stride);
                float16x8_t b11 = vld1q_f16(matrix_b + 8 + 1 * in_b_stride);
                float16x8_t b12 = vld1q_f16(matrix_b + 16 + 1 * in_b_stride);
                float16x8_t b13 = vld1q_f16(matrix_b + 24 + 1 * in_b_stride);

                acc0 = vaddq_f16(acc0, vmulq_lane_f16(b00, a0l, 0));
                acc1 = vaddq_f16(acc1, vmulq_lane_f16(b01, a0l, 0));
                acc2 = vaddq_f16(acc2, vmulq_lane_f16(b02, a0l, 0));
                acc3 = vaddq_f16(acc3, vmulq_lane_f16(b03, a0l, 0));
                acc0 = vaddq_f16(acc0, vmulq_lane_f16(b10, a0l, 1));
                acc1 = vaddq_f16(acc1, vmulq_lane_f16(b11, a0l, 1));
                acc2 = vaddq_f16(acc2, vmulq_lane_f16(b12, a0l, 1));
                acc3 = vaddq_f16(acc3, vmulq_lane_f16(b13, a0l, 1));

                matrix_b += 2 * in_b_stride;

                b00 = vld1q_f16(matrix_b + 0 + 0 * in_b_stride);
                b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride);
                b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride);
                b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride);
                b10 = vld1q_f16(matrix_b + 0 + 1 * in_b_stride);
                b11 = vld1q_f16(matrix_b + 8 + 1 * in_b_stride);
                b12 = vld1q_f16(matrix_b + 16 + 1 * in_b_stride);
                b13 = vld1q_f16(matrix_b + 24 + 1 * in_b_stride);

                acc0 = vaddq_f16(acc0, vmulq_lane_f16(b00, a0l, 2));
                acc1 = vaddq_f16(acc1, vmulq_lane_f16(b01, a0l, 2));
                acc2 = vaddq_f16(acc2, vmulq_lane_f16(b02, a0l, 2));
                acc3 = vaddq_f16(acc3, vmulq_lane_f16(b03, a0l, 2));
                acc0 = vaddq_f16(acc0, vmulq_lane_f16(b10, a0l, 3));
                acc1 = vaddq_f16(acc1, vmulq_lane_f16(b11, a0l, 3));
                acc2 = vaddq_f16(acc2, vmulq_lane_f16(b12, a0l, 3));
                acc3 = vaddq_f16(acc3, vmulq_lane_f16(b13, a0l, 3));

                vec_a += 4;
                matrix_b += 2 * in_b_stride;
            }

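            // Accumulate the left-over elements of vector A (fewer than 4 remaining)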
            for(; vec_a < vec_a_end_addr; ++vec_a)
            {
                const float16_t   a0  = *vec_a;
                const float16x8_t b00 = vld1q_f16(matrix_b + 0 + 0 * in_b_stride);
                const float16x8_t b01 = vld1q_f16(matrix_b + 8 + 0 * in_b_stride);
                const float16x8_t b02 = vld1q_f16(matrix_b + 16 + 0 * in_b_stride);
                const float16x8_t b03 = vld1q_f16(matrix_b + 24 + 0 * in_b_stride);

                acc0 = vaddq_f16(acc0, vmulq_n_f16(b00, a0));
                acc1 = vaddq_f16(acc1, vmulq_n_f16(b01, a0));
                acc2 = vaddq_f16(acc2, vmulq_n_f16(b02, a0));
                acc3 = vaddq_f16(acc3, vmulq_n_f16(b03, a0));

                matrix_b += in_b_stride;
            }

            // Multiply by the weight of matrix product (alpha)
            if(multiply_alpha)
            {
                acc0 = vmulq_f16(acc0, alpha_f16);
                acc1 = vmulq_f16(acc1, alpha_f16);
                acc2 = vmulq_f16(acc2, alpha_f16);
                acc3 = vmulq_f16(acc3, alpha_f16);
            }

            auto vec_out = reinterpret_cast<float16_t *>(out.ptr()) + x;

            vst1q_f16(vec_out + 0, acc0);
            vst1q_f16(vec_out + 8, acc1);
            vst1q_f16(vec_out + 16, acc2);
            vst1q_f16(vec_out + 24, acc3);
        }

        for(; x < window_end_x; ++x)
        {
            if(x > width_matrix_b)
            {
                return;
            }

            auto matrix_b = reinterpret_cast<const float16_t *>(inb.ptr()) + x;

            float16x4_t vacc = vdup_n_f16(0.f);

            auto             vec_a          = reinterpret_cast<const float16_t *>(ina.ptr());
            const float16_t *vec_a_end_addr = vec_a + num_elems_vec_a;
            for(; vec_a <= (vec_a_end_addr - 4); vec_a += 4)
            {
                const float16x4_t a0l = vld1_f16(vec_a);

                const float16x4_t b_col =
                {
                    *(matrix_b + 0 * in_b_stride),
                    *(matrix_b + 1 * in_b_stride),
                    *(matrix_b + 2 * in_b_stride),
                    *(matrix_b + 3 * in_b_stride),
                };

                vacc = vadd_f16(vacc, vmul_f16(a0l, b_col));

                matrix_b += 4 * in_b_stride;
            }

            float16_t acc = vget_lane_f16(vacc, 0) + vget_lane_f16(vacc, 1) + vget_lane_f16(vacc, 2) + vget_lane_f16(vacc, 3);

            for(; vec_a < vec_a_end_addr; ++vec_a)
            {
                const float16_t a0  = *vec_a;
                const float16_t b00 = *matrix_b;

                acc += b00 * a0;

                matrix_b += in_b_stride;
            }

            // Multiply by the weight of matrix product (alpha)
            if(multiply_alpha)
            {
                acc *= static_cast<float16_t>(alpha);
            }

            auto vec_out = reinterpret_cast<float16_t *>(out.ptr()) + x;

            *(vec_out) = acc;
        }
    },
    ina, inb, out);
}
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */

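// FP32 variant of the vector * matrix kernel: computes dst = alpha * (lhs * rhs), where lhs is a 1 x K vector
// and rhs is a K x N matrix. Output columns are distributed across threads in chunks of 16 elements.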
void vector_matrix_multiply_f32(const ITensor *lhs, const ITensor *rhs, ITensor *dst, const Window &window, const ThreadInfo &info, float alpha)
{
    const auto width_matrix_b  = static_cast<int>(dst->info()->dimension(0));
    const auto in_b_stride     = static_cast<int>(rhs->info()->strides_in_bytes()[1] / data_size_from_type(rhs->info()->data_type()));
    const auto num_elems_vec_a = static_cast<int>(lhs->info()->dimension(0));

    // The implementation computes 16 elements per iteration
    const int window_start_x = 16 * info.thread_id;
    const int window_step_x  = 16 * info.num_threads;
    // Make sure (window_end_x - window_start_x) is a multiple of window_step_x
    const int window_end_x = ceil_to_multiple(width_matrix_b - window_start_x, window_step_x) + window_start_x;

    Window win_out(window);
    win_out.set(Window::DimX, Window::Dimension(0, 1, 1));
    win_out.set(Window::DimY, Window::Dimension(0, 1, 1));

    Window win_a(window);
    win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
    win_a.set(Window::DimY, Window::Dimension(0, 0, 0));

    Window win_b;
    // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
    // This scenario can happen when the matrix multiplication is used to perform a convolution operation
    if(rhs->info()->num_dimensions() >= 3)
    {
        win_b = window;
    }
    win_b.set(Window::DimX, Window::Dimension(0, 1, 1));
    win_b.set(Window::DimY, Window::Dimension(0, 1, 1));

    Iterator ina(lhs, win_a);
    Iterator inb(rhs, win_b);
    Iterator out(dst, win_out);

    const bool multiply_alpha = !(helpers::float_ops::is_one(alpha));

    const float32x4_t alpha_f32 = vdupq_n_f32(alpha);

    execute_window_loop(win_out, [&](const Coordinates &)
    {
        int x = window_start_x;
        // Here we don't check for x <= (window_end_x - window_step_x) because window_end_x is rounded up
        // above; taking a full 16-wide step at the last position could write out of bounds in dst, so the
        // left-over loop below handles the remaining columns instead.
        for(; x < (window_end_x - window_step_x); x += window_step_x)
        {
            if(x > width_matrix_b)
            {
                return;
            }

            float32x4_t acc0 = vdupq_n_f32(0.f);
            float32x4_t acc1 = vdupq_n_f32(0.f);
            float32x4_t acc2 = vdupq_n_f32(0.f);
            float32x4_t acc3 = vdupq_n_f32(0.f);

            auto vec_a    = reinterpret_cast<const float *>(ina.ptr());
            auto matrix_b = reinterpret_cast<const float *>(inb.ptr()) + x;

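            // On 32-bit Arm, prefetch the upcoming cache lines of vector A and matrix B to hide memory latency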
#if __arm__
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(vec_a)));
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b)));
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + in_b_stride)));
#endif /* __arm__ */

            auto vec_a_end_addr = vec_a + num_elems_vec_a;
            for(; vec_a <= (vec_a_end_addr - 4);)
            {
                float32x2_t a0l = vld1_f32(vec_a);

                float32x4_t b00 = vld1q_f32(matrix_b + 0 + 0 * in_b_stride);
                float32x4_t b01 = vld1q_f32(matrix_b + 4 + 0 * in_b_stride);
                float32x4_t b02 = vld1q_f32(matrix_b + 8 + 0 * in_b_stride);
                float32x4_t b03 = vld1q_f32(matrix_b + 12 + 0 * in_b_stride);

                float32x4_t b10 = vld1q_f32(matrix_b + 0 + 1 * in_b_stride);
                float32x4_t b11 = vld1q_f32(matrix_b + 4 + 1 * in_b_stride);
                float32x4_t b12 = vld1q_f32(matrix_b + 8 + 1 * in_b_stride);
                float32x4_t b13 = vld1q_f32(matrix_b + 12 + 1 * in_b_stride);

#if __arm__
                asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(vec_a)));
                asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 1 * in_b_stride)));
                asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 2 * in_b_stride)));
                asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 3 * in_b_stride)));
                asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 4 * in_b_stride)));
#endif /* __arm__ */

                acc0 = vmlaq_lane_f32(acc0, b00, a0l, 0);
                acc1 = vmlaq_lane_f32(acc1, b01, a0l, 0);
                acc2 = vmlaq_lane_f32(acc2, b02, a0l, 0);
                acc3 = vmlaq_lane_f32(acc3, b03, a0l, 0);

                acc0 = vmlaq_lane_f32(acc0, b10, a0l, 1);
                acc1 = vmlaq_lane_f32(acc1, b11, a0l, 1);
                acc2 = vmlaq_lane_f32(acc2, b12, a0l, 1);
                acc3 = vmlaq_lane_f32(acc3, b13, a0l, 1);

                vec_a += 2;
                matrix_b += 2 * in_b_stride;

                a0l = vld1_f32(vec_a);

                b00 = vld1q_f32(matrix_b + 0 + 0 * in_b_stride);
                b01 = vld1q_f32(matrix_b + 4 + 0 * in_b_stride);
                b02 = vld1q_f32(matrix_b + 8 + 0 * in_b_stride);
                b03 = vld1q_f32(matrix_b + 12 + 0 * in_b_stride);

                b10 = vld1q_f32(matrix_b + 0 + 1 * in_b_stride);
                b11 = vld1q_f32(matrix_b + 4 + 1 * in_b_stride);
                b12 = vld1q_f32(matrix_b + 8 + 1 * in_b_stride);
                b13 = vld1q_f32(matrix_b + 12 + 1 * in_b_stride);

                acc0 = vmlaq_lane_f32(acc0, b00, a0l, 0);
                acc1 = vmlaq_lane_f32(acc1, b01, a0l, 0);
                acc2 = vmlaq_lane_f32(acc2, b02, a0l, 0);
                acc3 = vmlaq_lane_f32(acc3, b03, a0l, 0);

                acc0 = vmlaq_lane_f32(acc0, b10, a0l, 1);
                acc1 = vmlaq_lane_f32(acc1, b11, a0l, 1);
                acc2 = vmlaq_lane_f32(acc2, b12, a0l, 1);
                acc3 = vmlaq_lane_f32(acc3, b13, a0l, 1);

                vec_a += 2;
                matrix_b += 2 * in_b_stride;
            }

            for(; vec_a < vec_a_end_addr; ++vec_a)
            {
                const float a0 = *vec_a;

                const float32x4_t b00 = vld1q_f32(matrix_b + 0 + 0 * in_b_stride);
                const float32x4_t b01 = vld1q_f32(matrix_b + 4 + 0 * in_b_stride);
                const float32x4_t b02 = vld1q_f32(matrix_b + 8 + 0 * in_b_stride);
                const float32x4_t b03 = vld1q_f32(matrix_b + 12 + 0 * in_b_stride);

                acc0 = vmlaq_n_f32(acc0, b00, a0);
                acc1 = vmlaq_n_f32(acc1, b01, a0);
                acc2 = vmlaq_n_f32(acc2, b02, a0);
                acc3 = vmlaq_n_f32(acc3, b03, a0);

                matrix_b += in_b_stride;
            }

            // Multiply by the weight of matrix product (alpha)
            if(multiply_alpha)
            {
                acc0 = vmulq_f32(acc0, alpha_f32);
                acc1 = vmulq_f32(acc1, alpha_f32);
                acc2 = vmulq_f32(acc2, alpha_f32);
                acc3 = vmulq_f32(acc3, alpha_f32);
            }

            const auto vec_out = reinterpret_cast<float *>(out.ptr()) + x;

            vst1q_f32(vec_out + 0, acc0);
            vst1q_f32(vec_out + 4, acc1);
            vst1q_f32(vec_out + 8, acc2);
            vst1q_f32(vec_out + 12, acc3);
        }

        // Left-over loop
        for(; x < window_end_x; ++x)
        {
            if(x > width_matrix_b)
            {
                return;
            }

            float32x4_t vacc = vdupq_n_f32(0.f);

            auto vec_a    = reinterpret_cast<const float *>(ina.ptr());
            auto matrix_b = reinterpret_cast<const float *>(inb.ptr()) + x;

#if __arm__
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(vec_a)));
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b)));
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + in_b_stride)));
#endif /* __arm__ */

            auto vec_a_end_addr = vec_a + num_elems_vec_a;
            for(; vec_a <= (vec_a_end_addr - 4); vec_a += 4)
            {
                const float32x4_t a0l = vld1q_f32(vec_a);

                const float32x4_t b_col =
                {
                    *(matrix_b + 0 * in_b_stride),
                    *(matrix_b + 1 * in_b_stride),
                    *(matrix_b + 2 * in_b_stride),
                    *(matrix_b + 3 * in_b_stride),
                };

#if __arm__
                asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(vec_a)));
                asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 1 * in_b_stride)));
                asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 2 * in_b_stride)));
                asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 3 * in_b_stride)));
                asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(matrix_b + 4 * in_b_stride)));
#endif /* __arm__ */

                vacc = vmlaq_f32(vacc, b_col, a0l);

                matrix_b += 4 * in_b_stride;
            }

            float acc = vgetq_lane_f32(vacc, 0) + vgetq_lane_f32(vacc, 1) + vgetq_lane_f32(vacc, 2) + vgetq_lane_f32(vacc, 3);

            for(; vec_a < vec_a_end_addr; ++vec_a)
            {
                const float a0 = *vec_a;

                const float b00 = *matrix_b;

                acc += b00 * a0;

                matrix_b += in_b_stride;
            }

            // Multiply by the weight of matrix product (alpha)
            if(multiply_alpha)
            {
                acc *= alpha;
            }

            const auto vec_out = reinterpret_cast<float *>(out.ptr()) + x;

            *vec_out = acc;
        }
    },
    ina, inb, out);
}

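// Computes dst = alpha * (lhs * rhs) in FP32, where lhs has been reshaped with CpuGemmInterleave4x4 and
// rhs with CpuGemmTranspose1xW; each window iteration produces a 4x8 block of the dst matrix.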
void matrix_matrix_multiply_f32(const ITensor *lhs, const ITensor *rhs, ITensor *dst, const Window &window, const ThreadInfo &info, float alpha)
{
    ARM_COMPUTE_UNUSED(info);
    const int    out_width            = static_cast<int>(dst->info()->dimension(0));
    const int    out_height           = static_cast<int>(dst->info()->dimension(1));
    const size_t in_b_stride          = rhs->info()->strides_in_bytes()[1] / data_size_from_type(rhs->info()->data_type());
    const size_t out_stride1          = dst->info()->strides_in_bytes()[1] / data_size_from_type(dst->info()->data_type());
    const size_t out_stride2          = out_stride1 * 2;
    const size_t out_stride3          = out_stride1 * 3;
    const int    num_elems_matrix_b_x = rhs->info()->dimension(0);

    // Set step_x and step_y for matrix A. Scale the Y range by a factor of 4 as the interleaved input matrix A has 4 times fewer rows than the dst matrix
    Window win_a(window);
    win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
    win_a.set(Window::DimY, Window::Dimension(window.y().start() / 4, std::max(window.y().end() / 4, 1), 1));

    Window win_b;
    // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
    // This scenario can happen when the matrix multiplication is used to perform a convolution operation
    if(rhs->info()->num_dimensions() >= 3)
    {
        win_b = window;
    }
    // Set step_x and step_y for matrix B. Scale the X range by a factor of 4 as the transposed input matrix B has 4 times fewer columns than the dst matrix
    // The step along the x direction is 2 times the in_b_stride because for each iteration we compute 2 blocks of size 4x4
    win_b.set(Window::DimX, Window::Dimension(window.x().start() / 4, window.x().end() / 4, 2 * in_b_stride));
    win_b.set(Window::DimY, Window::Dimension(0, 0, 0));

    Iterator ina(lhs, win_a);
    Iterator inb(rhs, win_b);
    Iterator out(dst, window);

    const bool multiply_alpha = !(helpers::float_ops::is_one(alpha));

    const float32x4_t alpha_f32 = vdupq_n_f32(alpha);

    // The implementation assumes that the matrix A and Matrix B have been reshaped respectively with CpuGemmInterleave4x4 and CpuGemmTranspose1xW
    // The reshaping of the matrices helps to have a cache friendly implementation and helps to avoid the data re-arrangements needed for computing 16x4 elements per iteration
    // All the values needed for computing a single 4x4 block will be read from consecutive memory positions
    execute_window_loop(window, [&](const Coordinates & id)
    {
        auto mtx_a0 = reinterpret_cast<const float *>(ina.ptr());
        auto mtx_b0 = reinterpret_cast<const float *>(inb.ptr());
        auto mtx_b1 = mtx_b0 + in_b_stride;

        float32x4_t acc00 = vdupq_n_f32(0.f);
        float32x4_t acc10 = vdupq_n_f32(0.f);
        float32x4_t acc20 = vdupq_n_f32(0.f);
        float32x4_t acc30 = vdupq_n_f32(0.f);

        float32x4_t acc01 = vdupq_n_f32(0.f);
        float32x4_t acc11 = vdupq_n_f32(0.f);
        float32x4_t acc21 = vdupq_n_f32(0.f);
        float32x4_t acc31 = vdupq_n_f32(0.f);

#if __arm__
        asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
        asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
        asm volatile("PLD [%0, #128*1]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b1)));
#endif /* __arm__ */

        auto mtx_b0_end_addr = mtx_b0 + num_elems_matrix_b_x;
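        // Main loop: each iteration performs four unrolled steps; every step reads 8 values of interleaved A
        // and 8 values from each of the two transposed B rows, accumulating into the two 4x4 output blocks.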
        for(; mtx_b0 <= (mtx_b0_end_addr - 32);)
        {
            float32x4_t a0 = vld1q_dup_f32(mtx_a0 + 0);
            float32x4_t a1 = vld1q_dup_f32(mtx_a0 + 1);
            float32x4_t a2 = vld1q_dup_f32(mtx_a0 + 2);
            float32x4_t a3 = vld1q_dup_f32(mtx_a0 + 3);

            float32x4_t b00 = vld1q_f32(mtx_b0);
            float32x4_t b10 = vld1q_f32(mtx_b1);
            float32x4_t b01 = vld1q_f32(mtx_b0 + 4);
            float32x4_t b11 = vld1q_f32(mtx_b1 + 4);

#if __arm__
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b1)));
#endif /* __arm__ */

            // 4x4 block 0
            acc00 = vmlaq_f32(acc00, b00, a0);
            acc10 = vmlaq_f32(acc10, b00, a1);
            acc20 = vmlaq_f32(acc20, b00, a2);
            acc30 = vmlaq_f32(acc30, b00, a3);

            float32x4_t a4 = vld1q_dup_f32(mtx_a0 + 4);
            float32x4_t a5 = vld1q_dup_f32(mtx_a0 + 5);
            float32x4_t a6 = vld1q_dup_f32(mtx_a0 + 6);
            float32x4_t a7 = vld1q_dup_f32(mtx_a0 + 7);

            // 4x4 block 1
            acc01 = vmlaq_f32(acc01, b10, a0);
            acc11 = vmlaq_f32(acc11, b10, a1);
            acc21 = vmlaq_f32(acc21, b10, a2);
            acc31 = vmlaq_f32(acc31, b10, a3);

            // 4x4 block 0
            acc00 = vmlaq_f32(acc00, b01, a4);
            acc10 = vmlaq_f32(acc10, b01, a5);
            acc20 = vmlaq_f32(acc20, b01, a6);
            acc30 = vmlaq_f32(acc30, b01, a7);

            // 4x4 block 1
            acc01 = vmlaq_f32(acc01, b11, a4);
            acc11 = vmlaq_f32(acc11, b11, a5);
            acc21 = vmlaq_f32(acc21, b11, a6);
            acc31 = vmlaq_f32(acc31, b11, a7);

            mtx_a0 += 8;
            mtx_b0 += 8;
            mtx_b1 += 8;

            a0 = vld1q_dup_f32(mtx_a0 + 0);
            a1 = vld1q_dup_f32(mtx_a0 + 1);
            a2 = vld1q_dup_f32(mtx_a0 + 2);
            a3 = vld1q_dup_f32(mtx_a0 + 3);

            b00 = vld1q_f32(mtx_b0);
            b10 = vld1q_f32(mtx_b1);
            b01 = vld1q_f32(mtx_b0 + 4);
            b11 = vld1q_f32(mtx_b1 + 4);

            // 4x4 block 0
            acc00 = vmlaq_f32(acc00, b00, a0);
            acc10 = vmlaq_f32(acc10, b00, a1);
            acc20 = vmlaq_f32(acc20, b00, a2);
            acc30 = vmlaq_f32(acc30, b00, a3);

            a4 = vld1q_dup_f32(mtx_a0 + 4);
            a5 = vld1q_dup_f32(mtx_a0 + 5);
            a6 = vld1q_dup_f32(mtx_a0 + 6);
            a7 = vld1q_dup_f32(mtx_a0 + 7);

            // 4x4 block 1
            acc01 = vmlaq_f32(acc01, b10, a0);
            acc11 = vmlaq_f32(acc11, b10, a1);
            acc21 = vmlaq_f32(acc21, b10, a2);
            acc31 = vmlaq_f32(acc31, b10, a3);

            // 4x4 block 0
            acc00 = vmlaq_f32(acc00, b01, a4);
            acc10 = vmlaq_f32(acc10, b01, a5);
            acc20 = vmlaq_f32(acc20, b01, a6);
            acc30 = vmlaq_f32(acc30, b01, a7);

            // 4x4 block 1
            acc01 = vmlaq_f32(acc01, b11, a4);
            acc11 = vmlaq_f32(acc11, b11, a5);
            acc21 = vmlaq_f32(acc21, b11, a6);
            acc31 = vmlaq_f32(acc31, b11, a7);

            mtx_a0 += 8;
            mtx_b0 += 8;
            mtx_b1 += 8;

            a0  = vld1q_dup_f32(mtx_a0 + 0);
            a1  = vld1q_dup_f32(mtx_a0 + 1);
            a2  = vld1q_dup_f32(mtx_a0 + 2);
            a3  = vld1q_dup_f32(mtx_a0 + 3);
            b00 = vld1q_f32(mtx_b0);
            b10 = vld1q_f32(mtx_b1);
            b01 = vld1q_f32(mtx_b0 + 4);
            b11 = vld1q_f32(mtx_b1 + 4);

#if __arm__
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
            asm volatile("PLD [%0, #128*4]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b1)));
#endif /* __arm__ */

            // 4x4 block 0
            acc00 = vmlaq_f32(acc00, b00, a0);
            acc10 = vmlaq_f32(acc10, b00, a1);
            acc20 = vmlaq_f32(acc20, b00, a2);
            acc30 = vmlaq_f32(acc30, b00, a3);

            a4 = vld1q_dup_f32(mtx_a0 + 4);
            a5 = vld1q_dup_f32(mtx_a0 + 5);
            a6 = vld1q_dup_f32(mtx_a0 + 6);
            a7 = vld1q_dup_f32(mtx_a0 + 7);

            // 4x4 block 1
            acc01 = vmlaq_f32(acc01, b10, a0);
            acc11 = vmlaq_f32(acc11, b10, a1);
            acc21 = vmlaq_f32(acc21, b10, a2);
            acc31 = vmlaq_f32(acc31, b10, a3);

            // 4x4 block 0
            acc00 = vmlaq_f32(acc00, b01, a4);
            acc10 = vmlaq_f32(acc10, b01, a5);
            acc20 = vmlaq_f32(acc20, b01, a6);
            acc30 = vmlaq_f32(acc30, b01, a7);

            // 4x4 block 1
            acc01 = vmlaq_f32(acc01, b11, a4);
            acc11 = vmlaq_f32(acc11, b11, a5);
            acc21 = vmlaq_f32(acc21, b11, a6);
            acc31 = vmlaq_f32(acc31, b11, a7);

            mtx_a0 += 8;
            mtx_b0 += 8;
            mtx_b1 += 8;

            a0  = vld1q_dup_f32(mtx_a0 + 0);
            a1  = vld1q_dup_f32(mtx_a0 + 1);
            a2  = vld1q_dup_f32(mtx_a0 + 2);
            a3  = vld1q_dup_f32(mtx_a0 + 3);
            b00 = vld1q_f32(mtx_b0);
            b10 = vld1q_f32(mtx_b1);
            b01 = vld1q_f32(mtx_b0 + 4);
            b11 = vld1q_f32(mtx_b1 + 4);

            // 4x4 block 0
            acc00 = vmlaq_f32(acc00, b00, a0);
            acc10 = vmlaq_f32(acc10, b00, a1);
            acc20 = vmlaq_f32(acc20, b00, a2);
            acc30 = vmlaq_f32(acc30, b00, a3);

            a4 = vld1q_dup_f32(mtx_a0 + 4);
            a5 = vld1q_dup_f32(mtx_a0 + 5);
            a6 = vld1q_dup_f32(mtx_a0 + 6);
            a7 = vld1q_dup_f32(mtx_a0 + 7);

            // 4x4 block 1
            acc01 = vmlaq_f32(acc01, b10, a0);
            acc11 = vmlaq_f32(acc11, b10, a1);
            acc21 = vmlaq_f32(acc21, b10, a2);
            acc31 = vmlaq_f32(acc31, b10, a3);

            // 4x4 block 0
            acc00 = vmlaq_f32(acc00, b01, a4);
            acc10 = vmlaq_f32(acc10, b01, a5);
            acc20 = vmlaq_f32(acc20, b01, a6);
            acc30 = vmlaq_f32(acc30, b01, a7);

            // 4x4 block 1
            acc01 = vmlaq_f32(acc01, b11, a4);
            acc11 = vmlaq_f32(acc11, b11, a5);
            acc21 = vmlaq_f32(acc21, b11, a6);
            acc31 = vmlaq_f32(acc31, b11, a7);

            mtx_a0 += 8;
            mtx_b0 += 8;
            mtx_b1 += 8;
        }

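        // Left-over loop: process the remaining values of the transposed B rows 4 at a time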
        for(; mtx_b0 < mtx_b0_end_addr;)
        {
            float32x4_t a0  = vld1q_dup_f32(mtx_a0 + 0);
            float32x4_t a1  = vld1q_dup_f32(mtx_a0 + 1);
            float32x4_t a2  = vld1q_dup_f32(mtx_a0 + 2);
            float32x4_t a3  = vld1q_dup_f32(mtx_a0 + 3);
            float32x4_t b00 = vld1q_f32(mtx_b0);
            float32x4_t b10 = vld1q_f32(mtx_b1);

#if __arm__
            asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_a0)));
            asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b0)));
            asm volatile("PLD [%0, #128*2]" ::"r"(reinterpret_cast<const uint8_t *>(mtx_b1)));
#endif /* __arm__ */
            // 4x4 block 0
            acc00 = vmlaq_f32(acc00, b00, a0);
            acc10 = vmlaq_f32(acc10, b00, a1);
            acc20 = vmlaq_f32(acc20, b00, a2);
            acc30 = vmlaq_f32(acc30, b00, a3);

            // 4x4 block 1
            acc01 = vmlaq_f32(acc01, b10, a0);
            acc11 = vmlaq_f32(acc11, b10, a1);
            acc21 = vmlaq_f32(acc21, b10, a2);
            acc31 = vmlaq_f32(acc31, b10, a3);

            mtx_a0 += 4;
            mtx_b0 += 4;
            mtx_b1 += 4;
        }

        // Multiply by the weight of matrix product (alpha)
        if(multiply_alpha)
        {
            acc00 = vmulq_f32(acc00, alpha_f32);
            acc10 = vmulq_f32(acc10, alpha_f32);
            acc20 = vmulq_f32(acc20, alpha_f32);
            acc30 = vmulq_f32(acc30, alpha_f32);
            acc01 = vmulq_f32(acc01, alpha_f32);
            acc11 = vmulq_f32(acc11, alpha_f32);
            acc21 = vmulq_f32(acc21, alpha_f32);
            acc31 = vmulq_f32(acc31, alpha_f32);
        }

        const auto mtx_out0 = reinterpret_cast<float *>(out.ptr());
        const auto mtx_out1 = mtx_out0 + 4;

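        // Store the 4x8 result block, guarding against partial blocks at the right and bottom edges of dst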
        if(id.x() < (out_width - 8))
        {
            vst1q_f32(mtx_out0, acc00);
            vst1q_f32(mtx_out1, acc01);
            if(id.y() + 1 < out_height)
            {
                vst1q_f32(mtx_out0 + out_stride1, acc10);
                vst1q_f32(mtx_out1 + out_stride1, acc11);
                if(id.y() + 2 < out_height)
                {
                    vst1q_f32(mtx_out0 + out_stride2, acc20);
                    vst1q_f32(mtx_out1 + out_stride2, acc21);
                    if(id.y() + 3 < out_height)
                    {
                        vst1q_f32(mtx_out0 + out_stride3, acc30);
                        vst1q_f32(mtx_out1 + out_stride3, acc31);
                    }
                }
            }
        }
        else if(id.x() < (out_width - 4))
        {
            vst1q_f32(mtx_out0, acc00);
            if(id.y() + 1 < out_height)
            {
                vst1q_f32(mtx_out0 + out_stride1, acc10);
                if(id.y() + 2 < out_height)
                {
                    vst1q_f32(mtx_out0 + out_stride2, acc20);
                    if(id.y() + 3 < out_height)
                    {
                        vst1q_f32(mtx_out0 + out_stride3, acc30);
                    }
                }
            }
            // Left-over columns
            const int columns_left = out_width - id.x() - 4;
            for(auto x = 0; x < columns_left; ++x)
            {
                *(mtx_out1 + x) = acc01[x];
                if(id.y() + 1 < out_height)
                {
                    *(mtx_out1 + x + out_stride1) = acc11[x];
                    if(id.y() + 2 < out_height)
                    {
                        *(mtx_out1 + x + out_stride2) = acc21[x];
                        if(id.y() + 3 < out_height)
                        {
                            *(mtx_out1 + x + out_stride3) = acc31[x];
                        }
                    }
                }
            }
        }
        else
        {
            // Left-over columns
            const int columns_left = out_width - id.x();
            for(int x = 0; x < columns_left; ++x)
            {
                *(mtx_out0 + x) = acc00[x];
                if(id.y() + 1 < out_height)
                {
                    *(mtx_out0 + x + out_stride1) = acc10[x];
                    if(id.y() + 2 < out_height)
                    {
                        *(mtx_out0 + x + out_stride2) = acc20[x];
                        if(id.y() + 3 < out_height)
                        {
                            *(mtx_out0 + x + out_stride3) = acc30[x];
                        }
                    }
                }
            }
        }
    },
    ina, inb, out);
}

#ifdef __ARM_FEATURE_FP16_VECTOR_ARITHMETIC
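// FP16 variant of the reshaped matrix multiplication: lhs is interleaved 4x4 and rhs is transposed 1x8;
// each window iteration produces a 4x8 block of the dst matrix.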
void matrix_matrix_multiply_f16(const ITensor *lhs, const ITensor *rhs, ITensor *dst, const Window &window, const ThreadInfo &info, float alpha)
{
    ARM_COMPUTE_UNUSED(info);
    const int    out_width            = static_cast<int>(dst->info()->dimension(0));
    const int    out_height           = static_cast<int>(dst->info()->dimension(1));
    const size_t in_b_stride          = rhs->info()->strides_in_bytes()[1] / data_size_from_type(rhs->info()->data_type());
    const size_t out_stride           = dst->info()->strides_in_bytes()[1] / data_size_from_type(dst->info()->data_type());
    const int    num_elems_matrix_b_x = rhs->info()->dimension(0);

    // Set step_x and step_y for matrix A. Scale the Y range by a factor of 4 as the interleaved input matrix A has 4 times fewer rows than the dst matrix
    Window win_a(window);
    win_a.set(Window::DimX, Window::Dimension(0, 0, 0));
    win_a.set(Window::DimY, Window::Dimension(window.y().start() / 4, std::max(window.y().end() / 4, 1), 1));

    Window win_b;
    // Don't slice matrix B along the z dimension if matrix B has just 2 dimensions and matrix A more than 2
    // This scenario can happen when the matrix multiplication is used to perform a convolution operation
    if(rhs->info()->num_dimensions() >= 3)
    {
        win_b = window;
    }
    // Set step_x and step_y for matrix B. Scale the X range by a factor of 8 as the transposed input matrix B has 8 times fewer columns than the dst matrix
    win_b.set(Window::DimX, Window::Dimension(window.x().start() / 8, window.x().end() / 8, in_b_stride));
    win_b.set(Window::DimY, Window::Dimension(0, 0, 0));

    Iterator ina(lhs, win_a);
    Iterator inb(rhs, win_b);
    Iterator out(dst, window);

    const bool multiply_alpha = !(helpers::float_ops::is_one(alpha));

    const float16x8_t alpha_f16 = vdupq_n_f16(alpha);

    execute_window_loop(window, [&](const Coordinates & id)
    {
        const auto   *mtx_a0  = reinterpret_cast<const float16_t *>(ina.ptr());
        const auto   *mtx_b0  = reinterpret_cast<const float16_t *>(inb.ptr());
        auto         *mtx_out = reinterpret_cast<float16_t *>(out.ptr());
        float16x8x4_t c =
        {
            {
                vdupq_n_f16(0.f),
                vdupq_n_f16(0.f),
                vdupq_n_f16(0.f),
                vdupq_n_f16(0.f)
            }
        };

        /*
        This kernel puts the values in a 4x4 block of Matrix A on the same row (Interleaved values)
             |a00 a01 a02 a03 | a04 a05 a06 a07|
             |a10 a11 a12 a13 | a14 a15 a16 a17|
             |a20 a21 a22 a23 | a24 a25 a26 a27| = | a00 a10 a20 a30 || a01 a11 a21 a31 || a02 a12 a22 a32 || a03 a13 a23 a33 | a40 a50 a60 a70 | ...
             |a30 a31 a32 a33 | a34 a35 a36 a37|   | a04 a14 a24 a34 || a05 a15 a25 a35 || a06 a16 a26 a36 || a07 a17 a27 a37 | a44 a54 a64 a74 | ...
             |a40 a41 a42 a43 | a44 a45 a46 a47|
             |a50 a51 a52 a53 | a54 a55 a56 a57|
             |a60 a61 a62 a63 | a64 a65 a66 a67|
             |a70 a71 a72 a73 | a74 a75 a76 a77|

             After this operation, the dst matrix will have the following shape: [ height * 4, width / 4 ]

        B Matrix has been transposed as shown below

           |b00 b01 b02 b03 b04 b05 b06 b07|
           |b10 b11 b12 b13 b14 b15 b16 b17|
           |b20 b21 b22 b23 b24 b25 b26 b27|
           |b30 b31 b32 b33 b34 b35 b36 b37|
          ------------------->

           |b00 b01 b02 b03 b04 b05 b06 b07||b10 b11 b12 b13 b14 b15 b16 b17||b20 b21 b22 b23 b24 b25 b26 b27||b30 b31 b32 b33 b34 b35 b36 b37|

            c.val[0][0] = a00*b00 + a01*b10 + a02*b20 + a03*b30
            c.val[0][1] = a00*b01 + a01*b11 + a02*b21 + a03*b31

        The size of the dst tensor's XY-plane must be the following shape [ width * 8, height / 8 ]. All other dimensions must have the same size.
        */
        const float16_t *mtx_b0_end_addr = mtx_b0 + num_elems_matrix_b_x;

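        // Main loop: each iteration reads 16 values of interleaved A and 32 values of transposed B,
        // accumulating four rows of 8 output elements each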
        for(; mtx_b0 <= (mtx_b0_end_addr - 32);)
        {
            const float16x8_t p00 = vld1q_f16(mtx_a0);
            const float16x8_t p02 = vld1q_f16(mtx_a0 + 8);

            const float16x8_t q00 = vld1q_f16(mtx_b0);
            const float16x8_t q02 = vld1q_f16(mtx_b0 + 8);
            const float16x8_t q04 = vld1q_f16(mtx_b0 + 16);
            const float16x8_t q06 = vld1q_f16(mtx_b0 + 24);

            c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q00, vgetq_lane_f16(p00, 0)));
            c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q00, vgetq_lane_f16(p00, 1)));
            c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q00, vgetq_lane_f16(p00, 2)));
            c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q00, vgetq_lane_f16(p00, 3)));

            c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q02, vgetq_lane_f16(p00, 4)));
            c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q02, vgetq_lane_f16(p00, 5)));
            c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q02, vgetq_lane_f16(p00, 6)));
            c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q02, vgetq_lane_f16(p00, 7)));

            c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q04, vgetq_lane_f16(p02, 0)));
            c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q04, vgetq_lane_f16(p02, 1)));
            c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q04, vgetq_lane_f16(p02, 2)));
            c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q04, vgetq_lane_f16(p02, 3)));

            c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q06, vgetq_lane_f16(p02, 4)));
            c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q06, vgetq_lane_f16(p02, 5)));
            c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q06, vgetq_lane_f16(p02, 6)));
            c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q06, vgetq_lane_f16(p02, 7)));

            mtx_a0 += 16;
            mtx_b0 += 32;
        }

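        // Left-over loop: process the remaining values of B (fewer than 32) 8 at a time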
        for(; mtx_b0 < mtx_b0_end_addr;)
        {
            const float16x4_t p00 = vld1_f16(mtx_a0);
            const float16x8_t q00 = vld1q_f16(mtx_b0);

            c.val[0] = vaddq_f16(c.val[0], vmulq_n_f16(q00, vget_lane_f16(p00, 0)));
            c.val[1] = vaddq_f16(c.val[1], vmulq_n_f16(q00, vget_lane_f16(p00, 1)));
            c.val[2] = vaddq_f16(c.val[2], vmulq_n_f16(q00, vget_lane_f16(p00, 2)));
            c.val[3] = vaddq_f16(c.val[3], vmulq_n_f16(q00, vget_lane_f16(p00, 3)));

            mtx_a0 += 4;
            mtx_b0 += 8;
        }

        if(multiply_alpha)
        {
            c.val[0] = vmulq_f16(c.val[0], alpha_f16);
            c.val[1] = vmulq_f16(c.val[1], alpha_f16);
            c.val[2] = vmulq_f16(c.val[2], alpha_f16);
            c.val[3] = vmulq_f16(c.val[3], alpha_f16);
        }

        if(id.x() < (out_width - 8))
        {
            vst1q_f16(mtx_out, c.val[0]);
            if(id.y() + 1 < out_height)
            {
                vst1q_f16(mtx_out + 1 * out_stride, c.val[1]);
                if(id.y() + 2 < out_height)
                {
                    vst1q_f16(mtx_out + 2 * out_stride, c.val[2]);
                    if(id.y() + 3 < out_height)
                    {
                        vst1q_f16(mtx_out + 3 * out_stride, c.val[3]);
                    }
                }
            }
        }
        else
        {
            // Left-over columns
            const int columns_left = out_width - id.x();
            for(int x = 0; x < columns_left; ++x)
            {
                *(mtx_out + x) = c.val[0][x];
                if(id.y() + 1 < out_height)
                {
                    *(mtx_out + x + 1 * out_stride) = c.val[1][x];
                    if(id.y() + 2 < out_height)
                    {
                        *(mtx_out + x + 2 * out_stride) = c.val[2][x];
                        if(id.y() + 3 < out_height)
                        {
                            *(mtx_out + x + 3 * out_stride) = c.val[3][x];
                        }
                    }
                }
            }
        }
    },
    ina, inb, out);
}
#endif /* __ARM_FEATURE_FP16_VECTOR_ARITHMETIC */

} // namespace cpu

} // namespace arm_compute