xref: /aosp_15_r20/external/ComputeLibrary/src/cpu/kernels/CpuSoftmaxKernel.h (revision c217d954acce2dbc11938adb493fc0abd69584f3)
1 /*
2  * Copyright (c) 2017-2022 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #ifndef ARM_COMPUTE_CPU_SOFTMAX_KERNEL_H
25 #define ARM_COMPUTE_CPU_SOFTMAX_KERNEL_H
26 
27 #include "src/core/common/Macros.h"
28 #include "src/cpu/ICpuKernel.h"
29 
30 namespace arm_compute
31 {
32 namespace cpu
33 {
34 namespace kernels
35 {
36 /** Interface for the identifying the max value of 1D Logits */
37 class CpuLogits1DMaxKernel : public ICpuKernel<CpuLogits1DMaxKernel>
38 {
39 private:
40     using SoftmaxLogits1DMaxKernelPtr = std::add_pointer<void(const ITensor *, ITensor *, const Window &)>::type;
41 
42 public:
43     CpuLogits1DMaxKernel() = default;
44     ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE(CpuLogits1DMaxKernel);
45     /** Set the input and output tensors.
46      *
47      * @param[in]  src Source tensor info. Data types supported: QASYMM8/QASYMM8_SIGNED/F16/F32.
48      * @param[out] dst Destination tensor info. Data types supported: same as @p input
49      */
50     void configure(const ITensorInfo *src, ITensorInfo *dst);
51     /** Static function to check if given info will lead to a valid configuration
52      *
53      * Similar to CpuLogits1DMaxKernel::configure()
54      *
55      * @return a status
56      */
57     static Status validate(const ITensorInfo *src, const ITensorInfo *dst);
58 
59     // Inherited methods overridden:
60     void run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) override;
61     const char *name() const override;
62 
63     struct SoftmaxLogits1DMaxKernel
64     {
65         const char                  *name;
66         const DataTypeISASelectorPtr is_selected;
67         SoftmaxLogits1DMaxKernelPtr  ukernel;
68     };
69 
70     static const std::vector<SoftmaxLogits1DMaxKernel> &get_available_kernels();
71 
72 private:
73     SoftmaxLogits1DMaxKernelPtr _run_method{ nullptr };
74     std::string                 _name{};
75 };
76 
77 /** Interface for softmax computation for QASYMM8 with pre-computed max. */
78 template <bool IS_LOG = false>
79 class CpuLogits1DSoftmaxKernel : public ICpuKernel<CpuLogits1DSoftmaxKernel<IS_LOG>>
80 {
81 private:
82     using SoftmaxLogits1DKernelPtr = std::add_pointer<void(const ITensor *, const ITensor *, void *const, ITensor *, float, bool, const Window &)>::type;
83 
84 public:
85     CpuLogits1DSoftmaxKernel() = default;
86     ARM_COMPUTE_DISALLOW_COPY_ALLOW_MOVE(CpuLogits1DSoftmaxKernel);
87 
88     /** Set the input and output tensors.
89      *
90      * @param[in]  src  Source tensor info. Data types supported: QASYMM8/QASYMM8_SIGNED/F16/F32.
91      * @param[in]  max  Max values tensor info. Same shape as input with dimension 0 set to 1.
92      *                  Data types supported: same as @p input.
93      * @param[out] dst  Destination tensor info. Data types supported: same as @p input.
94      * @param[in]  beta A scaling factor for the exponent.
95      *
96      * @param      tmp    Auxiliary tensor info. Must be type F32 and same shape as the input.
97      */
98     void configure(const ITensorInfo *src, const ITensorInfo *max, ITensorInfo *dst, const float beta, ITensorInfo *tmp);
99     /** Static function to check if given info will lead to a valid configuration
100      *
101      * Similar to CpuLogits1DSoftmaxKernel::configure()
102      *
103      * @return a status
104      */
105     static Status validate(const ITensorInfo *src, const ITensorInfo *max,
106                            const ITensorInfo *dst, const float beta, const ITensorInfo *tmp);
107 
108     // Inherited methods overridden:
109     void run_op(ITensorPack &tensors, const Window &window, const ThreadInfo &info) override;
110     const char *name() const override;
111 
112     struct SoftmaxLogits1DKernel
113     {
114         const char                  *name;
115         const DataTypeISASelectorPtr is_selected;
116         SoftmaxLogits1DKernelPtr     ukernel;
117     };
118 
119     static const std::vector<SoftmaxLogits1DKernel> &get_available_kernels();
120 
121 private:
122     float                    _beta{ 1.0f };
123     SoftmaxLogits1DKernelPtr _run_method{ nullptr };
124     std::string              _name{};
125 };
126 } // namespace kernels
127 } // namespace cpu
128 } // namespace arm_compute
129 #endif /* ARM_COMPUTE_CPU_SOFTMAX_KERNEL_H */
130