1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2020 Arm Ltd. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker
6*89c4ff92SAndroid Build Coastguard Worker
7*89c4ff92SAndroid Build Coastguard Worker #include "NMS.hpp"
8*89c4ff92SAndroid Build Coastguard Worker
9*89c4ff92SAndroid Build Coastguard Worker #include <cmath>
10*89c4ff92SAndroid Build Coastguard Worker #include <algorithm>
11*89c4ff92SAndroid Build Coastguard Worker #include <cstddef>
12*89c4ff92SAndroid Build Coastguard Worker #include <numeric>
13*89c4ff92SAndroid Build Coastguard Worker #include <ostream>
14*89c4ff92SAndroid Build Coastguard Worker
15*89c4ff92SAndroid Build Coastguard Worker namespace yolov3 {
16*89c4ff92SAndroid Build Coastguard Worker namespace {
/** Count of floats that describe the geometry of one box in a raw record */
constexpr int box_elements{4};
/** Count of floats that hold the confidence factor of one box in a raw record */
constexpr int confidence_elements{1};
21*89c4ff92SAndroid Build Coastguard Worker
22*89c4ff92SAndroid Build Coastguard Worker /** Calculate Intersection Over Union of two boxes
23*89c4ff92SAndroid Build Coastguard Worker *
24*89c4ff92SAndroid Build Coastguard Worker * @param[in] box1 First box
25*89c4ff92SAndroid Build Coastguard Worker * @param[in] box2 Second box
26*89c4ff92SAndroid Build Coastguard Worker *
27*89c4ff92SAndroid Build Coastguard Worker * @return The IoU of the two boxes
28*89c4ff92SAndroid Build Coastguard Worker */
iou(const Box & box1,const Box & box2)29*89c4ff92SAndroid Build Coastguard Worker float iou(const Box& box1, const Box& box2)
30*89c4ff92SAndroid Build Coastguard Worker {
31*89c4ff92SAndroid Build Coastguard Worker const float area1 = (box1.xmax - box1.xmin) * (box1.ymax - box1.ymin);
32*89c4ff92SAndroid Build Coastguard Worker const float area2 = (box2.xmax - box2.xmin) * (box2.ymax - box2.ymin);
33*89c4ff92SAndroid Build Coastguard Worker float overlap;
34*89c4ff92SAndroid Build Coastguard Worker if (area1 <= 0 || area2 <= 0)
35*89c4ff92SAndroid Build Coastguard Worker {
36*89c4ff92SAndroid Build Coastguard Worker overlap = 0.0f;
37*89c4ff92SAndroid Build Coastguard Worker }
38*89c4ff92SAndroid Build Coastguard Worker else
39*89c4ff92SAndroid Build Coastguard Worker {
40*89c4ff92SAndroid Build Coastguard Worker const auto y_min_intersection = std::max<float>(box1.ymin, box2.ymin);
41*89c4ff92SAndroid Build Coastguard Worker const auto x_min_intersection = std::max<float>(box1.xmin, box2.xmin);
42*89c4ff92SAndroid Build Coastguard Worker const auto y_max_intersection = std::min<float>(box1.ymax, box2.ymax);
43*89c4ff92SAndroid Build Coastguard Worker const auto x_max_intersection = std::min<float>(box1.xmax, box2.xmax);
44*89c4ff92SAndroid Build Coastguard Worker const auto area_intersection =
45*89c4ff92SAndroid Build Coastguard Worker std::max<float>(y_max_intersection - y_min_intersection, 0.0f) *
46*89c4ff92SAndroid Build Coastguard Worker std::max<float>(x_max_intersection - x_min_intersection, 0.0f);
47*89c4ff92SAndroid Build Coastguard Worker overlap = area_intersection / (area1 + area2 - area_intersection);
48*89c4ff92SAndroid Build Coastguard Worker }
49*89c4ff92SAndroid Build Coastguard Worker return overlap;
50*89c4ff92SAndroid Build Coastguard Worker }
51*89c4ff92SAndroid Build Coastguard Worker
convert_to_detections(const NMSConfig & config,const std::vector<float> & detected_boxes)52*89c4ff92SAndroid Build Coastguard Worker std::vector<Detection> convert_to_detections(const NMSConfig& config,
53*89c4ff92SAndroid Build Coastguard Worker const std::vector<float>& detected_boxes)
54*89c4ff92SAndroid Build Coastguard Worker {
55*89c4ff92SAndroid Build Coastguard Worker const size_t element_step = static_cast<size_t>(
56*89c4ff92SAndroid Build Coastguard Worker box_elements + confidence_elements + config.num_classes);
57*89c4ff92SAndroid Build Coastguard Worker std::vector<Detection> detections;
58*89c4ff92SAndroid Build Coastguard Worker
59*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < config.num_boxes; ++i)
60*89c4ff92SAndroid Build Coastguard Worker {
61*89c4ff92SAndroid Build Coastguard Worker const float* cur_box = &detected_boxes[i * element_step];
62*89c4ff92SAndroid Build Coastguard Worker if (cur_box[4] > config.confidence_threshold)
63*89c4ff92SAndroid Build Coastguard Worker {
64*89c4ff92SAndroid Build Coastguard Worker Detection det;
65*89c4ff92SAndroid Build Coastguard Worker det.box = {cur_box[0], cur_box[0] + cur_box[2], cur_box[1],
66*89c4ff92SAndroid Build Coastguard Worker cur_box[1] + cur_box[3]};
67*89c4ff92SAndroid Build Coastguard Worker det.confidence = cur_box[4];
68*89c4ff92SAndroid Build Coastguard Worker det.classes.resize(static_cast<size_t>(config.num_classes), 0);
69*89c4ff92SAndroid Build Coastguard Worker for (unsigned int c = 0; c < config.num_classes; ++c)
70*89c4ff92SAndroid Build Coastguard Worker {
71*89c4ff92SAndroid Build Coastguard Worker const float class_prob = det.confidence * cur_box[5 + c];
72*89c4ff92SAndroid Build Coastguard Worker if (class_prob > config.confidence_threshold)
73*89c4ff92SAndroid Build Coastguard Worker {
74*89c4ff92SAndroid Build Coastguard Worker det.classes[c] = class_prob;
75*89c4ff92SAndroid Build Coastguard Worker }
76*89c4ff92SAndroid Build Coastguard Worker }
77*89c4ff92SAndroid Build Coastguard Worker detections.emplace_back(std::move(det));
78*89c4ff92SAndroid Build Coastguard Worker }
79*89c4ff92SAndroid Build Coastguard Worker }
80*89c4ff92SAndroid Build Coastguard Worker return detections;
81*89c4ff92SAndroid Build Coastguard Worker }
82*89c4ff92SAndroid Build Coastguard Worker } // namespace
83*89c4ff92SAndroid Build Coastguard Worker
compare_detection(const yolov3::Detection & detection,const std::vector<float> & expected)84*89c4ff92SAndroid Build Coastguard Worker bool compare_detection(const yolov3::Detection& detection,
85*89c4ff92SAndroid Build Coastguard Worker const std::vector<float>& expected)
86*89c4ff92SAndroid Build Coastguard Worker {
87*89c4ff92SAndroid Build Coastguard Worker float tolerance = 0.001f;
88*89c4ff92SAndroid Build Coastguard Worker return (std::fabs(detection.classes[0] - expected[0]) < tolerance &&
89*89c4ff92SAndroid Build Coastguard Worker std::fabs(detection.box.xmin - expected[1]) < tolerance &&
90*89c4ff92SAndroid Build Coastguard Worker std::fabs(detection.box.ymin - expected[2]) < tolerance &&
91*89c4ff92SAndroid Build Coastguard Worker std::fabs(detection.box.xmax - expected[3]) < tolerance &&
92*89c4ff92SAndroid Build Coastguard Worker std::fabs(detection.box.ymax - expected[4]) < tolerance &&
93*89c4ff92SAndroid Build Coastguard Worker std::fabs(detection.confidence - expected[5]) < tolerance );
94*89c4ff92SAndroid Build Coastguard Worker }
95*89c4ff92SAndroid Build Coastguard Worker
print_detection(std::ostream & os,const std::vector<Detection> & detections)96*89c4ff92SAndroid Build Coastguard Worker void print_detection(std::ostream& os,
97*89c4ff92SAndroid Build Coastguard Worker const std::vector<Detection>& detections)
98*89c4ff92SAndroid Build Coastguard Worker {
99*89c4ff92SAndroid Build Coastguard Worker for (const auto& detection : detections)
100*89c4ff92SAndroid Build Coastguard Worker {
101*89c4ff92SAndroid Build Coastguard Worker for (unsigned int c = 0; c < detection.classes.size(); ++c)
102*89c4ff92SAndroid Build Coastguard Worker {
103*89c4ff92SAndroid Build Coastguard Worker if (detection.classes[c] != 0.0f)
104*89c4ff92SAndroid Build Coastguard Worker {
105*89c4ff92SAndroid Build Coastguard Worker os << c << " " << detection.classes[c] << " " << detection.box.xmin
106*89c4ff92SAndroid Build Coastguard Worker << " " << detection.box.ymin << " " << detection.box.xmax << " "
107*89c4ff92SAndroid Build Coastguard Worker << detection.box.ymax << std::endl;
108*89c4ff92SAndroid Build Coastguard Worker }
109*89c4ff92SAndroid Build Coastguard Worker }
110*89c4ff92SAndroid Build Coastguard Worker }
111*89c4ff92SAndroid Build Coastguard Worker }
112*89c4ff92SAndroid Build Coastguard Worker
nms(const NMSConfig & config,const std::vector<float> & detected_boxes)113*89c4ff92SAndroid Build Coastguard Worker std::vector<Detection> nms(const NMSConfig& config,
114*89c4ff92SAndroid Build Coastguard Worker const std::vector<float>& detected_boxes) {
115*89c4ff92SAndroid Build Coastguard Worker // Get detections that comply with the expected confidence threshold
116*89c4ff92SAndroid Build Coastguard Worker std::vector<Detection> detections =
117*89c4ff92SAndroid Build Coastguard Worker convert_to_detections(config, detected_boxes);
118*89c4ff92SAndroid Build Coastguard Worker
119*89c4ff92SAndroid Build Coastguard Worker const unsigned int num_detections = static_cast<unsigned int>(detections.size());
120*89c4ff92SAndroid Build Coastguard Worker for (unsigned int c = 0; c < config.num_classes; ++c)
121*89c4ff92SAndroid Build Coastguard Worker {
122*89c4ff92SAndroid Build Coastguard Worker // Sort classes
123*89c4ff92SAndroid Build Coastguard Worker std::sort(detections.begin(), detections.begin() + static_cast<std::ptrdiff_t>(num_detections),
124*89c4ff92SAndroid Build Coastguard Worker [c](Detection& detection1, Detection& detection2)
125*89c4ff92SAndroid Build Coastguard Worker {
126*89c4ff92SAndroid Build Coastguard Worker return (detection1.classes[c] - detection2.classes[c]) > 0;
127*89c4ff92SAndroid Build Coastguard Worker });
128*89c4ff92SAndroid Build Coastguard Worker // Clear detections with high IoU
129*89c4ff92SAndroid Build Coastguard Worker for (unsigned int d = 0; d < num_detections; ++d)
130*89c4ff92SAndroid Build Coastguard Worker {
131*89c4ff92SAndroid Build Coastguard Worker // Check if class is already cleared/invalidated
132*89c4ff92SAndroid Build Coastguard Worker if (detections[d].classes[c] == 0.f)
133*89c4ff92SAndroid Build Coastguard Worker {
134*89c4ff92SAndroid Build Coastguard Worker continue;
135*89c4ff92SAndroid Build Coastguard Worker }
136*89c4ff92SAndroid Build Coastguard Worker
137*89c4ff92SAndroid Build Coastguard Worker // Filter out boxes on IoU threshold
138*89c4ff92SAndroid Build Coastguard Worker const Box& box1 = detections[d].box;
139*89c4ff92SAndroid Build Coastguard Worker for (unsigned int b = d + 1; b < num_detections; ++b)
140*89c4ff92SAndroid Build Coastguard Worker {
141*89c4ff92SAndroid Build Coastguard Worker const Box& box2 = detections[b].box;
142*89c4ff92SAndroid Build Coastguard Worker if (iou(box1, box2) > config.iou_threshold)
143*89c4ff92SAndroid Build Coastguard Worker {
144*89c4ff92SAndroid Build Coastguard Worker detections[b].classes[c] = 0.f;
145*89c4ff92SAndroid Build Coastguard Worker }
146*89c4ff92SAndroid Build Coastguard Worker }
147*89c4ff92SAndroid Build Coastguard Worker }
148*89c4ff92SAndroid Build Coastguard Worker }
149*89c4ff92SAndroid Build Coastguard Worker return detections;
150*89c4ff92SAndroid Build Coastguard Worker }
151*89c4ff92SAndroid Build Coastguard Worker } // namespace yolov3
152