1 /*
2  * Copyright (c) 2017-2021 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #ifndef ARM_COMPUTE_NEGEMMLOWPMATRIXMULTIPLYCORE_H
25 #define ARM_COMPUTE_NEGEMMLOWPMATRIXMULTIPLYCORE_H
26 
27 #include "arm_compute/core/Types.h"
28 #include "arm_compute/runtime/IFunction.h"
29 #include "arm_compute/runtime/IMemoryManager.h"
30 #include "arm_compute/runtime/IWeightsManager.h"
31 
32 #include <memory>
33 
34 namespace arm_compute
35 {
36 class ITensor;
37 class ITensorInfo;
38 
39 /** Function to run Gemm on quantized types.
40  *
41  *  This function calls the following:
42  *
43  * -# @ref cpu::CpuGemmLowpMatrixMultiplyCore
44  */
45 class NEGEMMLowpMatrixMultiplyCore : public IFunction
46 {
47 public:
48     /** Constructor */
49     NEGEMMLowpMatrixMultiplyCore(std::shared_ptr<IMemoryManager> memory_manager = nullptr, IWeightsManager *weights_manager = nullptr);
50     /** Prevent instances of this class from being copied (As this class contains pointers) */
51     NEGEMMLowpMatrixMultiplyCore(const NEGEMMLowpMatrixMultiplyCore &) = delete;
52     /** Default move constructor */
53     NEGEMMLowpMatrixMultiplyCore(NEGEMMLowpMatrixMultiplyCore &&) = default;
54     /** Prevent instances of this class from being copied (As this class contains pointers) */
55     NEGEMMLowpMatrixMultiplyCore &operator=(const NEGEMMLowpMatrixMultiplyCore &) = delete;
56     /** Default move assignment operator */
57     NEGEMMLowpMatrixMultiplyCore &operator=(NEGEMMLowpMatrixMultiplyCore &&) = default;
58     /** Default destructor */
59     ~NEGEMMLowpMatrixMultiplyCore();
60     /** Initialise the kernel's inputs, output
61      *
62      * Valid data layouts:
63      * - NHWC
64      * - NCHW
65      *
66      * Valid data type configurations:
67      * |src0           |src1               |src2     |dst            |
68      * |:--------------|:------------------|:--------|:--------------|
69      * |QASYMM8        |QASYMM8            |S32      |QASYMM8        |
70      * |QASYMM8        |QSYMM8_PER_CHANNEL |S32      |QASYMM8        |
71      * |QASYMM8        |QSYMM8             |S32      |QASYMM8        |
72      * |QASYMM8        |QASYMM8            |S32      |S32            |
73      * |QASYMM8        |QSYMM8_PER_CHANNEL |S32      |S32            |
74      * |QASYMM8        |QSYMM8             |S32      |S32            |
75      * |QASYMM8_SIGNED |QASYMM8_SIGNED     |S32      |QASYMM8_SIGNED |
76      * |QASYMM8_SIGNED |QSYMM8_PER_CHANNEL |S32      |QASYMM8_SIGNED |
77      * |QASYMM8_SIGNED |QSYMM8             |S32      |QASYMM8_SIGNED |
78      * |QASYMM8_SIGNED |QASYMM8_SIGNED     |S32      |S32            |
79      * |QASYMM8_SIGNED |QSYMM8_PER_CHANNEL |S32      |S32            |
80      * |QASYMM8_SIGNED |QSYMM8             |S32      |S32            |
81      *
82      * @note GEMM_LOWP:  low precision GEMM kernel
83      *  This kernel performs the following computations:
84      *
85      *  -# Convert a values from QASYMM8 to int32 and add a_offset to each of them.
86      *  -# Convert b values from QASYMM8 to int32 add b_offset to each of them.
87      *  -# Compute the matrix product of the resulting a * b in int32.
88      *
89      * @note The @p output type is S32 if @p gemm_info.type == GEMMLowpOutputStageType::NONE. It is QASYMM8/QASYMM8_SIGNED otherwise
90      *
91      * @param[in]  a         First input tensor  (Matrix A). Data type supported: QASYMM8/QASYMM8_SIGNED.
92      * @param[in]  b         Second input tensor (Matrix B). Data type supported: QASYMM8/QASYMM8_SIGNED/QSYMM8/QSYMM8_PER_CHANNEL.
93      * @param[in]  c         Third input tensor  (Matrix C). It can be a nullptr. Data type supported: S32
94      * @param[out] output    Output tensor. Data type supported: Data type supported: S32/QASYMM8/QASYMM8_SIGNED
95      * @param[in]  gemm_info (Optional) Specifies if the matrix A and/or matrix B have been reshaped and
96      *                       if the reshape of matrix B should be executed only for the first run
97      */
98     void configure(const ITensor *a, const ITensor *b, const ITensor *c, ITensor *output, const GEMMInfo &gemm_info = GEMMInfo());
99     /** Static function to check if given info will lead to a valid configuration of @ref NEGEMMLowpMatrixMultiplyCore
100      *
101      * Similar to @ref NEGEMMLowpMatrixMultiplyCore::configure()
102      *
103      * @return a status
104      */
105     static Status validate(const ITensorInfo *a, const ITensorInfo *b, const ITensorInfo *c, const ITensorInfo *output, const GEMMInfo &gemm_info = GEMMInfo());
106 
107     // Inherited methods overridden
108     void run() override;
109     void prepare() override;
110 
111 private:
112     struct Impl;
113     std::unique_ptr<Impl> _impl;
114 };
115 } // namespace arm_compute
116 #endif /*ARM_COMPUTE_NEGEMMLOWPMATRIXMULTIPLYCORE_H */
117