xref: /aosp_15_r20/external/armnn/src/armnn/layers/DetectionPostProcessLayer.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "DetectionPostProcessLayer.hpp"
7 
8 #include "LayerCloneBase.hpp"
9 
10 #include <armnn/TypesUtils.hpp>
11 #include <armnn/backends/TensorHandle.hpp>
12 #include <armnn/backends/WorkloadData.hpp>
13 #include <armnn/backends/WorkloadFactory.hpp>
14 
15 namespace armnn
16 {
17 
DetectionPostProcessLayer(const DetectionPostProcessDescriptor & param,const char * name)18 DetectionPostProcessLayer::DetectionPostProcessLayer(const DetectionPostProcessDescriptor& param, const char* name)
19     : LayerWithParameters(2, 4, LayerType::DetectionPostProcess, param, name)
20 {
21 }
22 
CreateWorkload(const armnn::IWorkloadFactory & factory) const23 std::unique_ptr<IWorkload> DetectionPostProcessLayer::CreateWorkload(const armnn::IWorkloadFactory& factory) const
24 {
25     DetectionPostProcessQueueDescriptor descriptor;
26     descriptor.m_Anchors = m_Anchors.get();
27     SetAdditionalInfo(descriptor);
28 
29     return factory.CreateWorkload(LayerType::DetectionPostProcess, descriptor, PrepInfoAndDesc(descriptor));
30 }
31 
Clone(Graph & graph) const32 DetectionPostProcessLayer* DetectionPostProcessLayer::Clone(Graph& graph) const
33 {
34     auto layer = CloneBase<DetectionPostProcessLayer>(graph, m_Param, GetName());
35     layer->m_Anchors = m_Anchors ? m_Anchors : nullptr;
36     return std::move(layer);
37 }
38 
ValidateTensorShapesFromInputs()39 void DetectionPostProcessLayer::ValidateTensorShapesFromInputs()
40 {
41     VerifyLayerConnections(2, CHECK_LOCATION());
42 
43     const TensorShape& outputShape = GetOutputSlot(0).GetTensorInfo().GetShape();
44 
45     VerifyShapeInferenceType(outputShape, m_ShapeInferenceMethod);
46 
47     // on this level constant data should not be released.
48     ARMNN_ASSERT_MSG(m_Anchors != nullptr, "DetectionPostProcessLayer: Anchors data should not be null.");
49 
50     ARMNN_ASSERT_MSG(GetNumOutputSlots() == 4, "DetectionPostProcessLayer: The layer should return 4 outputs.");
51 
52     std::vector<TensorShape> inferredShapes = InferOutputShapes(
53             { GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape(),
54               GetInputSlot(1).GetConnection()->GetTensorInfo().GetShape() });
55 
56     ARMNN_ASSERT(inferredShapes.size() == 4);
57     ARMNN_ASSERT(inferredShapes[0].GetDimensionality() == Dimensionality::Specified);
58     ARMNN_ASSERT(inferredShapes[1].GetDimensionality() == Dimensionality::Specified);
59     ARMNN_ASSERT(inferredShapes[2].GetDimensionality() == Dimensionality::Specified);
60     ARMNN_ASSERT(inferredShapes[3].GetDimensionality() == Dimensionality::Specified);
61 
62     ValidateAndCopyShape(outputShape, inferredShapes[0], m_ShapeInferenceMethod, "DetectionPostProcessLayer");
63 
64     ValidateAndCopyShape(GetOutputSlot(1).GetTensorInfo().GetShape(),
65                          inferredShapes[1],
66                          m_ShapeInferenceMethod,
67                          "DetectionPostProcessLayer", 1);
68 
69     ValidateAndCopyShape(GetOutputSlot(2).GetTensorInfo().GetShape(),
70                          inferredShapes[2],
71                          m_ShapeInferenceMethod,
72                          "DetectionPostProcessLayer", 2);
73 
74     ValidateAndCopyShape(GetOutputSlot(3).GetTensorInfo().GetShape(),
75                          inferredShapes[3],
76                          m_ShapeInferenceMethod,
77                          "DetectionPostProcessLayer", 3);
78 }
79 
InferOutputShapes(const std::vector<TensorShape> &) const80 std::vector<TensorShape> DetectionPostProcessLayer::InferOutputShapes(const std::vector<TensorShape>&) const
81 {
82     unsigned int detectedBoxes = m_Param.m_MaxDetections * m_Param.m_MaxClassesPerDetection;
83 
84     std::vector<TensorShape> results;
85     results.push_back({ 1, detectedBoxes, 4 });
86     results.push_back({ 1, detectedBoxes });
87     results.push_back({ 1, detectedBoxes });
88     results.push_back({ 1 });
89     return results;
90 }
91 
GetConstantTensorsByRef() const92 Layer::ImmutableConstantTensors DetectionPostProcessLayer::GetConstantTensorsByRef() const
93 {
94     // For API stability DO NOT ALTER order and add new members to the end of vector
95     return { m_Anchors };
96 }
97 
ExecuteStrategy(IStrategy & strategy) const98 void DetectionPostProcessLayer::ExecuteStrategy(IStrategy& strategy) const
99 {
100     ManagedConstTensorHandle managedAnchors(m_Anchors);
101     std::vector<armnn::ConstTensor> constTensors { {managedAnchors.GetTensorInfo(), managedAnchors.Map()} };
102     strategy.ExecuteStrategy(this, GetParameters(), constTensors, GetName());
103 }
104 
105 } // namespace armnn
106