xref: /aosp_15_r20/external/armnn/tests/TfLiteYoloV3Big-Armnn/NMS.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
//
// Copyright © 2020 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #pragma once
7*89c4ff92SAndroid Build Coastguard Worker 
8*89c4ff92SAndroid Build Coastguard Worker #include <ostream>
9*89c4ff92SAndroid Build Coastguard Worker #include <vector>
10*89c4ff92SAndroid Build Coastguard Worker 
11*89c4ff92SAndroid Build Coastguard Worker namespace yolov3 {
/** Non-Maxima Suppression (NMS) configuration meta-data.
 *
 * All members have defaults, so a default-constructed NMSConfig is fully
 * initialized: zero classes, zero boxes and both thresholds at 0.8.
 */
struct NMSConfig {
    /** Number of classes in the detected boxes */
    unsigned int num_classes = 0;
    /** Number of detected boxes */
    unsigned int num_boxes = 0;
    /** Inclusion confidence threshold for a box */
    float confidence_threshold = 0.8f;
    /** Inclusion threshold for Intersection-Over-Union */
    float iou_threshold = 0.8f;
};
19*89c4ff92SAndroid Build Coastguard Worker 
/** Box representation structure: minimum and maximum coordinate per axis.
 *
 * Every member carries a default initializer, so a default-constructed Box
 * is all zeros instead of holding indeterminate floats. The declaration
 * order (xmin, xmax, ymin, ymax) is also the order consumed by aggregate
 * initialization, e.g. Box{x0, x1, y0, y1} (valid with default member
 * initializers from C++14 onwards).
 */
struct Box {
    float xmin{0.0f}; /**< X position of the lower-left corner */
    float xmax{0.0f}; /**< X position of the top-right corner */
    float ymin{0.0f}; /**< Y position of the lower-left corner */
    float ymax{0.0f}; /**< Y position of the top-right corner */
};
27*89c4ff92SAndroid Build Coastguard Worker 
28*89c4ff92SAndroid Build Coastguard Worker /** Detection structure */
29*89c4ff92SAndroid Build Coastguard Worker struct Detection {
30*89c4ff92SAndroid Build Coastguard Worker     Box box;                    /**< Detection box */
31*89c4ff92SAndroid Build Coastguard Worker     float confidence;           /**< Confidence of detection */
32*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> classes; /**< Probability of classes */
33*89c4ff92SAndroid Build Coastguard Worker };
34*89c4ff92SAndroid Build Coastguard Worker 
35*89c4ff92SAndroid Build Coastguard Worker /** Print identified yolo detections
36*89c4ff92SAndroid Build Coastguard Worker  *
37*89c4ff92SAndroid Build Coastguard Worker  * @param[in, out] os          Output stream to print to
38*89c4ff92SAndroid Build Coastguard Worker  * @param[in]      detections  Detections to print
39*89c4ff92SAndroid Build Coastguard Worker  */
40*89c4ff92SAndroid Build Coastguard Worker void print_detection(std::ostream& os,
41*89c4ff92SAndroid Build Coastguard Worker                      const std::vector<Detection>& detections);
42*89c4ff92SAndroid Build Coastguard Worker 
43*89c4ff92SAndroid Build Coastguard Worker /** Compare a detection object with a vector of float values
44*89c4ff92SAndroid Build Coastguard Worker  *
45*89c4ff92SAndroid Build Coastguard Worker  * @param detection [in] Detection object
46*89c4ff92SAndroid Build Coastguard Worker  * @param expected  [in] Vector of expected float values
47*89c4ff92SAndroid Build Coastguard Worker  * @return Boolean to represent if they match or not
48*89c4ff92SAndroid Build Coastguard Worker  */
49*89c4ff92SAndroid Build Coastguard Worker bool compare_detection(const yolov3::Detection& detection,
50*89c4ff92SAndroid Build Coastguard Worker                        const std::vector<float>& expected);
51*89c4ff92SAndroid Build Coastguard Worker 
52*89c4ff92SAndroid Build Coastguard Worker /** Perform Non-Maxima Supression on a list of given detections
53*89c4ff92SAndroid Build Coastguard Worker  *
54*89c4ff92SAndroid Build Coastguard Worker  * @param[in] config         Configuration metadata for NMS
55*89c4ff92SAndroid Build Coastguard Worker  * @param[in] detected_boxes Detected boxes
56*89c4ff92SAndroid Build Coastguard Worker  *
57*89c4ff92SAndroid Build Coastguard Worker  * @return A vector with the final detections
58*89c4ff92SAndroid Build Coastguard Worker  */
59*89c4ff92SAndroid Build Coastguard Worker std::vector<Detection> nms(const NMSConfig& config,
60*89c4ff92SAndroid Build Coastguard Worker                            const std::vector<float>& detected_boxes);
61*89c4ff92SAndroid Build Coastguard Worker } // namespace yolov3
62