/*
 * Copyright (c) 2021-2023 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#include "src/cpu/operators/CpuGemm.h"

#include "arm_compute/core/TensorInfo.h"
#include "arm_compute/core/Validate.h"
#include "arm_compute/core/utils/misc/ShapeCalculator.h"
#include "arm_compute/runtime/NEON/NEScheduler.h"
#include "src/common/utils/Log.h"
#include "src/core/CPP/Validate.h"
#include "src/core/helpers/AutoConfiguration.h"
#include "src/core/helpers/MemoryHelpers.h"
#include "src/cpu/utils/CpuAuxTensorHandler.h"

using namespace arm_compute::experimental;
using namespace arm_compute::misc::shape_calculator;

namespace arm_compute
{
namespace cpu
{
namespace
{
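// Translate the public GEMMInfo descriptor into the cpu::AsmGemmInfo consumed by
// CpuGemmAssemblyDispatch. Only the fields copied below are forwarded to the
// assembly path; everything else in GEMMInfo is handled by the fallback kernels.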
cpu::AsmGemmInfo init_assembly_metadata(const GEMMInfo &info)
{
    cpu::AsmGemmInfo asm_info;
    asm_info.method                  = cpu::AsmConvMethod::Im2Col;
    asm_info.reinterpret_input_as_3d = info.reinterpret_input_as_3d();
    asm_info.depth_output_gemm3d     = info.depth_output_gemm3d();
    asm_info.activation_info         = info.activation_info();
    asm_info.fast_mode               = info.fast_math();
    asm_info.fixed_format            = info.fixed_format();
    asm_info.weight_format           = info.weight_format();

    return asm_info;
}
} // namespace

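// CpuGemm computes d = alpha * (a * b) + beta * c, either through the optimised
// assembly dispatch or through the interleave/transpose/multiply kernels configured
// below. A rough usage sketch (illustrative only; how tensors and the auxiliary
// workspace are allocated depends on the caller's memory management):
//
//   CpuGemm gemm;
//   gemm.configure(&a_info, &b_info, &c_info, &d_info, alpha, beta, gemm_info);
//   // ... allocate tensors matching the infos and any workspace() requirements ...
//   ITensorPack pack{ { ACL_SRC_0, &a }, { ACL_SRC_1, &b }, { ACL_SRC_2, &c }, { ACL_DST, &d } };
//   gemm.run(pack);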
void CpuGemm::configure(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, ITensorInfo *d, float alpha, float beta, const GEMMInfo &gemm_info)
{
    ARM_COMPUTE_ERROR_ON_NULLPTR(a, b, d);
    ARM_COMPUTE_ERROR_THROW_ON(CpuGemm::validate(a, b, c, d, alpha, beta, gemm_info));
    ARM_COMPUTE_LOG_PARAMS(a, b, c, d, alpha, beta, gemm_info);

    const cpu::AsmGemmInfo asm_info      = init_assembly_metadata(gemm_info);
    const bool             is_c_bias     = gemm_info.reshape_b_only_on_first_run();
    bool                   run_optimised = bool(cpu::CpuGemmAssemblyDispatch::validate(a, b, (is_c_bias) ? c : nullptr, d, asm_info)) && gemm_info.reshape_b_only_on_first_run();

    // Check if we need to reshape the matrix B only on the first run
    _is_prepared                      = false;
    _reshape_b_only_on_first_run      = gemm_info.reshape_b_only_on_first_run();
    _run_vector_matrix_multiplication = a->dimension(1) < 2;
    _run_alpha_scale                  = alpha != 1.f;
    _run_bias_addition                = c != nullptr && gemm_info.reshape_b_only_on_first_run();
    _run_addition                     = beta != 0 && c != nullptr && !gemm_info.reshape_b_only_on_first_run();
    _run_activation                   = gemm_info.activation_info().enabled() && (!run_optimised || (run_optimised && !cpu::CpuGemmAssemblyDispatch::is_activation_supported(gemm_info.activation_info())));

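    // Two execution paths: the assembly dispatch handles the whole GEMM (with an
    // optional trailing alpha scale), while the fallback below composes the
    // interleave, transpose, matrix-multiply, bias-add and addition kernels.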
    if(run_optimised)
    {
        const ITensorInfo *c_to_use = is_c_bias ? c : nullptr;
        _asm_glue                   = std::make_unique<cpu::CpuGemmAssemblyDispatch>();
        _asm_glue->configure(a, b, c_to_use, d, asm_info);
        ARM_COMPUTE_ERROR_ON(!_asm_glue->is_configured());

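        // Forward the assembly dispatch's workspace requirements through _aux_mem so
        // the caller can allocate them together with this operator's own buffers.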
        auto asm_mem_req           = _asm_glue->workspace();
        _aux_mem[AsmGemmWorkspace] = asm_mem_req[AsmGemmWorkspace];
        _aux_mem[Pretraspose]      = asm_mem_req[Pretraspose];

        // Scale product by alpha
        if(_run_alpha_scale)
        {
            _alpha_scale_func = std::make_unique<cpu::CpuActivation>();
            _alpha_scale_func->configure(d, nullptr, ActivationLayerInfo(ActivationLayerInfo::ActivationFunction::LINEAR, alpha, 0.f));
        }
    }
    else
    {
        // Pick output tensor in case bias addition should be performed
        ITensorInfo *gemm_output_to_use = (_run_bias_addition) ? &_tmp_d : d;

        _mm_kernel = std::make_unique<cpu::kernels::CpuGemmMatrixMultiplyKernel>();

        // Select between GEMV and GEMM
        if(_run_vector_matrix_multiplication)
        {
            // Configure the matrix multiply kernel
            _mm_kernel->configure(a, b, gemm_output_to_use, alpha, false);
        }
        else
        {
            const int m = a->dimension(1);
            const int n = b->dimension(0);
            const int k = a->dimension(0);

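            // A is blocked into an interleaved 4x4 layout and B into a transposed 1xW
            // layout; the matrix multiply kernel then consumes the reshaped operands,
            // with the original m/n/k dimensions carried through GEMMReshapeInfo.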
            // Configure interleave kernel
            _interleave_kernel = std::make_unique<cpu::kernels::CpuGemmInterleave4x4Kernel>();
            _interleave_kernel->configure(a, &_tmp_a);
            _aux_mem[InterleavedLHS] = MemoryInfo(offset_int_vec(InterleavedLHS), MemoryLifetime::Temporary, _tmp_a.total_size());

            // Configure transpose kernel
            _transpose_kernel = std::make_unique<cpu::kernels::CpuGemmTranspose1xWKernel>();
            _transpose_kernel->configure(b, &_tmp_b);
            _aux_mem[TransposedRHS] = MemoryInfo(offset_int_vec(TransposedRHS), MemoryLifetime::Persistent, _tmp_b.total_size());

            // Configure matrix multiplication kernel
            _mm_kernel->configure(&_tmp_a, &_tmp_b, gemm_output_to_use, alpha, true, GEMMReshapeInfo(m, n, k));
        }

        if(_run_bias_addition)
        {
            _add_bias = std::make_unique<cpu::CpuAdd>();
            _add_bias->configure(gemm_output_to_use, c, d, ConvertPolicy::SATURATE);
            _aux_mem[TempResult] = MemoryInfo(offset_int_vec(TempResult), MemoryLifetime::Temporary, _tmp_d.total_size());
        }
    }

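    // Note the split between the two C paths: the bias path above adds C on top of the
    // product via CpuAdd, while the matrix addition kernel below handles the beta * C
    // term when C is not treated as a bias.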
    // Configure matrix addition kernel
    if(_run_addition)
    {
        _ma_kernel = std::make_unique<cpu::kernels::CpuGemmMatrixAdditionKernel>();
        _ma_kernel->configure(c, d, beta);
    }

    // Configure activation
    if(_run_activation)
    {
        _activation_func = std::make_unique<cpu::CpuActivation>();
        _activation_func->configure(d, nullptr, gemm_info.activation_info());
    }
}

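// Static validation: mirrors the checks performed by configure() (data types, shapes,
// and per-kernel validation of whichever path would be taken) without configuring any
// kernels or allocating memory.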
Status CpuGemm::validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d, float alpha, float beta, const GEMMInfo &gemm_info)
{
    ARM_COMPUTE_UNUSED(alpha);
    const bool is_c_bias = gemm_info.reshape_b_only_on_first_run();

    ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(a);
    ARM_COMPUTE_RETURN_ERROR_ON_CPU_BF16_UNSUPPORTED(a);
    ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(a, 1, DataType::BFLOAT16, DataType::F16, DataType::F32);

    if (is_fixed_format_fast_math(gemm_info.weight_format()))
    {
        ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(a, DataType::F32);
        ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_NOT_IN(b, DataType::BFLOAT16);
    }
    else
    {
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(a, b);
    }

    ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->dimension(0) != b->dimension(1), "The product AB is defined only if the number of columns in A is equal to the number of rows in B");
    ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.is_a_reshaped(), "Matrix A already reshaped is not supported");
    ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.is_b_reshaped(), "Matrix B already reshaped is not supported");
    if(a->data_type() != DataType::BFLOAT16)
    {
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(a, d);
    }

    if(c != nullptr && !is_c_bias)
    {
        ARM_COMPUTE_RETURN_ERROR_ON(gemm_info.depth_output_gemm3d() != 0);
        ARM_COMPUTE_RETURN_ERROR_ON(gemm_info.reinterpret_input_as_3d());
        ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(c, d);
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(a->dimension(1) != c->dimension(1), "The C matrix must have the same number of rows as the matrix A");
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(b->dimension(0) != c->dimension(0), "The C matrix must have the same number of columns as the matrix B");
    }

    if(d->total_size() != 0)
    {
        // For fixed format we are expecting some kind of blocked format for B/RHS so the dimension won't necessarily match the result matrix any more.
        ARM_COMPUTE_RETURN_ERROR_ON(!gemm_info.fixed_format() && b->dimension(0) != d->dimension(0));
        if(gemm_info.depth_output_gemm3d() != 0)
        {
            if(gemm_info.reinterpret_input_as_3d())
            {
                ARM_COMPUTE_RETURN_ERROR_ON(a->dimension(1) != d->dimension(1));
                ARM_COMPUTE_RETURN_ERROR_ON(a->dimension(2) != d->dimension(2));
            }
            else
            {
                ARM_COMPUTE_RETURN_ERROR_ON(a->dimension(1) != d->dimension(1) * d->dimension(2));
            }
        }
        else
        {
            ARM_COMPUTE_RETURN_ERROR_ON(a->dimension(1) != d->dimension(1));
        }
    }

    // Check if we need to run the optimized assembly kernel
    cpu::AsmGemmInfo asm_info      = init_assembly_metadata(gemm_info);
    const bool       run_optimised = bool(cpu::CpuGemmAssemblyDispatch::validate(a, b, is_c_bias ? c : nullptr, d, asm_info));

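    // If the assembly dispatch cannot handle this case, validate the fallback pipeline:
    // optional interleave/transpose reshaping followed by the matrix multiply kernel and,
    // where requested, the bias addition.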
    if(!run_optimised)
    {
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.reinterpret_input_as_3d(), "CpuGemm cannot reinterpret the input tensor as 3D");
        ARM_COMPUTE_RETURN_ERROR_ON_MSG(gemm_info.depth_output_gemm3d() != 0, "CpuGemm cannot reinterpret the output tensor as 3D");

        // Check if the first input tensor is a vector.
        const bool run_vector_matrix_multiplication = a->dimension(1) < 2;
        // Check if we need to reshape the matrix A and matrix B
        const bool run_interleave_transpose = !run_vector_matrix_multiplication && !(gemm_info.reshape_b_only_on_first_run());

        // Arguments used by GEMMReshapeInfo
        // If we pass the matrix A and matrix B reshaped to CpuGemmMatrixMultiplyKernel, we need to pass m, n, k, mult_transpose1xW_width and mult_interleave4x4_height to GEMMReshapeInfo
        // in order to know how the matrices have been reshaped
        const int m                         = a->dimension(1);
        const int n                         = b->dimension(0);
        const int k                         = a->dimension(0);
        int       mult_transpose1xW_width   = 1;
        int       mult_interleave4x4_height = 1;

        const GEMMReshapeInfo reshape_info = GEMMReshapeInfo(m, n, k, mult_transpose1xW_width, mult_interleave4x4_height, gemm_info.depth_output_gemm3d());

        const ITensorInfo *matrix_a_info = a;
        const ITensorInfo *matrix_b_info = b;

        TensorInfo tmp_a_info{};
        TensorInfo tmp_b_info{};
        TensorInfo tmp_output_info = *d->clone();

        if(run_interleave_transpose)
        {
            matrix_a_info = &tmp_a_info;
            matrix_b_info = &tmp_b_info;

            // Validate interleave kernel
            auto_init_if_empty(tmp_a_info, a->clone()->set_tensor_shape(compute_interleaved_shape(*a, mult_interleave4x4_height, gemm_info.reinterpret_input_as_3d())));
            ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuGemmInterleave4x4Kernel::validate(a, &tmp_a_info));

            // Validate transpose kernel
            auto_init_if_empty(tmp_b_info, b->clone()->set_tensor_shape(compute_transpose1xW_with_element_size_shape(*b, mult_transpose1xW_width)));
            ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuGemmTranspose1xWKernel::validate(b, &tmp_b_info));
        }

        // Validate matrix multiply
        auto_init_if_empty(tmp_output_info, matrix_a_info->clone()->set_tensor_shape(compute_mm_shape(*matrix_a_info, *matrix_b_info, run_interleave_transpose, reshape_info)));
        ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuGemmMatrixMultiplyKernel::validate(matrix_a_info, matrix_b_info, &tmp_output_info, alpha, run_interleave_transpose, reshape_info));

        if(c != nullptr && gemm_info.reshape_b_only_on_first_run())
        {
            ARM_COMPUTE_RETURN_ON_ERROR(cpu::CpuAdd::validate(&tmp_output_info, c, d, ConvertPolicy::SATURATE));
        }
    }

    // Validate matrix addition kernel
    if(beta != 0 && c != nullptr && !is_c_bias)
    {
        ARM_COMPUTE_RETURN_ON_ERROR(cpu::kernels::CpuGemmMatrixAdditionKernel::validate(c, d, beta));
    }

    // Validate activation
    const ActivationLayerInfo &activation = gemm_info.activation_info();
    if(activation.enabled())
    {
        ARM_COMPUTE_RETURN_ON_ERROR(cpu::CpuActivation::validate(d, nullptr, activation));
    }

    return Status{};
}

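// Execute the configured GEMM. The tensor pack must carry the operands under ACL_SRC_0 (A),
// ACL_SRC_1 (B), ACL_SRC_2 (C, optional) and ACL_DST (D), plus any auxiliary workspace
// slots reported by workspace(). prepare() is invoked first so one-off work (e.g. the
// persistent B transpose) only happens on the initial call.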
void CpuGemm::run(ITensorPack &tensors)
{
    prepare(tensors);

    auto a = tensors.get_const_tensor(ACL_SRC_0);
    auto b = tensors.get_const_tensor(ACL_SRC_1);
    auto c = tensors.get_const_tensor(ACL_SRC_2);
    auto d = tensors.get_tensor(ACL_DST);

    if(_asm_glue && _asm_glue->is_configured())
    {
        // Pass c to asm dispatch only if it's the bias tensor
        ITensorPack asm_pack = tensors;
        asm_pack.add_const_tensor(ACL_SRC_2, (_reshape_b_only_on_first_run) ? c : nullptr);
        _asm_glue->run(asm_pack);
        if(_run_alpha_scale)
        {
            ITensorPack pack{ { ACL_SRC, d }, { ACL_DST, d } };
            _alpha_scale_func->run(pack);
        }
    }
    else
    {
        CpuAuxTensorHandler interleaved_a(offset_int_vec(InterleavedLHS), _tmp_a, tensors, true);
        CpuAuxTensorHandler transposed_b(offset_int_vec(TransposedRHS), _tmp_b, tensors, true);
        CpuAuxTensorHandler temp_d(offset_int_vec(TempResult), _tmp_d, tensors, true);

        ITensorPack mm_pack{ { ACL_SRC_0, a }, { ACL_SRC_1, b }, { ACL_DST, (_run_bias_addition) ? temp_d.get() : d } };
        if(!_run_vector_matrix_multiplication)
        {
            // Run interleave kernel
            ITensorPack interleave_pack{ { ACL_SRC, a }, { ACL_DST, interleaved_a.get() } };
            NEScheduler::get().schedule_op(_interleave_kernel.get(), Window::DimY, _interleave_kernel->window(), interleave_pack);

            if(!_reshape_b_only_on_first_run)
            {
                // Run transpose kernel
                ITensorPack transpose_pack{ { ACL_SRC, b }, { ACL_DST, transposed_b.get() } };
                NEScheduler::get().schedule_op(_transpose_kernel.get(), Window::DimY, _transpose_kernel->window(), transpose_pack);
            }

            // Use reshaped matrices
            mm_pack.add_const_tensor(ACL_SRC_0, interleaved_a.get());
            mm_pack.add_const_tensor(ACL_SRC_1, transposed_b.get());
        }

        NEScheduler::get().schedule_op(_mm_kernel.get(), _run_vector_matrix_multiplication ? Window::DimX : Window::DimY, _mm_kernel->window(), mm_pack);

        // Run bias addition kernel
        if(_run_bias_addition)
        {
            ITensorPack pack{ { ACL_SRC_0, temp_d.get() }, { ACL_SRC_1, c }, { ACL_DST, d } };
            _add_bias->run(pack);
        }
    }

    // Run matrix addition kernel
    if(_run_addition)
    {
        ITensorPack c_add_pack{ { ACL_SRC, c }, { ACL_DST, d } };
        NEScheduler::get().schedule_op(_ma_kernel.get(), Window::DimY, _ma_kernel->window(), c_add_pack);
    }

    // Run activation function
    if(_run_activation)
    {
        ITensorPack pack{ { ACL_SRC, d }, { ACL_DST, d } };
        _activation_func->run(pack);
    }
}

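// One-off preparation: either let the assembly dispatch pre-process its weights or, for
// the fallback path with reshape_b_only_on_first_run set, perform the persistent B
// transpose into the TransposedRHS workspace. Guarded by _is_prepared so it runs once.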
void CpuGemm::prepare(ITensorPack &tensors)
{
    if(!_is_prepared)
    {
        if(_asm_glue && _asm_glue->is_configured())
        {
            _asm_glue->prepare(tensors);
        }
        else if(_reshape_b_only_on_first_run && !_run_vector_matrix_multiplication)
        {
            const ITensor *b     = tensors.get_const_tensor(ACL_SRC_1);
            ITensor       *b_aux = utils::cast::polymorphic_cast<ITensor *>(tensors.get_tensor(offset_int_vec(TransposedRHS)));
            ARM_COMPUTE_ERROR_ON_NULLPTR(b, b_aux);

            CpuAuxTensorHandler transposed_b(_tmp_b, *b_aux);
            ITensorPack         transpose_pack{ { ACL_SRC, b }, { ACL_DST, transposed_b.get() } };
            NEScheduler::get().schedule_op(_transpose_kernel.get(), Window::DimY, _transpose_kernel->window(), transpose_pack);
        }
        _is_prepared = true;
    }
}

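// Report the auxiliary memory (workspace and persistent buffers) gathered during
// configure(), so the caller can provision it before calling run().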
experimental::MemoryRequirements CpuGemm::workspace() const
{
    return _aux_mem;
}

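// Query whether an optimised assembly kernel exists for the given tensors and, if so,
// report the weight format it expects, without configuring anything.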
Status CpuGemm::has_opt_impl(arm_compute::WeightFormat &expected_weight_format, const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *d,
                             const GEMMInfo &gemm_info)
{
    const cpu::AsmGemmInfo asm_info = init_assembly_metadata(gemm_info);

    return CpuGemmAssemblyDispatch::has_opt_impl(expected_weight_format, a, b, c, d, asm_info);
}

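// True only when the configured assembly kernel can consume variable weight formats.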
bool CpuGemm::isVarWeightsKernel() const
{
    return _asm_glue && _asm_glue->isVarWeightsKernel();
}
} // namespace cpu
} // namespace arm_compute