xref: /aosp_15_r20/external/ComputeLibrary/src/cpu/kernels/elementwise_binary/generic/sve/impl.h (revision c217d954acce2dbc11938adb493fc0abd69584f3)
1 /*
2  * Copyright (c) 2021-2022 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #ifndef SRC_CORE_SVE_KERNELS_ELEMENTWISE_LIST_H
25 #define SRC_CORE_SVE_KERNELS_ELEMENTWISE_LIST_H
26 
27 #include "arm_compute/core/Helpers.h"
28 #include "src/core/NEON/wrapper/intrinsics/intrinsics.h"
29 #include "src/core/NEON/wrapper/svtraits.h"
30 
31 namespace arm_compute
32 {
33 namespace cpu
34 {
35 using namespace arm_compute::wrapper;
36 
37 template <typename VectorType>
elementwise_pow(svbool_t & pg,const VectorType & a,const VectorType & b)38 VectorType elementwise_pow(svbool_t &pg, const VectorType &a, const VectorType &b)
39 {
40     return svpow_z(pg, a, b);
41 }
42 
43 template <typename VectorType>
elementwise_div(svbool_t & pg,const VectorType & a,const VectorType & b)44 VectorType elementwise_div(svbool_t &pg, const VectorType &a, const VectorType &b)
45 {
46     return svdiv_z(pg, a, b);
47 }
48 
49 template <uint32_t bytewidth>
narrow_to_byte_predicate(svbool_t pg)50 svbool_t narrow_to_byte_predicate(svbool_t pg)
51 {
52     const auto all_false = svpfalse();
53 
54     switch(bytewidth)
55     {
56         case 8:
57             pg = svuzp1_b32(pg, all_false);
58         /* fall through */
59         case 4:
60             pg = svuzp1_b16(pg, all_false);
61         /* fall through */
62         case 2:
63             pg = svuzp1_b8(pg, all_false);
64         /* fall through */
65         default:
66             break;
67     }
68     return pg;
69 }
70 
71 template <typename VectorType>
elementwise_arithmetic_op(svbool_t & pg,const VectorType & a,const VectorType & b,ArithmeticOperation op)72 VectorType elementwise_arithmetic_op(svbool_t &pg, const VectorType &a, const VectorType &b, ArithmeticOperation op)
73 {
74     using ScalarType = typename wrapper::sve_scalar<VectorType>::type;
75     VectorType res{};
76 
77     switch(op)
78     {
79         case ArithmeticOperation::MAX:
80             res = svmax_z(pg, a, b);
81             break;
82         case ArithmeticOperation::MIN:
83             res = svmin_z(pg, a, b);
84             break;
85         case ArithmeticOperation::SQUARED_DIFF:
86         {
87             const auto tmp = svsub_z(pg, a, b);
88             res            = svmul_z(pg, tmp, tmp);
89             break;
90         }
91         case ArithmeticOperation::PRELU:
92         {
93             const auto zero = svdup_n(ScalarType(0));
94             const auto tmp  = svmul_z(pg, a, b);
95             const auto gt   = svcmpgt(pg, a, zero);
96             res             = svsel(gt, a, tmp);
97             break;
98         }
99         case ArithmeticOperation::DIV:
100         {
101             res = elementwise_div(pg, a, b);
102             break;
103         }
104         case ArithmeticOperation::POWER:
105         {
106             res = elementwise_pow(pg, a, b);
107             break;
108         }
109         default:
110             ARM_COMPUTE_ERROR("NOT_SUPPORTED!");
111     }
112 
113     return res;
114 }
115 
116 template <typename InputVectorType, typename OutputVectorType>
elementwise_comparison_op(svbool_t & pg,const InputVectorType & a,const InputVectorType & b,ComparisonOperation op)117 OutputVectorType elementwise_comparison_op(svbool_t &pg, const InputVectorType &a, const InputVectorType &b, ComparisonOperation op)
118 {
119     svbool_t selection_vector{};
120 
121     switch(op)
122     {
123         case ComparisonOperation::Equal:
124             selection_vector = svcmpeq(pg, a, b);
125             break;
126         case ComparisonOperation::NotEqual:
127             selection_vector = svcmpne(pg, a, b);
128             break;
129         case ComparisonOperation::Greater:
130             selection_vector = svcmpgt(pg, a, b);
131             break;
132         case ComparisonOperation::GreaterEqual:
133             selection_vector = svcmpge(pg, a, b);
134             break;
135         case ComparisonOperation::Less:
136             selection_vector = svcmplt(pg, a, b);
137             break;
138         case ComparisonOperation::LessEqual:
139             selection_vector = svcmple(pg, a, b);
140             break;
141         default:
142             ARM_COMPUTE_ERROR("NOT_SUPPORTED!");
143     }
144 
145     using InputScalarType = typename wrapper::sve_scalar<InputVectorType>::type;
146     selection_vector      = narrow_to_byte_predicate<sizeof(InputScalarType)>(selection_vector);
147 
148     using OutputScalarType  = typename wrapper::sve_scalar<OutputVectorType>::type;
149     const auto false_vector = svdup_n(static_cast<OutputScalarType>((uint32_t)0));
150     const auto true_vector  = svdup_n(static_cast<OutputScalarType>(~(uint32_t)0));
151     auto       ret          = svsel(selection_vector, true_vector, false_vector);
152 
153     return ret;
154 }
155 
156 template <typename ScalarType>
157 void elementwise_arithmetic_op(const ITensor *in1, const ITensor *in2, ITensor *out, ArithmeticOperation op, const Window &window);
158 
159 template <typename ScalarType, typename OutputScalarType = uint8_t>
160 void elementwise_comparison_op(const ITensor *in1, const ITensor *in2, ITensor *out, ComparisonOperation op, const Window &window);
161 } // namespace cpu
162 } // namespace arm_compute
163 #endif /* SRC_CORE_SVE_KERNELS_ELEMENTWISE_LIST_H */
164