1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
#include <reference/workloads/DetectionPostProcess.hpp>

#include <armnn/Descriptors.hpp>
#include <armnn/Types.hpp>

#include <doctest/doctest.h>

#include <cstddef>
#include <vector>
12
13 TEST_SUITE("RefDetectionPostProcess")
14 {
15 TEST_CASE("TopKSortTest")
16 {
17 unsigned int k = 3;
18 unsigned int indices[8] = { 0, 1, 2, 3, 4, 5, 6, 7 };
19 float values[8] = { 0, 7, 6, 5, 4, 3, 2, 500 };
20 armnn::TopKSort(k, indices, values, 8);
21 CHECK(indices[0] == 7);
22 CHECK(indices[1] == 1);
23 CHECK(indices[2] == 2);
24 }
25
26 TEST_CASE("FullTopKSortTest")
27 {
28 unsigned int k = 8;
29 unsigned int indices[8] = { 0, 1, 2, 3, 4, 5, 6, 7 };
30 float values[8] = { 0, 7, 6, 5, 4, 3, 2, 500 };
31 armnn::TopKSort(k, indices, values, 8);
32 CHECK(indices[0] == 7);
33 CHECK(indices[1] == 1);
34 CHECK(indices[2] == 2);
35 CHECK(indices[3] == 3);
36 CHECK(indices[4] == 4);
37 CHECK(indices[5] == 5);
38 CHECK(indices[6] == 6);
39 CHECK(indices[7] == 0);
40 }
41
42 TEST_CASE("IouTest")
43 {
44 float boxI[4] = { 0.0f, 0.0f, 10.0f, 10.0f };
45 float boxJ[4] = { 1.0f, 1.0f, 11.0f, 11.0f };
46 float iou = armnn::IntersectionOverUnion(boxI, boxJ);
47 CHECK(iou == doctest::Approx(0.68).epsilon(0.001f));
48 }
49
50 TEST_CASE("NmsFunction")
51 {
52 std::vector<float> boxCorners({
53 0.0f, 0.0f, 1.0f, 1.0f,
54 0.0f, 0.1f, 1.0f, 1.1f,
55 0.0f, -0.1f, 1.0f, 0.9f,
56 0.0f, 10.0f, 1.0f, 11.0f,
57 0.0f, 10.1f, 1.0f, 11.1f,
58 0.0f, 100.0f, 1.0f, 101.0f
59 });
60
61 std::vector<float> scores({ 0.9f, 0.75f, 0.6f, 0.93f, 0.5f, 0.3f });
62
63 std::vector<unsigned int> result =
64 armnn::NonMaxSuppression(6, boxCorners, scores, 0.0, 3, 0.5);
65
66 CHECK(result.size() == 3);
67 CHECK(result[0] == 3);
68 CHECK(result[1] == 0);
69 CHECK(result[2] == 5);
70 }
71
DetectionPostProcessTestImpl(bool useRegularNms,const std::vector<float> & expectedDetectionBoxes,const std::vector<float> & expectedDetectionClasses,const std::vector<float> & expectedDetectionScores,const std::vector<float> & expectedNumDetections)72 void DetectionPostProcessTestImpl(bool useRegularNms,
73 const std::vector<float>& expectedDetectionBoxes,
74 const std::vector<float>& expectedDetectionClasses,
75 const std::vector<float>& expectedDetectionScores,
76 const std::vector<float>& expectedNumDetections)
77 {
78 armnn::TensorInfo boxEncodingsInfo({ 1, 6, 4 }, armnn::DataType::Float32);
79 armnn::TensorInfo scoresInfo({ 1, 6, 3 }, armnn::DataType::Float32);
80 armnn::TensorInfo anchorsInfo({ 6, 4 }, armnn::DataType::Float32);
81
82 armnn::TensorInfo detectionBoxesInfo({ 1, 3, 4 }, armnn::DataType::Float32);
83 armnn::TensorInfo detectionScoresInfo({ 1, 3 }, armnn::DataType::Float32);
84 armnn::TensorInfo detectionClassesInfo({ 1, 3 }, armnn::DataType::Float32);
85 armnn::TensorInfo numDetectionInfo({ 1 }, armnn::DataType::Float32);
86
87 armnn::DetectionPostProcessDescriptor desc;
88 desc.m_UseRegularNms = useRegularNms;
89 desc.m_MaxDetections = 3;
90 desc.m_MaxClassesPerDetection = 1;
91 desc.m_DetectionsPerClass =1;
92 desc.m_NmsScoreThreshold = 0.0;
93 desc.m_NmsIouThreshold = 0.5;
94 desc.m_NumClasses = 2;
95 desc.m_ScaleY = 10.0;
96 desc.m_ScaleX = 10.0;
97 desc.m_ScaleH = 5.0;
98 desc.m_ScaleW = 5.0;
99
100 std::vector<float> boxEncodings({
101 0.0f, 0.0f, 0.0f, 0.0f,
102 0.0f, 1.0f, 0.0f, 0.0f,
103 0.0f, -1.0f, 0.0f, 0.0f,
104 0.0f, 0.0f, 0.0f, 0.0f,
105 0.0f, 1.0f, 0.0f, 0.0f,
106 0.0f, 0.0f, 0.0f, 0.0f
107 });
108
109 std::vector<float> scores({
110 0.0f, 0.9f, 0.8f,
111 0.0f, 0.75f, 0.72f,
112 0.0f, 0.6f, 0.5f,
113 0.0f, 0.93f, 0.95f,
114 0.0f, 0.5f, 0.4f,
115 0.0f, 0.3f, 0.2f
116 });
117
118 std::vector<float> anchors({
119 0.5f, 0.5f, 1.0f, 1.0f,
120 0.5f, 0.5f, 1.0f, 1.0f,
121 0.5f, 0.5f, 1.0f, 1.0f,
122 0.5f, 10.5f, 1.0f, 1.0f,
123 0.5f, 10.5f, 1.0f, 1.0f,
124 0.5f, 100.5f, 1.0f, 1.0f
125 });
126
127 auto boxEncodingsDecoder = armnn::MakeDecoder<float>(boxEncodingsInfo, boxEncodings.data());
128 auto scoresDecoder = armnn::MakeDecoder<float>(scoresInfo, scores.data());
129 auto anchorsDecoder = armnn::MakeDecoder<float>(anchorsInfo, anchors.data());
130
131 std::vector<float> detectionBoxes(detectionBoxesInfo.GetNumElements());
132 std::vector<float> detectionScores(detectionScoresInfo.GetNumElements());
133 std::vector<float> detectionClasses(detectionClassesInfo.GetNumElements());
134 std::vector<float> numDetections(1);
135
136 armnn::DetectionPostProcess(boxEncodingsInfo,
137 scoresInfo,
138 anchorsInfo,
139 detectionBoxesInfo,
140 detectionClassesInfo,
141 detectionScoresInfo,
142 numDetectionInfo,
143 desc,
144 *boxEncodingsDecoder,
145 *scoresDecoder,
146 *anchorsDecoder,
147 detectionBoxes.data(),
148 detectionClasses.data(),
149 detectionScores.data(),
150 numDetections.data());
151
152 CHECK(std::equal(detectionBoxes.begin(),
153 detectionBoxes.end(),
154 expectedDetectionBoxes.begin(),
155 expectedDetectionBoxes.end()));
156
157 CHECK(std::equal(detectionScores.begin(), detectionScores.end(),
158 expectedDetectionScores.begin(), expectedDetectionScores.end()));
159
160 CHECK(std::equal(detectionClasses.begin(), detectionClasses.end(),
161 expectedDetectionClasses.begin(), expectedDetectionClasses.end()));
162
163 CHECK(std::equal(numDetections.begin(), numDetections.end(),
164 expectedNumDetections.begin(), expectedNumDetections.end()));
165 }
166
167 TEST_CASE("RegularNmsDetectionPostProcess")
168 {
169 std::vector<float> expectedDetectionBoxes({
170 0.0f, 10.0f, 1.0f, 11.0f,
171 0.0f, 10.0f, 1.0f, 11.0f,
172 0.0f, 0.0f, 0.0f, 0.0f
173 });
174
175 std::vector<float> expectedDetectionScores({ 0.95f, 0.93f, 0.0f });
176 std::vector<float> expectedDetectionClasses({ 1.0f, 0.0f, 0.0f });
177 std::vector<float> expectedNumDetections({ 2.0f });
178
179 DetectionPostProcessTestImpl(true, expectedDetectionBoxes, expectedDetectionClasses,
180 expectedDetectionScores, expectedNumDetections);
181 }
182
183 TEST_CASE("FastNmsDetectionPostProcess")
184 {
185 std::vector<float> expectedDetectionBoxes({
186 0.0f, 10.0f, 1.0f, 11.0f,
187 0.0f, 0.0f, 1.0f, 1.0f,
188 0.0f, 100.0f, 1.0f, 101.0f
189 });
190 std::vector<float> expectedDetectionScores({ 0.95f, 0.9f, 0.3f });
191 std::vector<float> expectedDetectionClasses({ 1.0f, 0.0f, 0.0f });
192 std::vector<float> expectedNumDetections({ 3.0f });
193
194 DetectionPostProcessTestImpl(false, expectedDetectionBoxes, expectedDetectionClasses,
195 expectedDetectionScores, expectedNumDetections);
196 }
197
198 }