xref: /aosp_15_r20/external/ComputeLibrary/tests/validation/fixtures/PoolingLayerFixture.h (revision c217d954acce2dbc11938adb493fc0abd69584f3)
1 /*
2  * Copyright (c) 2017-2021 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #ifndef ARM_COMPUTE_TEST_POOLING_LAYER_FIXTURE
25 #define ARM_COMPUTE_TEST_POOLING_LAYER_FIXTURE
26 
27 #include "arm_compute/core/TensorShape.h"
28 #include "arm_compute/core/Types.h"
29 #include "arm_compute/core/utils/misc/ShapeCalculator.h"
30 #include "arm_compute/runtime/Tensor.h"
31 #include "tests/AssetsLibrary.h"
32 #include "tests/Globals.h"
33 #include "tests/IAccessor.h"
34 #include "tests/framework/Asserts.h"
35 #include "tests/framework/Fixture.h"
36 #include "tests/validation/reference/PoolingLayer.h"
37 #include <random>
38 namespace arm_compute
39 {
40 namespace test
41 {
42 namespace validation
43 {
44 template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
45 class PoolingLayerValidationGenericFixture : public framework::Fixture
46 {
47 public:
48     template <typename...>
49     void setup(TensorShape shape, PoolingLayerInfo pool_info, DataType data_type, DataLayout data_layout, bool indices = false,
50                QuantizationInfo input_qinfo = QuantizationInfo(), QuantizationInfo output_qinfo = QuantizationInfo(), bool mixed_layout = false)
51     {
52         _mixed_layout = mixed_layout;
53         _pool_info    = pool_info;
54         _target       = compute_target(shape, pool_info, data_type, data_layout, input_qinfo, output_qinfo, indices);
55         _reference    = compute_reference(shape, pool_info, data_type, data_layout, input_qinfo, output_qinfo, indices);
56     }
57 
58 protected:
mix_layout(FunctionType & layer,TensorType & src,TensorType & dst)59     void mix_layout(FunctionType &layer, TensorType &src, TensorType &dst)
60     {
61         const DataLayout data_layout = src.info()->data_layout();
62         // Test Multi DataLayout graph cases, when the data layout changes after configure
63         src.info()->set_data_layout(data_layout == DataLayout::NCHW ? DataLayout::NHWC : DataLayout::NCHW);
64         dst.info()->set_data_layout(data_layout == DataLayout::NCHW ? DataLayout::NHWC : DataLayout::NCHW);
65 
66         // Compute Convolution function
67         layer.run();
68 
69         // Reinstating original data layout for the test suite to properly check the values
70         src.info()->set_data_layout(data_layout);
71         dst.info()->set_data_layout(data_layout);
72     }
73 
74     template <typename U>
fill(U && tensor)75     void fill(U &&tensor)
76     {
77         if(tensor.data_type() == DataType::F32)
78         {
79             std::uniform_real_distribution<float> distribution(-1.0f, 1.0f);
80             library->fill(tensor, distribution, 0);
81         }
82         else if(tensor.data_type() == DataType::F16)
83         {
84             arm_compute::utils::uniform_real_distribution_16bit<half> distribution{ -1.0f, 1.0f };
85             library->fill(tensor, distribution, 0);
86         }
87         else // data type is quantized_asymmetric
88         {
89             library->fill_tensor_uniform(tensor, 0);
90         }
91     }
92 
compute_target(TensorShape shape,PoolingLayerInfo info,DataType data_type,DataLayout data_layout,QuantizationInfo input_qinfo,QuantizationInfo output_qinfo,bool indices)93     TensorType compute_target(TensorShape shape, PoolingLayerInfo info,
94                               DataType data_type, DataLayout data_layout,
95                               QuantizationInfo input_qinfo, QuantizationInfo output_qinfo,
96                               bool indices)
97     {
98         // Change shape in case of NHWC.
99         if(data_layout == DataLayout::NHWC)
100         {
101             permute(shape, PermutationVector(2U, 0U, 1U));
102         }
103         // Create tensors
104         TensorType        src       = create_tensor<TensorType>(shape, data_type, 1, input_qinfo, data_layout);
105         const TensorShape dst_shape = misc::shape_calculator::compute_pool_shape(*(src.info()), info);
106         TensorType        dst       = create_tensor<TensorType>(dst_shape, data_type, 1, output_qinfo, data_layout);
107         _target_indices             = create_tensor<TensorType>(dst_shape, DataType::U32, 1, output_qinfo, data_layout);
108 
109         // Create and configure function
110         FunctionType pool_layer;
111         pool_layer.configure(&src, &dst, info, (indices) ? &_target_indices : nullptr);
112 
113         ARM_COMPUTE_ASSERT(src.info()->is_resizable());
114         ARM_COMPUTE_ASSERT(dst.info()->is_resizable());
115         ARM_COMPUTE_ASSERT(_target_indices.info()->is_resizable());
116 
117         add_padding_x({ &src, &dst, &_target_indices }, data_layout);
118 
119         // Allocate tensors
120         src.allocator()->allocate();
121         dst.allocator()->allocate();
122         _target_indices.allocator()->allocate();
123 
124         ARM_COMPUTE_ASSERT(!src.info()->is_resizable());
125         ARM_COMPUTE_ASSERT(!dst.info()->is_resizable());
126         ARM_COMPUTE_ASSERT(!_target_indices.info()->is_resizable());
127 
128         // Fill tensors
129         fill(AccessorType(src));
130 
131         if(_mixed_layout)
132         {
133             mix_layout(pool_layer, src, dst);
134         }
135         else
136         {
137             // Compute function
138             pool_layer.run();
139         }
140         return dst;
141     }
142 
compute_reference(TensorShape shape,PoolingLayerInfo info,DataType data_type,DataLayout data_layout,QuantizationInfo input_qinfo,QuantizationInfo output_qinfo,bool indices)143     SimpleTensor<T> compute_reference(TensorShape shape, PoolingLayerInfo info, DataType data_type, DataLayout data_layout,
144                                       QuantizationInfo input_qinfo, QuantizationInfo output_qinfo, bool indices)
145     {
146         // Create reference
147         SimpleTensor<T> src(shape, data_type, 1, input_qinfo);
148         // Fill reference
149         fill(src);
150         return reference::pooling_layer<T>(src, info, output_qinfo, indices ? &_ref_indices : nullptr, data_layout);
151     }
152 
153     TensorType             _target{};
154     SimpleTensor<T>        _reference{};
155     PoolingLayerInfo       _pool_info{};
156     bool                   _mixed_layout{ false };
157     TensorType             _target_indices{};
158     SimpleTensor<uint32_t> _ref_indices{};
159 };
160 template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
161 class PoolingLayerIndicesValidationFixture : public PoolingLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>
162 {
163 public:
164     template <typename...>
setup(TensorShape shape,PoolingType pool_type,Size2D pool_size,PadStrideInfo pad_stride_info,bool exclude_padding,DataType data_type,DataLayout data_layout)165     void setup(TensorShape shape, PoolingType pool_type, Size2D pool_size, PadStrideInfo pad_stride_info, bool exclude_padding, DataType data_type, DataLayout data_layout)
166     {
167         PoolingLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, PoolingLayerInfo(pool_type, pool_size, data_layout, pad_stride_info, exclude_padding),
168                                                                                                data_type, data_layout, true);
169     }
170 };
171 
172 template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool mixed_layout = false>
173 class PoolingLayerValidationFixture : public PoolingLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>
174 {
175 public:
176     template <typename...>
setup(TensorShape shape,PoolingType pool_type,Size2D pool_size,PadStrideInfo pad_stride_info,bool exclude_padding,DataType data_type,DataLayout data_layout)177     void setup(TensorShape shape, PoolingType pool_type, Size2D pool_size, PadStrideInfo pad_stride_info, bool exclude_padding, DataType data_type, DataLayout data_layout)
178     {
179         PoolingLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, PoolingLayerInfo(pool_type, pool_size, data_layout, pad_stride_info, exclude_padding),
180                                                                                                data_type, data_layout, false, mixed_layout);
181     }
182 };
183 
184 template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
185 class PoolingLayerValidationMixedPrecisionFixture : public PoolingLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>
186 {
187 public:
188     template <typename...>
189     void setup(TensorShape shape, PoolingType pool_type, Size2D pool_size, PadStrideInfo pad_stride_info, bool exclude_padding, DataType data_type, DataLayout data_layout, bool fp_mixed_precision = false)
190     {
191         PoolingLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, PoolingLayerInfo(pool_type, pool_size, data_layout, pad_stride_info, exclude_padding, fp_mixed_precision),
192                                                                                                data_type, data_layout);
193     }
194 };
195 
196 template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool mixed_layout = false>
197 class PoolingLayerValidationQuantizedFixture : public PoolingLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>
198 {
199 public:
200     template <typename...>
201     void setup(TensorShape shape, PoolingType pool_type, Size2D pool_size, PadStrideInfo pad_stride_info, bool exclude_padding, DataType data_type, DataLayout data_layout = DataLayout::NCHW,
202                QuantizationInfo input_qinfo = QuantizationInfo(), QuantizationInfo output_qinfo = QuantizationInfo())
203     {
204         PoolingLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, PoolingLayerInfo(pool_type, pool_size, data_layout, pad_stride_info, exclude_padding),
205                                                                                                data_type, data_layout, false, input_qinfo, output_qinfo, mixed_layout);
206     }
207 };
208 
209 template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
210 class SpecialPoolingLayerValidationFixture : public PoolingLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>
211 {
212 public:
213     template <typename...>
setup(TensorShape src_shape,PoolingLayerInfo pool_info,DataType data_type)214     void setup(TensorShape src_shape, PoolingLayerInfo pool_info, DataType data_type)
215     {
216         PoolingLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(src_shape, pool_info, data_type, pool_info.data_layout);
217     }
218 };
219 
220 template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
221 class GlobalPoolingLayerValidationFixture : public PoolingLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>
222 {
223 public:
224     template <typename...>
225     void setup(TensorShape shape, PoolingType pool_type, DataType data_type, DataLayout data_layout = DataLayout::NCHW)
226     {
227         PoolingLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(shape, PoolingLayerInfo(pool_type, data_layout), data_type, data_layout);
228     }
229 };
230 
231 } // namespace validation
232 } // namespace test
233 } // namespace arm_compute
234 #endif /* ARM_COMPUTE_TEST_POOLING_LAYER_FIXTURE */
235