xref: /aosp_15_r20/external/armnn/src/backends/reference/test/RefDetectionPostProcessTests.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include <reference/workloads/DetectionPostProcess.hpp>
7 
8 #include <armnn/Descriptors.hpp>
9 #include <armnn/Types.hpp>
10 
11 #include <doctest/doctest.h>
12 
13 TEST_SUITE("RefDetectionPostProcess")
14 {
15 TEST_CASE("TopKSortTest")
16 {
17     unsigned int k = 3;
18     unsigned int indices[8] = { 0, 1, 2, 3, 4, 5, 6, 7 };
19     float values[8] = { 0, 7, 6, 5, 4, 3, 2, 500 };
20     armnn::TopKSort(k, indices, values, 8);
21     CHECK(indices[0] == 7);
22     CHECK(indices[1] == 1);
23     CHECK(indices[2] == 2);
24 }
25 
26 TEST_CASE("FullTopKSortTest")
27 {
28     unsigned int k = 8;
29     unsigned int indices[8] = { 0, 1, 2, 3, 4, 5, 6, 7 };
30     float values[8] = { 0, 7, 6, 5, 4, 3, 2, 500 };
31     armnn::TopKSort(k, indices, values, 8);
32     CHECK(indices[0] == 7);
33     CHECK(indices[1] == 1);
34     CHECK(indices[2] == 2);
35     CHECK(indices[3] == 3);
36     CHECK(indices[4] == 4);
37     CHECK(indices[5] == 5);
38     CHECK(indices[6] == 6);
39     CHECK(indices[7] == 0);
40 }
41 
42 TEST_CASE("IouTest")
43 {
44     float boxI[4] = { 0.0f, 0.0f, 10.0f, 10.0f };
45     float boxJ[4] = { 1.0f, 1.0f, 11.0f, 11.0f };
46     float iou = armnn::IntersectionOverUnion(boxI, boxJ);
47     CHECK(iou == doctest::Approx(0.68).epsilon(0.001f));
48 }
49 
50 TEST_CASE("NmsFunction")
51 {
52     std::vector<float> boxCorners({
53         0.0f, 0.0f, 1.0f, 1.0f,
54         0.0f, 0.1f, 1.0f, 1.1f,
55         0.0f, -0.1f, 1.0f, 0.9f,
56         0.0f, 10.0f, 1.0f, 11.0f,
57         0.0f, 10.1f, 1.0f, 11.1f,
58         0.0f, 100.0f, 1.0f, 101.0f
59     });
60 
61     std::vector<float> scores({ 0.9f, 0.75f, 0.6f, 0.93f, 0.5f, 0.3f });
62 
63     std::vector<unsigned int> result =
64         armnn::NonMaxSuppression(6, boxCorners, scores, 0.0, 3, 0.5);
65 
66     CHECK(result.size() == 3);
67     CHECK(result[0] == 3);
68     CHECK(result[1] == 0);
69     CHECK(result[2] == 5);
70 }
71 
DetectionPostProcessTestImpl(bool useRegularNms,const std::vector<float> & expectedDetectionBoxes,const std::vector<float> & expectedDetectionClasses,const std::vector<float> & expectedDetectionScores,const std::vector<float> & expectedNumDetections)72 void DetectionPostProcessTestImpl(bool useRegularNms,
73                                   const std::vector<float>& expectedDetectionBoxes,
74                                   const std::vector<float>& expectedDetectionClasses,
75                                   const std::vector<float>& expectedDetectionScores,
76                                   const std::vector<float>& expectedNumDetections)
77 {
78     armnn::TensorInfo boxEncodingsInfo({ 1, 6, 4 }, armnn::DataType::Float32);
79     armnn::TensorInfo scoresInfo({ 1, 6, 3 }, armnn::DataType::Float32);
80     armnn::TensorInfo anchorsInfo({ 6, 4 }, armnn::DataType::Float32);
81 
82     armnn::TensorInfo detectionBoxesInfo({ 1, 3, 4 }, armnn::DataType::Float32);
83     armnn::TensorInfo detectionScoresInfo({ 1, 3 }, armnn::DataType::Float32);
84     armnn::TensorInfo detectionClassesInfo({ 1, 3 }, armnn::DataType::Float32);
85     armnn::TensorInfo numDetectionInfo({ 1 }, armnn::DataType::Float32);
86 
87     armnn::DetectionPostProcessDescriptor desc;
88     desc.m_UseRegularNms = useRegularNms;
89     desc.m_MaxDetections = 3;
90     desc.m_MaxClassesPerDetection = 1;
91     desc.m_DetectionsPerClass =1;
92     desc.m_NmsScoreThreshold = 0.0;
93     desc.m_NmsIouThreshold = 0.5;
94     desc.m_NumClasses = 2;
95     desc.m_ScaleY = 10.0;
96     desc.m_ScaleX = 10.0;
97     desc.m_ScaleH = 5.0;
98     desc.m_ScaleW = 5.0;
99 
100     std::vector<float> boxEncodings({
101         0.0f, 0.0f, 0.0f, 0.0f,
102         0.0f, 1.0f, 0.0f, 0.0f,
103         0.0f, -1.0f, 0.0f, 0.0f,
104         0.0f, 0.0f, 0.0f, 0.0f,
105         0.0f, 1.0f, 0.0f, 0.0f,
106         0.0f, 0.0f, 0.0f, 0.0f
107     });
108 
109     std::vector<float> scores({
110         0.0f, 0.9f, 0.8f,
111         0.0f, 0.75f, 0.72f,
112         0.0f, 0.6f, 0.5f,
113         0.0f, 0.93f, 0.95f,
114         0.0f, 0.5f, 0.4f,
115         0.0f, 0.3f, 0.2f
116     });
117 
118     std::vector<float> anchors({
119         0.5f, 0.5f, 1.0f, 1.0f,
120         0.5f, 0.5f, 1.0f, 1.0f,
121         0.5f, 0.5f, 1.0f, 1.0f,
122         0.5f, 10.5f, 1.0f, 1.0f,
123         0.5f, 10.5f, 1.0f, 1.0f,
124         0.5f, 100.5f, 1.0f, 1.0f
125     });
126 
127     auto boxEncodingsDecoder = armnn::MakeDecoder<float>(boxEncodingsInfo, boxEncodings.data());
128     auto scoresDecoder       = armnn::MakeDecoder<float>(scoresInfo, scores.data());
129     auto anchorsDecoder      = armnn::MakeDecoder<float>(anchorsInfo, anchors.data());
130 
131     std::vector<float> detectionBoxes(detectionBoxesInfo.GetNumElements());
132     std::vector<float> detectionScores(detectionScoresInfo.GetNumElements());
133     std::vector<float> detectionClasses(detectionClassesInfo.GetNumElements());
134     std::vector<float> numDetections(1);
135 
136     armnn::DetectionPostProcess(boxEncodingsInfo,
137                                 scoresInfo,
138                                 anchorsInfo,
139                                 detectionBoxesInfo,
140                                 detectionClassesInfo,
141                                 detectionScoresInfo,
142                                 numDetectionInfo,
143                                 desc,
144                                 *boxEncodingsDecoder,
145                                 *scoresDecoder,
146                                 *anchorsDecoder,
147                                 detectionBoxes.data(),
148                                 detectionClasses.data(),
149                                 detectionScores.data(),
150                                 numDetections.data());
151 
152     CHECK(std::equal(detectionBoxes.begin(),
153                                   detectionBoxes.end(),
154                                   expectedDetectionBoxes.begin(),
155                                   expectedDetectionBoxes.end()));
156 
157     CHECK(std::equal(detectionScores.begin(), detectionScores.end(),
158         expectedDetectionScores.begin(), expectedDetectionScores.end()));
159 
160     CHECK(std::equal(detectionClasses.begin(), detectionClasses.end(),
161         expectedDetectionClasses.begin(), expectedDetectionClasses.end()));
162 
163     CHECK(std::equal(numDetections.begin(), numDetections.end(),
164         expectedNumDetections.begin(), expectedNumDetections.end()));
165 }
166 
167 TEST_CASE("RegularNmsDetectionPostProcess")
168 {
169     std::vector<float> expectedDetectionBoxes({
170         0.0f, 10.0f, 1.0f, 11.0f,
171         0.0f, 10.0f, 1.0f, 11.0f,
172         0.0f, 0.0f, 0.0f, 0.0f
173     });
174 
175     std::vector<float> expectedDetectionScores({ 0.95f, 0.93f, 0.0f });
176     std::vector<float> expectedDetectionClasses({ 1.0f, 0.0f, 0.0f });
177     std::vector<float> expectedNumDetections({ 2.0f });
178 
179     DetectionPostProcessTestImpl(true, expectedDetectionBoxes, expectedDetectionClasses,
180                                  expectedDetectionScores, expectedNumDetections);
181 }
182 
183 TEST_CASE("FastNmsDetectionPostProcess")
184 {
185     std::vector<float> expectedDetectionBoxes({
186         0.0f, 10.0f, 1.0f, 11.0f,
187         0.0f, 0.0f, 1.0f, 1.0f,
188         0.0f, 100.0f, 1.0f, 101.0f
189     });
190     std::vector<float> expectedDetectionScores({ 0.95f, 0.9f, 0.3f });
191     std::vector<float> expectedDetectionClasses({ 1.0f, 0.0f, 0.0f });
192     std::vector<float> expectedNumDetections({ 3.0f });
193 
194     DetectionPostProcessTestImpl(false, expectedDetectionBoxes, expectedDetectionClasses,
195                                  expectedDetectionScores, expectedNumDetections);
196 }
197 
198 }