1 //
2 // Copyright © 2017-2023 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #include "DetectionPostProcessLayer.hpp"
7
8 #include "LayerCloneBase.hpp"
9
#include <armnn/Exceptions.hpp>
#include <armnn/TypesUtils.hpp>
#include <armnn/backends/TensorHandle.hpp>
#include <armnn/backends/WorkloadData.hpp>
#include <armnn/backends/WorkloadFactory.hpp>
14
15 namespace armnn
16 {
17
DetectionPostProcessLayer(const DetectionPostProcessDescriptor & param,const char * name)18 DetectionPostProcessLayer::DetectionPostProcessLayer(const DetectionPostProcessDescriptor& param, const char* name)
19 : LayerWithParameters(2, 4, LayerType::DetectionPostProcess, param, name)
20 {
21 }
22
CreateWorkload(const armnn::IWorkloadFactory & factory) const23 std::unique_ptr<IWorkload> DetectionPostProcessLayer::CreateWorkload(const armnn::IWorkloadFactory& factory) const
24 {
25 DetectionPostProcessQueueDescriptor descriptor;
26 descriptor.m_Anchors = m_Anchors.get();
27 SetAdditionalInfo(descriptor);
28
29 return factory.CreateWorkload(LayerType::DetectionPostProcess, descriptor, PrepInfoAndDesc(descriptor));
30 }
31
Clone(Graph & graph) const32 DetectionPostProcessLayer* DetectionPostProcessLayer::Clone(Graph& graph) const
33 {
34 auto layer = CloneBase<DetectionPostProcessLayer>(graph, m_Param, GetName());
35 layer->m_Anchors = m_Anchors ? m_Anchors : nullptr;
36 return std::move(layer);
37 }
38
ValidateTensorShapesFromInputs()39 void DetectionPostProcessLayer::ValidateTensorShapesFromInputs()
40 {
41 VerifyLayerConnections(2, CHECK_LOCATION());
42
43 const TensorShape& outputShape = GetOutputSlot(0).GetTensorInfo().GetShape();
44
45 VerifyShapeInferenceType(outputShape, m_ShapeInferenceMethod);
46
47 // on this level constant data should not be released.
48 ARMNN_ASSERT_MSG(m_Anchors != nullptr, "DetectionPostProcessLayer: Anchors data should not be null.");
49
50 ARMNN_ASSERT_MSG(GetNumOutputSlots() == 4, "DetectionPostProcessLayer: The layer should return 4 outputs.");
51
52 std::vector<TensorShape> inferredShapes = InferOutputShapes(
53 { GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape(),
54 GetInputSlot(1).GetConnection()->GetTensorInfo().GetShape() });
55
56 ARMNN_ASSERT(inferredShapes.size() == 4);
57 ARMNN_ASSERT(inferredShapes[0].GetDimensionality() == Dimensionality::Specified);
58 ARMNN_ASSERT(inferredShapes[1].GetDimensionality() == Dimensionality::Specified);
59 ARMNN_ASSERT(inferredShapes[2].GetDimensionality() == Dimensionality::Specified);
60 ARMNN_ASSERT(inferredShapes[3].GetDimensionality() == Dimensionality::Specified);
61
62 ValidateAndCopyShape(outputShape, inferredShapes[0], m_ShapeInferenceMethod, "DetectionPostProcessLayer");
63
64 ValidateAndCopyShape(GetOutputSlot(1).GetTensorInfo().GetShape(),
65 inferredShapes[1],
66 m_ShapeInferenceMethod,
67 "DetectionPostProcessLayer", 1);
68
69 ValidateAndCopyShape(GetOutputSlot(2).GetTensorInfo().GetShape(),
70 inferredShapes[2],
71 m_ShapeInferenceMethod,
72 "DetectionPostProcessLayer", 2);
73
74 ValidateAndCopyShape(GetOutputSlot(3).GetTensorInfo().GetShape(),
75 inferredShapes[3],
76 m_ShapeInferenceMethod,
77 "DetectionPostProcessLayer", 3);
78 }
79
InferOutputShapes(const std::vector<TensorShape> &) const80 std::vector<TensorShape> DetectionPostProcessLayer::InferOutputShapes(const std::vector<TensorShape>&) const
81 {
82 unsigned int detectedBoxes = m_Param.m_MaxDetections * m_Param.m_MaxClassesPerDetection;
83
84 std::vector<TensorShape> results;
85 results.push_back({ 1, detectedBoxes, 4 });
86 results.push_back({ 1, detectedBoxes });
87 results.push_back({ 1, detectedBoxes });
88 results.push_back({ 1 });
89 return results;
90 }
91
GetConstantTensorsByRef() const92 Layer::ImmutableConstantTensors DetectionPostProcessLayer::GetConstantTensorsByRef() const
93 {
94 // For API stability DO NOT ALTER order and add new members to the end of vector
95 return { m_Anchors };
96 }
97
ExecuteStrategy(IStrategy & strategy) const98 void DetectionPostProcessLayer::ExecuteStrategy(IStrategy& strategy) const
99 {
100 ManagedConstTensorHandle managedAnchors(m_Anchors);
101 std::vector<armnn::ConstTensor> constTensors { {managedAnchors.GetTensorInfo(), managedAnchors.Map()} };
102 strategy.ExecuteStrategy(this, GetParameters(), constTensors, GetName());
103 }
104
105 } // namespace armnn
106