/*
 * Copyright (c) 2017-2022 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */
#ifndef ARM_COMPUTE_CPU_GEMMLOWP_REDUCTION_KERNEL_H
#define ARM_COMPUTE_CPU_GEMMLOWP_REDUCTION_KERNEL_H

#include "src/core/common/Macros.h"
#include "src/cpu/ICpuKernel.h"

namespace arm_compute
{
// Forward declarations
struct GEMMLowpReductionKernelInfo;
namespace cpu
{
namespace kernels
{
/** Kernel used to compute the row-vector of sums of all the entries in each row of Matrix A.
 *
 * @note This stage is needed to handle the offset of the matrix product.
 *       https://github.com/google/gemmlowp/blob/master/doc/low-precision.md
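 *
 * In outline (an illustrative sketch of the offset handling described in the link above, with
 * A of shape MxK, B of shape KxN, and a_offset/b_offset the quantization offsets of A and B):
 * expanding sum_k (A[i][k] - a_offset) * (B[k][j] - b_offset) gives
 * sum_k A[i][k]*B[k][j] - a_offset * colsum(B)[j] - b_offset * rowsum(A)[i] + K * a_offset * b_offset,
 * so the per-row sums of A computed by this kernel supply the b_offset correction term.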
 */
class CpuGemmLowpMatrixAReductionKernel : public ICpuKernel<CpuGemmLowpMatrixAReductionKernel>
{
public:
    /** Default constructor */
    CpuGemmLowpMatrixAReductionKernel() = default;
    ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE(CpuGemmLowpMatrixAReductionKernel);
    /** Initialise the kernel's input and output.
     *
     * @param[in]  src  Input tensor. Data type supported: QASYMM8/QASYMM8_SIGNED/QSYMM8/QSYMM8_PER_CHANNEL
     * @param[out] dst  Output row-vector of sums of all the entries in each row of the input matrix A. Data type supported: S32
     * @param[in]  info Kernel metadata:
     *                            - k             (num_mtx_a_cols) Number of matrix A columns
     *                            - is_reshaped   (is_interleaved4x4) True if matrix A has been reshaped (interleaved 4x4)
     *                            - scalar        Scalar value to multiply each reduced row by.
     *                            - mul_by_scalar True if each reduced row must be multiplied by a scalar value.
     */
    void configure(const ITensorInfo *src, ITensorInfo *dst, const GEMMLowpReductionKernelInfo &info);
    /** Static function to check if given info will lead to a valid configuration
     *
     * Similar to CpuGemmLowpMatrixAReductionKernel::configure()
     *
     * @return a status
     */
    static Status validate(const ITensorInfo *src, const ITensorInfo *dst, const GEMMLowpReductionKernelInfo &info);
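
    // Minimal usage sketch (illustrative only, not part of the API contract): configuring the
    // kernel for a QASYMM8 matrix A with M rows and K columns. The GEMMLowpReductionKernelInfo
    // fields follow the metadata listed above; the exact construction shown here is an
    // assumption and may differ between library versions.
    //
    //   TensorInfo a_info(TensorShape(K, M), 1, DataType::QASYMM8); // A is stored as [K, M] (x = columns)
    //   TensorInfo sums(TensorShape(M), 1, DataType::S32);          // one S32 sum per row of A
    //   GEMMLowpReductionKernelInfo red_info(K, false /* is_reshaped */, 0 /* scalar */, false /* mul_by_scalar */);
    //   ARM_COMPUTE_ERROR_THROW_ON(CpuGemmLowpMatrixAReductionKernel::validate(&a_info, &sums, red_info));
    //   CpuGemmLowpMatrixAReductionKernel kernel;
    //   kernel.configure(&a_info, &sums, red_info);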

    // Inherited methods overridden:
    void run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) override;
    const char *name() const override;

private:
    /** Execution of the reduction kernel specialized on the input type
     *
     * @param[in]  src    Input tensor
     * @param[out] dst    Output tensor
     * @param[in]  window Execution window
     */
    template <typename T>
    void run_internal(const ITensor *src, ITensor *dst, const Window &window);

    /** Common signature for all reduction functions
     *
     * @param[in]  src    Input tensor
     * @param[out] dst    Output tensor
     * @param[in]  window Region on which to execute the kernel. (Must be a valid region of the window returned by window()).
     */
    using CpuGemmLowpMatrixAReductionKernelPtr = void (CpuGemmLowpMatrixAReductionKernel::*)(const ITensor *src, ITensor *dst, const Window &window);

    CpuGemmLowpMatrixAReductionKernelPtr _func{ nullptr };
    int32_t                              _k{ 0 };
    int32_t                              _scalar{ 0 };
    bool                                 _mul_by_scalar{ false };
};

/** Kernel used to compute the row-vector of sums of all the entries in each column of Matrix B.
 *
 * @note This stage is needed to handle the offset of the matrix product.
 *       https://github.com/google/gemmlowp/blob/master/doc/low-precision.md
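 *
 * The per-column sums of B computed here supply the a_offset correction term in the offset
 * expansion sketched in the documentation of @ref CpuGemmLowpMatrixAReductionKernel above.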
 */
class CpuGemmLowpMatrixBReductionKernel : public ICpuKernel<CpuGemmLowpMatrixBReductionKernel>
{
public:
    /** Default constructor */
    CpuGemmLowpMatrixBReductionKernel() = default;
    ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE(CpuGemmLowpMatrixBReductionKernel);
    /** Initialise the kernel's input and output.
     *
     * @param[in]  src  Input tensor. Data type supported: QASYMM8/QASYMM8_SIGNED/QSYMM8/QSYMM8_PER_CHANNEL
     * @param[out] dst  Output row-vector of sums of all the entries in each column of the input matrix B. Data type supported: S32
     * @param[in]  info Kernel metadata:
     *                            - k             (num_mtx_b_rows) Number of matrix B rows.
     *                            - is_reshaped   (is_transposed1xW) True if matrix B has been reshaped (transposed 1xW).
     *                            - scalar        Scalar value to multiply each reduced row by.
     *                            - mul_by_scalar True if each reduced row must be multiplied by a scalar value.
     */
    void configure(const ITensorInfo *src, ITensorInfo *dst, const GEMMLowpReductionKernelInfo &info);
    /** Static function to check if given info will lead to a valid configuration
     *
     * Similar to CpuGemmLowpMatrixBReductionKernel::configure()
     *
     * @return a status
     */
    static Status validate(const ITensorInfo *src, const ITensorInfo *dst, const GEMMLowpReductionKernelInfo &info);
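
    // Illustrative note (not part of the API contract), mirroring the usage sketch in
    // CpuGemmLowpMatrixAReductionKernel: for a QASYMM8 matrix B with K rows and N columns,
    // src would be shaped [N, K], dst would hold N S32 column sums (TensorShape(N)), and the
    // assumed GEMMLowpReductionKernelInfo would be constructed with k = K.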

    // Inherited methods overridden:
    void run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) override;
    const char *name() const override;

private:
    /** Execution of the reduction kernel specialized on the input type
     *
     * @param[in]  src    Input tensor
     * @param[out] dst    Output tensor
     * @param[in]  window Execution window
     * @param[in]  info   Thread-related information
     */
    template <typename T>
    void run_internal(const ITensor *src, ITensor *dst, const Window &window, const ThreadInfo &info);

    /** Common signature for all reduction functions
     *
     * @param[in]  src    Input tensor
     * @param[out] dst    Output tensor
     * @param[in]  window Region on which to execute the kernel. (Must be a valid region of the window returned by window()).
     */
    using CpuGemmLowpMatrixBReductionKernelPtr = void (CpuGemmLowpMatrixBReductionKernel::*)(const ITensor *src, ITensor *dst, const Window &window, const ThreadInfo &info);

    CpuGemmLowpMatrixBReductionKernelPtr _func{ nullptr };
    int32_t                              _k{ 0 };
    int32_t                              _scalar{ 0 };
    bool                                 _mul_by_scalar{ false };
};
} // namespace kernels
} // namespace cpu
} // namespace arm_compute
#endif /* ARM_COMPUTE_CPU_GEMMLOWP_REDUCTION_KERNEL_H */