xref: /aosp_15_r20/external/ComputeLibrary/src/gpu/cl/operators/ClGemmLowpMatrixMultiplyCore.h (revision c217d954acce2dbc11938adb493fc0abd69584f3)
1 /*
2  * Copyright (c) 2017-2022 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #ifndef ARM_COMPUTE_CL_GEMMLOWP_MATRIXMULTIPLY_CORE_H
25 #define ARM_COMPUTE_CL_GEMMLOWP_MATRIXMULTIPLY_CORE_H
26 
27 #include "arm_compute/core/TensorInfo.h"
28 #include "arm_compute/runtime/CL/CLTypes.h"
29 
30 #include "src/gpu/cl/ClCompileContext.h"
31 #include "src/gpu/cl/IClOperator.h"
32 
namespace arm_compute
{
namespace opencl
{
namespace kernels
{
// Forward declarations of the kernels this operator owns through the
// unique_ptr members declared in ClGemmLowpMatrixMultiplyCore below
// (keeps the kernel headers out of this public header).
class ClCastKernel;
class ClGemmLowpMatrixMultiplyNativeKernel;
class ClGemmLowpMatrixMultiplyReshapedOnlyRhsKernel;
class ClGemmLowpMatrixMultiplyReshapedOnlyRhsMMULKernel;
class ClGemmReshapeRhsMatrixKernel;
class ClGemmLowpMatrixAReductionKernel;
class ClGemmLowpMatrixBReductionKernel;
class ClGemmLowpOffsetContributionKernel;
class ClGemmLowpOffsetContributionOutputStageKernel;
} // namespace kernels
50 
/** Basic function to execute GEMMLowpMatrixMultiplyCore on OpenCL.
 *
 * Orchestrates a quantized (GEMMLowp) matrix multiplication by scheduling the
 * kernels held in the unique_ptr members below; intermediate results live in
 * the auxiliary tensors published through workspace().
 */
class ClGemmLowpMatrixMultiplyCore : public IClOperator
{
public:
    /** Default constructor */
    ClGemmLowpMatrixMultiplyCore();
    /** Default destructor */
    ~ClGemmLowpMatrixMultiplyCore();
    /** Initialise the kernel's inputs, output
     *
     * Valid data layouts:
     * - NHWC
     * - NCHW
     *
     * Valid data type configurations:
     * |src0           |src1               |src2     |dst            |
     * |:--------------|:------------------|:--------|:--------------|
     * |QASYMM8        |QASYMM8            |S32      |QASYMM8        |
     * |QASYMM8        |QSYMM8_PER_CHANNEL |S32      |QASYMM8        |
     * |QASYMM8        |QSYMM8             |S32      |QASYMM8        |
     * |QASYMM8        |QASYMM8            |S32      |S32            |
     * |QASYMM8        |QSYMM8_PER_CHANNEL |S32      |S32            |
     * |QASYMM8        |QSYMM8             |S32      |S32            |
     * |QASYMM8_SIGNED |QASYMM8_SIGNED     |S32      |QASYMM8_SIGNED |
     * |QASYMM8_SIGNED |QSYMM8_PER_CHANNEL |S32      |QASYMM8_SIGNED |
     * |QASYMM8_SIGNED |QSYMM8             |S32      |QASYMM8_SIGNED |
     * |QASYMM8_SIGNED |QASYMM8_SIGNED     |S32      |S32            |
     * |QASYMM8_SIGNED |QSYMM8_PER_CHANNEL |S32      |S32            |
     * |QASYMM8_SIGNED |QSYMM8             |S32      |S32            |
     *
     * @note GEMMLowp:  low precision GEMM kernel. [A * B + C]
     *  This kernel performs the following computations:
     *
     *  -# Convert a values from 8-bit quantized to int32 and add a_offset to each of them.
     *  -# Convert b values from 8-bit quantized to int32 and add b_offset to each of them.
     *  -# Compute the matrix product of the resulting a * b in int32.
     *  -# Quantize to uint8 if gemm_info.gemmlowp_output_stage != NONE
     *
     * @param[in]  compile_context The compile context to be used.
     * @param[in]  a               First input tensor  (Matrix A). Data type supported: QASYMM8/QASYMM8_SIGNED.
     * @param[in]  b               Second input tensor (Matrix B). Data type supported: same as @p a
     * @param[in]  c               Third input tensor  (Matrix C). It can be a nullptr. Data type supported: S32
     * @param[out] output          Output tensor. Data type supported: S32 or QASYMM8/QASYMM8_SIGNED if gemm_info.gemmlowp_output_stage != NONE
     * @param[in]  gemm_info       (Optional) Specifies if the matrix A and/or matrix B have been reshaped and
     *                       if the reshape of matrix B should be executed only for the first run
     */
    void configure(const CLCompileContext &compile_context, ITensorInfo *a, ITensorInfo *b, ITensorInfo *c, ITensorInfo *output, const GEMMInfo &gemm_info = GEMMInfo());
    /** Static function to check if given info will lead to a valid configuration
     *
     * Similar to ClGemmLowpMatrixMultiplyCore::configure()
     *
     * @return a status
     */
    static Status validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, const GEMMInfo &gemm_info = GEMMInfo());

    // Inherited methods overridden:
    void run(ITensorPack &tensors) override;
    void prepare(ITensorPack &constants) override;
    experimental::MemoryRequirements workspace() const override;

private:
    /** Slot indices of the auxiliary (workspace) tensors.
     *
     * The names appear to map one-to-one onto the temporary TensorInfo members
     * declared further below; Count is the total number of slots.
     */
    enum AuxTensorIdx
    {
        ResultS32 = 0, // Intermediate S32 GEMM result (_mm_result_s32)
        RhsQAsymm8,    // Weights cast to 8-bit asymmetric (_qasymm8_weights)
        RhsReshape,    // Reshaped RHS matrix (_tmp_b)
        VecSumCol,     // Column-sum vector from the B-reduction (_vector_sum_col)
        VecSumRow,     // Row-sum vector from the A-reduction (_vector_sum_row)
        Multipliers,   // Output-stage multipliers (_gemm_output_stage_multipliers)
        Shifts,        // Output-stage shifts (_gemm_output_stage_shifts)
        Count          // Number of auxiliary tensor slots
    };

private:
    // Kernels used (owned; types are forward-declared to keep kernel headers out of this public header)
    std::unique_ptr<kernels::ClCastKernel>                                      _weights_to_qasymm8;
    std::unique_ptr<kernels::ClGemmLowpMatrixMultiplyNativeKernel>              _mm_native_kernel;
    std::unique_ptr<kernels::ClGemmLowpMatrixMultiplyReshapedOnlyRhsKernel>     _mm_reshaped_only_rhs_kernel;
    std::unique_ptr<kernels::ClGemmLowpMatrixMultiplyReshapedOnlyRhsMMULKernel> _mm_reshaped_only_rhs_mmul_kernel;
    std::unique_ptr<kernels::ClGemmReshapeRhsMatrixKernel>                      _mtx_b_reshape_kernel;
    std::unique_ptr<kernels::ClGemmLowpMatrixAReductionKernel>                  _mtx_a_reduction_kernel;
    std::unique_ptr<kernels::ClGemmLowpMatrixBReductionKernel>                  _mtx_b_reduction_kernel;
    std::unique_ptr<kernels::ClGemmLowpOffsetContributionKernel>                _offset_contribution_kernel;
    std::unique_ptr<kernels::ClGemmLowpOffsetContributionOutputStageKernel>     _offset_contribution_output_stage_kernel;

    // Temporary tensors (metadata for the workspace slots enumerated in AuxTensorIdx)
    TensorInfo _qasymm8_weights{};
    TensorInfo _vector_sum_col{};
    TensorInfo _vector_sum_row{};
    TensorInfo _tmp_b{};
    TensorInfo _mm_result_s32{};
    TensorInfo _gemm_output_stage_multipliers{};
    TensorInfo _gemm_output_stage_shifts{};

    int32_t          _a_offset{ 0 };                        // Quantization offset added to matrix A values (see class @note)
    int32_t          _b_offset{ 0 };                        // Quantization offset added to matrix B values (see class @note)
    bool             _reshape_b_only_on_first_run{ false }; // From gemm_info: reshape matrix B only on the first run
    bool             _run_output_stage{ false };            // Presumably set when gemm_info.gemmlowp_output_stage != NONE — verify in .cpp
    bool             _convert_to_qasymm8{ false };          // Presumably gates the ClCastKernel weight conversion — verify in .cpp
    bool             _run_offset_contribution{ false };     // Presumably selects the standalone offset-contribution kernel — verify in .cpp
    bool             _is_prepared{ false };                 // Tracks whether prepare() has already run
    GEMMInfo         _gemm_info{};                          // Copy of the GEMMInfo passed to configure()
    CLGEMMKernelType _gemm_kernel_type{};                   // Matrix-multiply kernel variant selected at configure time

    experimental::MemoryRequirements _aux_mem{};            // Memory requirements reported by workspace()
};
155 } // namespace opencl
156 } // namespace arm_compute
157 #endif /* ARM_COMPUTE_CL_GEMMLOWP_MATRIXMULTIPLY_CORE_H */