xref: /aosp_15_r20/external/armnn/tests/TfLiteYoloV3Big-Armnn/NMS.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2020 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 
7 #include "NMS.hpp"
8 
9 #include <cmath>
10 #include <algorithm>
11 #include <cstddef>
12 #include <numeric>
13 #include <ostream>
14 
15 namespace yolov3 {
16 namespace {
/** How many consecutive floats describe one bounding box */
constexpr int box_elements{4};
/** How many consecutive floats hold the confidence factor of one box */
constexpr int confidence_elements{1};
21 
22 /** Calculate Intersection Over Union of two boxes
23  *
24  * @param[in] box1 First box
25  * @param[in] box2 Second box
26  *
27  * @return The IoU of the two boxes
28  */
iou(const Box & box1,const Box & box2)29 float iou(const Box& box1, const Box& box2)
30 {
31     const float area1 = (box1.xmax - box1.xmin) * (box1.ymax - box1.ymin);
32     const float area2 = (box2.xmax - box2.xmin) * (box2.ymax - box2.ymin);
33     float overlap;
34     if (area1 <= 0 || area2 <= 0)
35     {
36         overlap = 0.0f;
37     }
38     else
39     {
40         const auto y_min_intersection = std::max<float>(box1.ymin, box2.ymin);
41         const auto x_min_intersection = std::max<float>(box1.xmin, box2.xmin);
42         const auto y_max_intersection = std::min<float>(box1.ymax, box2.ymax);
43         const auto x_max_intersection = std::min<float>(box1.xmax, box2.xmax);
44         const auto area_intersection =
45             std::max<float>(y_max_intersection - y_min_intersection, 0.0f) *
46             std::max<float>(x_max_intersection - x_min_intersection, 0.0f);
47         overlap = area_intersection / (area1 + area2 - area_intersection);
48     }
49     return overlap;
50 }
51 
convert_to_detections(const NMSConfig & config,const std::vector<float> & detected_boxes)52 std::vector<Detection> convert_to_detections(const NMSConfig& config,
53                                              const std::vector<float>& detected_boxes)
54 {
55     const size_t element_step = static_cast<size_t>(
56         box_elements + confidence_elements + config.num_classes);
57     std::vector<Detection> detections;
58 
59     for (unsigned int i = 0; i < config.num_boxes; ++i)
60     {
61         const float* cur_box = &detected_boxes[i * element_step];
62         if (cur_box[4] > config.confidence_threshold)
63         {
64             Detection det;
65             det.box = {cur_box[0], cur_box[0] + cur_box[2], cur_box[1],
66                        cur_box[1] + cur_box[3]};
67             det.confidence = cur_box[4];
68             det.classes.resize(static_cast<size_t>(config.num_classes), 0);
69             for (unsigned int c = 0; c < config.num_classes; ++c)
70             {
71                 const float class_prob = det.confidence * cur_box[5 + c];
72                 if (class_prob > config.confidence_threshold)
73                 {
74                     det.classes[c] = class_prob;
75                 }
76             }
77             detections.emplace_back(std::move(det));
78         }
79     }
80     return detections;
81 }
82 } // namespace
83 
compare_detection(const yolov3::Detection & detection,const std::vector<float> & expected)84 bool compare_detection(const yolov3::Detection& detection,
85                        const std::vector<float>& expected)
86 {
87     float tolerance = 0.001f;
88     return (std::fabs(detection.classes[0] - expected[0]) < tolerance  &&
89             std::fabs(detection.box.xmin   - expected[1]) < tolerance  &&
90             std::fabs(detection.box.ymin   - expected[2]) < tolerance  &&
91             std::fabs(detection.box.xmax   - expected[3]) < tolerance  &&
92             std::fabs(detection.box.ymax   - expected[4]) < tolerance  &&
93             std::fabs(detection.confidence - expected[5]) < tolerance  );
94 }
95 
print_detection(std::ostream & os,const std::vector<Detection> & detections)96 void print_detection(std::ostream& os,
97                      const std::vector<Detection>& detections)
98 {
99     for (const auto& detection : detections)
100     {
101         for (unsigned int c = 0; c < detection.classes.size(); ++c)
102         {
103             if (detection.classes[c] != 0.0f)
104             {
105                 os << c << " " << detection.classes[c] << " " << detection.box.xmin
106                    << " " << detection.box.ymin << " " << detection.box.xmax << " "
107                    << detection.box.ymax << std::endl;
108             }
109         }
110     }
111 }
112 
nms(const NMSConfig & config,const std::vector<float> & detected_boxes)113 std::vector<Detection> nms(const NMSConfig& config,
114                            const std::vector<float>& detected_boxes) {
115     // Get detections that comply with the expected confidence threshold
116     std::vector<Detection> detections =
117         convert_to_detections(config, detected_boxes);
118 
119     const unsigned int num_detections = static_cast<unsigned int>(detections.size());
120     for (unsigned int c = 0; c < config.num_classes; ++c)
121     {
122         // Sort classes
123         std::sort(detections.begin(), detections.begin() + static_cast<std::ptrdiff_t>(num_detections),
124                   [c](Detection& detection1, Detection& detection2)
125                     {
126                         return (detection1.classes[c] - detection2.classes[c]) > 0;
127                     });
128         // Clear detections with high IoU
129         for (unsigned int d = 0; d < num_detections; ++d)
130         {
131             // Check if class is already cleared/invalidated
132             if (detections[d].classes[c] == 0.f)
133             {
134                 continue;
135             }
136 
137             // Filter out boxes on IoU threshold
138             const Box& box1 = detections[d].box;
139             for (unsigned int b = d + 1; b < num_detections; ++b)
140             {
141                 const Box& box2 = detections[b].box;
142                 if (iou(box1, box2) > config.iou_threshold)
143                 {
144                     detections[b].classes[c] = 0.f;
145                 }
146             }
147         }
148     }
149     return detections;
150 }
151 } // namespace yolov3
152