xref: /aosp_15_r20/external/ComputeLibrary/tests/validation/fixtures/ElementwiseUnaryFixture.h (revision c217d954acce2dbc11938adb493fc0abd69584f3)
1 /*
2  * Copyright (c) 2018-2021 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #ifndef ARM_COMPUTE_TEST_ELEMENTWISE_UNARY_FIXTURE
25 #define ARM_COMPUTE_TEST_ELEMENTWISE_UNARY_FIXTURE
26 
27 #include "arm_compute/core/TensorShape.h"
28 #include "arm_compute/core/Types.h"
29 #include "tests/AssetsLibrary.h"
30 #include "tests/Globals.h"
31 #include "tests/IAccessor.h"
32 #include "tests/framework/Asserts.h"
33 #include "tests/framework/Fixture.h"
34 #include "tests/validation/reference/ElementwiseUnary.h"
35 
36 namespace arm_compute
37 {
38 namespace test
39 {
40 namespace validation
41 {
42 template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
43 class ElementWiseUnaryValidationFixture : public framework::Fixture
44 {
45 public:
46     template <typename...>
47     void setup(TensorShape input_shape, DataType input_data_type, bool in_place, ElementWiseUnary op, bool use_dynamic_shape = false)
48     {
49         _op                = op;
50         _target            = compute_target(input_shape, input_data_type, in_place);
51         _reference         = compute_reference(input_shape, input_data_type);
52         _use_dynamic_shape = use_dynamic_shape;
53     }
54 
55 protected:
56     template <typename U>
fill(U && tensor,int i,DataType data_type)57     void fill(U &&tensor, int i, DataType data_type)
58     {
59         using FloatType             = typename std::conditional < std::is_same<T, half>::value || std::is_floating_point<T>::value, T, float >::type;
60         using FloatDistributionType = typename std::conditional<std::is_same<T, half>::value, arm_compute::utils::uniform_real_distribution_16bit<T>, std::uniform_real_distribution<FloatType>>::type;
61 
62         switch(_op)
63         {
64             case ElementWiseUnary::EXP:
65             {
66                 FloatDistributionType distribution{ FloatType(-1.0f), FloatType(1.0f) };
67                 library->fill(tensor, distribution, i);
68                 break;
69             }
70             case ElementWiseUnary::RSQRT:
71             {
72                 FloatDistributionType distribution{ FloatType(1.0f), FloatType(2.0f) };
73                 library->fill(tensor, distribution, i);
74                 break;
75             }
76             case ElementWiseUnary::ABS:
77             case ElementWiseUnary::NEG:
78             {
79                 switch(data_type)
80                 {
81                     case DataType::F16:
82                     {
83                         arm_compute::utils::uniform_real_distribution_16bit<half> distribution{ -2.0f, 2.0f };
84                         library->fill(tensor, distribution, i);
85                         break;
86                     }
87                     case DataType::F32:
88                     {
89                         FloatDistributionType distribution{ FloatType(-2.0f), FloatType(2.0f) };
90                         library->fill(tensor, distribution, i);
91                         break;
92                     }
93                     case DataType::S32:
94                     {
95                         std::uniform_int_distribution<int32_t> distribution(-100, 100);
96                         library->fill(tensor, distribution, i);
97                         break;
98                     }
99                     default:
100                         ARM_COMPUTE_ERROR("DataType for Elementwise Negation Not implemented");
101                 }
102                 break;
103             }
104             case ElementWiseUnary::LOG:
105             {
106                 FloatDistributionType distribution{ FloatType(0.0000001f), FloatType(100.0f) };
107                 library->fill(tensor, distribution, i);
108                 break;
109             }
110             case ElementWiseUnary::SIN:
111             {
112                 FloatDistributionType distribution{ FloatType(-100.00f), FloatType(100.00f) };
113                 library->fill(tensor, distribution, i);
114                 break;
115             }
116             case ElementWiseUnary::ROUND:
117             {
118                 FloatDistributionType distribution{ FloatType(100.0f), FloatType(-100.0f) };
119                 library->fill(tensor, distribution, i);
120                 break;
121             }
122             default:
123                 ARM_COMPUTE_ERROR("Not implemented");
124         }
125     }
126 
compute_target(const TensorShape & shape,DataType data_type,bool in_place)127     TensorType compute_target(const TensorShape &shape, DataType data_type, bool in_place)
128     {
129         // Create tensors
130         TensorType src = create_tensor<TensorType>(shape, data_type);
131         TensorType dst = create_tensor<TensorType>(shape, data_type);
132 
133         TensorType *actual_dst = in_place ? &src : &dst;
134 
135         // if _use_dynamic_shape is true, this fixture will test scenario for dynamic shapes.
136         // - At configure time, all input tensors are marked as dynamic using set_tensor_dynamic()
137         // - After configure, tensors are marked as static for run using set_tensor_static()
138         // - The tensors with static shape are given to run()
139         if(_use_dynamic_shape)
140         {
141             set_tensor_dynamic(src);
142         }
143 
144         // Create and configure function
145         FunctionType elwiseunary_layer;
146         elwiseunary_layer.configure(&src, actual_dst);
147 
148         if(_use_dynamic_shape)
149         {
150             set_tensor_static(src);
151         }
152 
153         ARM_COMPUTE_ASSERT(src.info()->is_resizable());
154         src.allocator()->allocate();
155         ARM_COMPUTE_ASSERT(!src.info()->is_resizable());
156         if(!in_place)
157         {
158             ARM_COMPUTE_ASSERT(dst.info()->is_resizable());
159             dst.allocator()->allocate();
160             ARM_COMPUTE_ASSERT(!dst.info()->is_resizable());
161         }
162 
163         // Fill tensors
164         fill(AccessorType(src), 0, data_type);
165 
166         // Compute function
167         elwiseunary_layer.run();
168 
169         if(in_place)
170         {
171             return src;
172         }
173         else
174         {
175             return dst;
176         }
177     }
178 
compute_reference(const TensorShape & shape,DataType data_type)179     SimpleTensor<T> compute_reference(const TensorShape &shape, DataType data_type)
180     {
181         // Create reference
182         SimpleTensor<T> src{ shape, data_type };
183 
184         // Fill reference
185         fill(src, 0, data_type);
186 
187         return reference::elementwise_unary<T>(src, _op);
188     }
189 
190     TensorType       _target{};
191     SimpleTensor<T>  _reference{};
192     ElementWiseUnary _op{};
193     bool             _use_dynamic_shape{ false };
194 };
195 
196 template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
197 class RsqrtValidationFixture : public ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>
198 {
199 public:
200     template <typename...>
setup(const TensorShape & shape,DataType data_type)201     void setup(const TensorShape &shape, DataType data_type)
202     {
203         ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, false, ElementWiseUnary::RSQRT);
204     }
205 };
206 
207 template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
208 class RsqrtDynamicShapeValidationFixture : public ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>
209 {
210 public:
211     template <typename...>
setup(const TensorShape & shape,DataType data_type)212     void setup(const TensorShape &shape, DataType data_type)
213     {
214         ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, false, ElementWiseUnary::RSQRT, true);
215     }
216 };
217 
218 template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
219 class ExpValidationFixture : public ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>
220 {
221 public:
222     template <typename...>
setup(const TensorShape & shape,DataType data_type)223     void setup(const TensorShape &shape, DataType data_type)
224     {
225         ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, false, ElementWiseUnary::EXP);
226     }
227 };
228 
229 template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
230 class NegValidationFixture : public ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>
231 {
232 public:
233     template <typename...>
setup(const TensorShape & shape,DataType data_type)234     void setup(const TensorShape &shape, DataType data_type)
235     {
236         ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, false, ElementWiseUnary::NEG);
237     }
238 };
239 
240 template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
241 class NegValidationInPlaceFixture : public ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>
242 {
243 public:
244     template <typename...>
setup(const TensorShape & shape,DataType data_type,bool in_place)245     void setup(const TensorShape &shape, DataType data_type, bool in_place)
246     {
247         ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, in_place, ElementWiseUnary::NEG);
248     }
249 };
250 
251 template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
252 class LogValidationFixture : public ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>
253 {
254 public:
255     template <typename...>
setup(const TensorShape & shape,DataType data_type)256     void setup(const TensorShape &shape, DataType data_type)
257     {
258         ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, false, ElementWiseUnary::LOG);
259     }
260 };
261 
262 template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
263 class AbsValidationFixture : public ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>
264 {
265 public:
266     template <typename...>
setup(const TensorShape & shape,DataType data_type)267     void setup(const TensorShape &shape, DataType data_type)
268     {
269         ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, false, ElementWiseUnary::ABS);
270     }
271 };
272 
273 template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
274 class SinValidationFixture : public ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>
275 {
276 public:
277     template <typename...>
setup(const TensorShape & shape,DataType data_type)278     void setup(const TensorShape &shape, DataType data_type)
279     {
280         ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, false, ElementWiseUnary::SIN);
281     }
282 };
283 
284 template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
285 class RoundValidationFixture : public ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>
286 {
287 public:
288     template <typename...>
setup(const TensorShape & shape,DataType data_type)289     void setup(const TensorShape &shape, DataType data_type)
290     {
291         ElementWiseUnaryValidationFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, data_type, false, ElementWiseUnary::ROUND);
292     }
293 };
294 } // namespace validation
295 } // namespace test
296 } // namespace arm_compute
297 #endif /* ARM_COMPUTE_TEST_ELEMENTWISE_UNARY_FIXTURE */
298