xref: /aosp_15_r20/external/ComputeLibrary/tests/validate_examples/graph_fully_connected.cpp (revision c217d954acce2dbc11938adb493fc0abd69584f3)
1 /*
2  * Copyright (c) 2019-2020 Arm Limited.
3  *
4  * SPDX-License-Identifier: MIT
5  *
6  * Permission is hereby granted, free of charge, to any person obtaining a copy
7  * of this software and associated documentation files (the "Software"), to
8  * deal in the Software without restriction, including without limitation the
9  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10  * sell copies of the Software, and to permit persons to whom the Software is
11  * furnished to do so, subject to the following conditions:
12  *
13  * The above copyright notice and this permission notice shall be included in all
14  * copies or substantial portions of the Software.
15  *
16  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22  * SOFTWARE.
23  */
24 #include "arm_compute/graph.h"
25 
26 #include "tests/NEON/Accessor.h"
27 #include "tests/validation/Validation.h"
28 #include "tests/validation/reference/FullyConnectedLayer.h"
29 #include "tests/validation/reference/Permute.h"
30 
31 #include "utils/CommonGraphOptions.h"
32 #include "utils/GraphUtils.h"
33 #include "utils/Utils.h"
34 
35 #include "ValidateExample.h"
36 #include "graph_validate_utils.h"
37 
38 #include <utility>
39 
40 using namespace arm_compute::utils;
41 using namespace arm_compute::graph::frontend;
42 using namespace arm_compute::graph_utils;
43 using namespace arm_compute::graph;
44 using namespace arm_compute;
45 using namespace arm_compute::test;
46 using namespace arm_compute::test::validation;
47 
48 namespace
49 {
50 /** Fully connected command line options used to configure the graph examples
51  *
52  * (Similar to common options)
53  * The options in this object get populated when "parse()" is called on the parser used to construct it.
54  * The expected workflow is:
55  *
56  * CommandLineParser parser;
57  * CommonOptions options( parser );
58  * parser.parse(argc, argv);
59  */
60 class FullyConnectedOptions final : public CommonGraphValidateOptions
61 {
62 public:
FullyConnectedOptions(CommandLineParser & parser)63     explicit FullyConnectedOptions(CommandLineParser &parser) noexcept
64         : CommonGraphValidateOptions(parser),
65           width(parser.add_option<SimpleOption<int>>("width", 3)),
66           batch(parser.add_option<SimpleOption<int>>("batch", 1)),
67           input_scale(parser.add_option<SimpleOption<float>>("input_scale", 1.0f)),
68           input_offset(parser.add_option<SimpleOption<int>>("input_offset", 0)),
69           weights_scale(parser.add_option<SimpleOption<float>>("weights_scale", 1.0f)),
70           weights_offset(parser.add_option<SimpleOption<int>>("weights_offset", 0)),
71           output_scale(parser.add_option<SimpleOption<float>>("output_scale", 1.0f)),
72           output_offset(parser.add_option<SimpleOption<int>>("output_offset", 0)),
73           num_outputs(parser.add_option<SimpleOption<int>>("num_outputs", 1)),
74           input_range_low(parser.add_option<SimpleOption<uint64_t>>("input_range_low")),
75           input_range_high(parser.add_option<SimpleOption<uint64_t>>("input_range_high")),
76           weights_range_low(parser.add_option<SimpleOption<uint64_t>>("weights_range_low")),
77           weights_range_high(parser.add_option<SimpleOption<uint64_t>>("weights_range_high"))
78     {
79         width->set_help("Set Input dimension width");
80         batch->set_help("Set Input dimension batch");
81         input_scale->set_help("Quantization scale from QASYMM8");
82         input_offset->set_help("Quantization offset from QASYMM8");
83         weights_scale->set_help("Quantization scale from QASYMM8");
84         weights_offset->set_help("Quantization offset from QASYMM8");
85         output_scale->set_help("Quantization scale from QASYMM8");
86         output_offset->set_help("Quantization offset from QASYMM8");
87         num_outputs->set_help("Number of outputs.");
88         input_range_low->set_help("Lower bound for input randomization range");
89         input_range_high->set_help("Lower bound for input randomization range");
90         weights_range_low->set_help("Lower bound for input randomization range");
91         weights_range_high->set_help("Lower bound for input randomization range");
92     }
93 
94     /** Fill out the supplied parameters with user supplied parameters
95      *
96      * @param[out] os            Output stream.
97      * @param[in]  common_params Example parameters to output
98      *
99      * @return None.
100      */
consume_parameters(ExampleParams & common_params)101     void consume_parameters(ExampleParams &common_params)
102     {
103         common_params.input.width      = width->value();
104         common_params.input.batch      = batch->value();
105         common_params.input.quant_info = QuantizationInfo(input_scale->value(), input_offset->value());
106         common_params.input.range_low  = input_range_low->value();
107         common_params.input.range_high = input_range_high->value();
108 
109         common_params.weights.quant_info = QuantizationInfo(weights_scale->value(), weights_offset->value());
110         common_params.weights.range_low  = weights_range_low->value();
111         common_params.weights.range_high = weights_range_high->value();
112 
113         common_params.output.quant_info = QuantizationInfo(output_scale->value(), output_offset->value());
114 
115         common_params.data_type                   = data_type->value();
116         common_params.fully_connected.num_outputs = num_outputs->value();
117     }
118 
print_parameters(::std::ostream & os,const ExampleParams & common_params)119     void print_parameters(::std::ostream &os, const ExampleParams &common_params) override
120     {
121         os << "Threads : " << common_params.common_params.threads << std::endl;
122         os << "Target : " << common_params.common_params.target << std::endl;
123         os << "Data type : " << common_params.data_type << std::endl;
124         os << "Input dimensions(X,Y, Channels, Batch) : (" << common_params.input.width << "," << common_params.input.height << "," << common_params.input.fm << "," << common_params.input.batch << ")"
125            << std::endl;
126         os << "Number of outputs : " << common_params.fully_connected.num_outputs << std::endl;
127     }
128 
129     /** Prevent instances of this class from being copied (As this class contains pointers) */
130     FullyConnectedOptions(const FullyConnectedOptions &) = delete;
131     /** Prevent instances of this class from being copied (As this class contains pointers) */
132     FullyConnectedOptions &operator=(const FullyConnectedOptions &) = delete;
133     /** Allow instances of this class to be moved */
134     FullyConnectedOptions(FullyConnectedOptions &&) noexcept(true) = default;
135     /** Allow instances of this class to be moved */
136     FullyConnectedOptions &operator=(FullyConnectedOptions &&) noexcept(true) = default;
137     /** Default destructor */
138     ~FullyConnectedOptions() override = default;
139 
140 private:
141     SimpleOption<int>      *width;              /**< Input width */
142     SimpleOption<int>      *batch;              /**< Input batch */
143     SimpleOption<float>    *input_scale;        /**< Input Quantization scale from QASSYMM8 */
144     SimpleOption<int>      *input_offset;       /**< Input Quantization offset from QASSYMM8 */
145     SimpleOption<float>    *weights_scale;      /**< Weights Quantization scale from QASSYMM8 */
146     SimpleOption<int>      *weights_offset;     /**< Weights Quantization offset from QASSYMM8 */
147     SimpleOption<float>    *output_scale;       /**< Output Quantization scale from QASSYMM8 */
148     SimpleOption<int>      *output_offset;      /**< Output Quantization offset from QASSYMM8 */
149     SimpleOption<int>      *num_outputs;        /**< Number of outputs. */
150     SimpleOption<uint64_t> *input_range_low;    /**< Lower bound for input randomization range */
151     SimpleOption<uint64_t> *input_range_high;   /**< Upper bound for input randomization range */
152     SimpleOption<uint64_t> *weights_range_low;  /**< Lower bound for weights randomization range */
153     SimpleOption<uint64_t> *weights_range_high; /**< Upper bound for weights randomization range */
154 };
155 
156 /** Fully Connected Layer Graph example validation accessor class */
157 template <typename D>
158 class FullyConnectedVerifyAccessor final : public VerifyAccessor<D>
159 {
160     using BaseClassType = VerifyAccessor<D>;
161     using BaseClassType::BaseClassType;
162     using BaseClassType::_params;
163     using TBias = typename std::conditional<std::is_same<typename std::decay<D>::type, uint8_t>::value, int32_t, D>::type;
164 
165     // Inherited methods overriden:
create_tensors(arm_compute::test::SimpleTensor<D> & src,arm_compute::test::SimpleTensor<D> & weights,arm_compute::test::SimpleTensor<TBias> & bias,ITensor & tensor)166     void create_tensors(arm_compute::test::SimpleTensor<D>     &src,
167                         arm_compute::test::SimpleTensor<D>     &weights,
168                         arm_compute::test::SimpleTensor<TBias> &bias,
169                         ITensor                                &tensor) override
170     {
171         // Calculate Tensor shapes for verification
172         const TensorShape      input_shape        = TensorShape(_params.input.width, _params.input.height, _params.input.fm, _params.input.batch);
173         const TensorDescriptor input_descriptor   = TensorDescriptor(input_shape, _params.data_type, _params.input.quant_info);
174         const TensorDescriptor weights_descriptor = FullyConnectedLayerNode::compute_weights_descriptor(input_descriptor,
175                                                                                                         _params.fully_connected.num_outputs,
176                                                                                                         _params.fully_connected.info,
177                                                                                                         _params.weights.quant_info);
178         const TensorDescriptor output_desciptor = FullyConnectedLayerNode::compute_output_descriptor(input_descriptor, _params.fully_connected.num_outputs, _params.output.quant_info);
179 
180         //Create Input tensors
181         src     = SimpleTensor<D> { input_descriptor.shape, _params.data_type, 1, input_descriptor.quant_info };
182         weights = SimpleTensor<D> { weights_descriptor.shape, _params.data_type, 1, weights_descriptor.quant_info };
183         bias    = SimpleTensor<TBias> { TensorShape(tensor.info()->tensor_shape().x()), _params.data_type, 1, _params.input.quant_info };
184     }
185 
output_shape(ITensor & tensor)186     TensorShape output_shape(ITensor &tensor) override
187     {
188         ARM_COMPUTE_UNUSED(tensor);
189 
190         const TensorShape      input_shape      = TensorShape(_params.input.width, _params.input.height, _params.input.fm, _params.input.batch);
191         const TensorDescriptor input_descriptor = TensorDescriptor(input_shape, _params.data_type, _params.input.quant_info);
192         const TensorDescriptor output_desciptor = FullyConnectedLayerNode::compute_output_descriptor(input_descriptor, _params.fully_connected.num_outputs, _params.output.quant_info);
193 
194         return output_desciptor.shape;
195     }
196 
reference(arm_compute::test::SimpleTensor<D> & src,arm_compute::test::SimpleTensor<D> & weights,arm_compute::test::SimpleTensor<TBias> & bias,const arm_compute::TensorShape & output_shape)197     arm_compute::test::SimpleTensor<D> reference(arm_compute::test::SimpleTensor<D>     &src,
198                                                  arm_compute::test::SimpleTensor<D>     &weights,
199                                                  arm_compute::test::SimpleTensor<TBias> &bias,
200                                                  const arm_compute::TensorShape         &output_shape) override
201     {
202         return reference::fully_connected_layer<D>(src, weights, bias, output_shape, _params.output.quant_info);
203     }
204 
relative_tolerance()205     float relative_tolerance() override
206     {
207         const std::map<arm_compute::graph::Target, const std::map<DataType, float>> relative_tolerance
208         {
209             {
210                 arm_compute::graph::Target::CL,
211                 {   { DataType::F16, 0.2f },
212                     { DataType::F32, 0.05f },
213                     { DataType::QASYMM8, 1.0f }
214                 }
215             },
216             {
217                 arm_compute::graph::Target::NEON,
218                 {   { DataType::F16, 0.2f },
219                     { DataType::F32, 0.01f },
220                     { DataType::QASYMM8, 1.0f }
221                 }
222             }
223         };
224 
225         return relative_tolerance.at(_params.common_params.target).at(_params.data_type);
226     }
227 
absolute_tolerance()228     float absolute_tolerance() override
229     {
230         const std::map<Target, const std::map<DataType, float>> absolute_tolerance
231         {
232             {
233                 Target::CL,
234                 {   { DataType::F16, 0.0f },
235                     { DataType::F32, 0.0001f },
236                     { DataType::QASYMM8, 1.0f }
237                 }
238             },
239             {
240                 Target::NEON,
241                 {   { DataType::F16, 0.3f },
242                     { DataType::F32, 0.1f },
243                     { DataType::QASYMM8, 1.0f }
244                 }
245             }
246         };
247 
248         return absolute_tolerance.at(_params.common_params.target).at(_params.data_type);
249     }
250 
tolerance_number()251     float tolerance_number() override
252     {
253         const std::map<Target, const std::map<DataType, float>> absolute_tolerance
254         {
255             {
256                 Target::CL,
257                 {   { DataType::F16, 0.07f },
258                     { DataType::F32, 0.07f },
259                     { DataType::QASYMM8, 0.0f }
260                 }
261             },
262             {
263                 Target::NEON,
264                 {   { DataType::F16, 0.07f },
265                     { DataType::F32, 0.0f },
266                     { DataType::QASYMM8, 0.0f }
267                 }
268             }
269         };
270 
271         return absolute_tolerance.at(_params.common_params.target).at(_params.data_type);
272     }
273 };
274 
275 } // namespace
276 
277 class GraphFullyConnectedValidateExample final : public GraphValidateExample<FullyConnectedLayer, FullyConnectedOptions, FullyConnectedVerifyAccessor>
278 {
279     using GraphValidateExample::graph;
280 
281 public:
GraphFullyConnectedValidateExample()282     GraphFullyConnectedValidateExample()
283         : GraphValidateExample("Fully_connected Graph example")
284     {
285     }
286 
GraphFunctionLayer(ExampleParams & params)287     FullyConnectedLayer GraphFunctionLayer(ExampleParams &params) override
288     {
289         const PixelValue lower = PixelValue(params.input.range_low, params.data_type, params.input.quant_info);
290         const PixelValue upper = PixelValue(params.input.range_high, params.data_type, params.input.quant_info);
291 
292         const PixelValue weights_lower = PixelValue(params.weights.range_low, params.data_type, params.weights.quant_info);
293         const PixelValue weights_upper = PixelValue(params.weights.range_high, params.data_type, params.weights.quant_info);
294 
295         return FullyConnectedLayer(params.fully_connected.num_outputs,
296                                    get_random_accessor(weights_lower, weights_upper, 1),
297                                    get_random_accessor(lower, upper, 2),
298                                    params.fully_connected.info, params.weights.quant_info, params.output.quant_info);
299     }
300 };
301 
302 /** Main program for Graph fully_connected test
303  *
304  * @param[in] argc Number of arguments
305  * @param[in] argv Arguments ( Input dimensions [width, batch]
306  *                             Fully connected  [num_outputs,type]
307  *                             Verification[tolerance_number,absolute_tolerance,relative_tolerance] )
308  *
309  */
main(int argc,char ** argv)310 int main(int argc, char **argv)
311 {
312     return arm_compute::utils::run_example<GraphFullyConnectedValidateExample>(argc, argv);
313 }
314