1*c217d954SCole Faust /* 2*c217d954SCole Faust * Copyright (c) 2017-2022 Arm Limited. 3*c217d954SCole Faust * 4*c217d954SCole Faust * SPDX-License-Identifier: MIT 5*c217d954SCole Faust * 6*c217d954SCole Faust * Permission is hereby granted, free of charge, to any person obtaining a copy 7*c217d954SCole Faust * of this software and associated documentation files (the "Software"), to 8*c217d954SCole Faust * deal in the Software without restriction, including without limitation the 9*c217d954SCole Faust * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or 10*c217d954SCole Faust * sell copies of the Software, and to permit persons to whom the Software is 11*c217d954SCole Faust * furnished to do so, subject to the following conditions: 12*c217d954SCole Faust * 13*c217d954SCole Faust * The above copyright notice and this permission notice shall be included in all 14*c217d954SCole Faust * copies or substantial portions of the Software. 15*c217d954SCole Faust * 16*c217d954SCole Faust * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR 17*c217d954SCole Faust * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, 18*c217d954SCole Faust * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE 19*c217d954SCole Faust * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER 20*c217d954SCole Faust * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, 21*c217d954SCole Faust * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE 22*c217d954SCole Faust * SOFTWARE. 23*c217d954SCole Faust */ 24*c217d954SCole Faust #ifndef ARM_COMPUTE_TEST_FULLY_CONNECTED_LAYER_FIXTURE 25*c217d954SCole Faust #define ARM_COMPUTE_TEST_FULLY_CONNECTED_LAYER_FIXTURE 26*c217d954SCole Faust 27*c217d954SCole Faust #include "arm_compute/core/TensorShape.h" 28*c217d954SCole Faust #include "arm_compute/core/Types.h" 29*c217d954SCole Faust #include "arm_compute/core/Utils.h" 30*c217d954SCole Faust #include "tests/AssetsLibrary.h" 31*c217d954SCole Faust #include "tests/Globals.h" 32*c217d954SCole Faust #include "tests/IAccessor.h" 33*c217d954SCole Faust #include "tests/RawTensor.h" 34*c217d954SCole Faust #include "tests/framework/Asserts.h" 35*c217d954SCole Faust #include "tests/framework/Fixture.h" 36*c217d954SCole Faust #include "tests/validation/Helpers.h" 37*c217d954SCole Faust #include "tests/validation/reference/ActivationLayer.h" 38*c217d954SCole Faust #include "tests/validation/reference/FullyConnectedLayer.h" 39*c217d954SCole Faust #include "tests/validation/reference/Utils.h" 40*c217d954SCole Faust 41*c217d954SCole Faust #include <random> 42*c217d954SCole Faust 43*c217d954SCole Faust namespace arm_compute 44*c217d954SCole Faust { 45*c217d954SCole Faust namespace test 46*c217d954SCole Faust { 47*c217d954SCole Faust namespace validation 48*c217d954SCole Faust { 49*c217d954SCole Faust template <typename TensorType, typename AccessorType, typename FunctionType, typename T> 50*c217d954SCole Faust class FullyConnectedLayerValidationGenericFixture : public framework::Fixture 51*c217d954SCole Faust { 52*c217d954SCole Faust public: 53*c217d954SCole Faust using TDecay = typename std::decay<T>::type; 54*c217d954SCole Faust using TBias = typename std::conditional < (std::is_same<TDecay, uint8_t>::value || std::is_same<TDecay, int8_t>::value), int32_t, T >::type; 55*c217d954SCole Faust 56*c217d954SCole Faust public: 57*c217d954SCole Faust template <typename...> 58*c217d954SCole Faust void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, bool transpose_weights, bool reshape_weights, 59*c217d954SCole Faust DataType data_type, QuantizationInfo quantization_info, ActivationLayerInfo activation_info, bool mixed_layout = false) 60*c217d954SCole Faust { 61*c217d954SCole Faust ARM_COMPUTE_UNUSED(weights_shape); 62*c217d954SCole Faust ARM_COMPUTE_UNUSED(bias_shape); 63*c217d954SCole Faust 64*c217d954SCole Faust _mixed_layout = mixed_layout; 65*c217d954SCole Faust _data_type = data_type; 66*c217d954SCole Faust _bias_data_type = is_data_type_quantized_asymmetric(data_type) ? DataType::S32 : data_type; 67*c217d954SCole Faust _quantization_info = quantization_info; 68*c217d954SCole Faust _activation_info = activation_info; 69*c217d954SCole Faust 70*c217d954SCole Faust _target = compute_target(input_shape, weights_shape, bias_shape, output_shape, transpose_weights, reshape_weights); 71*c217d954SCole Faust _reference = compute_reference(input_shape, weights_shape, bias_shape, output_shape); 72*c217d954SCole Faust } 73*c217d954SCole Faust 74*c217d954SCole Faust protected: mix_layout(FunctionType & layer,TensorType & src,TensorType & dst)75*c217d954SCole Faust void mix_layout(FunctionType &layer, TensorType &src, TensorType &dst) 76*c217d954SCole Faust { 77*c217d954SCole Faust const DataLayout data_layout = src.info()->data_layout(); 78*c217d954SCole Faust // Test Multi DataLayout graph cases, when the data layout changes after configure 79*c217d954SCole Faust src.info()->set_data_layout(data_layout == DataLayout::NCHW ? DataLayout::NHWC : DataLayout::NCHW); 80*c217d954SCole Faust dst.info()->set_data_layout(data_layout == DataLayout::NCHW ? DataLayout::NHWC : DataLayout::NCHW); 81*c217d954SCole Faust 82*c217d954SCole Faust // Compute Convolution function 83*c217d954SCole Faust layer.run(); 84*c217d954SCole Faust 85*c217d954SCole Faust // Reinstating original data layout for the test suite to properly check the values 86*c217d954SCole Faust src.info()->set_data_layout(data_layout); 87*c217d954SCole Faust dst.info()->set_data_layout(data_layout); 88*c217d954SCole Faust } 89*c217d954SCole Faust 90*c217d954SCole Faust template <typename U> fill(U && tensor,int i)91*c217d954SCole Faust void fill(U &&tensor, int i) 92*c217d954SCole Faust { 93*c217d954SCole Faust if(_data_type == DataType::QASYMM8) 94*c217d954SCole Faust { 95*c217d954SCole Faust std::uniform_int_distribution<uint32_t> distribution(0, 30); 96*c217d954SCole Faust library->fill(tensor, distribution, i); 97*c217d954SCole Faust } 98*c217d954SCole Faust else if(_data_type == DataType::QASYMM8_SIGNED) 99*c217d954SCole Faust { 100*c217d954SCole Faust std::uniform_int_distribution<int32_t> distribution(-15, 15); 101*c217d954SCole Faust library->fill(tensor, distribution, i); 102*c217d954SCole Faust } 103*c217d954SCole Faust else if(_data_type == DataType::S32) 104*c217d954SCole Faust { 105*c217d954SCole Faust std::uniform_int_distribution<int32_t> distribution(-50, 50); 106*c217d954SCole Faust library->fill(tensor, distribution, i); 107*c217d954SCole Faust } 108*c217d954SCole Faust else if(_data_type == DataType::F16) 109*c217d954SCole Faust { 110*c217d954SCole Faust arm_compute::utils::uniform_real_distribution_16bit<half> distribution(-1.0f, 1.0f); 111*c217d954SCole Faust library->fill(tensor, distribution, i); 112*c217d954SCole Faust } 113*c217d954SCole Faust else if(_data_type == DataType::F32) 114*c217d954SCole Faust { 115*c217d954SCole Faust std::uniform_real_distribution<float> distribution(-1.0f, 1.0f); 116*c217d954SCole Faust library->fill(tensor, distribution, i); 117*c217d954SCole Faust } 118*c217d954SCole Faust else 119*c217d954SCole Faust { 120*c217d954SCole Faust library->fill_tensor_uniform(tensor, i); 121*c217d954SCole Faust } 122*c217d954SCole Faust } 123*c217d954SCole Faust compute_target(const TensorShape & input_shape,const TensorShape & weights_shape,const TensorShape & bias_shape,const TensorShape & output_shape,bool transpose_weights,bool reshape_weights)124*c217d954SCole Faust TensorType compute_target(const TensorShape &input_shape, const TensorShape &weights_shape, const TensorShape &bias_shape, const TensorShape &output_shape, bool transpose_weights, 125*c217d954SCole Faust bool reshape_weights) 126*c217d954SCole Faust { 127*c217d954SCole Faust TensorShape reshaped_weights_shape(weights_shape); 128*c217d954SCole Faust 129*c217d954SCole Faust // Test actions depending on the target settings 130*c217d954SCole Faust // 131*c217d954SCole Faust // | reshape | !reshape 132*c217d954SCole Faust // -----------+-----------+--------------------------- 133*c217d954SCole Faust // transpose | | *** 134*c217d954SCole Faust // -----------+-----------+--------------------------- 135*c217d954SCole Faust // !transpose | transpose | transpose 136*c217d954SCole Faust // | | 137*c217d954SCole Faust // 138*c217d954SCole Faust // ***: That combination is invalid. But we can ignore the transpose flag and handle all !reshape the same 139*c217d954SCole Faust if(!reshape_weights || !transpose_weights) 140*c217d954SCole Faust { 141*c217d954SCole Faust const size_t shape_x = reshaped_weights_shape.x(); 142*c217d954SCole Faust reshaped_weights_shape.set(0, reshaped_weights_shape.y()); 143*c217d954SCole Faust reshaped_weights_shape.set(1, shape_x); 144*c217d954SCole Faust } 145*c217d954SCole Faust 146*c217d954SCole Faust // Create tensors 147*c217d954SCole Faust TensorType src = create_tensor<TensorType>(input_shape, _data_type, 1, _quantization_info); 148*c217d954SCole Faust TensorType weights = create_tensor<TensorType>(reshaped_weights_shape, _data_type, 1, _quantization_info); 149*c217d954SCole Faust TensorType bias = create_tensor<TensorType>(bias_shape, _bias_data_type, 1, _quantization_info); 150*c217d954SCole Faust TensorType dst = create_tensor<TensorType>(output_shape, _data_type, 1, _quantization_info); 151*c217d954SCole Faust 152*c217d954SCole Faust // Create Fully Connected layer info 153*c217d954SCole Faust FullyConnectedLayerInfo fc_info; 154*c217d954SCole Faust fc_info.transpose_weights = transpose_weights; 155*c217d954SCole Faust fc_info.are_weights_reshaped = !reshape_weights; 156*c217d954SCole Faust fc_info.activation_info = _activation_info; 157*c217d954SCole Faust 158*c217d954SCole Faust // Create and configure function. 159*c217d954SCole Faust FunctionType fc; 160*c217d954SCole Faust fc.configure(&src, &weights, &bias, &dst, fc_info); 161*c217d954SCole Faust 162*c217d954SCole Faust ARM_COMPUTE_ASSERT(src.info()->is_resizable()); 163*c217d954SCole Faust ARM_COMPUTE_ASSERT(weights.info()->is_resizable()); 164*c217d954SCole Faust ARM_COMPUTE_ASSERT(bias.info()->is_resizable()); 165*c217d954SCole Faust ARM_COMPUTE_ASSERT(dst.info()->is_resizable()); 166*c217d954SCole Faust 167*c217d954SCole Faust add_padding_x({ &src, &weights, &bias, &dst }); 168*c217d954SCole Faust 169*c217d954SCole Faust // Allocate tensors 170*c217d954SCole Faust src.allocator()->allocate(); 171*c217d954SCole Faust weights.allocator()->allocate(); 172*c217d954SCole Faust bias.allocator()->allocate(); 173*c217d954SCole Faust dst.allocator()->allocate(); 174*c217d954SCole Faust 175*c217d954SCole Faust ARM_COMPUTE_ASSERT(!src.info()->is_resizable()); 176*c217d954SCole Faust ARM_COMPUTE_ASSERT(!weights.info()->is_resizable()); 177*c217d954SCole Faust ARM_COMPUTE_ASSERT(!bias.info()->is_resizable()); 178*c217d954SCole Faust ARM_COMPUTE_ASSERT(!dst.info()->is_resizable()); 179*c217d954SCole Faust 180*c217d954SCole Faust // Fill tensors 181*c217d954SCole Faust fill(AccessorType(src), 0); 182*c217d954SCole Faust fill(AccessorType(bias), 2); 183*c217d954SCole Faust 184*c217d954SCole Faust if(!reshape_weights || !transpose_weights) 185*c217d954SCole Faust { 186*c217d954SCole Faust TensorShape tmp_shape(weights_shape); 187*c217d954SCole Faust RawTensor tmp(tmp_shape, _data_type, 1); 188*c217d954SCole Faust 189*c217d954SCole Faust // Fill with original shape 190*c217d954SCole Faust fill(tmp, 1); 191*c217d954SCole Faust 192*c217d954SCole Faust // Transpose elementwise 193*c217d954SCole Faust tmp = transpose(tmp); 194*c217d954SCole Faust 195*c217d954SCole Faust AccessorType weights_accessor(weights); 196*c217d954SCole Faust 197*c217d954SCole Faust for(int i = 0; i < tmp.num_elements(); ++i) 198*c217d954SCole Faust { 199*c217d954SCole Faust Coordinates coord = index2coord(tmp.shape(), i); 200*c217d954SCole Faust std::copy_n(static_cast<const RawTensor::value_type *>(tmp(coord)), 201*c217d954SCole Faust tmp.element_size(), 202*c217d954SCole Faust static_cast<RawTensor::value_type *>(weights_accessor(coord))); 203*c217d954SCole Faust } 204*c217d954SCole Faust } 205*c217d954SCole Faust else 206*c217d954SCole Faust { 207*c217d954SCole Faust fill(AccessorType(weights), 1); 208*c217d954SCole Faust } 209*c217d954SCole Faust 210*c217d954SCole Faust if(_mixed_layout) 211*c217d954SCole Faust { 212*c217d954SCole Faust mix_layout(fc, src, dst); 213*c217d954SCole Faust } 214*c217d954SCole Faust else 215*c217d954SCole Faust { 216*c217d954SCole Faust // Compute NEFullyConnectedLayer function 217*c217d954SCole Faust fc.run(); 218*c217d954SCole Faust } 219*c217d954SCole Faust 220*c217d954SCole Faust return dst; 221*c217d954SCole Faust } 222*c217d954SCole Faust compute_reference(const TensorShape & input_shape,const TensorShape & weights_shape,const TensorShape & bias_shape,const TensorShape & output_shape)223*c217d954SCole Faust SimpleTensor<T> compute_reference(const TensorShape &input_shape, const TensorShape &weights_shape, const TensorShape &bias_shape, const TensorShape &output_shape) 224*c217d954SCole Faust { 225*c217d954SCole Faust // Create reference 226*c217d954SCole Faust SimpleTensor<T> src{ input_shape, _data_type, 1, _quantization_info }; 227*c217d954SCole Faust SimpleTensor<T> weights{ weights_shape, _data_type, 1, _quantization_info }; 228*c217d954SCole Faust SimpleTensor<TBias> bias{ bias_shape, _bias_data_type, 1, _quantization_info }; 229*c217d954SCole Faust 230*c217d954SCole Faust // Fill reference 231*c217d954SCole Faust fill(src, 0); 232*c217d954SCole Faust fill(weights, 1); 233*c217d954SCole Faust fill(bias, 2); 234*c217d954SCole Faust 235*c217d954SCole Faust return reference::activation_layer(reference::fully_connected_layer<T>(src, weights, bias, output_shape, _quantization_info), _activation_info, _quantization_info); 236*c217d954SCole Faust } 237*c217d954SCole Faust 238*c217d954SCole Faust TensorType _target{}; 239*c217d954SCole Faust SimpleTensor<T> _reference{}; 240*c217d954SCole Faust DataType _data_type{}; 241*c217d954SCole Faust DataType _bias_data_type{}; 242*c217d954SCole Faust bool _mixed_layout{ false }; 243*c217d954SCole Faust QuantizationInfo _quantization_info{}; 244*c217d954SCole Faust ActivationLayerInfo _activation_info{}; 245*c217d954SCole Faust }; 246*c217d954SCole Faust 247*c217d954SCole Faust template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool mixed_layout = false> 248*c217d954SCole Faust class FullyConnectedLayerValidationFixture : public FullyConnectedLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T> 249*c217d954SCole Faust { 250*c217d954SCole Faust public: 251*c217d954SCole Faust template <typename...> setup(TensorShape input_shape,TensorShape weights_shape,TensorShape bias_shape,TensorShape output_shape,bool transpose_weights,bool reshape_weights,DataType data_type,ActivationLayerInfo activation_info)252*c217d954SCole Faust void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, bool transpose_weights, bool reshape_weights, DataType data_type, 253*c217d954SCole Faust ActivationLayerInfo activation_info) 254*c217d954SCole Faust { 255*c217d954SCole Faust FullyConnectedLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape, weights_shape, bias_shape, output_shape, transpose_weights, 256*c217d954SCole Faust reshape_weights, data_type, 257*c217d954SCole Faust QuantizationInfo(), activation_info, mixed_layout); 258*c217d954SCole Faust } 259*c217d954SCole Faust }; 260*c217d954SCole Faust 261*c217d954SCole Faust template <typename TensorType, typename AccessorType, typename FunctionType, typename T, bool mixed_layout = false> 262*c217d954SCole Faust class FullyConnectedLayerValidationQuantizedFixture : public FullyConnectedLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T> 263*c217d954SCole Faust { 264*c217d954SCole Faust public: 265*c217d954SCole Faust template <typename...> setup(TensorShape input_shape,TensorShape weights_shape,TensorShape bias_shape,TensorShape output_shape,bool transpose_weights,bool reshape_weights,DataType data_type,QuantizationInfo quantization_info,ActivationLayerInfo activation_info)266*c217d954SCole Faust void setup(TensorShape input_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape output_shape, bool transpose_weights, bool reshape_weights, DataType data_type, 267*c217d954SCole Faust QuantizationInfo quantization_info, ActivationLayerInfo activation_info) 268*c217d954SCole Faust { 269*c217d954SCole Faust FullyConnectedLayerValidationGenericFixture<TensorType, AccessorType, FunctionType, T>::setup(input_shape, weights_shape, bias_shape, output_shape, transpose_weights, 270*c217d954SCole Faust reshape_weights, data_type, 271*c217d954SCole Faust quantization_info, activation_info, mixed_layout); 272*c217d954SCole Faust } 273*c217d954SCole Faust }; 274*c217d954SCole Faust 275*c217d954SCole Faust template <typename TensorType, typename AccessorType, typename FunctionType, typename T> 276*c217d954SCole Faust class FullyConnectedWithDynamicTensorsFixture : public framework::Fixture 277*c217d954SCole Faust { 278*c217d954SCole Faust private: 279*c217d954SCole Faust template <typename U> fill(U && tensor,int i)280*c217d954SCole Faust void fill(U &&tensor, int i) 281*c217d954SCole Faust { 282*c217d954SCole Faust if(_data_type == DataType::F16) 283*c217d954SCole Faust { 284*c217d954SCole Faust arm_compute::utils::uniform_real_distribution_16bit<half> distribution(-1.0f, 1.0f); 285*c217d954SCole Faust library->fill(tensor, distribution, i); 286*c217d954SCole Faust } 287*c217d954SCole Faust else if(_data_type == DataType::F32) 288*c217d954SCole Faust { 289*c217d954SCole Faust std::uniform_real_distribution<float> distribution(-1.0f, 1.0f); 290*c217d954SCole Faust library->fill(tensor, distribution, i); 291*c217d954SCole Faust } 292*c217d954SCole Faust else if(_data_type == DataType::QASYMM8) 293*c217d954SCole Faust { 294*c217d954SCole Faust std::uniform_int_distribution<uint32_t> distribution(0, 30); 295*c217d954SCole Faust library->fill(tensor, distribution, i); 296*c217d954SCole Faust } 297*c217d954SCole Faust else if(_data_type == DataType::S32) 298*c217d954SCole Faust { 299*c217d954SCole Faust std::uniform_int_distribution<int32_t> distribution(-50, 50); 300*c217d954SCole Faust library->fill(tensor, distribution, i); 301*c217d954SCole Faust } 302*c217d954SCole Faust else 303*c217d954SCole Faust { 304*c217d954SCole Faust library->fill_tensor_uniform(tensor, i); 305*c217d954SCole Faust } 306*c217d954SCole Faust } 307*c217d954SCole Faust fill_transposed_weights(TensorType & weights,TensorShape weights_shape,int seed)308*c217d954SCole Faust void fill_transposed_weights(TensorType &weights, TensorShape weights_shape, int seed) 309*c217d954SCole Faust { 310*c217d954SCole Faust RawTensor tmp(weights_shape, _data_type, 1); 311*c217d954SCole Faust 312*c217d954SCole Faust // Fill with original shape 313*c217d954SCole Faust fill(tmp, seed); 314*c217d954SCole Faust 315*c217d954SCole Faust // Transpose elementwise 316*c217d954SCole Faust tmp = transpose(tmp); 317*c217d954SCole Faust 318*c217d954SCole Faust AccessorType weights_accessor(weights); 319*c217d954SCole Faust 320*c217d954SCole Faust for(int i = 0; i < tmp.num_elements(); ++i) 321*c217d954SCole Faust { 322*c217d954SCole Faust Coordinates coord = index2coord(tmp.shape(), i); 323*c217d954SCole Faust std::copy_n(static_cast<const RawTensor::value_type *>(tmp(coord)), 324*c217d954SCole Faust tmp.element_size(), 325*c217d954SCole Faust static_cast<RawTensor::value_type *>(weights_accessor(coord))); 326*c217d954SCole Faust } 327*c217d954SCole Faust } 328*c217d954SCole Faust validate_with_tolerance(TensorType & target,SimpleTensor<T> & ref)329*c217d954SCole Faust void validate_with_tolerance(TensorType &target, SimpleTensor<T> &ref) 330*c217d954SCole Faust { 331*c217d954SCole Faust if(_data_type == DataType::F32) 332*c217d954SCole Faust { 333*c217d954SCole Faust constexpr RelativeTolerance<float> rel_tolerance_f32(0.05f); 334*c217d954SCole Faust constexpr AbsoluteTolerance<float> abs_tolerance_f32(0.0001f); 335*c217d954SCole Faust validate(AccessorType(target), ref, rel_tolerance_f32, 0, abs_tolerance_f32); 336*c217d954SCole Faust } 337*c217d954SCole Faust else if(_data_type == DataType::QASYMM8) 338*c217d954SCole Faust { 339*c217d954SCole Faust constexpr AbsoluteTolerance<uint32_t> tolerance_qasymm8(1); 340*c217d954SCole Faust validate(AccessorType(target), ref, tolerance_qasymm8); 341*c217d954SCole Faust } 342*c217d954SCole Faust else 343*c217d954SCole Faust { 344*c217d954SCole Faust validate(AccessorType(target), ref); 345*c217d954SCole Faust } 346*c217d954SCole Faust } 347*c217d954SCole Faust 348*c217d954SCole Faust public: 349*c217d954SCole Faust using TDecay = typename std::decay<T>::type; 350*c217d954SCole Faust using TBias = typename std::conditional < (std::is_same<TDecay, uint8_t>::value || std::is_same<TDecay, int8_t>::value), int32_t, T >::type; 351*c217d954SCole Faust 352*c217d954SCole Faust template <typename...> setup(TensorShape src_shape,TensorShape weights_shape,TensorShape bias_shape,TensorShape dst_shape,DataType data_type,ActivationLayerInfo activation_info,bool constant_weights,bool constant_bias)353*c217d954SCole Faust void setup(TensorShape src_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape dst_shape, 354*c217d954SCole Faust DataType data_type, ActivationLayerInfo activation_info, bool constant_weights, bool constant_bias) 355*c217d954SCole Faust { 356*c217d954SCole Faust _data_type = data_type; 357*c217d954SCole Faust 358*c217d954SCole Faust const bool is_quantized = is_data_type_quantized(data_type); 359*c217d954SCole Faust 360*c217d954SCole Faust const DataType bias_data_type = (is_quantized) ? DataType::S32 : data_type; 361*c217d954SCole Faust 362*c217d954SCole Faust const QuantizationInfo src_qinfo = is_quantized ? QuantizationInfo(0.1f, 10) : QuantizationInfo(); 363*c217d954SCole Faust const QuantizationInfo weights_qinfo = is_quantized ? QuantizationInfo(0.3f, 20) : QuantizationInfo(); 364*c217d954SCole Faust const QuantizationInfo dst_qinfo = is_quantized ? QuantizationInfo(0.2f, 5) : QuantizationInfo(); 365*c217d954SCole Faust 366*c217d954SCole Faust // Setup tensor meta-data 367*c217d954SCole Faust const TensorInfo src_info(src_shape, 1, data_type, src_qinfo); 368*c217d954SCole Faust _src.allocator()->init(src_info); 369*c217d954SCole Faust 370*c217d954SCole Faust TensorInfo wei_info(weights_shape, 1, data_type, weights_qinfo); 371*c217d954SCole Faust if(!constant_weights) 372*c217d954SCole Faust { 373*c217d954SCole Faust const TensorShape tr_weights_shape{ weights_shape[1], weights_shape[0] }; 374*c217d954SCole Faust wei_info.set_tensor_shape(tr_weights_shape); 375*c217d954SCole Faust } 376*c217d954SCole Faust wei_info.set_are_values_constant(constant_weights); 377*c217d954SCole Faust _weights.allocator()->init(wei_info); 378*c217d954SCole Faust 379*c217d954SCole Faust TensorInfo bias_info(bias_shape, 1, bias_data_type); 380*c217d954SCole Faust bias_info.set_are_values_constant(constant_bias); 381*c217d954SCole Faust _bias.allocator()->init(bias_info); 382*c217d954SCole Faust 383*c217d954SCole Faust const TensorInfo dst_info(dst_shape, 1, data_type, dst_qinfo); 384*c217d954SCole Faust _dst.allocator()->init(dst_info); 385*c217d954SCole Faust 386*c217d954SCole Faust // Configure FC layer and mark the weights as non constant 387*c217d954SCole Faust FullyConnectedLayerInfo fc_info; 388*c217d954SCole Faust fc_info.activation_info = activation_info; 389*c217d954SCole Faust if(!constant_weights) 390*c217d954SCole Faust { 391*c217d954SCole Faust fc_info.are_weights_reshaped = true; 392*c217d954SCole Faust fc_info.transpose_weights = false; 393*c217d954SCole Faust } 394*c217d954SCole Faust FunctionType fc; 395*c217d954SCole Faust fc.configure(&_src, &_weights, &_bias, &_dst, fc_info); 396*c217d954SCole Faust 397*c217d954SCole Faust // Allocate all the tensors 398*c217d954SCole Faust _src.allocator()->allocate(); 399*c217d954SCole Faust _weights.allocator()->allocate(); 400*c217d954SCole Faust _bias.allocator()->allocate(); 401*c217d954SCole Faust _dst.allocator()->allocate(); 402*c217d954SCole Faust 403*c217d954SCole Faust // Run multiple iterations with different inputs 404*c217d954SCole Faust constexpr int num_iterations = 5; 405*c217d954SCole Faust int randomizer_offset = 0; 406*c217d954SCole Faust 407*c217d954SCole Faust // Create reference tensors 408*c217d954SCole Faust SimpleTensor<T> src{ src_shape, data_type, 1, src_qinfo }; 409*c217d954SCole Faust SimpleTensor<T> weights{ weights_shape, data_type, 1, weights_qinfo }; 410*c217d954SCole Faust SimpleTensor<TBias> bias{ bias_shape, bias_data_type }; 411*c217d954SCole Faust 412*c217d954SCole Faust // Fill weights and/or bias if they remain constant 413*c217d954SCole Faust if(constant_weights) 414*c217d954SCole Faust { 415*c217d954SCole Faust fill(AccessorType(_weights), 1); 416*c217d954SCole Faust fill(weights, 1); 417*c217d954SCole Faust } 418*c217d954SCole Faust if(constant_bias) 419*c217d954SCole Faust { 420*c217d954SCole Faust fill(AccessorType(_bias), 2); 421*c217d954SCole Faust fill(bias, 2); 422*c217d954SCole Faust } 423*c217d954SCole Faust 424*c217d954SCole Faust for(int i = 0; i < num_iterations; ++i) 425*c217d954SCole Faust { 426*c217d954SCole Faust // Run target 427*c217d954SCole Faust { 428*c217d954SCole Faust fill(AccessorType(_src), randomizer_offset); 429*c217d954SCole Faust if(!constant_weights) 430*c217d954SCole Faust { 431*c217d954SCole Faust fill_transposed_weights(_weights, weights_shape, randomizer_offset + 1); 432*c217d954SCole Faust } 433*c217d954SCole Faust if(!constant_bias) 434*c217d954SCole Faust { 435*c217d954SCole Faust fill(AccessorType(_bias), randomizer_offset + 2); 436*c217d954SCole Faust } 437*c217d954SCole Faust 438*c217d954SCole Faust fc.run(); 439*c217d954SCole Faust } 440*c217d954SCole Faust 441*c217d954SCole Faust // Run reference and compare 442*c217d954SCole Faust { 443*c217d954SCole Faust // Fill reference 444*c217d954SCole Faust fill(src, randomizer_offset); 445*c217d954SCole Faust if(!constant_weights) 446*c217d954SCole Faust { 447*c217d954SCole Faust fill(weights, randomizer_offset + 1); 448*c217d954SCole Faust } 449*c217d954SCole Faust if(!constant_bias) 450*c217d954SCole Faust { 451*c217d954SCole Faust fill(bias, randomizer_offset + 2); 452*c217d954SCole Faust } 453*c217d954SCole Faust 454*c217d954SCole Faust auto dst = reference::activation_layer(reference::fully_connected_layer<T>(src, weights, bias, dst_shape, dst_qinfo), activation_info, dst_qinfo); 455*c217d954SCole Faust 456*c217d954SCole Faust // Validate 457*c217d954SCole Faust validate_with_tolerance(_dst, dst); 458*c217d954SCole Faust } 459*c217d954SCole Faust 460*c217d954SCole Faust randomizer_offset += 100; 461*c217d954SCole Faust } 462*c217d954SCole Faust } 463*c217d954SCole Faust 464*c217d954SCole Faust private: 465*c217d954SCole Faust TensorType _src{}, _weights{}, _bias{}, _dst{}; 466*c217d954SCole Faust DataType _data_type{ DataType::UNKNOWN }; 467*c217d954SCole Faust }; 468*c217d954SCole Faust 469*c217d954SCole Faust template <typename TensorType, typename AccessorType, typename FunctionType, typename T> 470*c217d954SCole Faust class FullyConnectedWithDynamicWeightsFixture : public FullyConnectedWithDynamicTensorsFixture<TensorType, AccessorType, FunctionType, T> 471*c217d954SCole Faust { 472*c217d954SCole Faust public: 473*c217d954SCole Faust template <typename...> setup(TensorShape src_shape,TensorShape weights_shape,TensorShape bias_shape,TensorShape dst_shape,DataType data_type,ActivationLayerInfo activation_info)474*c217d954SCole Faust void setup(TensorShape src_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape dst_shape, 475*c217d954SCole Faust DataType data_type, ActivationLayerInfo activation_info) 476*c217d954SCole Faust { 477*c217d954SCole Faust FullyConnectedWithDynamicTensorsFixture<TensorType, AccessorType, FunctionType, T>::setup(src_shape, weights_shape, bias_shape, 478*c217d954SCole Faust dst_shape, data_type, activation_info, false, true); 479*c217d954SCole Faust } 480*c217d954SCole Faust }; 481*c217d954SCole Faust 482*c217d954SCole Faust template <typename TensorType, typename AccessorType, typename FunctionType, typename T> 483*c217d954SCole Faust class FullyConnectedWithDynamicBiasFixture : public FullyConnectedWithDynamicTensorsFixture<TensorType, AccessorType, FunctionType, T> 484*c217d954SCole Faust { 485*c217d954SCole Faust public: 486*c217d954SCole Faust template <typename...> setup(TensorShape src_shape,TensorShape weights_shape,TensorShape bias_shape,TensorShape dst_shape,DataType data_type,ActivationLayerInfo activation_info)487*c217d954SCole Faust void setup(TensorShape src_shape, TensorShape weights_shape, TensorShape bias_shape, TensorShape dst_shape, 488*c217d954SCole Faust DataType data_type, ActivationLayerInfo activation_info) 489*c217d954SCole Faust { 490*c217d954SCole Faust FullyConnectedWithDynamicTensorsFixture<TensorType, AccessorType, FunctionType, T>::setup(src_shape, weights_shape, bias_shape, 491*c217d954SCole Faust dst_shape, data_type, activation_info, true, false); 492*c217d954SCole Faust } 493*c217d954SCole Faust }; 494*c217d954SCole Faust } // namespace validation 495*c217d954SCole Faust } // namespace test 496*c217d954SCole Faust } // namespace arm_compute 497*c217d954SCole Faust #endif /* ARM_COMPUTE_TEST_FULLY_CONNECTED_LAYER_FIXTURE */ 498