xref: /aosp_15_r20/external/armnn/samples/ObjectDetection/src/YoloResultDecoder.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "YoloResultDecoder.hpp"
7 
8 #include "NonMaxSuppression.hpp"
9 
10 #include <cassert>
11 #include <stdexcept>
12 
13 namespace od
14 {
15 
Decode(const common::InferenceResults<float> & networkResults,const common::Size & outputFrameSize,const common::Size & resizedFrameSize,const std::vector<std::string> & labels)16 DetectedObjects YoloResultDecoder::Decode(const common::InferenceResults<float>& networkResults,
17                                          const common::Size& outputFrameSize,
18                                          const common::Size& resizedFrameSize,
19                                          const std::vector<std::string>& labels)
20 {
21 
22     // Yolo v3 network outputs 1 tensor
23     if (networkResults.size() != 1)
24     {
25         throw std::runtime_error("Number of outputs from Yolo model doesn't equal 1");
26     }
27     auto element_step = m_boxElements + m_confidenceElements + m_numClasses;
28 
29     float longEdgeInput = std::max(resizedFrameSize.m_Width, resizedFrameSize.m_Height);
30     float longEdgeOutput = std::max(outputFrameSize.m_Width, outputFrameSize.m_Height);
31     const float resizeFactor = longEdgeOutput/longEdgeInput;
32 
33     DetectedObjects detectedObjects;
34     DetectedObjects resultsAfterNMS;
35 
36     for (const common::InferenceResult<float>& result : networkResults)
37     {
38         for (unsigned int i = 0; i < m_numBoxes; ++i)
39         {
40             const float* cur_box = &result[i * element_step];
41             // Objectness score
42             if (cur_box[4] > m_objectThreshold)
43             {
44                 for (unsigned int classIndex = 0; classIndex < m_numClasses; ++classIndex)
45                 {
46                     const float class_prob =  cur_box[4] * cur_box[5 + classIndex];
47 
48                     // class confidence
49 
50                     if (class_prob > m_ClsThreshold)
51                     {
52                         DetectedObject detectedObject;
53 
54                         detectedObject.SetScore(class_prob);
55 
56                         float topLeftX = cur_box[0] * resizeFactor;
57                         float topLeftY = cur_box[1] * resizeFactor;
58                         float botRightX = cur_box[2] * resizeFactor;
59                         float botRightY = cur_box[3] * resizeFactor;
60 
61                         assert(botRightX > topLeftX);
62                         assert(botRightY > topLeftY);
63 
64                         detectedObject.SetBoundingBox({static_cast<int>(topLeftX),
65                                                        static_cast<int>(topLeftY),
66                                                        static_cast<unsigned int>(botRightX-topLeftX),
67                                                        static_cast<unsigned int>(botRightY-topLeftY)});
68                         if(labels.size() > classIndex)
69                         {
70                             detectedObject.SetLabel(labels.at(classIndex));
71                         }
72                         else
73                         {
74                             detectedObject.SetLabel(std::to_string(classIndex));
75                         }
76                         detectedObject.SetId(classIndex);
77                         detectedObjects.emplace_back(detectedObject);
78                     }
79                 }
80             }
81         }
82 
83         std::vector<int> keepIndiciesAfterNMS = od::NonMaxSuppression(detectedObjects, m_NmsThreshold);
84 
85         for (const int ind: keepIndiciesAfterNMS)
86         {
87             resultsAfterNMS.emplace_back(detectedObjects[ind]);
88         }
89     }
90 
91     return resultsAfterNMS;
92 }
93 
YoloResultDecoder(float NMSThreshold,float ClsThreshold,float ObjectThreshold)94 YoloResultDecoder::YoloResultDecoder(float NMSThreshold, float ClsThreshold, float ObjectThreshold)
95         : m_NmsThreshold(NMSThreshold), m_ClsThreshold(ClsThreshold), m_objectThreshold(ObjectThreshold) {}
96 
97 }// namespace od
98 
99 
100 
101