// xref: /aosp_15_r20/external/armnn/samples/ObjectDetection/src/NonMaxSuppression.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
//
// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
#include "NonMaxSuppression.hpp"

#include <algorithm>
#include <cstddef>
#include <numeric>
#include <vector>
8 
9 namespace od
10 {
11 
GenerateRangeK(unsigned int k)12 static std::vector<unsigned int> GenerateRangeK(unsigned int k)
13 {
14     std::vector<unsigned int> range(k);
15     std::iota(range.begin(), range.end(), 0);
16     return range;
17 }
18 
19 
20 /**
21 * @brief Returns the intersection over union for two bounding boxes
22 *
23 * @param[in]  First detect containing bounding box.
24 * @param[in]  Second detect containing bounding box.
25 * @return     Calculated intersection over union.
26 *
27 */
IntersectionOverUnion(DetectedObject & detect1,DetectedObject & detect2)28 static double IntersectionOverUnion(DetectedObject& detect1, DetectedObject& detect2)
29 {
30     uint32_t area1 = (detect1.GetBoundingBox().GetHeight() * detect1.GetBoundingBox().GetWidth());
31     uint32_t area2 = (detect2.GetBoundingBox().GetHeight() * detect2.GetBoundingBox().GetWidth());
32 
33     float yMinIntersection = std::max(detect1.GetBoundingBox().GetY(), detect2.GetBoundingBox().GetY());
34     float xMinIntersection = std::max(detect1.GetBoundingBox().GetX(), detect2.GetBoundingBox().GetX());
35 
36     float yMaxIntersection = std::min(detect1.GetBoundingBox().GetY() + detect1.GetBoundingBox().GetHeight(),
37                                       detect2.GetBoundingBox().GetY() + detect2.GetBoundingBox().GetHeight());
38     float xMaxIntersection = std::min(detect1.GetBoundingBox().GetX() + detect1.GetBoundingBox().GetWidth(),
39                                       detect2.GetBoundingBox().GetX() + detect2.GetBoundingBox().GetWidth());
40 
41     double areaIntersection = std::max(yMaxIntersection - yMinIntersection, 0.0f) *
42                               std::max(xMaxIntersection - xMinIntersection, 0.0f);
43     double areaUnion = area1 + area2 - areaIntersection;
44 
45     return areaIntersection / areaUnion;
46 }
47 
NonMaxSuppression(DetectedObjects & inputDetections,float iouThresh)48 std::vector<int> NonMaxSuppression(DetectedObjects& inputDetections, float iouThresh)
49 {
50     // Sort indicies of detections by highest score to lowest.
51     std::vector<unsigned int> sortedIndicies = GenerateRangeK(inputDetections.size());
52     std::sort(sortedIndicies.begin(), sortedIndicies.end(),
53         [&inputDetections](int idx1, int idx2)
54         {
55             return inputDetections[idx1].GetScore() > inputDetections[idx2].GetScore();
56         });
57 
58     std::vector<bool> visited(inputDetections.size(), false);
59     std::vector<int> outputIndiciesAfterNMS;
60 
61     for (int i=0; i < inputDetections.size(); ++i)
62     {
63         // Each new unvisited detect should be kept.
64         if (!visited[sortedIndicies[i]])
65         {
66             outputIndiciesAfterNMS.emplace_back(sortedIndicies[i]);
67             visited[sortedIndicies[i]] = true;
68         }
69 
70         // Look for detections to suppress.
71         for (int j=i+1; j<inputDetections.size(); ++j)
72         {
73             // Skip if already kept or suppressed.
74             if (!visited[sortedIndicies[j]])
75             {
76                 // Detects must have the same label to be suppressed.
77                 if (inputDetections[sortedIndicies[j]].GetLabel() == inputDetections[sortedIndicies[i]].GetLabel())
78                 {
79                     auto iou = IntersectionOverUnion(inputDetections[sortedIndicies[i]],
80                                                     inputDetections[sortedIndicies[j]]);
81                     if (iou > iouThresh)
82                     {
83                         visited[sortedIndicies[j]] = true;
84                     }
85                 }
86             }
87         }
88     }
89     return outputIndiciesAfterNMS;
90 }
91 
92 } // namespace od
93