xref: /aosp_15_r20/external/armnn/tests/TfLiteYoloV3Big-Armnn/NMS.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2020 Arm Ltd. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker 
7*89c4ff92SAndroid Build Coastguard Worker #include "NMS.hpp"
8*89c4ff92SAndroid Build Coastguard Worker 
9*89c4ff92SAndroid Build Coastguard Worker #include <cmath>
10*89c4ff92SAndroid Build Coastguard Worker #include <algorithm>
11*89c4ff92SAndroid Build Coastguard Worker #include <cstddef>
12*89c4ff92SAndroid Build Coastguard Worker #include <numeric>
13*89c4ff92SAndroid Build Coastguard Worker #include <ostream>
14*89c4ff92SAndroid Build Coastguard Worker 
15*89c4ff92SAndroid Build Coastguard Worker namespace yolov3 {
16*89c4ff92SAndroid Build Coastguard Worker namespace {
/** Count of floats that encode a single box's geometry in the raw output */
constexpr int box_elements{4};
/** Count of floats that encode a single box's confidence score */
constexpr int confidence_elements{1};
21*89c4ff92SAndroid Build Coastguard Worker 
22*89c4ff92SAndroid Build Coastguard Worker /** Calculate Intersection Over Union of two boxes
23*89c4ff92SAndroid Build Coastguard Worker  *
24*89c4ff92SAndroid Build Coastguard Worker  * @param[in] box1 First box
25*89c4ff92SAndroid Build Coastguard Worker  * @param[in] box2 Second box
26*89c4ff92SAndroid Build Coastguard Worker  *
27*89c4ff92SAndroid Build Coastguard Worker  * @return The IoU of the two boxes
28*89c4ff92SAndroid Build Coastguard Worker  */
iou(const Box & box1,const Box & box2)29*89c4ff92SAndroid Build Coastguard Worker float iou(const Box& box1, const Box& box2)
30*89c4ff92SAndroid Build Coastguard Worker {
31*89c4ff92SAndroid Build Coastguard Worker     const float area1 = (box1.xmax - box1.xmin) * (box1.ymax - box1.ymin);
32*89c4ff92SAndroid Build Coastguard Worker     const float area2 = (box2.xmax - box2.xmin) * (box2.ymax - box2.ymin);
33*89c4ff92SAndroid Build Coastguard Worker     float overlap;
34*89c4ff92SAndroid Build Coastguard Worker     if (area1 <= 0 || area2 <= 0)
35*89c4ff92SAndroid Build Coastguard Worker     {
36*89c4ff92SAndroid Build Coastguard Worker         overlap = 0.0f;
37*89c4ff92SAndroid Build Coastguard Worker     }
38*89c4ff92SAndroid Build Coastguard Worker     else
39*89c4ff92SAndroid Build Coastguard Worker     {
40*89c4ff92SAndroid Build Coastguard Worker         const auto y_min_intersection = std::max<float>(box1.ymin, box2.ymin);
41*89c4ff92SAndroid Build Coastguard Worker         const auto x_min_intersection = std::max<float>(box1.xmin, box2.xmin);
42*89c4ff92SAndroid Build Coastguard Worker         const auto y_max_intersection = std::min<float>(box1.ymax, box2.ymax);
43*89c4ff92SAndroid Build Coastguard Worker         const auto x_max_intersection = std::min<float>(box1.xmax, box2.xmax);
44*89c4ff92SAndroid Build Coastguard Worker         const auto area_intersection =
45*89c4ff92SAndroid Build Coastguard Worker             std::max<float>(y_max_intersection - y_min_intersection, 0.0f) *
46*89c4ff92SAndroid Build Coastguard Worker             std::max<float>(x_max_intersection - x_min_intersection, 0.0f);
47*89c4ff92SAndroid Build Coastguard Worker         overlap = area_intersection / (area1 + area2 - area_intersection);
48*89c4ff92SAndroid Build Coastguard Worker     }
49*89c4ff92SAndroid Build Coastguard Worker     return overlap;
50*89c4ff92SAndroid Build Coastguard Worker }
51*89c4ff92SAndroid Build Coastguard Worker 
convert_to_detections(const NMSConfig & config,const std::vector<float> & detected_boxes)52*89c4ff92SAndroid Build Coastguard Worker std::vector<Detection> convert_to_detections(const NMSConfig& config,
53*89c4ff92SAndroid Build Coastguard Worker                                              const std::vector<float>& detected_boxes)
54*89c4ff92SAndroid Build Coastguard Worker {
55*89c4ff92SAndroid Build Coastguard Worker     const size_t element_step = static_cast<size_t>(
56*89c4ff92SAndroid Build Coastguard Worker         box_elements + confidence_elements + config.num_classes);
57*89c4ff92SAndroid Build Coastguard Worker     std::vector<Detection> detections;
58*89c4ff92SAndroid Build Coastguard Worker 
59*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int i = 0; i < config.num_boxes; ++i)
60*89c4ff92SAndroid Build Coastguard Worker     {
61*89c4ff92SAndroid Build Coastguard Worker         const float* cur_box = &detected_boxes[i * element_step];
62*89c4ff92SAndroid Build Coastguard Worker         if (cur_box[4] > config.confidence_threshold)
63*89c4ff92SAndroid Build Coastguard Worker         {
64*89c4ff92SAndroid Build Coastguard Worker             Detection det;
65*89c4ff92SAndroid Build Coastguard Worker             det.box = {cur_box[0], cur_box[0] + cur_box[2], cur_box[1],
66*89c4ff92SAndroid Build Coastguard Worker                        cur_box[1] + cur_box[3]};
67*89c4ff92SAndroid Build Coastguard Worker             det.confidence = cur_box[4];
68*89c4ff92SAndroid Build Coastguard Worker             det.classes.resize(static_cast<size_t>(config.num_classes), 0);
69*89c4ff92SAndroid Build Coastguard Worker             for (unsigned int c = 0; c < config.num_classes; ++c)
70*89c4ff92SAndroid Build Coastguard Worker             {
71*89c4ff92SAndroid Build Coastguard Worker                 const float class_prob = det.confidence * cur_box[5 + c];
72*89c4ff92SAndroid Build Coastguard Worker                 if (class_prob > config.confidence_threshold)
73*89c4ff92SAndroid Build Coastguard Worker                 {
74*89c4ff92SAndroid Build Coastguard Worker                     det.classes[c] = class_prob;
75*89c4ff92SAndroid Build Coastguard Worker                 }
76*89c4ff92SAndroid Build Coastguard Worker             }
77*89c4ff92SAndroid Build Coastguard Worker             detections.emplace_back(std::move(det));
78*89c4ff92SAndroid Build Coastguard Worker         }
79*89c4ff92SAndroid Build Coastguard Worker     }
80*89c4ff92SAndroid Build Coastguard Worker     return detections;
81*89c4ff92SAndroid Build Coastguard Worker }
82*89c4ff92SAndroid Build Coastguard Worker } // namespace
83*89c4ff92SAndroid Build Coastguard Worker 
compare_detection(const yolov3::Detection & detection,const std::vector<float> & expected)84*89c4ff92SAndroid Build Coastguard Worker bool compare_detection(const yolov3::Detection& detection,
85*89c4ff92SAndroid Build Coastguard Worker                        const std::vector<float>& expected)
86*89c4ff92SAndroid Build Coastguard Worker {
87*89c4ff92SAndroid Build Coastguard Worker     float tolerance = 0.001f;
88*89c4ff92SAndroid Build Coastguard Worker     return (std::fabs(detection.classes[0] - expected[0]) < tolerance  &&
89*89c4ff92SAndroid Build Coastguard Worker             std::fabs(detection.box.xmin   - expected[1]) < tolerance  &&
90*89c4ff92SAndroid Build Coastguard Worker             std::fabs(detection.box.ymin   - expected[2]) < tolerance  &&
91*89c4ff92SAndroid Build Coastguard Worker             std::fabs(detection.box.xmax   - expected[3]) < tolerance  &&
92*89c4ff92SAndroid Build Coastguard Worker             std::fabs(detection.box.ymax   - expected[4]) < tolerance  &&
93*89c4ff92SAndroid Build Coastguard Worker             std::fabs(detection.confidence - expected[5]) < tolerance  );
94*89c4ff92SAndroid Build Coastguard Worker }
95*89c4ff92SAndroid Build Coastguard Worker 
print_detection(std::ostream & os,const std::vector<Detection> & detections)96*89c4ff92SAndroid Build Coastguard Worker void print_detection(std::ostream& os,
97*89c4ff92SAndroid Build Coastguard Worker                      const std::vector<Detection>& detections)
98*89c4ff92SAndroid Build Coastguard Worker {
99*89c4ff92SAndroid Build Coastguard Worker     for (const auto& detection : detections)
100*89c4ff92SAndroid Build Coastguard Worker     {
101*89c4ff92SAndroid Build Coastguard Worker         for (unsigned int c = 0; c < detection.classes.size(); ++c)
102*89c4ff92SAndroid Build Coastguard Worker         {
103*89c4ff92SAndroid Build Coastguard Worker             if (detection.classes[c] != 0.0f)
104*89c4ff92SAndroid Build Coastguard Worker             {
105*89c4ff92SAndroid Build Coastguard Worker                 os << c << " " << detection.classes[c] << " " << detection.box.xmin
106*89c4ff92SAndroid Build Coastguard Worker                    << " " << detection.box.ymin << " " << detection.box.xmax << " "
107*89c4ff92SAndroid Build Coastguard Worker                    << detection.box.ymax << std::endl;
108*89c4ff92SAndroid Build Coastguard Worker             }
109*89c4ff92SAndroid Build Coastguard Worker         }
110*89c4ff92SAndroid Build Coastguard Worker     }
111*89c4ff92SAndroid Build Coastguard Worker }
112*89c4ff92SAndroid Build Coastguard Worker 
nms(const NMSConfig & config,const std::vector<float> & detected_boxes)113*89c4ff92SAndroid Build Coastguard Worker std::vector<Detection> nms(const NMSConfig& config,
114*89c4ff92SAndroid Build Coastguard Worker                            const std::vector<float>& detected_boxes) {
115*89c4ff92SAndroid Build Coastguard Worker     // Get detections that comply with the expected confidence threshold
116*89c4ff92SAndroid Build Coastguard Worker     std::vector<Detection> detections =
117*89c4ff92SAndroid Build Coastguard Worker         convert_to_detections(config, detected_boxes);
118*89c4ff92SAndroid Build Coastguard Worker 
119*89c4ff92SAndroid Build Coastguard Worker     const unsigned int num_detections = static_cast<unsigned int>(detections.size());
120*89c4ff92SAndroid Build Coastguard Worker     for (unsigned int c = 0; c < config.num_classes; ++c)
121*89c4ff92SAndroid Build Coastguard Worker     {
122*89c4ff92SAndroid Build Coastguard Worker         // Sort classes
123*89c4ff92SAndroid Build Coastguard Worker         std::sort(detections.begin(), detections.begin() + static_cast<std::ptrdiff_t>(num_detections),
124*89c4ff92SAndroid Build Coastguard Worker                   [c](Detection& detection1, Detection& detection2)
125*89c4ff92SAndroid Build Coastguard Worker                     {
126*89c4ff92SAndroid Build Coastguard Worker                         return (detection1.classes[c] - detection2.classes[c]) > 0;
127*89c4ff92SAndroid Build Coastguard Worker                     });
128*89c4ff92SAndroid Build Coastguard Worker         // Clear detections with high IoU
129*89c4ff92SAndroid Build Coastguard Worker         for (unsigned int d = 0; d < num_detections; ++d)
130*89c4ff92SAndroid Build Coastguard Worker         {
131*89c4ff92SAndroid Build Coastguard Worker             // Check if class is already cleared/invalidated
132*89c4ff92SAndroid Build Coastguard Worker             if (detections[d].classes[c] == 0.f)
133*89c4ff92SAndroid Build Coastguard Worker             {
134*89c4ff92SAndroid Build Coastguard Worker                 continue;
135*89c4ff92SAndroid Build Coastguard Worker             }
136*89c4ff92SAndroid Build Coastguard Worker 
137*89c4ff92SAndroid Build Coastguard Worker             // Filter out boxes on IoU threshold
138*89c4ff92SAndroid Build Coastguard Worker             const Box& box1 = detections[d].box;
139*89c4ff92SAndroid Build Coastguard Worker             for (unsigned int b = d + 1; b < num_detections; ++b)
140*89c4ff92SAndroid Build Coastguard Worker             {
141*89c4ff92SAndroid Build Coastguard Worker                 const Box& box2 = detections[b].box;
142*89c4ff92SAndroid Build Coastguard Worker                 if (iou(box1, box2) > config.iou_threshold)
143*89c4ff92SAndroid Build Coastguard Worker                 {
144*89c4ff92SAndroid Build Coastguard Worker                     detections[b].classes[c] = 0.f;
145*89c4ff92SAndroid Build Coastguard Worker                 }
146*89c4ff92SAndroid Build Coastguard Worker             }
147*89c4ff92SAndroid Build Coastguard Worker         }
148*89c4ff92SAndroid Build Coastguard Worker     }
149*89c4ff92SAndroid Build Coastguard Worker     return detections;
150*89c4ff92SAndroid Build Coastguard Worker }
151*89c4ff92SAndroid Build Coastguard Worker } // namespace yolov3
152