1*89c4ff92SAndroid Build Coastguard Worker // 2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2020 Arm Ltd and Contributors. All rights reserved. 3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT 4*89c4ff92SAndroid Build Coastguard Worker // 5*89c4ff92SAndroid Build Coastguard Worker 6*89c4ff92SAndroid Build Coastguard Worker #include <catch.hpp> 7*89c4ff92SAndroid Build Coastguard Worker 8*89c4ff92SAndroid Build Coastguard Worker #include "NonMaxSuppression.hpp" 9*89c4ff92SAndroid Build Coastguard Worker 10*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("Non_Max_Suppression_1") 11*89c4ff92SAndroid Build Coastguard Worker { 12*89c4ff92SAndroid Build Coastguard Worker // Box with iou exactly 0.5. 13*89c4ff92SAndroid Build Coastguard Worker od::DetectedObject detectedObject1; 14*89c4ff92SAndroid Build Coastguard Worker detectedObject1.SetLabel("2"); 15*89c4ff92SAndroid Build Coastguard Worker detectedObject1.SetScore(171); 16*89c4ff92SAndroid Build Coastguard Worker detectedObject1.SetBoundingBox({0, 0, 150, 150}); 17*89c4ff92SAndroid Build Coastguard Worker 18*89c4ff92SAndroid Build Coastguard Worker // Strongest detection. 19*89c4ff92SAndroid Build Coastguard Worker od::DetectedObject detectedObject2; 20*89c4ff92SAndroid Build Coastguard Worker detectedObject2.SetLabel("2"); 21*89c4ff92SAndroid Build Coastguard Worker detectedObject2.SetScore(230); 22*89c4ff92SAndroid Build Coastguard Worker detectedObject2.SetBoundingBox({0, 75, 150, 75}); 23*89c4ff92SAndroid Build Coastguard Worker 24*89c4ff92SAndroid Build Coastguard Worker // Weaker detection with same coordinates of strongest. 25*89c4ff92SAndroid Build Coastguard Worker od::DetectedObject detectedObject3; 26*89c4ff92SAndroid Build Coastguard Worker detectedObject3.SetLabel("2"); 27*89c4ff92SAndroid Build Coastguard Worker detectedObject3.SetScore(20); 28*89c4ff92SAndroid Build Coastguard Worker detectedObject3.SetBoundingBox({0, 75, 150, 75}); 29*89c4ff92SAndroid Build Coastguard Worker 30*89c4ff92SAndroid Build Coastguard Worker // Detection not overlapping strongest. 31*89c4ff92SAndroid Build Coastguard Worker od::DetectedObject detectedObject4; 32*89c4ff92SAndroid Build Coastguard Worker detectedObject4.SetLabel("2"); 33*89c4ff92SAndroid Build Coastguard Worker detectedObject4.SetScore(222); 34*89c4ff92SAndroid Build Coastguard Worker detectedObject4.SetBoundingBox({0, 0, 50, 50}); 35*89c4ff92SAndroid Build Coastguard Worker 36*89c4ff92SAndroid Build Coastguard Worker // Small detection inside strongest. 37*89c4ff92SAndroid Build Coastguard Worker od::DetectedObject detectedObject5; 38*89c4ff92SAndroid Build Coastguard Worker detectedObject5.SetLabel("2"); 39*89c4ff92SAndroid Build Coastguard Worker detectedObject5.SetScore(201); 40*89c4ff92SAndroid Build Coastguard Worker detectedObject5.SetBoundingBox({100, 100, 20, 20}); 41*89c4ff92SAndroid Build Coastguard Worker 42*89c4ff92SAndroid Build Coastguard Worker // Box with iou exactly 0.5 but different label. 43*89c4ff92SAndroid Build Coastguard Worker od::DetectedObject detectedObject6; 44*89c4ff92SAndroid Build Coastguard Worker detectedObject6.SetLabel("1"); 45*89c4ff92SAndroid Build Coastguard Worker detectedObject6.SetScore(75); 46*89c4ff92SAndroid Build Coastguard Worker detectedObject6.SetBoundingBox({0, 0, 150, 150}); 47*89c4ff92SAndroid Build Coastguard Worker 48*89c4ff92SAndroid Build Coastguard Worker od::DetectedObjects expectedResults {detectedObject1, 49*89c4ff92SAndroid Build Coastguard Worker detectedObject2, 50*89c4ff92SAndroid Build Coastguard Worker detectedObject3, 51*89c4ff92SAndroid Build Coastguard Worker detectedObject4, 52*89c4ff92SAndroid Build Coastguard Worker detectedObject5, 53*89c4ff92SAndroid Build Coastguard Worker detectedObject6}; 54*89c4ff92SAndroid Build Coastguard Worker 55*89c4ff92SAndroid Build Coastguard Worker auto sorted = od::NonMaxSuppression(expectedResults, 0.49); 56*89c4ff92SAndroid Build Coastguard Worker 57*89c4ff92SAndroid Build Coastguard Worker // 1st and 3rd detection should be suppressed. 58*89c4ff92SAndroid Build Coastguard Worker REQUIRE(sorted.size() == 4); 59*89c4ff92SAndroid Build Coastguard Worker 60*89c4ff92SAndroid Build Coastguard Worker // Final detects should be ordered strongest to weakest. 61*89c4ff92SAndroid Build Coastguard Worker REQUIRE(sorted[0] == 1); 62*89c4ff92SAndroid Build Coastguard Worker REQUIRE(sorted[1] == 3); 63*89c4ff92SAndroid Build Coastguard Worker REQUIRE(sorted[2] == 4); 64*89c4ff92SAndroid Build Coastguard Worker REQUIRE(sorted[3] == 5); 65*89c4ff92SAndroid Build Coastguard Worker } 66*89c4ff92SAndroid Build Coastguard Worker 67*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("Non_Max_Suppression_2") 68*89c4ff92SAndroid Build Coastguard Worker { 69*89c4ff92SAndroid Build Coastguard Worker // Real box examples. 70*89c4ff92SAndroid Build Coastguard Worker od::DetectedObject detectedObject1; 71*89c4ff92SAndroid Build Coastguard Worker detectedObject1.SetLabel("2"); 72*89c4ff92SAndroid Build Coastguard Worker detectedObject1.SetScore(220); 73*89c4ff92SAndroid Build Coastguard Worker detectedObject1.SetBoundingBox({430, 158, 68, 68}); 74*89c4ff92SAndroid Build Coastguard Worker 75*89c4ff92SAndroid Build Coastguard Worker od::DetectedObject detectedObject2; 76*89c4ff92SAndroid Build Coastguard Worker detectedObject2.SetLabel("2"); 77*89c4ff92SAndroid Build Coastguard Worker detectedObject2.SetScore(171); 78*89c4ff92SAndroid Build Coastguard Worker detectedObject2.SetBoundingBox({438, 158, 68, 68}); 79*89c4ff92SAndroid Build Coastguard Worker 80*89c4ff92SAndroid Build Coastguard Worker od::DetectedObjects expectedResults {detectedObject1, 81*89c4ff92SAndroid Build Coastguard Worker detectedObject2}; 82*89c4ff92SAndroid Build Coastguard Worker 83*89c4ff92SAndroid Build Coastguard Worker auto sorted = od::NonMaxSuppression(expectedResults, 0.5); 84*89c4ff92SAndroid Build Coastguard Worker 85*89c4ff92SAndroid Build Coastguard Worker // 2nd detect should be suppressed. 86*89c4ff92SAndroid Build Coastguard Worker REQUIRE(sorted.size() == 1); 87*89c4ff92SAndroid Build Coastguard Worker 88*89c4ff92SAndroid Build Coastguard Worker // First detect should be strongest and kept. 89*89c4ff92SAndroid Build Coastguard Worker REQUIRE(sorted[0] == 0); 90*89c4ff92SAndroid Build Coastguard Worker } 91