xref: /aosp_15_r20/external/armnn/src/armnn/layers/MergeLayer.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #include "MergeLayer.hpp"
6 
7 #include "LayerCloneBase.hpp"
8 
9 #include <armnn/backends/WorkloadData.hpp>
10 #include <armnn/backends/WorkloadFactory.hpp>
11 
12 namespace armnn
13 {
14 
MergeLayer(const char * name)15 MergeLayer::MergeLayer(const char* name)
16     : Layer(2, 1, LayerType::Merge, name)
17 {}
18 
CreateWorkload(const IWorkloadFactory & factory) const19 std::unique_ptr<IWorkload> MergeLayer::CreateWorkload(const IWorkloadFactory& factory) const
20 {
21     IgnoreUnused(factory);
22     return nullptr;
23 }
24 
Clone(Graph & graph) const25 MergeLayer* MergeLayer::Clone(Graph& graph) const
26 {
27     return CloneBase<MergeLayer>(graph, GetName());
28 }
29 
ValidateTensorShapesFromInputs()30 void MergeLayer::ValidateTensorShapesFromInputs()
31 {
32     VerifyLayerConnections(2, CHECK_LOCATION());
33 
34     const TensorShape& outputShape = GetOutputSlot(0).GetTensorInfo().GetShape();
35 
36     VerifyShapeInferenceType(outputShape, m_ShapeInferenceMethod);
37 
38     std::vector<TensorShape> inferredShapes = InferOutputShapes({
39         GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape(),
40         GetInputSlot(1).GetConnection()->GetTensorInfo().GetShape(),
41     });
42 
43     ARMNN_ASSERT(inferredShapes.size() == 1);
44 
45     ValidateAndCopyShape(outputShape, inferredShapes[0], m_ShapeInferenceMethod, "MergeLayer");
46 }
47 
InferOutputShapes(const std::vector<TensorShape> & inputShapes) const48 std::vector<TensorShape> MergeLayer::InferOutputShapes(const std::vector<TensorShape>& inputShapes) const
49 {
50     ARMNN_ASSERT(inputShapes.size() == 2);
51 
52     ConditionalThrowIfNotEqual<LayerValidationException>(
53         "MergeLayer: TensorShapes set on inputs do not match",
54         inputShapes[0],
55         inputShapes[1]
56     );
57 
58     return {inputShapes[0]};
59 }
60 
ExecuteStrategy(IStrategy & strategy) const61 void MergeLayer::ExecuteStrategy(IStrategy& strategy) const
62 {
63     strategy.ExecuteStrategy(this, GetParameters(), {}, GetName());
64 }
65 
66 } // namespace armnn
67