1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
7 #include <ResolveType.hpp>
8 
9 #include <armnn/Types.hpp>
10 
11 #include <armnn/backends/TensorHandle.hpp>
12 #include <armnn/backends/IBackendInternal.hpp>
13 #include <armnn/backends/WorkloadFactory.hpp>
14 
15 #include <armnnTestUtils/TensorCopyUtils.hpp>
16 #include <backendsCommon/test/WorkloadFactoryHelper.hpp>
17 #include <armnnTestUtils/WorkloadTestUtils.hpp>
18 
19 #include <armnnTestUtils/TensorHelpers.hpp>
20 
21 #include <doctest/doctest.h>
22 
namespace
{

// Plain float buffer used for both the (unquantized) input values and the expected outputs.
using FloatData = std::vector<float>;
// Quantization parameters for one tensor: .first is passed to SetQuantizationScale and
// .second to SetQuantizationOffset by the quantized tests below.
using QuantData = std::pair<float, int32_t>;

// Inputs shared by every DetectionPostProcess test in this file: tensor shapes, per-tensor
// quantization parameters, and float values. The quantized tests quantize these floats
// with the matching QuantData before running the workload.
struct TestData
{
    static const armnn::TensorShape s_BoxEncodingsShape;
    static const armnn::TensorShape s_ScoresShape;
    static const armnn::TensorShape s_AnchorsShape;

    static const QuantData s_BoxEncodingsQuantData;
    static const QuantData s_ScoresQuantData;
    static const QuantData s_AnchorsQuantData;

    static const FloatData s_BoxEncodings;
    static const FloatData s_Scores;
    static const FloatData s_Anchors;
};

// Expected workload outputs when m_UseRegularNms is true.
struct RegularNmsExpectedResults
{
    static const FloatData s_DetectionBoxes;
    static const FloatData s_DetectionScores;
    static const FloatData s_DetectionClasses;
    static const FloatData s_NumDetections;
};

// Expected workload outputs when m_UseRegularNms is false.
struct FastNmsExpectedResults
{
    static const FloatData s_DetectionBoxes;
    static const FloatData s_DetectionScores;
    static const FloatData s_DetectionClasses;
    static const FloatData s_NumDetections;
};

// 6 boxes: 4 encoding values per box, 3 score columns per box, 4 values per anchor.
const armnn::TensorShape TestData::s_BoxEncodingsShape = { 1, 6, 4 };
const armnn::TensorShape TestData::s_ScoresShape       = { 1, 6, 3 };
const armnn::TensorShape TestData::s_AnchorsShape      = { 6, 4 };

// (scale, offset) for each input tensor when running the quantized variants.
const QuantData TestData::s_BoxEncodingsQuantData = { 1.00f, 1 };
const QuantData TestData::s_ScoresQuantData       = { 0.01f, 0 };
const QuantData TestData::s_AnchorsQuantData      = { 0.50f, 0 };

const FloatData TestData::s_BoxEncodings =
{
    0.0f,  0.0f, 0.0f, 0.0f,
    0.0f,  1.0f, 0.0f, 0.0f,
    0.0f, -1.0f, 0.0f, 0.0f,
    0.0f,  0.0f, 0.0f, 0.0f,
    0.0f,  1.0f, 0.0f, 0.0f,
    0.0f,  0.0f, 0.0f, 0.0f
};

// One row per box. Column 0 is 0.0 for every box; presumably a background class, since
// the workload is configured with m_NumClasses = 2 against 3 score columns — TODO confirm.
const FloatData TestData::s_Scores =
{
    0.0f, 0.90f, 0.80f,
    0.0f, 0.75f, 0.72f,
    0.0f, 0.60f, 0.50f,
    0.0f, 0.93f, 0.95f,
    0.0f, 0.50f, 0.40f,
    0.0f, 0.30f, 0.20f
};

const FloatData TestData::s_Anchors =
{
    0.5f,   0.5f, 1.0f, 1.0f,
    0.5f,   0.5f, 1.0f, 1.0f,
    0.5f,   0.5f, 1.0f, 1.0f,
    0.5f,  10.5f, 1.0f, 1.0f,
    0.5f,  10.5f, 1.0f, 1.0f,
    0.5f, 100.5f, 1.0f, 1.0f
};

// Three detection slots (the workload uses m_MaxDetections = 3), 4 values per box.
// Only 2 detections are expected here, so the third slot is all zeros.
const FloatData RegularNmsExpectedResults::s_DetectionBoxes =
{
    0.0f, 10.0f, 1.0f, 11.0f,
    0.0f, 10.0f, 1.0f, 11.0f,
    0.0f,  0.0f, 0.0f,  0.0f
};

const FloatData RegularNmsExpectedResults::s_DetectionScores =
{
    0.95f, 0.93f, 0.0f
};

const FloatData RegularNmsExpectedResults::s_DetectionClasses =
{
    1.0f, 0.0f, 0.0f
};

const FloatData RegularNmsExpectedResults::s_NumDetections = { 2.0f };

// Three detection slots, all filled in the fast-NMS case (s_NumDetections is 3).
const FloatData FastNmsExpectedResults::s_DetectionBoxes =
{
    0.0f,  10.0f, 1.0f,  11.0f,
    0.0f,   0.0f, 1.0f,   1.0f,
    0.0f, 100.0f, 1.0f, 101.0f
};

const FloatData FastNmsExpectedResults::s_DetectionScores =
{
    0.95f, 0.9f, 0.3f
};

const FloatData FastNmsExpectedResults::s_DetectionClasses =
{
    1.0f, 0.0f, 0.0f
};

const FloatData FastNmsExpectedResults::s_NumDetections = { 3.0f };

} // anonymous namespace
137 
138 template<typename FactoryType,
139          armnn::DataType ArmnnType,
140          typename T = armnn::ResolveType<ArmnnType>>
DetectionPostProcessImpl(const armnn::TensorInfo & boxEncodingsInfo,const armnn::TensorInfo & scoresInfo,const armnn::TensorInfo & anchorsInfo,const std::vector<T> & boxEncodingsData,const std::vector<T> & scoresData,const std::vector<T> & anchorsData,const std::vector<float> & expectedDetectionBoxes,const std::vector<float> & expectedDetectionClasses,const std::vector<float> & expectedDetectionScores,const std::vector<float> & expectedNumDetections,bool useRegularNms)141 void DetectionPostProcessImpl(const armnn::TensorInfo& boxEncodingsInfo,
142                               const armnn::TensorInfo& scoresInfo,
143                               const armnn::TensorInfo& anchorsInfo,
144                               const std::vector<T>& boxEncodingsData,
145                               const std::vector<T>& scoresData,
146                               const std::vector<T>& anchorsData,
147                               const std::vector<float>& expectedDetectionBoxes,
148                               const std::vector<float>& expectedDetectionClasses,
149                               const std::vector<float>& expectedDetectionScores,
150                               const std::vector<float>& expectedNumDetections,
151                               bool useRegularNms)
152 {
153     std::unique_ptr<armnn::IProfiler> profiler = std::make_unique<armnn::IProfiler>();
154     armnn::ProfilerManager::GetInstance().RegisterProfiler(profiler.get());
155 
156     auto memoryManager = WorkloadFactoryHelper<FactoryType>::GetMemoryManager();
157     FactoryType workloadFactory = WorkloadFactoryHelper<FactoryType>::GetFactory(memoryManager);
158     auto tensorHandleFactory = WorkloadFactoryHelper<FactoryType>::GetTensorHandleFactory(memoryManager);
159 
160     armnn::TensorInfo detectionBoxesInfo({ 1, 3, 4 }, armnn::DataType::Float32);
161     armnn::TensorInfo detectionClassesInfo({ 1, 3 }, armnn::DataType::Float32);
162     armnn::TensorInfo detectionScoresInfo({ 1, 3 }, armnn::DataType::Float32);
163     armnn::TensorInfo numDetectionInfo({ 1 }, armnn::DataType::Float32);
164 
165     std::vector<float> actualDetectionBoxesOutput(detectionBoxesInfo.GetNumElements());
166     std::vector<float> actualDetectionClassesOutput(detectionClassesInfo.GetNumElements());
167     std::vector<float> actualDetectionScoresOutput(detectionScoresInfo.GetNumElements());
168     std::vector<float> actualNumDetectionOutput(numDetectionInfo.GetNumElements());
169 
170     auto boxedHandle = tensorHandleFactory.CreateTensorHandle(boxEncodingsInfo);
171     auto scoreshandle = tensorHandleFactory.CreateTensorHandle(scoresInfo);
172     auto anchorsHandle = tensorHandleFactory.CreateTensorHandle(anchorsInfo);
173     auto outputBoxesHandle = tensorHandleFactory.CreateTensorHandle(detectionBoxesInfo);
174     auto classesHandle = tensorHandleFactory.CreateTensorHandle(detectionClassesInfo);
175     auto outputScoresHandle = tensorHandleFactory.CreateTensorHandle(detectionScoresInfo);
176     auto numDetectionHandle = tensorHandleFactory.CreateTensorHandle(numDetectionInfo);
177 
178     armnn::ScopedTensorHandle anchorsTensor(anchorsInfo);
179     AllocateAndCopyDataToITensorHandle(&anchorsTensor, anchorsData.data());
180 
181     armnn::DetectionPostProcessQueueDescriptor data;
182     data.m_Parameters.m_UseRegularNms = useRegularNms;
183     data.m_Parameters.m_MaxDetections = 3;
184     data.m_Parameters.m_MaxClassesPerDetection = 1;
185     data.m_Parameters.m_DetectionsPerClass =1;
186     data.m_Parameters.m_NmsScoreThreshold = 0.0;
187     data.m_Parameters.m_NmsIouThreshold = 0.5;
188     data.m_Parameters.m_NumClasses = 2;
189     data.m_Parameters.m_ScaleY = 10.0;
190     data.m_Parameters.m_ScaleX = 10.0;
191     data.m_Parameters.m_ScaleH = 5.0;
192     data.m_Parameters.m_ScaleW = 5.0;
193     data.m_Anchors = &anchorsTensor;
194 
195     armnn::WorkloadInfo info;
196     AddInputToWorkload(data,  info, boxEncodingsInfo, boxedHandle.get());
197     AddInputToWorkload(data,  info, scoresInfo, scoreshandle.get());
198     AddOutputToWorkload(data, info, detectionBoxesInfo, outputBoxesHandle.get());
199     AddOutputToWorkload(data, info, detectionClassesInfo, classesHandle.get());
200     AddOutputToWorkload(data, info, detectionScoresInfo, outputScoresHandle.get());
201     AddOutputToWorkload(data, info, numDetectionInfo, numDetectionHandle.get());
202 
203     std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateWorkload(armnn::LayerType::DetectionPostProcess,
204                                                                                 data,
205                                                                                 info);
206 
207     boxedHandle->Allocate();
208     scoreshandle->Allocate();
209     outputBoxesHandle->Allocate();
210     classesHandle->Allocate();
211     outputScoresHandle->Allocate();
212     numDetectionHandle->Allocate();
213 
214     CopyDataToITensorHandle(boxedHandle.get(), boxEncodingsData.data());
215     CopyDataToITensorHandle(scoreshandle.get(), scoresData.data());
216 
217     workload->Execute();
218 
219     CopyDataFromITensorHandle(actualDetectionBoxesOutput.data(), outputBoxesHandle.get());
220     CopyDataFromITensorHandle(actualDetectionClassesOutput.data(), classesHandle.get());
221     CopyDataFromITensorHandle(actualDetectionScoresOutput.data(), outputScoresHandle.get());
222     CopyDataFromITensorHandle(actualNumDetectionOutput.data(), numDetectionHandle.get());
223 
224     auto result = CompareTensors(actualDetectionBoxesOutput,
225                                  expectedDetectionBoxes,
226                                  outputBoxesHandle->GetShape(),
227                                  detectionBoxesInfo.GetShape());
228     CHECK_MESSAGE(result.m_Result, result.m_Message.str());
229 
230     result = CompareTensors(actualDetectionClassesOutput,
231                             expectedDetectionClasses,
232                             classesHandle->GetShape(),
233                             detectionClassesInfo.GetShape());
234     CHECK_MESSAGE(result.m_Result, result.m_Message.str());
235 
236     result = CompareTensors(actualDetectionScoresOutput,
237                             expectedDetectionScores,
238                             outputScoresHandle->GetShape(),
239                             detectionScoresInfo.GetShape());
240     CHECK_MESSAGE(result.m_Result, result.m_Message.str());
241 
242     result = CompareTensors(actualNumDetectionOutput,
243                             expectedNumDetections,
244                             numDetectionHandle->GetShape(),
245                             numDetectionInfo.GetShape());
246     CHECK_MESSAGE(result.m_Result, result.m_Message.str());
247 }
248 
249 template<armnn::DataType QuantizedType, typename RawType = armnn::ResolveType<QuantizedType>>
QuantizeData(RawType * quant,const float * dequant,const armnn::TensorInfo & info)250 void QuantizeData(RawType* quant, const float* dequant, const armnn::TensorInfo& info)
251 {
252     for (size_t i = 0; i < info.GetNumElements(); i++)
253     {
254         quant[i] = armnn::Quantize<RawType>(
255             dequant[i], info.GetQuantizationScale(), info.GetQuantizationOffset());
256     }
257 }
258 
259 template<typename FactoryType>
DetectionPostProcessRegularNmsFloatTest()260 void DetectionPostProcessRegularNmsFloatTest()
261 {
262     return DetectionPostProcessImpl<FactoryType, armnn::DataType::Float32>(
263         armnn::TensorInfo(TestData::s_BoxEncodingsShape, armnn::DataType::Float32),
264         armnn::TensorInfo(TestData::s_ScoresShape, armnn::DataType::Float32),
265         armnn::TensorInfo(TestData::s_AnchorsShape, armnn::DataType::Float32),
266         TestData::s_BoxEncodings,
267         TestData::s_Scores,
268         TestData::s_Anchors,
269         RegularNmsExpectedResults::s_DetectionBoxes,
270         RegularNmsExpectedResults::s_DetectionClasses,
271         RegularNmsExpectedResults::s_DetectionScores,
272         RegularNmsExpectedResults::s_NumDetections,
273         true);
274 }
275 
276 template<typename FactoryType,
277          armnn::DataType QuantizedType,
278          typename RawType = armnn::ResolveType<QuantizedType>>
DetectionPostProcessRegularNmsQuantizedTest()279 void DetectionPostProcessRegularNmsQuantizedTest()
280 {
281     armnn::TensorInfo boxEncodingsInfo(TestData::s_BoxEncodingsShape, QuantizedType);
282     armnn::TensorInfo scoresInfo(TestData::s_ScoresShape, QuantizedType);
283     armnn::TensorInfo anchorsInfo(TestData::s_AnchorsShape, QuantizedType);
284 
285     boxEncodingsInfo.SetQuantizationScale(TestData::s_BoxEncodingsQuantData.first);
286     boxEncodingsInfo.SetQuantizationOffset(TestData::s_BoxEncodingsQuantData.second);
287 
288     scoresInfo.SetQuantizationScale(TestData::s_ScoresQuantData.first);
289     scoresInfo.SetQuantizationOffset(TestData::s_ScoresQuantData.second);
290 
291     anchorsInfo.SetQuantizationScale(TestData::s_AnchorsQuantData.first);
292     anchorsInfo.SetQuantizationOffset(TestData::s_BoxEncodingsQuantData.second);
293 
294     std::vector<RawType> boxEncodingsData(TestData::s_BoxEncodingsShape.GetNumElements());
295     QuantizeData<QuantizedType>(boxEncodingsData.data(),
296                                 TestData::s_BoxEncodings.data(),
297                                 boxEncodingsInfo);
298 
299     std::vector<RawType> scoresData(TestData::s_ScoresShape.GetNumElements());
300     QuantizeData<QuantizedType>(scoresData.data(),
301                                 TestData::s_Scores.data(),
302                                 scoresInfo);
303 
304     std::vector<RawType> anchorsData(TestData::s_AnchorsShape.GetNumElements());
305     QuantizeData<QuantizedType>(anchorsData.data(),
306                                 TestData::s_Anchors.data(),
307                                 anchorsInfo);
308 
309     return DetectionPostProcessImpl<FactoryType, QuantizedType>(
310         boxEncodingsInfo,
311         scoresInfo,
312         anchorsInfo,
313         boxEncodingsData,
314         scoresData,
315         anchorsData,
316         RegularNmsExpectedResults::s_DetectionBoxes,
317         RegularNmsExpectedResults::s_DetectionClasses,
318         RegularNmsExpectedResults::s_DetectionScores,
319         RegularNmsExpectedResults::s_NumDetections,
320         true);
321 }
322 
323 template<typename FactoryType>
DetectionPostProcessFastNmsFloatTest()324 void DetectionPostProcessFastNmsFloatTest()
325 {
326     return DetectionPostProcessImpl<FactoryType, armnn::DataType::Float32>(
327         armnn::TensorInfo(TestData::s_BoxEncodingsShape, armnn::DataType::Float32),
328         armnn::TensorInfo(TestData::s_ScoresShape, armnn::DataType::Float32),
329         armnn::TensorInfo(TestData::s_AnchorsShape, armnn::DataType::Float32),
330         TestData::s_BoxEncodings,
331         TestData::s_Scores,
332         TestData::s_Anchors,
333         FastNmsExpectedResults::s_DetectionBoxes,
334         FastNmsExpectedResults::s_DetectionClasses,
335         FastNmsExpectedResults::s_DetectionScores,
336         FastNmsExpectedResults::s_NumDetections,
337         false);
338 }
339 
340 template<typename FactoryType,
341          armnn::DataType QuantizedType,
342          typename RawType = armnn::ResolveType<QuantizedType>>
DetectionPostProcessFastNmsQuantizedTest()343 void DetectionPostProcessFastNmsQuantizedTest()
344 {
345     armnn::TensorInfo boxEncodingsInfo(TestData::s_BoxEncodingsShape, QuantizedType);
346     armnn::TensorInfo scoresInfo(TestData::s_ScoresShape, QuantizedType);
347     armnn::TensorInfo anchorsInfo(TestData::s_AnchorsShape, QuantizedType);
348 
349     boxEncodingsInfo.SetQuantizationScale(TestData::s_BoxEncodingsQuantData.first);
350     boxEncodingsInfo.SetQuantizationOffset(TestData::s_BoxEncodingsQuantData.second);
351 
352     scoresInfo.SetQuantizationScale(TestData::s_ScoresQuantData.first);
353     scoresInfo.SetQuantizationOffset(TestData::s_ScoresQuantData.second);
354 
355     anchorsInfo.SetQuantizationScale(TestData::s_AnchorsQuantData.first);
356     anchorsInfo.SetQuantizationOffset(TestData::s_BoxEncodingsQuantData.second);
357 
358     std::vector<RawType> boxEncodingsData(TestData::s_BoxEncodingsShape.GetNumElements());
359     QuantizeData<QuantizedType>(boxEncodingsData.data(),
360                                 TestData::s_BoxEncodings.data(),
361                                 boxEncodingsInfo);
362 
363     std::vector<RawType> scoresData(TestData::s_ScoresShape.GetNumElements());
364     QuantizeData<QuantizedType>(scoresData.data(),
365                                 TestData::s_Scores.data(),
366                                 scoresInfo);
367 
368     std::vector<RawType> anchorsData(TestData::s_AnchorsShape.GetNumElements());
369     QuantizeData<QuantizedType>(anchorsData.data(),
370                                 TestData::s_Anchors.data(),
371                                 anchorsInfo);
372 
373     return DetectionPostProcessImpl<FactoryType, QuantizedType>(
374         boxEncodingsInfo,
375         scoresInfo,
376         anchorsInfo,
377         boxEncodingsData,
378         scoresData,
379         anchorsData,
380         FastNmsExpectedResults::s_DetectionBoxes,
381         FastNmsExpectedResults::s_DetectionClasses,
382         FastNmsExpectedResults::s_DetectionScores,
383         FastNmsExpectedResults::s_NumDetections,
384         false);
385 }
386