xref: /aosp_15_r20/external/armnn/tests/TfLiteYoloV3Big-Armnn/NMS.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2020 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
8 #include <ostream>
9 #include <vector>
10 
11 namespace yolov3 {
/** Configuration meta-data for Non-Maxima Suppression */
struct NMSConfig {
    /** Number of classes in the detected boxes */
    unsigned int num_classes = 0;
    /** Number of detected boxes */
    unsigned int num_boxes = 0;
    /** Inclusion confidence threshold for a box */
    float confidence_threshold = 0.8f;
    /** Inclusion threshold for Intersection-Over-Union */
    float iou_threshold = 0.8f;
};
19 
/** Box representation structure
 *
 * Every coordinate defaults to zero so that a default-constructed Box never
 * holds indeterminate values. The member order (xmin, xmax, ymin, ymax) is
 * the order used by brace initialisation and must not change.
 */
struct Box {
    float xmin{0.0f};  /**< X position of the lower-left corner */
    float xmax{0.0f};  /**< X position of the top-right corner */
    float ymin{0.0f};  /**< Y position of the lower-left corner */
    float ymax{0.0f};  /**< Y position of the top-right corner */
};
27 
28 /** Detection structure */
29 struct Detection {
30     Box box;                    /**< Detection box */
31     float confidence;           /**< Confidence of detection */
32     std::vector<float> classes; /**< Probability of classes */
33 };
34 
35 /** Print identified yolo detections
36  *
37  * @param[in, out] os          Output stream to print to
38  * @param[in]      detections  Detections to print
39  */
40 void print_detection(std::ostream& os,
41                      const std::vector<Detection>& detections);
42 
43 /** Compare a detection object with a vector of float values
44  *
45  * @param detection [in] Detection object
46  * @param expected  [in] Vector of expected float values
47  * @return Boolean to represent if they match or not
48  */
49 bool compare_detection(const yolov3::Detection& detection,
50                        const std::vector<float>& expected);
51 
52 /** Perform Non-Maxima Supression on a list of given detections
53  *
54  * @param[in] config         Configuration metadata for NMS
55  * @param[in] detected_boxes Detected boxes
56  *
57  * @return A vector with the final detections
58  */
59 std::vector<Detection> nms(const NMSConfig& config,
60                            const std::vector<float>& detected_boxes);
61 } // namespace yolov3
62