xref: /aosp_15_r20/external/ComputeLibrary/src/cpu/kernels/CpuGemmMatrixMultiplyKernel.cpp (revision c217d954acce2dbc11938adb493fc0abd69584f3)
1 /*
2  * Copyright (c) 2017-2022 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #include "src/cpu/kernels/CpuGemmMatrixMultiplyKernel.h"
25 
26 #include "arm_compute/core/Helpers.h"
27 #include "arm_compute/core/TensorInfo.h"
28 #include "arm_compute/core/Types.h"
29 #include "arm_compute/core/Validate.h"
30 #include "arm_compute/core/utils/misc/ShapeCalculator.h"
31 #include "src/core/CPP/Validate.h"
32 #include "src/core/common/Registrars.h"
33 #include "src/core/helpers/AutoConfiguration.h"
34 #include "src/core/helpers/WindowHelpers.h"
35 #include "src/cpu/kernels/gemm_matrix_mul/list.h"
36 
37 namespace arm_compute
38 {
39 namespace cpu
40 {
41 namespace kernels
42 {
43 namespace
44 {
45 static const std::vector<CpuGemmMatrixMultiplyKernel::GemmMatrixMulKernel> available_kernels =
46 {
47     {
48         "neon_fp32_gemm_matrix_mul",
49         [](const DataTypeISASelectorData & data)
__anon77ad390e0202() 50         {
51             return (data.dt == DataType::F32);
52         },
53         REGISTER_FP32_NEON(neon_fp32_gemm_matrix_mul)
54     },
55     {
56         "neon_fp16_gemm_matrix_mul",
57         [](const DataTypeISASelectorData & data)
__anon77ad390e0302() 58         {
59             return (data.dt == DataType::F16) && data.isa.fp16;
60         },
61         REGISTER_FP16_NEON(neon_fp16_gemm_matrix_mul)
62     },
63 };
64 
validate_arguments(const ITensorInfo * lhs,const ITensorInfo * rhs,const ITensorInfo * dst,float alpha,bool is_interleaved,const GEMMReshapeInfo & reshape_info)65 inline Status validate_arguments(const ITensorInfo *lhs, const ITensorInfo *rhs, const ITensorInfo *dst, float alpha, bool is_interleaved, const GEMMReshapeInfo &reshape_info)
66 {
67     ARM_COMPUTE_UNUSED(alpha);
68 
69     ARM_COMPUTE_RETURN_ERROR_ON_CPU_F16_UNSUPPORTED(lhs);
70     ARM_COMPUTE_RETURN_ERROR_ON_DATA_TYPE_CHANNEL_NOT_IN(lhs, 1, DataType::F16, DataType::F32);
71     ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lhs, rhs, dst);
72 
73     if(!is_interleaved)
74     {
75         ARM_COMPUTE_RETURN_ERROR_ON(lhs->dimension(0) != rhs->dimension(1));
76 
77         if(dst->total_size() != 0)
78         {
79             ARM_COMPUTE_RETURN_ERROR_ON(rhs->dimension(0) != dst->dimension(0));
80             ARM_COMPUTE_RETURN_ERROR_ON(lhs->dimension(1) != dst->dimension(1));
81             ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lhs, dst);
82         }
83     }
84     else
85     {
86         const int m                         = reshape_info.m();
87         const int n                         = reshape_info.n();
88         const int k                         = reshape_info.k();
89         const int mult_transpose1xW_width   = reshape_info.mult_transpose1xW_width();
90         const int mult_interleave4x4_height = reshape_info.mult_interleave4x4_height();
91 
92         /* Interleave */
93         TensorShape tensor_shape0{ lhs->tensor_shape() };
94         tensor_shape0.set(0, k);
95         tensor_shape0.set(1, m);
96 
97         const TensorInfo tensor_info0          = lhs->clone()->set_tensor_shape(tensor_shape0);
98         const TensorInfo tensor_info_reshaped0 = lhs->clone()->set_tensor_shape(misc::shape_calculator::compute_interleaved_shape(tensor_info0, mult_interleave4x4_height));
99         ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(lhs, &tensor_info_reshaped0);
100 
101         if(n != 0) /* Transpose */
102         {
103             TensorShape tensor_shape1{ rhs->tensor_shape() };
104             tensor_shape1.set(0, n);
105             tensor_shape1.set(1, k);
106 
107             const TensorInfo tensor_info1          = rhs->clone()->set_tensor_shape(tensor_shape1);
108             const TensorInfo tensor_info_reshaped1 = rhs->clone()->set_tensor_shape(misc::shape_calculator::compute_transpose1xW_with_element_size_shape(tensor_info1, mult_transpose1xW_width));
109             ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_SHAPES(rhs, &tensor_info_reshaped1);
110         }
111 
112         if(dst->total_size() != 0)
113         {
114             if(n != 0)
115             {
116                 ARM_COMPUTE_RETURN_ERROR_ON(dst->dimension(0) != static_cast<size_t>(n));
117             }
118             ARM_COMPUTE_RETURN_ERROR_ON(dst->dimension(1) != static_cast<size_t>(m));
119             ARM_COMPUTE_RETURN_ERROR_ON_MISMATCHING_DATA_TYPES(lhs, dst);
120         }
121     }
122 
123     return Status{};
124 }
125 
126 } // namespace
127 
configure(const ITensorInfo * lhs,const ITensorInfo * rhs,ITensorInfo * dst,float alpha,bool is_interleaved,const GEMMReshapeInfo & reshape_info)128 void CpuGemmMatrixMultiplyKernel::configure(const ITensorInfo *lhs, const ITensorInfo *rhs, ITensorInfo *dst, float alpha, bool is_interleaved, const GEMMReshapeInfo &reshape_info)
129 {
130     ARM_COMPUTE_ERROR_ON_NULLPTR(lhs, rhs, dst);
131 
132     // dst tensor auto inizialitation if not yet initialized
133     TensorShape tensor_shape{ lhs->tensor_shape() };
134     tensor_shape.set(0, is_interleaved ? reshape_info.n() : rhs->dimension(0));
135     tensor_shape.set(1, is_interleaved ? reshape_info.m() : lhs->dimension(1));
136 
137     auto_init_if_empty(*dst, lhs->clone()->set_tensor_shape(tensor_shape));
138 
139     // Perform validate step
140     ARM_COMPUTE_ERROR_THROW_ON(validate_arguments(lhs, rhs, dst, alpha, is_interleaved, reshape_info));
141 
142     _alpha = alpha;
143 
144     // Configure kernel window
145     Window win{};
146 
147     // Check if the dst tensor is a vector. If so,the kernel runs the vector-matrix multiplication
148     const bool is_dst_vector = (dst->dimension(1) == 1);
149     if(is_dst_vector)
150     {
151         const unsigned int num_elems_processed_per_iteration_x = (lhs->data_type() == DataType::F32) ? 16 : 32;
152 
153         win = calculate_max_window(*dst, Steps(num_elems_processed_per_iteration_x));
154     }
155     else
156     {
157         constexpr unsigned int num_elems_processed_per_iteration_x = 8;
158         constexpr unsigned int num_elems_processed_per_iteration_y = 4;
159 
160         win = calculate_max_window(*dst, Steps(num_elems_processed_per_iteration_x, num_elems_processed_per_iteration_y));
161     }
162 
163     const auto uk = CpuGemmMatrixMultiplyKernel::get_implementation(DataTypeISASelectorData{ lhs->data_type(), CPUInfo::get().get_isa() });
164     ARM_COMPUTE_ERROR_ON_NULLPTR(uk);
165     _func = uk->ukernel;
166 
167     ICPPKernel::configure(win);
168 }
169 
validate(const ITensorInfo * lhs,const ITensorInfo * rhs,const ITensorInfo * dst,float alpha,bool is_interleaved,const GEMMReshapeInfo & reshape_info)170 Status CpuGemmMatrixMultiplyKernel::validate(const ITensorInfo *lhs, const ITensorInfo *rhs, const ITensorInfo *dst, float alpha, bool is_interleaved,
171                                              const GEMMReshapeInfo &reshape_info)
172 {
173     ARM_COMPUTE_RETURN_ON_ERROR(validate_arguments(lhs, rhs, dst, alpha, is_interleaved, reshape_info));
174 
175     return Status{};
176 }
177 
run_op(ITensorPack & tensors,const Window & window,const ThreadInfo & info)178 void CpuGemmMatrixMultiplyKernel::run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info)
179 {
180     ARM_COMPUTE_ERROR_ON_UNCONFIGURED_KERNEL(this);
181     ARM_COMPUTE_ERROR_ON_INVALID_SUBWINDOW(IKernel::window(), window);
182     ARM_COMPUTE_ERROR_ON(tensors.empty());
183     ARM_COMPUTE_ERROR_ON(_func == nullptr);
184 
185     const ITensor *lhs = tensors.get_const_tensor(TensorType::ACL_SRC_0);
186     const ITensor *rhs = tensors.get_const_tensor(TensorType::ACL_SRC_1);
187     ITensor       *dst = tensors.get_tensor(TensorType::ACL_DST);
188 
189     const bool is_dst_vector = (dst->info()->dimension(1) == 1);
190     (*_func)(lhs, rhs, dst, window, info, _alpha, is_dst_vector);
191 }
192 
name() const193 const char *CpuGemmMatrixMultiplyKernel::name() const
194 {
195     return "CpuGemmMatrixMultiplyKernel";
196 }
197 
get_available_kernels()198 const std::vector<CpuGemmMatrixMultiplyKernel::GemmMatrixMulKernel> &CpuGemmMatrixMultiplyKernel::get_available_kernels()
199 {
200     return available_kernels;
201 }
202 } // namespace kernels
203 } // namespace cpu
204 } // namespace arm_compute
205