1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6
7 #include <ResolveType.hpp>
8
9 #include <armnn/Types.hpp>
10
11 #include <armnn/backends/TensorHandle.hpp>
12 #include <armnn/backends/IBackendInternal.hpp>
13 #include <armnn/backends/WorkloadFactory.hpp>
14
15 #include <armnnTestUtils/TensorCopyUtils.hpp>
16 #include <backendsCommon/test/WorkloadFactoryHelper.hpp>
17 #include <armnnTestUtils/WorkloadTestUtils.hpp>
18
19 #include <armnnTestUtils/TensorHelpers.hpp>
20
21 #include <doctest/doctest.h>
22
23 namespace
24 {
25
26 using FloatData = std::vector<float>;
27 using QuantData = std::pair<float, int32_t>;
28
29 struct TestData
30 {
31 static const armnn::TensorShape s_BoxEncodingsShape;
32 static const armnn::TensorShape s_ScoresShape;
33 static const armnn::TensorShape s_AnchorsShape;
34
35 static const QuantData s_BoxEncodingsQuantData;
36 static const QuantData s_ScoresQuantData;
37 static const QuantData s_AnchorsQuantData;
38
39 static const FloatData s_BoxEncodings;
40 static const FloatData s_Scores;
41 static const FloatData s_Anchors;
42 };
43
44 struct RegularNmsExpectedResults
45 {
46 static const FloatData s_DetectionBoxes;
47 static const FloatData s_DetectionScores;
48 static const FloatData s_DetectionClasses;
49 static const FloatData s_NumDetections;
50 };
51
52 struct FastNmsExpectedResults
53 {
54 static const FloatData s_DetectionBoxes;
55 static const FloatData s_DetectionScores;
56 static const FloatData s_DetectionClasses;
57 static const FloatData s_NumDetections;
58 };
59
60 const armnn::TensorShape TestData::s_BoxEncodingsShape = { 1, 6, 4 };
61 const armnn::TensorShape TestData::s_ScoresShape = { 1, 6, 3 };
62 const armnn::TensorShape TestData::s_AnchorsShape = { 6, 4 };
63
64 const QuantData TestData::s_BoxEncodingsQuantData = { 1.00f, 1 };
65 const QuantData TestData::s_ScoresQuantData = { 0.01f, 0 };
66 const QuantData TestData::s_AnchorsQuantData = { 0.50f, 0 };
67
68 const FloatData TestData::s_BoxEncodings =
69 {
70 0.0f, 0.0f, 0.0f, 0.0f,
71 0.0f, 1.0f, 0.0f, 0.0f,
72 0.0f, -1.0f, 0.0f, 0.0f,
73 0.0f, 0.0f, 0.0f, 0.0f,
74 0.0f, 1.0f, 0.0f, 0.0f,
75 0.0f, 0.0f, 0.0f, 0.0f
76 };
77
78 const FloatData TestData::s_Scores =
79 {
80 0.0f, 0.90f, 0.80f,
81 0.0f, 0.75f, 0.72f,
82 0.0f, 0.60f, 0.50f,
83 0.0f, 0.93f, 0.95f,
84 0.0f, 0.50f, 0.40f,
85 0.0f, 0.30f, 0.20f
86 };
87
88 const FloatData TestData::s_Anchors =
89 {
90 0.5f, 0.5f, 1.0f, 1.0f,
91 0.5f, 0.5f, 1.0f, 1.0f,
92 0.5f, 0.5f, 1.0f, 1.0f,
93 0.5f, 10.5f, 1.0f, 1.0f,
94 0.5f, 10.5f, 1.0f, 1.0f,
95 0.5f, 100.5f, 1.0f, 1.0f
96 };
97
98 const FloatData RegularNmsExpectedResults::s_DetectionBoxes =
99 {
100 0.0f, 10.0f, 1.0f, 11.0f,
101 0.0f, 10.0f, 1.0f, 11.0f,
102 0.0f, 0.0f, 0.0f, 0.0f
103 };
104
105 const FloatData RegularNmsExpectedResults::s_DetectionScores =
106 {
107 0.95f, 0.93f, 0.0f
108 };
109
110 const FloatData RegularNmsExpectedResults::s_DetectionClasses =
111 {
112 1.0f, 0.0f, 0.0f
113 };
114
115 const FloatData RegularNmsExpectedResults::s_NumDetections = { 2.0f };
116
117 const FloatData FastNmsExpectedResults::s_DetectionBoxes =
118 {
119 0.0f, 10.0f, 1.0f, 11.0f,
120 0.0f, 0.0f, 1.0f, 1.0f,
121 0.0f, 100.0f, 1.0f, 101.0f
122 };
123
124 const FloatData FastNmsExpectedResults::s_DetectionScores =
125 {
126 0.95f, 0.9f, 0.3f
127 };
128
129 const FloatData FastNmsExpectedResults::s_DetectionClasses =
130 {
131 1.0f, 0.0f, 0.0f
132 };
133
134 const FloatData FastNmsExpectedResults::s_NumDetections = { 3.0f };
135
136 } // anonymous namespace
137
138 template<typename FactoryType,
139 armnn::DataType ArmnnType,
140 typename T = armnn::ResolveType<ArmnnType>>
DetectionPostProcessImpl(const armnn::TensorInfo & boxEncodingsInfo,const armnn::TensorInfo & scoresInfo,const armnn::TensorInfo & anchorsInfo,const std::vector<T> & boxEncodingsData,const std::vector<T> & scoresData,const std::vector<T> & anchorsData,const std::vector<float> & expectedDetectionBoxes,const std::vector<float> & expectedDetectionClasses,const std::vector<float> & expectedDetectionScores,const std::vector<float> & expectedNumDetections,bool useRegularNms)141 void DetectionPostProcessImpl(const armnn::TensorInfo& boxEncodingsInfo,
142 const armnn::TensorInfo& scoresInfo,
143 const armnn::TensorInfo& anchorsInfo,
144 const std::vector<T>& boxEncodingsData,
145 const std::vector<T>& scoresData,
146 const std::vector<T>& anchorsData,
147 const std::vector<float>& expectedDetectionBoxes,
148 const std::vector<float>& expectedDetectionClasses,
149 const std::vector<float>& expectedDetectionScores,
150 const std::vector<float>& expectedNumDetections,
151 bool useRegularNms)
152 {
153 std::unique_ptr<armnn::IProfiler> profiler = std::make_unique<armnn::IProfiler>();
154 armnn::ProfilerManager::GetInstance().RegisterProfiler(profiler.get());
155
156 auto memoryManager = WorkloadFactoryHelper<FactoryType>::GetMemoryManager();
157 FactoryType workloadFactory = WorkloadFactoryHelper<FactoryType>::GetFactory(memoryManager);
158 auto tensorHandleFactory = WorkloadFactoryHelper<FactoryType>::GetTensorHandleFactory(memoryManager);
159
160 armnn::TensorInfo detectionBoxesInfo({ 1, 3, 4 }, armnn::DataType::Float32);
161 armnn::TensorInfo detectionClassesInfo({ 1, 3 }, armnn::DataType::Float32);
162 armnn::TensorInfo detectionScoresInfo({ 1, 3 }, armnn::DataType::Float32);
163 armnn::TensorInfo numDetectionInfo({ 1 }, armnn::DataType::Float32);
164
165 std::vector<float> actualDetectionBoxesOutput(detectionBoxesInfo.GetNumElements());
166 std::vector<float> actualDetectionClassesOutput(detectionClassesInfo.GetNumElements());
167 std::vector<float> actualDetectionScoresOutput(detectionScoresInfo.GetNumElements());
168 std::vector<float> actualNumDetectionOutput(numDetectionInfo.GetNumElements());
169
170 auto boxedHandle = tensorHandleFactory.CreateTensorHandle(boxEncodingsInfo);
171 auto scoreshandle = tensorHandleFactory.CreateTensorHandle(scoresInfo);
172 auto anchorsHandle = tensorHandleFactory.CreateTensorHandle(anchorsInfo);
173 auto outputBoxesHandle = tensorHandleFactory.CreateTensorHandle(detectionBoxesInfo);
174 auto classesHandle = tensorHandleFactory.CreateTensorHandle(detectionClassesInfo);
175 auto outputScoresHandle = tensorHandleFactory.CreateTensorHandle(detectionScoresInfo);
176 auto numDetectionHandle = tensorHandleFactory.CreateTensorHandle(numDetectionInfo);
177
178 armnn::ScopedTensorHandle anchorsTensor(anchorsInfo);
179 AllocateAndCopyDataToITensorHandle(&anchorsTensor, anchorsData.data());
180
181 armnn::DetectionPostProcessQueueDescriptor data;
182 data.m_Parameters.m_UseRegularNms = useRegularNms;
183 data.m_Parameters.m_MaxDetections = 3;
184 data.m_Parameters.m_MaxClassesPerDetection = 1;
185 data.m_Parameters.m_DetectionsPerClass =1;
186 data.m_Parameters.m_NmsScoreThreshold = 0.0;
187 data.m_Parameters.m_NmsIouThreshold = 0.5;
188 data.m_Parameters.m_NumClasses = 2;
189 data.m_Parameters.m_ScaleY = 10.0;
190 data.m_Parameters.m_ScaleX = 10.0;
191 data.m_Parameters.m_ScaleH = 5.0;
192 data.m_Parameters.m_ScaleW = 5.0;
193 data.m_Anchors = &anchorsTensor;
194
195 armnn::WorkloadInfo info;
196 AddInputToWorkload(data, info, boxEncodingsInfo, boxedHandle.get());
197 AddInputToWorkload(data, info, scoresInfo, scoreshandle.get());
198 AddOutputToWorkload(data, info, detectionBoxesInfo, outputBoxesHandle.get());
199 AddOutputToWorkload(data, info, detectionClassesInfo, classesHandle.get());
200 AddOutputToWorkload(data, info, detectionScoresInfo, outputScoresHandle.get());
201 AddOutputToWorkload(data, info, numDetectionInfo, numDetectionHandle.get());
202
203 std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateWorkload(armnn::LayerType::DetectionPostProcess,
204 data,
205 info);
206
207 boxedHandle->Allocate();
208 scoreshandle->Allocate();
209 outputBoxesHandle->Allocate();
210 classesHandle->Allocate();
211 outputScoresHandle->Allocate();
212 numDetectionHandle->Allocate();
213
214 CopyDataToITensorHandle(boxedHandle.get(), boxEncodingsData.data());
215 CopyDataToITensorHandle(scoreshandle.get(), scoresData.data());
216
217 workload->Execute();
218
219 CopyDataFromITensorHandle(actualDetectionBoxesOutput.data(), outputBoxesHandle.get());
220 CopyDataFromITensorHandle(actualDetectionClassesOutput.data(), classesHandle.get());
221 CopyDataFromITensorHandle(actualDetectionScoresOutput.data(), outputScoresHandle.get());
222 CopyDataFromITensorHandle(actualNumDetectionOutput.data(), numDetectionHandle.get());
223
224 auto result = CompareTensors(actualDetectionBoxesOutput,
225 expectedDetectionBoxes,
226 outputBoxesHandle->GetShape(),
227 detectionBoxesInfo.GetShape());
228 CHECK_MESSAGE(result.m_Result, result.m_Message.str());
229
230 result = CompareTensors(actualDetectionClassesOutput,
231 expectedDetectionClasses,
232 classesHandle->GetShape(),
233 detectionClassesInfo.GetShape());
234 CHECK_MESSAGE(result.m_Result, result.m_Message.str());
235
236 result = CompareTensors(actualDetectionScoresOutput,
237 expectedDetectionScores,
238 outputScoresHandle->GetShape(),
239 detectionScoresInfo.GetShape());
240 CHECK_MESSAGE(result.m_Result, result.m_Message.str());
241
242 result = CompareTensors(actualNumDetectionOutput,
243 expectedNumDetections,
244 numDetectionHandle->GetShape(),
245 numDetectionInfo.GetShape());
246 CHECK_MESSAGE(result.m_Result, result.m_Message.str());
247 }
248
249 template<armnn::DataType QuantizedType, typename RawType = armnn::ResolveType<QuantizedType>>
QuantizeData(RawType * quant,const float * dequant,const armnn::TensorInfo & info)250 void QuantizeData(RawType* quant, const float* dequant, const armnn::TensorInfo& info)
251 {
252 for (size_t i = 0; i < info.GetNumElements(); i++)
253 {
254 quant[i] = armnn::Quantize<RawType>(
255 dequant[i], info.GetQuantizationScale(), info.GetQuantizationOffset());
256 }
257 }
258
259 template<typename FactoryType>
DetectionPostProcessRegularNmsFloatTest()260 void DetectionPostProcessRegularNmsFloatTest()
261 {
262 return DetectionPostProcessImpl<FactoryType, armnn::DataType::Float32>(
263 armnn::TensorInfo(TestData::s_BoxEncodingsShape, armnn::DataType::Float32),
264 armnn::TensorInfo(TestData::s_ScoresShape, armnn::DataType::Float32),
265 armnn::TensorInfo(TestData::s_AnchorsShape, armnn::DataType::Float32),
266 TestData::s_BoxEncodings,
267 TestData::s_Scores,
268 TestData::s_Anchors,
269 RegularNmsExpectedResults::s_DetectionBoxes,
270 RegularNmsExpectedResults::s_DetectionClasses,
271 RegularNmsExpectedResults::s_DetectionScores,
272 RegularNmsExpectedResults::s_NumDetections,
273 true);
274 }
275
276 template<typename FactoryType,
277 armnn::DataType QuantizedType,
278 typename RawType = armnn::ResolveType<QuantizedType>>
DetectionPostProcessRegularNmsQuantizedTest()279 void DetectionPostProcessRegularNmsQuantizedTest()
280 {
281 armnn::TensorInfo boxEncodingsInfo(TestData::s_BoxEncodingsShape, QuantizedType);
282 armnn::TensorInfo scoresInfo(TestData::s_ScoresShape, QuantizedType);
283 armnn::TensorInfo anchorsInfo(TestData::s_AnchorsShape, QuantizedType);
284
285 boxEncodingsInfo.SetQuantizationScale(TestData::s_BoxEncodingsQuantData.first);
286 boxEncodingsInfo.SetQuantizationOffset(TestData::s_BoxEncodingsQuantData.second);
287
288 scoresInfo.SetQuantizationScale(TestData::s_ScoresQuantData.first);
289 scoresInfo.SetQuantizationOffset(TestData::s_ScoresQuantData.second);
290
291 anchorsInfo.SetQuantizationScale(TestData::s_AnchorsQuantData.first);
292 anchorsInfo.SetQuantizationOffset(TestData::s_BoxEncodingsQuantData.second);
293
294 std::vector<RawType> boxEncodingsData(TestData::s_BoxEncodingsShape.GetNumElements());
295 QuantizeData<QuantizedType>(boxEncodingsData.data(),
296 TestData::s_BoxEncodings.data(),
297 boxEncodingsInfo);
298
299 std::vector<RawType> scoresData(TestData::s_ScoresShape.GetNumElements());
300 QuantizeData<QuantizedType>(scoresData.data(),
301 TestData::s_Scores.data(),
302 scoresInfo);
303
304 std::vector<RawType> anchorsData(TestData::s_AnchorsShape.GetNumElements());
305 QuantizeData<QuantizedType>(anchorsData.data(),
306 TestData::s_Anchors.data(),
307 anchorsInfo);
308
309 return DetectionPostProcessImpl<FactoryType, QuantizedType>(
310 boxEncodingsInfo,
311 scoresInfo,
312 anchorsInfo,
313 boxEncodingsData,
314 scoresData,
315 anchorsData,
316 RegularNmsExpectedResults::s_DetectionBoxes,
317 RegularNmsExpectedResults::s_DetectionClasses,
318 RegularNmsExpectedResults::s_DetectionScores,
319 RegularNmsExpectedResults::s_NumDetections,
320 true);
321 }
322
323 template<typename FactoryType>
DetectionPostProcessFastNmsFloatTest()324 void DetectionPostProcessFastNmsFloatTest()
325 {
326 return DetectionPostProcessImpl<FactoryType, armnn::DataType::Float32>(
327 armnn::TensorInfo(TestData::s_BoxEncodingsShape, armnn::DataType::Float32),
328 armnn::TensorInfo(TestData::s_ScoresShape, armnn::DataType::Float32),
329 armnn::TensorInfo(TestData::s_AnchorsShape, armnn::DataType::Float32),
330 TestData::s_BoxEncodings,
331 TestData::s_Scores,
332 TestData::s_Anchors,
333 FastNmsExpectedResults::s_DetectionBoxes,
334 FastNmsExpectedResults::s_DetectionClasses,
335 FastNmsExpectedResults::s_DetectionScores,
336 FastNmsExpectedResults::s_NumDetections,
337 false);
338 }
339
340 template<typename FactoryType,
341 armnn::DataType QuantizedType,
342 typename RawType = armnn::ResolveType<QuantizedType>>
DetectionPostProcessFastNmsQuantizedTest()343 void DetectionPostProcessFastNmsQuantizedTest()
344 {
345 armnn::TensorInfo boxEncodingsInfo(TestData::s_BoxEncodingsShape, QuantizedType);
346 armnn::TensorInfo scoresInfo(TestData::s_ScoresShape, QuantizedType);
347 armnn::TensorInfo anchorsInfo(TestData::s_AnchorsShape, QuantizedType);
348
349 boxEncodingsInfo.SetQuantizationScale(TestData::s_BoxEncodingsQuantData.first);
350 boxEncodingsInfo.SetQuantizationOffset(TestData::s_BoxEncodingsQuantData.second);
351
352 scoresInfo.SetQuantizationScale(TestData::s_ScoresQuantData.first);
353 scoresInfo.SetQuantizationOffset(TestData::s_ScoresQuantData.second);
354
355 anchorsInfo.SetQuantizationScale(TestData::s_AnchorsQuantData.first);
356 anchorsInfo.SetQuantizationOffset(TestData::s_BoxEncodingsQuantData.second);
357
358 std::vector<RawType> boxEncodingsData(TestData::s_BoxEncodingsShape.GetNumElements());
359 QuantizeData<QuantizedType>(boxEncodingsData.data(),
360 TestData::s_BoxEncodings.data(),
361 boxEncodingsInfo);
362
363 std::vector<RawType> scoresData(TestData::s_ScoresShape.GetNumElements());
364 QuantizeData<QuantizedType>(scoresData.data(),
365 TestData::s_Scores.data(),
366 scoresInfo);
367
368 std::vector<RawType> anchorsData(TestData::s_AnchorsShape.GetNumElements());
369 QuantizeData<QuantizedType>(anchorsData.data(),
370 TestData::s_Anchors.data(),
371 anchorsInfo);
372
373 return DetectionPostProcessImpl<FactoryType, QuantizedType>(
374 boxEncodingsInfo,
375 scoresInfo,
376 anchorsInfo,
377 boxEncodingsData,
378 scoresData,
379 anchorsData,
380 FastNmsExpectedResults::s_DetectionBoxes,
381 FastNmsExpectedResults::s_DetectionClasses,
382 FastNmsExpectedResults::s_DetectionScores,
383 FastNmsExpectedResults::s_NumDetections,
384 false);
385 }
386