/*
 * Copyright (c) 2018-2022 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
23 */ 24 #ifndef ARM_COMPUTE_TEST_CONVERT_FULLY_CONNECTED_WEIGHTS_FIXTURE 25 #define ARM_COMPUTE_TEST_CONVERT_FULLY_CONNECTED_WEIGHTS_FIXTURE 26 27 #include "arm_compute/core/TensorShape.h" 28 #include "arm_compute/core/Types.h" 29 #include "tests/AssetsLibrary.h" 30 #include "tests/Globals.h" 31 #include "tests/IAccessor.h" 32 #include "tests/framework/Asserts.h" 33 #include "tests/framework/Fixture.h" 34 #include "tests/validation/reference/ConvertFullyConnectedWeights.h" 35 36 namespace arm_compute 37 { 38 namespace test 39 { 40 namespace validation 41 { 42 template <typename TensorType, typename AccessorType, typename FunctionType, typename T> 43 class ConvertFullyConnectedWeightsValidationFixture : public framework::Fixture 44 { 45 public: 46 template <typename...> setup(TensorShape input_shape,unsigned int weights_w,DataLayout training_data_layout,DataType data_type)47 void setup(TensorShape input_shape, unsigned int weights_w, DataLayout training_data_layout, DataType data_type) 48 { 49 const unsigned int height = input_shape.x() * input_shape.y() * input_shape.z(); 50 const TensorShape weights_shape(weights_w, height); 51 52 _target = compute_target(input_shape, weights_shape, training_data_layout, data_type); 53 _reference = compute_reference(input_shape, weights_shape, training_data_layout, data_type); 54 } 55 56 protected: 57 template <typename U> fill(U && tensor,int i)58 void fill(U &&tensor, int i) 59 { 60 switch(tensor.data_type()) 61 { 62 case DataType::QASYMM8: 63 { 64 std::uniform_int_distribution<uint32_t> distribution(0, 10); 65 library->fill(tensor, distribution, i); 66 break; 67 } 68 case DataType::F16: 69 { 70 arm_compute::utils::uniform_real_distribution_16bit<half> distribution{ -1.0f, 1.0f }; 71 library->fill(tensor, distribution, i); 72 break; 73 } 74 case DataType::F32: 75 { 76 std::uniform_real_distribution<float> distribution(-1.0f, 1.0f); 77 library->fill(tensor, distribution, i); 78 break; 79 } 80 default: 81 
library->fill_tensor_uniform(tensor, i); 82 } 83 } 84 compute_target(const TensorShape & input_shape,const TensorShape & weights_shape,const DataLayout training_data_layout,const DataType data_type)85 TensorType compute_target(const TensorShape &input_shape, const TensorShape &weights_shape, const DataLayout training_data_layout, const DataType data_type) 86 { 87 // Create tensors 88 TensorType src = create_tensor<TensorType>(weights_shape, data_type); 89 TensorType dst = create_tensor<TensorType>(weights_shape, data_type); 90 91 // Create and configure function 92 FunctionType convert_weights; 93 94 convert_weights.configure(&src, &dst, input_shape, training_data_layout); 95 96 ARM_COMPUTE_ASSERT(src.info()->is_resizable()); 97 ARM_COMPUTE_ASSERT(dst.info()->is_resizable()); 98 99 // Allocate tensors 100 src.allocator()->allocate(); 101 dst.allocator()->allocate(); 102 103 ARM_COMPUTE_ASSERT(!src.info()->is_resizable()); 104 ARM_COMPUTE_ASSERT(!dst.info()->is_resizable()); 105 106 // Fill tensors 107 fill(AccessorType(src), 0); 108 109 // Compute function 110 convert_weights.run(); 111 112 return dst; 113 } 114 compute_reference(const TensorShape & input_shape,const TensorShape & weights_shape,const DataLayout training_data_layout,const DataType data_type)115 SimpleTensor<T> compute_reference(const TensorShape &input_shape, const TensorShape &weights_shape, const DataLayout training_data_layout, const DataType data_type) 116 { 117 // Create reference 118 SimpleTensor<T> src{ weights_shape, data_type }; 119 120 // Fill reference 121 fill(src, 0); 122 123 return reference::convert_fully_connected_weights(src, input_shape, training_data_layout); 124 } 125 126 TensorType _target{}; 127 SimpleTensor<T> _reference{}; 128 }; 129 } // namespace validation 130 } // namespace test 131 } // namespace arm_compute 132 #endif /* ARM_COMPUTE_TEST_CONVERT_FULLY_CONNECTED_WEIGHTS_FIXTURE */ 133