xref: /aosp_15_r20/external/ComputeLibrary/tests/validation/fixtures/ArithmeticOperationsFixture.h (revision c217d954acce2dbc11938adb493fc0abd69584f3)
1 /*
2  * Copyright (c) 2017-2021 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #ifndef ARM_COMPUTE_TEST_ARITHMETIC_OPERATIONS_FIXTURE
25 #define ARM_COMPUTE_TEST_ARITHMETIC_OPERATIONS_FIXTURE
26 
27 #include "arm_compute/core/TensorShape.h"
28 #include "arm_compute/core/Types.h"
29 #include "tests/AssetsLibrary.h"
30 #include "tests/Globals.h"
31 #include "tests/IAccessor.h"
32 #include "tests/framework/Asserts.h"
33 #include "tests/framework/Fixture.h"
34 #include "tests/validation/Helpers.h"
35 #include "tests/validation/reference/ActivationLayer.h"
36 #include "tests/validation/reference/ArithmeticOperations.h"
37 
38 namespace arm_compute
39 {
40 namespace test
41 {
42 namespace validation
43 {
44 template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
45 class ArithmeticOperationGenericFixture : public framework::Fixture
46 {
47 public:
48     template <typename...>
setup(reference::ArithmeticOperation op,const TensorShape & shape0,const TensorShape & shape1,DataType data_type,ConvertPolicy convert_policy,QuantizationInfo qinfo0,QuantizationInfo qinfo1,QuantizationInfo qinfo_out,ActivationLayerInfo act_info,bool is_inplace)49     void setup(reference::ArithmeticOperation op, const TensorShape &shape0, const TensorShape &shape1, DataType data_type, ConvertPolicy convert_policy,
50                QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out, ActivationLayerInfo act_info, bool is_inplace)
51     {
52         _op         = op;
53         _act_info   = act_info;
54         _is_inplace = is_inplace;
55         _target     = compute_target(shape0, shape1, data_type, convert_policy, qinfo0, qinfo1, qinfo_out);
56         _reference  = compute_reference(shape0, shape1, data_type, convert_policy, qinfo0, qinfo1, qinfo_out);
57     }
58 
59 protected:
60     template <typename U>
fill(U && tensor,int i)61     void fill(U &&tensor, int i)
62     {
63         library->fill_tensor_uniform(tensor, i);
64     }
65 
compute_target(const TensorShape & shape0,const TensorShape & shape1,DataType data_type,ConvertPolicy convert_policy,QuantizationInfo qinfo0,QuantizationInfo qinfo1,QuantizationInfo qinfo_out)66     TensorType compute_target(const TensorShape &shape0, const TensorShape &shape1, DataType data_type, ConvertPolicy convert_policy,
67                               QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out)
68     {
69         // Create tensors
70         const TensorShape out_shape = TensorShape::broadcast_shape(shape0, shape1);
71         TensorType        ref_src1  = create_tensor<TensorType>(shape0, data_type, 1, qinfo0);
72         TensorType        ref_src2  = create_tensor<TensorType>(shape1, data_type, 1, qinfo1);
73         TensorType        dst       = create_tensor<TensorType>(out_shape, data_type, 1, qinfo_out);
74 
75         // Check whether do in-place computation and whether inputs are broadcast compatible
76         TensorType *actual_dst = &dst;
77         if(_is_inplace)
78         {
79             bool src1_is_inplace = !arm_compute::detail::have_different_dimensions(out_shape, shape0, 0) && (qinfo0 == qinfo_out);
80             bool src2_is_inplace = !arm_compute::detail::have_different_dimensions(out_shape, shape1, 0) && (qinfo1 == qinfo_out);
81             bool do_in_place     = out_shape.total_size() != 0 && (src1_is_inplace || src2_is_inplace);
82             ARM_COMPUTE_ASSERT(do_in_place);
83 
84             if(src1_is_inplace)
85             {
86                 actual_dst = &ref_src1;
87             }
88             else
89             {
90                 actual_dst = &ref_src2;
91             }
92         }
93 
94         // Create and configure function
95         FunctionType arith_op;
96         arith_op.configure(&ref_src1, &ref_src2, actual_dst, convert_policy, _act_info);
97 
98         ARM_COMPUTE_ASSERT(ref_src1.info()->is_resizable());
99         ARM_COMPUTE_ASSERT(ref_src2.info()->is_resizable());
100 
101         // Allocate tensors
102         ref_src1.allocator()->allocate();
103         ref_src2.allocator()->allocate();
104 
105         ARM_COMPUTE_ASSERT(!ref_src1.info()->is_resizable());
106         ARM_COMPUTE_ASSERT(!ref_src2.info()->is_resizable());
107 
108         // If don't do in-place computation, still need to allocate original dst
109         if(!_is_inplace)
110         {
111             ARM_COMPUTE_ASSERT(dst.info()->is_resizable());
112             dst.allocator()->allocate();
113             ARM_COMPUTE_ASSERT(!dst.info()->is_resizable());
114         }
115 
116         // Fill tensors
117         fill(AccessorType(ref_src1), 0);
118         fill(AccessorType(ref_src2), 1);
119 
120         // Compute function
121         arith_op.run();
122 
123         return std::move(*actual_dst);
124     }
125 
compute_reference(const TensorShape & shape0,const TensorShape & shape1,DataType data_type,ConvertPolicy convert_policy,QuantizationInfo qinfo0,QuantizationInfo qinfo1,QuantizationInfo qinfo_out)126     SimpleTensor<T> compute_reference(const TensorShape &shape0, const TensorShape &shape1, DataType data_type, ConvertPolicy convert_policy,
127                                       QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out)
128     {
129         // Create reference
130         SimpleTensor<T> ref_src1{ shape0, data_type, 1, qinfo0 };
131         SimpleTensor<T> ref_src2{ shape1, data_type, 1, qinfo1 };
132         SimpleTensor<T> ref_dst{ TensorShape::broadcast_shape(shape0, shape1), data_type, 1, qinfo_out };
133 
134         // Fill reference
135         fill(ref_src1, 0);
136         fill(ref_src2, 1);
137 
138         auto result = reference::arithmetic_operation<T>(_op, ref_src1, ref_src2, ref_dst, convert_policy);
139         return _act_info.enabled() ? reference::activation_layer(result, _act_info, qinfo_out) : result;
140     }
141 
142     TensorType                     _target{};
143     SimpleTensor<T>                _reference{};
144     reference::ArithmeticOperation _op{ reference::ArithmeticOperation::ADD };
145     ActivationLayerInfo            _act_info{};
146     bool                           _is_inplace{};
147 };
148 
149 template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
150 class ArithmeticAdditionBroadcastValidationFixture : public ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>
151 {
152 public:
153     template <typename...>
setup(const TensorShape & shape0,const TensorShape & shape1,DataType data_type,ConvertPolicy convert_policy,bool is_inplace)154     void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type, ConvertPolicy convert_policy, bool is_inplace)
155     {
156         ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::ADD, shape0, shape1, data_type, convert_policy,
157                                                                                             QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), ActivationLayerInfo(), is_inplace);
158     }
159 };
160 
161 template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
162 class ArithmeticAdditionValidationFixture : public ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>
163 {
164 public:
165     template <typename...>
setup(const TensorShape & shape,DataType data_type,ConvertPolicy convert_policy,bool is_inplace)166     void setup(const TensorShape &shape, DataType data_type, ConvertPolicy convert_policy, bool is_inplace)
167     {
168         ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::ADD, shape, shape, data_type, convert_policy,
169                                                                                             QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), ActivationLayerInfo(), is_inplace);
170     }
171 };
172 
173 template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
174 class ArithmeticAdditionBroadcastValidationFloatFixture : public ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>
175 {
176 public:
177     template <typename...>
setup(const TensorShape & shape0,const TensorShape & shape1,DataType data_type,ConvertPolicy convert_policy,ActivationLayerInfo act_info,bool is_inplace)178     void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type, ConvertPolicy convert_policy, ActivationLayerInfo act_info, bool is_inplace)
179     {
180         ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::ADD, shape0, shape1, data_type, convert_policy,
181                                                                                             QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), act_info, is_inplace);
182     }
183 };
184 
185 template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
186 class ArithmeticAdditionValidationFloatFixture : public ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>
187 {
188 public:
189     template <typename...>
setup(const TensorShape & shape,DataType data_type,ConvertPolicy convert_policy,ActivationLayerInfo act_info,bool is_inplace)190     void setup(const TensorShape &shape, DataType data_type, ConvertPolicy convert_policy, ActivationLayerInfo act_info, bool is_inplace)
191     {
192         ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::ADD, shape, shape, data_type, convert_policy,
193                                                                                             QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), act_info, is_inplace);
194     }
195 };
196 
197 template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
198 class ArithmeticAdditionValidationQuantizedFixture : public ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>
199 {
200 public:
201     template <typename...>
setup(const TensorShape & shape,DataType data_type,ConvertPolicy convert_policy,QuantizationInfo qinfo0,QuantizationInfo qinfo1,QuantizationInfo qinfo_out,bool is_inplace)202     void setup(const TensorShape &shape, DataType data_type, ConvertPolicy convert_policy, QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out, bool is_inplace)
203 
204     {
205         ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::ADD, shape, shape, data_type, convert_policy,
206                                                                                             qinfo0, qinfo1, qinfo_out, ActivationLayerInfo(), is_inplace);
207     }
208 };
209 
210 template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
211 class ArithmeticAdditionValidationQuantizedBroadcastFixture : public ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>
212 {
213 public:
214     template <typename...>
setup(const TensorShape & shape0,const TensorShape & shape1,DataType data_type,ConvertPolicy convert_policy,QuantizationInfo qinfo0,QuantizationInfo qinfo1,QuantizationInfo qinfo_out,bool is_inplace)215     void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type, ConvertPolicy convert_policy, QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out,
216                bool is_inplace)
217     {
218         ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::ADD, shape0, shape1, data_type, convert_policy,
219                                                                                             qinfo0, qinfo1, qinfo_out, ActivationLayerInfo(), is_inplace);
220     }
221 };
222 
223 template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
224 class ArithmeticSubtractionBroadcastValidationFixture : public ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>
225 {
226 public:
227     template <typename...>
setup(const TensorShape & shape0,const TensorShape & shape1,DataType data_type,ConvertPolicy convert_policy,bool is_inplace)228     void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type, ConvertPolicy convert_policy, bool is_inplace)
229     {
230         ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::SUB, shape0, shape1, data_type, convert_policy,
231                                                                                             QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), ActivationLayerInfo(), is_inplace);
232     }
233 };
234 
235 template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
236 class ArithmeticSubtractionBroadcastValidationFloatFixture : public ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>
237 {
238 public:
239     template <typename...>
setup(const TensorShape & shape0,const TensorShape & shape1,DataType data_type,ConvertPolicy convert_policy,ActivationLayerInfo act_info,bool is_inplace)240     void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type, ConvertPolicy convert_policy, ActivationLayerInfo act_info,
241                bool is_inplace)
242     {
243         ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::SUB, shape0, shape1, data_type, convert_policy,
244                                                                                             QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), act_info, is_inplace);
245     }
246 };
247 
248 template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
249 class ArithmeticSubtractionValidationFixture : public ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>
250 {
251 public:
252     template <typename...>
setup(const TensorShape & shape,DataType data_type,ConvertPolicy convert_policy,bool is_inplace)253     void setup(const TensorShape &shape, DataType data_type, ConvertPolicy convert_policy, bool is_inplace)
254     {
255         ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::SUB, shape, shape, data_type, convert_policy,
256                                                                                             QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), ActivationLayerInfo(), is_inplace);
257     }
258 };
259 
260 template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
261 class ArithmeticSubtractionValidationFloatFixture : public ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>
262 {
263 public:
264     template <typename...>
setup(const TensorShape & shape,DataType data_type,ConvertPolicy convert_policy,ActivationLayerInfo act_info,bool is_inplace)265     void setup(const TensorShape &shape, DataType data_type, ConvertPolicy convert_policy, ActivationLayerInfo act_info, bool is_inplace)
266     {
267         ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::SUB, shape, shape, data_type, convert_policy,
268                                                                                             QuantizationInfo(), QuantizationInfo(), QuantizationInfo(), act_info, is_inplace);
269     }
270 };
271 
272 template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
273 class ArithmeticSubtractionValidationQuantizedFixture : public ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>
274 {
275 public:
276     template <typename...>
setup(const TensorShape & shape,DataType data_type,ConvertPolicy convert_policy,QuantizationInfo qinfo0,QuantizationInfo qinfo1,QuantizationInfo qinfo_out,bool is_inplace)277     void setup(const TensorShape &shape, DataType data_type, ConvertPolicy convert_policy, QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out, bool is_inplace)
278 
279     {
280         ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::SUB, shape, shape, data_type, convert_policy,
281                                                                                             qinfo0, qinfo1, qinfo_out, ActivationLayerInfo(), is_inplace);
282     }
283 };
284 
285 template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
286 class ArithmeticSubtractionValidationQuantizedBroadcastFixture : public ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>
287 {
288 public:
289     template <typename...>
setup(const TensorShape & shape0,const TensorShape & shape1,DataType data_type,ConvertPolicy convert_policy,QuantizationInfo qinfo0,QuantizationInfo qinfo1,QuantizationInfo qinfo_out,bool is_inplace)290     void setup(const TensorShape &shape0, const TensorShape &shape1, DataType data_type, ConvertPolicy convert_policy, QuantizationInfo qinfo0, QuantizationInfo qinfo1, QuantizationInfo qinfo_out,
291                bool is_inplace)
292     {
293         ArithmeticOperationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(reference::ArithmeticOperation::SUB, shape0, shape1, data_type, convert_policy,
294                                                                                             qinfo0, qinfo1, qinfo_out, ActivationLayerInfo(), is_inplace);
295     }
296 };
297 } // namespace validation
298 } // namespace test
299 } // namespace arm_compute
300 #endif /* ARM_COMPUTE_TEST_ARITHMETIC_OPERATIONS_FIXTURE */
301