xref: /aosp_15_r20/external/ComputeLibrary/tests/validation/fixtures/FullyConnectedLayerFixture.h (revision c217d954acce2dbc11938adb493fc0abd69584f3)
1*c217d954SCole Faust /*
2*c217d954SCole Faust  * Copyright (c) 2017-2022 Arm Limited.
3*c217d954SCole Faust  *
4*c217d954SCole Faust  * SPDX-License-Identifier: MIT
5*c217d954SCole Faust  *
6*c217d954SCole Faust  * Permission is hereby granted, free of charge, to any person obtaining a copy
7*c217d954SCole Faust  * of this software and associated documentation files (the "Software"), to
8*c217d954SCole Faust  * deal in the Software without restriction, including without limitation the
9*c217d954SCole Faust  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10*c217d954SCole Faust  * sell copies of the Software, and to permit persons to whom the Software is
11*c217d954SCole Faust  * furnished to do so, subject to the following conditions:
12*c217d954SCole Faust  *
13*c217d954SCole Faust  * The above copyright notice and this permission notice shall be included in all
14*c217d954SCole Faust  * copies or substantial portions of the Software.
15*c217d954SCole Faust  *
16*c217d954SCole Faust  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17*c217d954SCole Faust  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18*c217d954SCole Faust  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19*c217d954SCole Faust  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20*c217d954SCole Faust  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21*c217d954SCole Faust  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22*c217d954SCole Faust  * SOFTWARE.
23*c217d954SCole Faust  */
24*c217d954SCole Faust #ifndef ARM_COMPUTE_TEST_FULLY_CONNECTED_LAYER_FIXTURE
25*c217d954SCole Faust #define ARM_COMPUTE_TEST_FULLY_CONNECTED_LAYER_FIXTURE
26*c217d954SCole Faust 
27*c217d954SCole Faust #include "arm_compute/core/TensorShape.h"
28*c217d954SCole Faust #include "arm_compute/core/Types.h"
29*c217d954SCole Faust #include "arm_compute/core/Utils.h"
30*c217d954SCole Faust #include "tests/AssetsLibrary.h"
31*c217d954SCole Faust #include "tests/Globals.h"
32*c217d954SCole Faust #include "tests/IAccessor.h"
33*c217d954SCole Faust #include "tests/RawTensor.h"
34*c217d954SCole Faust #include "tests/framework/Asserts.h"
35*c217d954SCole Faust #include "tests/framework/Fixture.h"
36*c217d954SCole Faust #include "tests/validation/Helpers.h"
37*c217d954SCole Faust #include "tests/validation/reference/ActivationLayer.h"
38*c217d954SCole Faust #include "tests/validation/reference/FullyConnectedLayer.h"
39*c217d954SCole Faust #include "tests/validation/reference/Utils.h"
40*c217d954SCole Faust 
41*c217d954SCole Faust #include <random>
42*c217d954SCole Faust 
43*c217d954SCole Faust namespace arm_compute
44*c217d954SCole Faust {
45*c217d954SCole Faust namespace test
46*c217d954SCole Faust {
47*c217d954SCole Faust namespace validation
48*c217d954SCole Faust {
49*c217d954SCole Faust template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
50*c217d954SCole Faust class FullyConnectedLayerValidationGenericFixture : public framework::Fixture
51*c217d954SCole Faust {
52*c217d954SCole Faust public:
53*c217d954SCole Faust     using TDecay = typename std::decay<T>::type;
54*c217d954SCole Faust     using TBias  = typename std::conditional < (std::is_same<TDecay, uint8_t>::value || std::is_same<TDecay, int8_t>::value), int32_t, T >::type;
55*c217d954SCole Faust 
56*c217d954SCole Faust public:
57*c217d954SCole Faust     template <typename...>
58*c217d954SCole Faust     void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, bool transpose_weights, bool reshape_weights,
59*c217d954SCole Faust                DataType data_type, QuantizationInfo quantization_info, ActivationLayerInfo activation_info, bool mixed_layout = false)
60*c217d954SCole Faust     {
61*c217d954SCole Faust         ARM_COMPUTE_UNUSED(weights_shape);
62*c217d954SCole Faust         ARM_COMPUTE_UNUSED(bias_shape);
63*c217d954SCole Faust 
64*c217d954SCole Faust         _mixed_layout      = mixed_layout;
65*c217d954SCole Faust         _data_type         = data_type;
66*c217d954SCole Faust         _bias_data_type    = is_data_type_quantized_asymmetric(data_type) ? DataType::S32 : data_type;
67*c217d954SCole Faust         _quantization_info = quantization_info;
68*c217d954SCole Faust         _activation_info   = activation_info;
69*c217d954SCole Faust 
70*c217d954SCole Faust         _target    = compute_target(input_shape, weights_shape, bias_shape, output_shape, transpose_weights, reshape_weights);
71*c217d954SCole Faust         _reference = compute_reference(input_shape, weights_shape, bias_shape, output_shape);
72*c217d954SCole Faust     }
73*c217d954SCole Faust 
74*c217d954SCole Faust protected:
mix_layout(FunctionType & layer,TensorType & src,TensorType & dst)75*c217d954SCole Faust     void mix_layout(FunctionType &layer, TensorType &src, TensorType &dst)
76*c217d954SCole Faust     {
77*c217d954SCole Faust         const DataLayout data_layout = src.info()->data_layout();
78*c217d954SCole Faust         // Test Multi DataLayout graph cases, when the data layout changes after configure
79*c217d954SCole Faust         src.info()->set_data_layout(data_layout == DataLayout::NCHW ? DataLayout::NHWC : DataLayout::NCHW);
80*c217d954SCole Faust         dst.info()->set_data_layout(data_layout == DataLayout::NCHW ? DataLayout::NHWC : DataLayout::NCHW);
81*c217d954SCole Faust 
82*c217d954SCole Faust         // Compute Convolution function
83*c217d954SCole Faust         layer.run();
84*c217d954SCole Faust 
85*c217d954SCole Faust         // Reinstating original data layout for the test suite to properly check the values
86*c217d954SCole Faust         src.info()->set_data_layout(data_layout);
87*c217d954SCole Faust         dst.info()->set_data_layout(data_layout);
88*c217d954SCole Faust     }
89*c217d954SCole Faust 
90*c217d954SCole Faust     template <typename U>
fill(U && tensor,int i)91*c217d954SCole Faust     void fill(U &&tensor, int i)
92*c217d954SCole Faust     {
93*c217d954SCole Faust         if(_data_type == DataType::QASYMM8)
94*c217d954SCole Faust         {
95*c217d954SCole Faust             std::uniform_int_distribution<uint32_t> distribution(0, 30);
96*c217d954SCole Faust             library->fill(tensor, distribution, i);
97*c217d954SCole Faust         }
98*c217d954SCole Faust         else if(_data_type == DataType::QASYMM8_SIGNED)
99*c217d954SCole Faust         {
100*c217d954SCole Faust             std::uniform_int_distribution<int32_t> distribution(-15, 15);
101*c217d954SCole Faust             library->fill(tensor, distribution, i);
102*c217d954SCole Faust         }
103*c217d954SCole Faust         else if(_data_type == DataType::S32)
104*c217d954SCole Faust         {
105*c217d954SCole Faust             std::uniform_int_distribution<int32_t> distribution(-50, 50);
106*c217d954SCole Faust             library->fill(tensor, distribution, i);
107*c217d954SCole Faust         }
108*c217d954SCole Faust         else if(_data_type == DataType::F16)
109*c217d954SCole Faust         {
110*c217d954SCole Faust             arm_compute::utils::uniform_real_distribution_16bit<half> distribution(-1.0f, 1.0f);
111*c217d954SCole Faust             library->fill(tensor, distribution, i);
112*c217d954SCole Faust         }
113*c217d954SCole Faust         else if(_data_type == DataType::F32)
114*c217d954SCole Faust         {
115*c217d954SCole Faust             std::uniform_real_distribution<float> distribution(-1.0f, 1.0f);
116*c217d954SCole Faust             library->fill(tensor, distribution, i);
117*c217d954SCole Faust         }
118*c217d954SCole Faust         else
119*c217d954SCole Faust         {
120*c217d954SCole Faust             library->fill_tensor_uniform(tensor, i);
121*c217d954SCole Faust         }
122*c217d954SCole Faust     }
123*c217d954SCole Faust 
compute_target(const TensorShape & input_shape,const TensorShape & weights_shape,const TensorShape & bias_shape,const TensorShape & output_shape,bool transpose_weights,bool reshape_weights)124*c217d954SCole Faust     TensorType compute_target(const TensorShape &input_shape, const TensorShape &weights_shape, const TensorShape &bias_shape, const TensorShape &output_shape, bool transpose_weights,
125*c217d954SCole Faust                               bool reshape_weights)
126*c217d954SCole Faust     {
127*c217d954SCole Faust         TensorShape reshaped_weights_shape(weights_shape);
128*c217d954SCole Faust 
129*c217d954SCole Faust         // Test actions depending on the target settings
130*c217d954SCole Faust         //
131*c217d954SCole Faust         //            | reshape   | !reshape
132*c217d954SCole Faust         // -----------+-----------+---------------------------
133*c217d954SCole Faust         //  transpose |           | ***
134*c217d954SCole Faust         // -----------+-----------+---------------------------
135*c217d954SCole Faust         // !transpose | transpose | transpose
136*c217d954SCole Faust         //            |           |
137*c217d954SCole Faust         //
138*c217d954SCole Faust         // ***: That combination is invalid. But we can ignore the transpose flag and handle all !reshape the same
139*c217d954SCole Faust         if(!reshape_weights || !transpose_weights)
140*c217d954SCole Faust         {
141*c217d954SCole Faust             const size_t shape_x = reshaped_weights_shape.x();
142*c217d954SCole Faust             reshaped_weights_shape.set(0, reshaped_weights_shape.y());
143*c217d954SCole Faust             reshaped_weights_shape.set(1, shape_x);
144*c217d954SCole Faust         }
145*c217d954SCole Faust 
146*c217d954SCole Faust         // Create tensors
147*c217d954SCole Faust         TensorType src     = create_tensor<TensorType>(input_shape, _data_type, 1, _quantization_info);
148*c217d954SCole Faust         TensorType weights = create_tensor<TensorType>(reshaped_weights_shape, _data_type, 1, _quantization_info);
149*c217d954SCole Faust         TensorType bias    = create_tensor<TensorType>(bias_shape, _bias_data_type, 1, _quantization_info);
150*c217d954SCole Faust         TensorType dst     = create_tensor<TensorType>(output_shape, _data_type, 1, _quantization_info);
151*c217d954SCole Faust 
152*c217d954SCole Faust         // Create Fully Connected layer info
153*c217d954SCole Faust         FullyConnectedLayerInfo fc_info;
154*c217d954SCole Faust         fc_info.transpose_weights    = transpose_weights;
155*c217d954SCole Faust         fc_info.are_weights_reshaped = !reshape_weights;
156*c217d954SCole Faust         fc_info.activation_info      = _activation_info;
157*c217d954SCole Faust 
158*c217d954SCole Faust         // Create and configure function.
159*c217d954SCole Faust         FunctionType fc;
160*c217d954SCole Faust         fc.configure(&src, &weights, &bias, &dst, fc_info);
161*c217d954SCole Faust 
162*c217d954SCole Faust         ARM_COMPUTE_ASSERT(src.info()->is_resizable());
163*c217d954SCole Faust         ARM_COMPUTE_ASSERT(weights.info()->is_resizable());
164*c217d954SCole Faust         ARM_COMPUTE_ASSERT(bias.info()->is_resizable());
165*c217d954SCole Faust         ARM_COMPUTE_ASSERT(dst.info()->is_resizable());
166*c217d954SCole Faust 
167*c217d954SCole Faust         add_padding_x({ &src, &weights, &bias, &dst });
168*c217d954SCole Faust 
169*c217d954SCole Faust         // Allocate tensors
170*c217d954SCole Faust         src.allocator()->allocate();
171*c217d954SCole Faust         weights.allocator()->allocate();
172*c217d954SCole Faust         bias.allocator()->allocate();
173*c217d954SCole Faust         dst.allocator()->allocate();
174*c217d954SCole Faust 
175*c217d954SCole Faust         ARM_COMPUTE_ASSERT(!src.info()->is_resizable());
176*c217d954SCole Faust         ARM_COMPUTE_ASSERT(!weights.info()->is_resizable());
177*c217d954SCole Faust         ARM_COMPUTE_ASSERT(!bias.info()->is_resizable());
178*c217d954SCole Faust         ARM_COMPUTE_ASSERT(!dst.info()->is_resizable());
179*c217d954SCole Faust 
180*c217d954SCole Faust         // Fill tensors
181*c217d954SCole Faust         fill(AccessorType(src), 0);
182*c217d954SCole Faust         fill(AccessorType(bias), 2);
183*c217d954SCole Faust 
184*c217d954SCole Faust         if(!reshape_weights || !transpose_weights)
185*c217d954SCole Faust         {
186*c217d954SCole Faust             TensorShape tmp_shape(weights_shape);
187*c217d954SCole Faust             RawTensor   tmp(tmp_shape, _data_type, 1);
188*c217d954SCole Faust 
189*c217d954SCole Faust             // Fill with original shape
190*c217d954SCole Faust             fill(tmp, 1);
191*c217d954SCole Faust 
192*c217d954SCole Faust             // Transpose elementwise
193*c217d954SCole Faust             tmp = transpose(tmp);
194*c217d954SCole Faust 
195*c217d954SCole Faust             AccessorType weights_accessor(weights);
196*c217d954SCole Faust 
197*c217d954SCole Faust             for(int i = 0; i < tmp.num_elements(); ++i)
198*c217d954SCole Faust             {
199*c217d954SCole Faust                 Coordinates coord = index2coord(tmp.shape(), i);
200*c217d954SCole Faust                 std::copy_n(static_cast<const RawTensor::value_type *>(tmp(coord)),
201*c217d954SCole Faust                             tmp.element_size(),
202*c217d954SCole Faust                             static_cast<RawTensor::value_type *>(weights_accessor(coord)));
203*c217d954SCole Faust             }
204*c217d954SCole Faust         }
205*c217d954SCole Faust         else
206*c217d954SCole Faust         {
207*c217d954SCole Faust             fill(AccessorType(weights), 1);
208*c217d954SCole Faust         }
209*c217d954SCole Faust 
210*c217d954SCole Faust         if(_mixed_layout)
211*c217d954SCole Faust         {
212*c217d954SCole Faust             mix_layout(fc, src, dst);
213*c217d954SCole Faust         }
214*c217d954SCole Faust         else
215*c217d954SCole Faust         {
216*c217d954SCole Faust             // Compute NEFullyConnectedLayer function
217*c217d954SCole Faust             fc.run();
218*c217d954SCole Faust         }
219*c217d954SCole Faust 
220*c217d954SCole Faust         return dst;
221*c217d954SCole Faust     }
222*c217d954SCole Faust 
compute_reference(const TensorShape & input_shape,const TensorShape & weights_shape,const TensorShape & bias_shape,const TensorShape & output_shape)223*c217d954SCole Faust     SimpleTensor<T> compute_reference(const TensorShape &input_shape, const TensorShape &weights_shape, const TensorShape &bias_shape, const TensorShape &output_shape)
224*c217d954SCole Faust     {
225*c217d954SCole Faust         // Create reference
226*c217d954SCole Faust         SimpleTensor<T>     src{ input_shape, _data_type, 1, _quantization_info };
227*c217d954SCole Faust         SimpleTensor<T>     weights{ weights_shape, _data_type, 1, _quantization_info };
228*c217d954SCole Faust         SimpleTensor<TBias> bias{ bias_shape, _bias_data_type, 1, _quantization_info };
229*c217d954SCole Faust 
230*c217d954SCole Faust         // Fill reference
231*c217d954SCole Faust         fill(src, 0);
232*c217d954SCole Faust         fill(weights, 1);
233*c217d954SCole Faust         fill(bias, 2);
234*c217d954SCole Faust 
235*c217d954SCole Faust         return reference::activation_layer(reference::fully_connected_layer<T>(src, weights, bias, output_shape, _quantization_info), _activation_info, _quantization_info);
236*c217d954SCole Faust     }
237*c217d954SCole Faust 
238*c217d954SCole Faust     TensorType          _target{};
239*c217d954SCole Faust     SimpleTensor<T>     _reference{};
240*c217d954SCole Faust     DataType            _data_type{};
241*c217d954SCole Faust     DataType            _bias_data_type{};
242*c217d954SCole Faust     bool                _mixed_layout{ false };
243*c217d954SCole Faust     QuantizationInfo    _quantization_info{};
244*c217d954SCole Faust     ActivationLayerInfo _activation_info{};
245*c217d954SCole Faust };
246*c217d954SCole Faust 
247*c217d954SCole Faust template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool mixed_layout = false>
248*c217d954SCole Faust class FullyConnectedLayerValidationFixture : public FullyConnectedLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>
249*c217d954SCole Faust {
250*c217d954SCole Faust public:
251*c217d954SCole Faust     template <typename...>
setup(TensorShape input_shape,TensorShape weights_shape,TensorShape bias_shape,TensorShape output_shape,bool transpose_weights,bool reshape_weights,DataType data_type,ActivationLayerInfo activation_info)252*c217d954SCole Faust     void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, bool transpose_weights, bool reshape_weights, DataType data_type,
253*c217d954SCole Faust                ActivationLayerInfo activation_info)
254*c217d954SCole Faust     {
255*c217d954SCole Faust         FullyConnectedLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape, weights_shape, bias_shape, output_shape, transpose_weights,
256*c217d954SCole Faust                                                                                                       reshape_weights, data_type,
257*c217d954SCole Faust                                                                                                       QuantizationInfo(), activation_info, mixed_layout);
258*c217d954SCole Faust     }
259*c217d954SCole Faust };
260*c217d954SCole Faust 
261*c217d954SCole Faust template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool mixed_layout = false>
262*c217d954SCole Faust class FullyConnectedLayerValidationQuantizedFixture : public FullyConnectedLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>
263*c217d954SCole Faust {
264*c217d954SCole Faust public:
265*c217d954SCole Faust     template <typename...>
setup(TensorShape input_shape,TensorShape weights_shape,TensorShape bias_shape,TensorShape output_shape,bool transpose_weights,bool reshape_weights,DataType data_type,QuantizationInfo quantization_info,ActivationLayerInfo activation_info)266*c217d954SCole Faust     void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, bool transpose_weights, bool reshape_weights, DataType data_type,
267*c217d954SCole Faust                QuantizationInfo quantization_info, ActivationLayerInfo activation_info)
268*c217d954SCole Faust     {
269*c217d954SCole Faust         FullyConnectedLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape, weights_shape, bias_shape, output_shape, transpose_weights,
270*c217d954SCole Faust                                                                                                       reshape_weights, data_type,
271*c217d954SCole Faust                                                                                                       quantization_info, activation_info, mixed_layout);
272*c217d954SCole Faust     }
273*c217d954SCole Faust };
274*c217d954SCole Faust 
275*c217d954SCole Faust template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
276*c217d954SCole Faust class FullyConnectedWithDynamicTensorsFixture : public framework::Fixture
277*c217d954SCole Faust {
278*c217d954SCole Faust private:
279*c217d954SCole Faust     template <typename U>
fill(U && tensor,int i)280*c217d954SCole Faust     void fill(U &&tensor, int i)
281*c217d954SCole Faust     {
282*c217d954SCole Faust         if(_data_type == DataType::F16)
283*c217d954SCole Faust         {
284*c217d954SCole Faust             arm_compute::utils::uniform_real_distribution_16bit<half> distribution(-1.0f, 1.0f);
285*c217d954SCole Faust             library->fill(tensor, distribution, i);
286*c217d954SCole Faust         }
287*c217d954SCole Faust         else if(_data_type == DataType::F32)
288*c217d954SCole Faust         {
289*c217d954SCole Faust             std::uniform_real_distribution<float> distribution(-1.0f, 1.0f);
290*c217d954SCole Faust             library->fill(tensor, distribution, i);
291*c217d954SCole Faust         }
292*c217d954SCole Faust         else if(_data_type == DataType::QASYMM8)
293*c217d954SCole Faust         {
294*c217d954SCole Faust             std::uniform_int_distribution<uint32_t> distribution(0, 30);
295*c217d954SCole Faust             library->fill(tensor, distribution, i);
296*c217d954SCole Faust         }
297*c217d954SCole Faust         else if(_data_type == DataType::S32)
298*c217d954SCole Faust         {
299*c217d954SCole Faust             std::uniform_int_distribution<int32_t> distribution(-50, 50);
300*c217d954SCole Faust             library->fill(tensor, distribution, i);
301*c217d954SCole Faust         }
302*c217d954SCole Faust         else
303*c217d954SCole Faust         {
304*c217d954SCole Faust             library->fill_tensor_uniform(tensor, i);
305*c217d954SCole Faust         }
306*c217d954SCole Faust     }
307*c217d954SCole Faust 
fill_transposed_weights(TensorType & weights,TensorShape weights_shape,int seed)308*c217d954SCole Faust     void fill_transposed_weights(TensorType &weights, TensorShape weights_shape, int seed)
309*c217d954SCole Faust     {
310*c217d954SCole Faust         RawTensor tmp(weights_shape, _data_type, 1);
311*c217d954SCole Faust 
312*c217d954SCole Faust         // Fill with original shape
313*c217d954SCole Faust         fill(tmp, seed);
314*c217d954SCole Faust 
315*c217d954SCole Faust         // Transpose elementwise
316*c217d954SCole Faust         tmp = transpose(tmp);
317*c217d954SCole Faust 
318*c217d954SCole Faust         AccessorType weights_accessor(weights);
319*c217d954SCole Faust 
320*c217d954SCole Faust         for(int i = 0; i < tmp.num_elements(); ++i)
321*c217d954SCole Faust         {
322*c217d954SCole Faust             Coordinates coord = index2coord(tmp.shape(), i);
323*c217d954SCole Faust             std::copy_n(static_cast<const RawTensor::value_type *>(tmp(coord)),
324*c217d954SCole Faust                         tmp.element_size(),
325*c217d954SCole Faust                         static_cast<RawTensor::value_type *>(weights_accessor(coord)));
326*c217d954SCole Faust         }
327*c217d954SCole Faust     }
328*c217d954SCole Faust 
validate_with_tolerance(TensorType & target,SimpleTensor<T> & ref)329*c217d954SCole Faust     void validate_with_tolerance(TensorType &target, SimpleTensor<T> &ref)
330*c217d954SCole Faust     {
331*c217d954SCole Faust         if(_data_type == DataType::F32)
332*c217d954SCole Faust         {
333*c217d954SCole Faust             constexpr RelativeTolerance<float> rel_tolerance_f32(0.05f);
334*c217d954SCole Faust             constexpr AbsoluteTolerance<float> abs_tolerance_f32(0.0001f);
335*c217d954SCole Faust             validate(AccessorType(target), ref, rel_tolerance_f32, 0, abs_tolerance_f32);
336*c217d954SCole Faust         }
337*c217d954SCole Faust         else if(_data_type == DataType::QASYMM8)
338*c217d954SCole Faust         {
339*c217d954SCole Faust             constexpr AbsoluteTolerance<uint32_t> tolerance_qasymm8(1);
340*c217d954SCole Faust             validate(AccessorType(target), ref, tolerance_qasymm8);
341*c217d954SCole Faust         }
342*c217d954SCole Faust         else
343*c217d954SCole Faust         {
344*c217d954SCole Faust             validate(AccessorType(target), ref);
345*c217d954SCole Faust         }
346*c217d954SCole Faust     }
347*c217d954SCole Faust 
348*c217d954SCole Faust public:
349*c217d954SCole Faust     using TDecay = typename std::decay<T>::type;
350*c217d954SCole Faust     using TBias  = typename std::conditional < (std::is_same<TDecay, uint8_t>::value || std::is_same<TDecay, int8_t>::value), int32_t, T >::type;
351*c217d954SCole Faust 
352*c217d954SCole Faust     template <typename...>
setup(TensorShape src_shape,TensorShape weights_shape,TensorShape bias_shape,TensorShape dst_shape,DataType data_type,ActivationLayerInfo activation_info,bool constant_weights,bool constant_bias)353*c217d954SCole Faust     void setup(TensorShape src_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape dst_shape,
354*c217d954SCole Faust                DataType data_type, ActivationLayerInfo activation_info, bool constant_weights, bool constant_bias)
355*c217d954SCole Faust     {
356*c217d954SCole Faust         _data_type = data_type;
357*c217d954SCole Faust 
358*c217d954SCole Faust         const bool is_quantized = is_data_type_quantized(data_type);
359*c217d954SCole Faust 
360*c217d954SCole Faust         const DataType bias_data_type = (is_quantized) ? DataType::S32 : data_type;
361*c217d954SCole Faust 
362*c217d954SCole Faust         const QuantizationInfo src_qinfo     = is_quantized ? QuantizationInfo(0.1f, 10) : QuantizationInfo();
363*c217d954SCole Faust         const QuantizationInfo weights_qinfo = is_quantized ? QuantizationInfo(0.3f, 20) : QuantizationInfo();
364*c217d954SCole Faust         const QuantizationInfo dst_qinfo     = is_quantized ? QuantizationInfo(0.2f, 5) : QuantizationInfo();
365*c217d954SCole Faust 
366*c217d954SCole Faust         // Setup tensor meta-data
367*c217d954SCole Faust         const TensorInfo src_info(src_shape, 1, data_type, src_qinfo);
368*c217d954SCole Faust         _src.allocator()->init(src_info);
369*c217d954SCole Faust 
370*c217d954SCole Faust         TensorInfo wei_info(weights_shape, 1, data_type, weights_qinfo);
371*c217d954SCole Faust         if(!constant_weights)
372*c217d954SCole Faust         {
373*c217d954SCole Faust             const TensorShape tr_weights_shape{ weights_shape[1], weights_shape[0] };
374*c217d954SCole Faust             wei_info.set_tensor_shape(tr_weights_shape);
375*c217d954SCole Faust         }
376*c217d954SCole Faust         wei_info.set_are_values_constant(constant_weights);
377*c217d954SCole Faust         _weights.allocator()->init(wei_info);
378*c217d954SCole Faust 
379*c217d954SCole Faust         TensorInfo bias_info(bias_shape, 1, bias_data_type);
380*c217d954SCole Faust         bias_info.set_are_values_constant(constant_bias);
381*c217d954SCole Faust         _bias.allocator()->init(bias_info);
382*c217d954SCole Faust 
383*c217d954SCole Faust         const TensorInfo dst_info(dst_shape, 1, data_type, dst_qinfo);
384*c217d954SCole Faust         _dst.allocator()->init(dst_info);
385*c217d954SCole Faust 
386*c217d954SCole Faust         // Configure FC layer and mark the weights as non constant
387*c217d954SCole Faust         FullyConnectedLayerInfo fc_info;
388*c217d954SCole Faust         fc_info.activation_info = activation_info;
389*c217d954SCole Faust         if(!constant_weights)
390*c217d954SCole Faust         {
391*c217d954SCole Faust             fc_info.are_weights_reshaped = true;
392*c217d954SCole Faust             fc_info.transpose_weights    = false;
393*c217d954SCole Faust         }
394*c217d954SCole Faust         FunctionType fc;
395*c217d954SCole Faust         fc.configure(&_src, &_weights, &_bias, &_dst, fc_info);
396*c217d954SCole Faust 
397*c217d954SCole Faust         // Allocate all the tensors
398*c217d954SCole Faust         _src.allocator()->allocate();
399*c217d954SCole Faust         _weights.allocator()->allocate();
400*c217d954SCole Faust         _bias.allocator()->allocate();
401*c217d954SCole Faust         _dst.allocator()->allocate();
402*c217d954SCole Faust 
403*c217d954SCole Faust         // Run multiple iterations with different inputs
404*c217d954SCole Faust         constexpr int num_iterations    = 5;
405*c217d954SCole Faust         int           randomizer_offset = 0;
406*c217d954SCole Faust 
407*c217d954SCole Faust         // Create reference tensors
408*c217d954SCole Faust         SimpleTensor<T>     src{ src_shape, data_type, 1, src_qinfo };
409*c217d954SCole Faust         SimpleTensor<T>     weights{ weights_shape, data_type, 1, weights_qinfo };
410*c217d954SCole Faust         SimpleTensor<TBias> bias{ bias_shape, bias_data_type };
411*c217d954SCole Faust 
412*c217d954SCole Faust         // Fill weights and/or bias if they remain constant
413*c217d954SCole Faust         if(constant_weights)
414*c217d954SCole Faust         {
415*c217d954SCole Faust             fill(AccessorType(_weights), 1);
416*c217d954SCole Faust             fill(weights, 1);
417*c217d954SCole Faust         }
418*c217d954SCole Faust         if(constant_bias)
419*c217d954SCole Faust         {
420*c217d954SCole Faust             fill(AccessorType(_bias), 2);
421*c217d954SCole Faust             fill(bias, 2);
422*c217d954SCole Faust         }
423*c217d954SCole Faust 
424*c217d954SCole Faust         for(int i = 0; i < num_iterations; ++i)
425*c217d954SCole Faust         {
426*c217d954SCole Faust             // Run target
427*c217d954SCole Faust             {
428*c217d954SCole Faust                 fill(AccessorType(_src), randomizer_offset);
429*c217d954SCole Faust                 if(!constant_weights)
430*c217d954SCole Faust                 {
431*c217d954SCole Faust                     fill_transposed_weights(_weights, weights_shape, randomizer_offset + 1);
432*c217d954SCole Faust                 }
433*c217d954SCole Faust                 if(!constant_bias)
434*c217d954SCole Faust                 {
435*c217d954SCole Faust                     fill(AccessorType(_bias), randomizer_offset + 2);
436*c217d954SCole Faust                 }
437*c217d954SCole Faust 
438*c217d954SCole Faust                 fc.run();
439*c217d954SCole Faust             }
440*c217d954SCole Faust 
441*c217d954SCole Faust             // Run reference and compare
442*c217d954SCole Faust             {
443*c217d954SCole Faust                 // Fill reference
444*c217d954SCole Faust                 fill(src, randomizer_offset);
445*c217d954SCole Faust                 if(!constant_weights)
446*c217d954SCole Faust                 {
447*c217d954SCole Faust                     fill(weights, randomizer_offset + 1);
448*c217d954SCole Faust                 }
449*c217d954SCole Faust                 if(!constant_bias)
450*c217d954SCole Faust                 {
451*c217d954SCole Faust                     fill(bias, randomizer_offset + 2);
452*c217d954SCole Faust                 }
453*c217d954SCole Faust 
454*c217d954SCole Faust                 auto dst = reference::activation_layer(reference::fully_connected_layer<T>(src, weights, bias, dst_shape, dst_qinfo), activation_info, dst_qinfo);
455*c217d954SCole Faust 
456*c217d954SCole Faust                 // Validate
457*c217d954SCole Faust                 validate_with_tolerance(_dst, dst);
458*c217d954SCole Faust             }
459*c217d954SCole Faust 
460*c217d954SCole Faust             randomizer_offset += 100;
461*c217d954SCole Faust         }
462*c217d954SCole Faust     }
463*c217d954SCole Faust 
464*c217d954SCole Faust private:
465*c217d954SCole Faust     TensorType _src{}, _weights{}, _bias{}, _dst{};
466*c217d954SCole Faust     DataType   _data_type{ DataType::UNKNOWN };
467*c217d954SCole Faust };
468*c217d954SCole Faust 
469*c217d954SCole Faust template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
470*c217d954SCole Faust class FullyConnectedWithDynamicWeightsFixture : public FullyConnectedWithDynamicTensorsFixture<TensorType, AccessorType, FunctionType, T>
471*c217d954SCole Faust {
472*c217d954SCole Faust public:
473*c217d954SCole Faust     template <typename...>
setup(TensorShape src_shape,TensorShape weights_shape,TensorShape bias_shape,TensorShape dst_shape,DataType data_type,ActivationLayerInfo activation_info)474*c217d954SCole Faust     void setup(TensorShape src_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape dst_shape,
475*c217d954SCole Faust                DataType data_type, ActivationLayerInfo activation_info)
476*c217d954SCole Faust     {
477*c217d954SCole Faust         FullyConnectedWithDynamicTensorsFixture<TensorType, AccessorType, FunctionType, T>::setup(src_shape, weights_shape, bias_shape,
478*c217d954SCole Faust                                                                                                   dst_shape, data_type, activation_info, false, true);
479*c217d954SCole Faust     }
480*c217d954SCole Faust };
481*c217d954SCole Faust 
482*c217d954SCole Faust template <typename TensorType, typename AccessorType, typename FunctionType, typename T>
483*c217d954SCole Faust class FullyConnectedWithDynamicBiasFixture : public FullyConnectedWithDynamicTensorsFixture<TensorType, AccessorType, FunctionType, T>
484*c217d954SCole Faust {
485*c217d954SCole Faust public:
486*c217d954SCole Faust     template <typename...>
setup(TensorShape src_shape,TensorShape weights_shape,TensorShape bias_shape,TensorShape dst_shape,DataType data_type,ActivationLayerInfo activation_info)487*c217d954SCole Faust     void setup(TensorShape src_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape dst_shape,
488*c217d954SCole Faust                DataType data_type, ActivationLayerInfo activation_info)
489*c217d954SCole Faust     {
490*c217d954SCole Faust         FullyConnectedWithDynamicTensorsFixture<TensorType, AccessorType, FunctionType, T>::setup(src_shape, weights_shape, bias_shape,
491*c217d954SCole Faust                                                                                                   dst_shape, data_type, activation_info, true, false);
492*c217d954SCole Faust     }
493*c217d954SCole Faust };
494*c217d954SCole Faust } // namespace validation
495*c217d954SCole Faust } // namespace test
496*c217d954SCole Faust } // namespace arm_compute
497*c217d954SCole Faust #endif /* ARM_COMPUTE_TEST_FULLY_CONNECTED_LAYER_FIXTURE */
498