1 //
2 // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #include "MergeLayer.hpp"
6
7 #include "LayerCloneBase.hpp"
8
9 #include <armnn/backends/WorkloadData.hpp>
10 #include <armnn/backends/WorkloadFactory.hpp>
11
12 namespace armnn
13 {
14
MergeLayer(const char * name)15 MergeLayer::MergeLayer(const char* name)
16 : Layer(2, 1, LayerType::Merge, name)
17 {}
18
CreateWorkload(const IWorkloadFactory & factory) const19 std::unique_ptr<IWorkload> MergeLayer::CreateWorkload(const IWorkloadFactory& factory) const
20 {
21 IgnoreUnused(factory);
22 return nullptr;
23 }
24
Clone(Graph & graph) const25 MergeLayer* MergeLayer::Clone(Graph& graph) const
26 {
27 return CloneBase<MergeLayer>(graph, GetName());
28 }
29
ValidateTensorShapesFromInputs()30 void MergeLayer::ValidateTensorShapesFromInputs()
31 {
32 VerifyLayerConnections(2, CHECK_LOCATION());
33
34 const TensorShape& outputShape = GetOutputSlot(0).GetTensorInfo().GetShape();
35
36 VerifyShapeInferenceType(outputShape, m_ShapeInferenceMethod);
37
38 std::vector<TensorShape> inferredShapes = InferOutputShapes({
39 GetInputSlot(0).GetConnection()->GetTensorInfo().GetShape(),
40 GetInputSlot(1).GetConnection()->GetTensorInfo().GetShape(),
41 });
42
43 ARMNN_ASSERT(inferredShapes.size() == 1);
44
45 ValidateAndCopyShape(outputShape, inferredShapes[0], m_ShapeInferenceMethod, "MergeLayer");
46 }
47
InferOutputShapes(const std::vector<TensorShape> & inputShapes) const48 std::vector<TensorShape> MergeLayer::InferOutputShapes(const std::vector<TensorShape>& inputShapes) const
49 {
50 ARMNN_ASSERT(inputShapes.size() == 2);
51
52 ConditionalThrowIfNotEqual<LayerValidationException>(
53 "MergeLayer: TensorShapes set on inputs do not match",
54 inputShapes[0],
55 inputShapes[1]
56 );
57
58 return {inputShapes[0]};
59 }
60
ExecuteStrategy(IStrategy & strategy) const61 void MergeLayer::ExecuteStrategy(IStrategy& strategy) const
62 {
63 strategy.ExecuteStrategy(this, GetParameters(), {}, GetName());
64 }
65
66 } // namespace armnn
67