1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017,2022 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker
6*89c4ff92SAndroid Build Coastguard Worker #include <TestUtils.hpp>
7*89c4ff92SAndroid Build Coastguard Worker
8*89c4ff92SAndroid Build Coastguard Worker #include <BackendSettings.hpp>
9*89c4ff92SAndroid Build Coastguard Worker #include <Graph.hpp>
10*89c4ff92SAndroid Build Coastguard Worker #include <Network.hpp>
11*89c4ff92SAndroid Build Coastguard Worker #include <Optimizer.hpp>
12*89c4ff92SAndroid Build Coastguard Worker
13*89c4ff92SAndroid Build Coastguard Worker #include <armnn/BackendHelper.hpp>
14*89c4ff92SAndroid Build Coastguard Worker #include <armnn/BackendRegistry.hpp>
15*89c4ff92SAndroid Build Coastguard Worker #include <armnn/INetwork.hpp>
16*89c4ff92SAndroid Build Coastguard Worker #include <armnn/StrategyBase.hpp>
17*89c4ff92SAndroid Build Coastguard Worker
18*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/Assert.hpp>
19*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/PolymorphicDowncast.hpp>
20*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/IBackendInternal.hpp>
21*89c4ff92SAndroid Build Coastguard Worker
22*89c4ff92SAndroid Build Coastguard Worker #include <backendsCommon/LayerSupportBase.hpp>
23*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/TensorHandle.hpp>
24*89c4ff92SAndroid Build Coastguard Worker
25*89c4ff92SAndroid Build Coastguard Worker #include <doctest/doctest.h>
26*89c4ff92SAndroid Build Coastguard Worker
27*89c4ff92SAndroid Build Coastguard Worker using namespace armnn;
28*89c4ff92SAndroid Build Coastguard Worker
29*89c4ff92SAndroid Build Coastguard Worker namespace
30*89c4ff92SAndroid Build Coastguard Worker {
31*89c4ff92SAndroid Build Coastguard Worker
CreateLSTMLayerHelper(Graph & graph,bool CifgEnabled)32*89c4ff92SAndroid Build Coastguard Worker void CreateLSTMLayerHelper(Graph &graph, bool CifgEnabled)
33*89c4ff92SAndroid Build Coastguard Worker {
34*89c4ff92SAndroid Build Coastguard Worker LstmDescriptor layerDesc;
35*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_ActivationFunc = 4;
36*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_ClippingThresCell = 0.2f;
37*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_ClippingThresProj = 0.4f;
38*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_CifgEnabled = CifgEnabled;
39*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_PeepholeEnabled = false;
40*89c4ff92SAndroid Build Coastguard Worker layerDesc.m_ProjectionEnabled = false;
41*89c4ff92SAndroid Build Coastguard Worker
42*89c4ff92SAndroid Build Coastguard Worker LstmLayer* const layer = graph.AddLayer<LstmLayer>(layerDesc, "layer");
43*89c4ff92SAndroid Build Coastguard Worker unsigned int batchSize = 3;
44*89c4ff92SAndroid Build Coastguard Worker unsigned int inputSize = 2;
45*89c4ff92SAndroid Build Coastguard Worker unsigned int numUnits = 4;
46*89c4ff92SAndroid Build Coastguard Worker unsigned int outputSize = 4;
47*89c4ff92SAndroid Build Coastguard Worker
48*89c4ff92SAndroid Build Coastguard Worker layer->m_BasicParameters.m_InputToForgetWeights = std::make_unique<ScopedTensorHandle>
49*89c4ff92SAndroid Build Coastguard Worker (TensorInfo({ numUnits, inputSize }, DataType::Float32));
50*89c4ff92SAndroid Build Coastguard Worker layer->m_BasicParameters.m_InputToCellWeights = std::make_unique<ScopedTensorHandle>
51*89c4ff92SAndroid Build Coastguard Worker (TensorInfo({ numUnits, inputSize }, DataType::Float32));
52*89c4ff92SAndroid Build Coastguard Worker layer->m_BasicParameters.m_InputToOutputWeights = std::make_unique<ScopedTensorHandle>
53*89c4ff92SAndroid Build Coastguard Worker (TensorInfo({ numUnits, inputSize }, DataType::Float32));
54*89c4ff92SAndroid Build Coastguard Worker layer->m_BasicParameters.m_RecurrentToForgetWeights = std::make_unique<ScopedTensorHandle>
55*89c4ff92SAndroid Build Coastguard Worker (TensorInfo({ numUnits, outputSize }, DataType::Float32));
56*89c4ff92SAndroid Build Coastguard Worker layer->m_BasicParameters.m_RecurrentToCellWeights = std::make_unique<ScopedTensorHandle>
57*89c4ff92SAndroid Build Coastguard Worker (TensorInfo({ numUnits, outputSize }, DataType::Float32));
58*89c4ff92SAndroid Build Coastguard Worker layer->m_BasicParameters.m_RecurrentToOutputWeights = std::make_unique<ScopedTensorHandle>
59*89c4ff92SAndroid Build Coastguard Worker (TensorInfo({ numUnits, outputSize }, DataType::Float32));
60*89c4ff92SAndroid Build Coastguard Worker layer->m_BasicParameters.m_ForgetGateBias = std::make_unique<ScopedTensorHandle>
61*89c4ff92SAndroid Build Coastguard Worker (TensorInfo({ numUnits }, DataType::Float32));
62*89c4ff92SAndroid Build Coastguard Worker layer->m_BasicParameters.m_CellBias = std::make_unique<ScopedTensorHandle>
63*89c4ff92SAndroid Build Coastguard Worker (TensorInfo({ numUnits }, DataType::Float32));
64*89c4ff92SAndroid Build Coastguard Worker layer->m_BasicParameters.m_OutputGateBias = std::make_unique<ScopedTensorHandle>
65*89c4ff92SAndroid Build Coastguard Worker (TensorInfo({ numUnits }, DataType::Float32));
66*89c4ff92SAndroid Build Coastguard Worker
67*89c4ff92SAndroid Build Coastguard Worker layer->m_BasicParameters.m_InputToForgetWeights->Allocate();
68*89c4ff92SAndroid Build Coastguard Worker layer->m_BasicParameters.m_InputToCellWeights->Allocate();
69*89c4ff92SAndroid Build Coastguard Worker layer->m_BasicParameters.m_InputToOutputWeights->Allocate();
70*89c4ff92SAndroid Build Coastguard Worker layer->m_BasicParameters.m_RecurrentToForgetWeights->Allocate();
71*89c4ff92SAndroid Build Coastguard Worker layer->m_BasicParameters.m_RecurrentToCellWeights->Allocate();
72*89c4ff92SAndroid Build Coastguard Worker layer->m_BasicParameters.m_RecurrentToOutputWeights->Allocate();
73*89c4ff92SAndroid Build Coastguard Worker layer->m_BasicParameters.m_ForgetGateBias->Allocate();
74*89c4ff92SAndroid Build Coastguard Worker layer->m_BasicParameters.m_CellBias->Allocate();
75*89c4ff92SAndroid Build Coastguard Worker layer->m_BasicParameters.m_OutputGateBias->Allocate();
76*89c4ff92SAndroid Build Coastguard Worker
77*89c4ff92SAndroid Build Coastguard Worker if (!layerDesc.m_CifgEnabled)
78*89c4ff92SAndroid Build Coastguard Worker {
79*89c4ff92SAndroid Build Coastguard Worker layer->m_CifgParameters.m_InputToInputWeights = std::make_unique<ScopedTensorHandle>
80*89c4ff92SAndroid Build Coastguard Worker (TensorInfo({ numUnits, inputSize }, DataType::Float32));
81*89c4ff92SAndroid Build Coastguard Worker layer->m_CifgParameters.m_RecurrentToInputWeights = std::make_unique<ScopedTensorHandle>
82*89c4ff92SAndroid Build Coastguard Worker (TensorInfo({ numUnits, outputSize }, DataType::Float32));
83*89c4ff92SAndroid Build Coastguard Worker layer->m_CifgParameters.m_InputGateBias = std::make_unique<ScopedTensorHandle>
84*89c4ff92SAndroid Build Coastguard Worker (TensorInfo({ numUnits }, DataType::Float32));
85*89c4ff92SAndroid Build Coastguard Worker layer->m_CifgParameters.m_InputToInputWeights->Allocate();
86*89c4ff92SAndroid Build Coastguard Worker layer->m_CifgParameters.m_RecurrentToInputWeights->Allocate();
87*89c4ff92SAndroid Build Coastguard Worker layer->m_CifgParameters.m_InputGateBias->Allocate();
88*89c4ff92SAndroid Build Coastguard Worker }
89*89c4ff92SAndroid Build Coastguard Worker
90*89c4ff92SAndroid Build Coastguard Worker if (layerDesc.m_ProjectionEnabled)
91*89c4ff92SAndroid Build Coastguard Worker {
92*89c4ff92SAndroid Build Coastguard Worker layer->m_ProjectionParameters.m_ProjectionWeights = std::make_unique<ScopedTensorHandle>
93*89c4ff92SAndroid Build Coastguard Worker (TensorInfo({ outputSize, numUnits }, DataType::Float32));
94*89c4ff92SAndroid Build Coastguard Worker layer->m_ProjectionParameters.m_ProjectionBias = std::make_unique<ScopedTensorHandle>
95*89c4ff92SAndroid Build Coastguard Worker (TensorInfo({ outputSize }, DataType::Float32));
96*89c4ff92SAndroid Build Coastguard Worker layer->m_ProjectionParameters.m_ProjectionWeights->Allocate();
97*89c4ff92SAndroid Build Coastguard Worker layer->m_ProjectionParameters.m_ProjectionBias->Allocate();
98*89c4ff92SAndroid Build Coastguard Worker }
99*89c4ff92SAndroid Build Coastguard Worker
100*89c4ff92SAndroid Build Coastguard Worker if (layerDesc.m_PeepholeEnabled)
101*89c4ff92SAndroid Build Coastguard Worker {
102*89c4ff92SAndroid Build Coastguard Worker if (!layerDesc.m_CifgEnabled)
103*89c4ff92SAndroid Build Coastguard Worker {
104*89c4ff92SAndroid Build Coastguard Worker layer->m_PeepholeParameters.m_CellToInputWeights = std::make_unique<ScopedTensorHandle>
105*89c4ff92SAndroid Build Coastguard Worker (TensorInfo({ numUnits }, DataType::Float32));
106*89c4ff92SAndroid Build Coastguard Worker layer->m_PeepholeParameters.m_CellToInputWeights->Allocate();
107*89c4ff92SAndroid Build Coastguard Worker }
108*89c4ff92SAndroid Build Coastguard Worker layer->m_PeepholeParameters.m_CellToForgetWeights = std::make_unique<ScopedTensorHandle>
109*89c4ff92SAndroid Build Coastguard Worker (TensorInfo({ numUnits }, DataType::Float32));
110*89c4ff92SAndroid Build Coastguard Worker layer->m_PeepholeParameters.m_CellToOutputWeights = std::make_unique<ScopedTensorHandle>
111*89c4ff92SAndroid Build Coastguard Worker (TensorInfo({ numUnits }, DataType::Float32));
112*89c4ff92SAndroid Build Coastguard Worker layer->m_PeepholeParameters.m_CellToForgetWeights->Allocate();
113*89c4ff92SAndroid Build Coastguard Worker layer->m_PeepholeParameters.m_CellToOutputWeights->Allocate();
114*89c4ff92SAndroid Build Coastguard Worker }
115*89c4ff92SAndroid Build Coastguard Worker
116*89c4ff92SAndroid Build Coastguard Worker // create input and output layers
117*89c4ff92SAndroid Build Coastguard Worker Layer* const input = graph.AddLayer<InputLayer>(0, "input");
118*89c4ff92SAndroid Build Coastguard Worker Layer* const outputStateIn = graph.AddLayer<InputLayer>(1, "outputStateIn");
119*89c4ff92SAndroid Build Coastguard Worker Layer* const cellStateIn = graph.AddLayer<InputLayer>(2, "cellStateIn");
120*89c4ff92SAndroid Build Coastguard Worker Layer* const scratchBuffer = graph.AddLayer<OutputLayer>(0, "scratchBuffer");
121*89c4ff92SAndroid Build Coastguard Worker Layer* const outputStateOut = graph.AddLayer<OutputLayer>(1, "outputStateOut");
122*89c4ff92SAndroid Build Coastguard Worker Layer* const cellStateOut = graph.AddLayer<OutputLayer>(2, "cellStateOut");
123*89c4ff92SAndroid Build Coastguard Worker Layer* const output = graph.AddLayer<OutputLayer>(3, "output");
124*89c4ff92SAndroid Build Coastguard Worker
125*89c4ff92SAndroid Build Coastguard Worker // connect up
126*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo lstmTensorInfo1({ batchSize, inputSize }, DataType::Float32);
127*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo lstmTensorInfo2({ batchSize, numUnits}, DataType::Float32);
128*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo lstmTensorInfo3({ batchSize, outputSize }, DataType::Float32);
129*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo lstmTensorInfoScratchBuff({ batchSize, numUnits * (layerDesc.m_CifgEnabled ? 3 : 4) },
130*89c4ff92SAndroid Build Coastguard Worker DataType::Float32);
131*89c4ff92SAndroid Build Coastguard Worker
132*89c4ff92SAndroid Build Coastguard Worker Connect(input, layer, lstmTensorInfo1, 0, 0);
133*89c4ff92SAndroid Build Coastguard Worker Connect(cellStateIn, layer, lstmTensorInfo2, 0, 1);
134*89c4ff92SAndroid Build Coastguard Worker Connect(outputStateIn, layer, lstmTensorInfo3, 0, 2);
135*89c4ff92SAndroid Build Coastguard Worker Connect(layer, scratchBuffer, lstmTensorInfoScratchBuff, 0, 0);
136*89c4ff92SAndroid Build Coastguard Worker Connect(layer, outputStateOut, lstmTensorInfo3, 1, 0);
137*89c4ff92SAndroid Build Coastguard Worker Connect(layer, cellStateOut, lstmTensorInfo2, 2, 0);
138*89c4ff92SAndroid Build Coastguard Worker Connect(layer, output, lstmTensorInfo3, 3, 0);
139*89c4ff92SAndroid Build Coastguard Worker }
140*89c4ff92SAndroid Build Coastguard Worker
141*89c4ff92SAndroid Build Coastguard Worker
142*89c4ff92SAndroid Build Coastguard Worker class MockLayerSupport : public LayerSupportBase
143*89c4ff92SAndroid Build Coastguard Worker {
144*89c4ff92SAndroid Build Coastguard Worker public:
IsLayerSupported(const LayerType & type,const std::vector<TensorInfo> & infos,const BaseDescriptor & descriptor,const Optional<LstmInputParamsInfo> &,const Optional<QuantizedLstmInputParamsInfo> &,Optional<std::string &> reasonIfUnsupported) const145*89c4ff92SAndroid Build Coastguard Worker bool IsLayerSupported(const LayerType& type,
146*89c4ff92SAndroid Build Coastguard Worker const std::vector<TensorInfo>& infos,
147*89c4ff92SAndroid Build Coastguard Worker const BaseDescriptor& descriptor,
148*89c4ff92SAndroid Build Coastguard Worker const Optional<LstmInputParamsInfo>& /*lstmParamsInfo*/,
149*89c4ff92SAndroid Build Coastguard Worker const Optional<QuantizedLstmInputParamsInfo>& /*quantizedLstmParamsInfo*/,
150*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const override
151*89c4ff92SAndroid Build Coastguard Worker {
152*89c4ff92SAndroid Build Coastguard Worker switch (type)
153*89c4ff92SAndroid Build Coastguard Worker {
154*89c4ff92SAndroid Build Coastguard Worker case LayerType::Input:
155*89c4ff92SAndroid Build Coastguard Worker return IsInputSupported(infos[0], reasonIfUnsupported);
156*89c4ff92SAndroid Build Coastguard Worker case LayerType::Output:
157*89c4ff92SAndroid Build Coastguard Worker return IsOutputSupported(infos[0], reasonIfUnsupported);
158*89c4ff92SAndroid Build Coastguard Worker case LayerType::Activation:
159*89c4ff92SAndroid Build Coastguard Worker return IsActivationSupported(infos[0],
160*89c4ff92SAndroid Build Coastguard Worker infos[1],
161*89c4ff92SAndroid Build Coastguard Worker *(PolymorphicDowncast<const ActivationDescriptor*>(&descriptor)),
162*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported);
163*89c4ff92SAndroid Build Coastguard Worker default:
164*89c4ff92SAndroid Build Coastguard Worker return false;
165*89c4ff92SAndroid Build Coastguard Worker }
166*89c4ff92SAndroid Build Coastguard Worker }
167*89c4ff92SAndroid Build Coastguard Worker
IsInputSupported(const TensorInfo &,Optional<std::string &>) const168*89c4ff92SAndroid Build Coastguard Worker bool IsInputSupported(const TensorInfo& /*input*/,
169*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> /*reasonIfUnsupported = EmptyOptional()*/) const override
170*89c4ff92SAndroid Build Coastguard Worker {
171*89c4ff92SAndroid Build Coastguard Worker return true;
172*89c4ff92SAndroid Build Coastguard Worker }
173*89c4ff92SAndroid Build Coastguard Worker
IsOutputSupported(const TensorInfo &,Optional<std::string &>) const174*89c4ff92SAndroid Build Coastguard Worker bool IsOutputSupported(const TensorInfo& /*input*/,
175*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> /*reasonIfUnsupported = EmptyOptional()*/) const override
176*89c4ff92SAndroid Build Coastguard Worker {
177*89c4ff92SAndroid Build Coastguard Worker return true;
178*89c4ff92SAndroid Build Coastguard Worker }
179*89c4ff92SAndroid Build Coastguard Worker
IsActivationSupported(const TensorInfo &,const TensorInfo &,const ActivationDescriptor &,Optional<std::string &>) const180*89c4ff92SAndroid Build Coastguard Worker bool IsActivationSupported(const TensorInfo& /*input0*/,
181*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& /*output*/,
182*89c4ff92SAndroid Build Coastguard Worker const ActivationDescriptor& /*descriptor*/,
183*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> /*reasonIfUnsupported = EmptyOptional()*/) const override
184*89c4ff92SAndroid Build Coastguard Worker {
185*89c4ff92SAndroid Build Coastguard Worker return true;
186*89c4ff92SAndroid Build Coastguard Worker }
187*89c4ff92SAndroid Build Coastguard Worker };
188*89c4ff92SAndroid Build Coastguard Worker
189*89c4ff92SAndroid Build Coastguard Worker template <typename NamePolicy>
190*89c4ff92SAndroid Build Coastguard Worker class CustomAllocatorBackend : public IBackendInternal
191*89c4ff92SAndroid Build Coastguard Worker {
192*89c4ff92SAndroid Build Coastguard Worker public:
CustomAllocatorBackend()193*89c4ff92SAndroid Build Coastguard Worker CustomAllocatorBackend() :
194*89c4ff92SAndroid Build Coastguard Worker m_BackendCapabilities(NamePolicy::GetIdStatic(), {{"NullCapability", false}}),
195*89c4ff92SAndroid Build Coastguard Worker m_CustomAllocator(false) {};
CustomAllocatorBackend(const BackendCapabilities & capabilities)196*89c4ff92SAndroid Build Coastguard Worker CustomAllocatorBackend(const BackendCapabilities& capabilities) :
197*89c4ff92SAndroid Build Coastguard Worker m_BackendCapabilities(capabilities),
198*89c4ff92SAndroid Build Coastguard Worker m_CustomAllocator(false) {};
199*89c4ff92SAndroid Build Coastguard Worker ~CustomAllocatorBackend() = default;
200*89c4ff92SAndroid Build Coastguard Worker
GetIdStatic()201*89c4ff92SAndroid Build Coastguard Worker static const BackendId& GetIdStatic()
202*89c4ff92SAndroid Build Coastguard Worker {
203*89c4ff92SAndroid Build Coastguard Worker return NamePolicy::GetIdStatic();
204*89c4ff92SAndroid Build Coastguard Worker }
GetId() const205*89c4ff92SAndroid Build Coastguard Worker const BackendId& GetId() const override
206*89c4ff92SAndroid Build Coastguard Worker {
207*89c4ff92SAndroid Build Coastguard Worker return GetIdStatic();
208*89c4ff92SAndroid Build Coastguard Worker }
209*89c4ff92SAndroid Build Coastguard Worker
CreateMemoryManager() const210*89c4ff92SAndroid Build Coastguard Worker IBackendInternal::IMemoryManagerUniquePtr CreateMemoryManager() const override
211*89c4ff92SAndroid Build Coastguard Worker {
212*89c4ff92SAndroid Build Coastguard Worker return nullptr;
213*89c4ff92SAndroid Build Coastguard Worker };
214*89c4ff92SAndroid Build Coastguard Worker
215*89c4ff92SAndroid Build Coastguard Worker IBackendInternal::IWorkloadFactoryPtr
CreateWorkloadFactory(const IBackendInternal::IMemoryManagerSharedPtr &) const216*89c4ff92SAndroid Build Coastguard Worker CreateWorkloadFactory(const IBackendInternal::IMemoryManagerSharedPtr&) const override
217*89c4ff92SAndroid Build Coastguard Worker {
218*89c4ff92SAndroid Build Coastguard Worker return nullptr;
219*89c4ff92SAndroid Build Coastguard Worker }
220*89c4ff92SAndroid Build Coastguard Worker
CreateBackendContext(const IRuntime::CreationOptions &) const221*89c4ff92SAndroid Build Coastguard Worker IBackendInternal::IBackendContextPtr CreateBackendContext(const IRuntime::CreationOptions&) const override
222*89c4ff92SAndroid Build Coastguard Worker {
223*89c4ff92SAndroid Build Coastguard Worker return nullptr;
224*89c4ff92SAndroid Build Coastguard Worker }
225*89c4ff92SAndroid Build Coastguard Worker
GetLayerSupport() const226*89c4ff92SAndroid Build Coastguard Worker IBackendInternal::ILayerSupportSharedPtr GetLayerSupport() const override
227*89c4ff92SAndroid Build Coastguard Worker {
228*89c4ff92SAndroid Build Coastguard Worker return std::make_shared<MockLayerSupport>();
229*89c4ff92SAndroid Build Coastguard Worker }
230*89c4ff92SAndroid Build Coastguard Worker
OptimizeSubgraphView(const SubgraphView &) const231*89c4ff92SAndroid Build Coastguard Worker OptimizationViews OptimizeSubgraphView(const SubgraphView&) const override
232*89c4ff92SAndroid Build Coastguard Worker {
233*89c4ff92SAndroid Build Coastguard Worker return {};
234*89c4ff92SAndroid Build Coastguard Worker };
235*89c4ff92SAndroid Build Coastguard Worker
GetCapabilities() const236*89c4ff92SAndroid Build Coastguard Worker BackendCapabilities GetCapabilities() const override
237*89c4ff92SAndroid Build Coastguard Worker {
238*89c4ff92SAndroid Build Coastguard Worker return m_BackendCapabilities;
239*89c4ff92SAndroid Build Coastguard Worker };
240*89c4ff92SAndroid Build Coastguard Worker
UseCustomMemoryAllocator(std::shared_ptr<ICustomAllocator> allocator,armnn::Optional<std::string &> errMsg)241*89c4ff92SAndroid Build Coastguard Worker virtual bool UseCustomMemoryAllocator(std::shared_ptr<ICustomAllocator> allocator,
242*89c4ff92SAndroid Build Coastguard Worker armnn::Optional<std::string&> errMsg) override
243*89c4ff92SAndroid Build Coastguard Worker {
244*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(errMsg, allocator);
245*89c4ff92SAndroid Build Coastguard Worker m_CustomAllocator = true;
246*89c4ff92SAndroid Build Coastguard Worker return m_CustomAllocator;
247*89c4ff92SAndroid Build Coastguard Worker }
248*89c4ff92SAndroid Build Coastguard Worker
249*89c4ff92SAndroid Build Coastguard Worker BackendCapabilities m_BackendCapabilities;
250*89c4ff92SAndroid Build Coastguard Worker bool m_CustomAllocator;
251*89c4ff92SAndroid Build Coastguard Worker };
252*89c4ff92SAndroid Build Coastguard Worker
253*89c4ff92SAndroid Build Coastguard Worker template <typename NamePolicy>
254*89c4ff92SAndroid Build Coastguard Worker class NoProtectedModeMockBackend : public IBackendInternal
255*89c4ff92SAndroid Build Coastguard Worker {
256*89c4ff92SAndroid Build Coastguard Worker public:
NoProtectedModeMockBackend()257*89c4ff92SAndroid Build Coastguard Worker NoProtectedModeMockBackend() : m_BackendCapabilities(NamePolicy::GetIdStatic(), {{"NullCapability", false}}) {};
NoProtectedModeMockBackend(const BackendCapabilities & capabilities)258*89c4ff92SAndroid Build Coastguard Worker NoProtectedModeMockBackend(const BackendCapabilities& capabilities) : m_BackendCapabilities(capabilities) {};
259*89c4ff92SAndroid Build Coastguard Worker ~NoProtectedModeMockBackend() = default;
260*89c4ff92SAndroid Build Coastguard Worker
GetIdStatic()261*89c4ff92SAndroid Build Coastguard Worker static const BackendId& GetIdStatic()
262*89c4ff92SAndroid Build Coastguard Worker {
263*89c4ff92SAndroid Build Coastguard Worker return NamePolicy::GetIdStatic();
264*89c4ff92SAndroid Build Coastguard Worker }
GetId() const265*89c4ff92SAndroid Build Coastguard Worker const BackendId& GetId() const override
266*89c4ff92SAndroid Build Coastguard Worker {
267*89c4ff92SAndroid Build Coastguard Worker return GetIdStatic();
268*89c4ff92SAndroid Build Coastguard Worker }
269*89c4ff92SAndroid Build Coastguard Worker
CreateMemoryManager() const270*89c4ff92SAndroid Build Coastguard Worker IBackendInternal::IMemoryManagerUniquePtr CreateMemoryManager() const override
271*89c4ff92SAndroid Build Coastguard Worker {
272*89c4ff92SAndroid Build Coastguard Worker return nullptr;
273*89c4ff92SAndroid Build Coastguard Worker };
274*89c4ff92SAndroid Build Coastguard Worker
275*89c4ff92SAndroid Build Coastguard Worker IBackendInternal::IWorkloadFactoryPtr
CreateWorkloadFactory(const IBackendInternal::IMemoryManagerSharedPtr &) const276*89c4ff92SAndroid Build Coastguard Worker CreateWorkloadFactory(const IBackendInternal::IMemoryManagerSharedPtr&) const override
277*89c4ff92SAndroid Build Coastguard Worker {
278*89c4ff92SAndroid Build Coastguard Worker return nullptr;
279*89c4ff92SAndroid Build Coastguard Worker }
280*89c4ff92SAndroid Build Coastguard Worker
CreateBackendContext(const IRuntime::CreationOptions &) const281*89c4ff92SAndroid Build Coastguard Worker IBackendInternal::IBackendContextPtr CreateBackendContext(const IRuntime::CreationOptions&) const override
282*89c4ff92SAndroid Build Coastguard Worker {
283*89c4ff92SAndroid Build Coastguard Worker return nullptr;
284*89c4ff92SAndroid Build Coastguard Worker }
285*89c4ff92SAndroid Build Coastguard Worker
GetLayerSupport() const286*89c4ff92SAndroid Build Coastguard Worker IBackendInternal::ILayerSupportSharedPtr GetLayerSupport() const override
287*89c4ff92SAndroid Build Coastguard Worker {
288*89c4ff92SAndroid Build Coastguard Worker return std::make_shared<MockLayerSupport>();
289*89c4ff92SAndroid Build Coastguard Worker }
290*89c4ff92SAndroid Build Coastguard Worker
OptimizeSubgraphView(const SubgraphView &) const291*89c4ff92SAndroid Build Coastguard Worker OptimizationViews OptimizeSubgraphView(const SubgraphView&) const override
292*89c4ff92SAndroid Build Coastguard Worker {
293*89c4ff92SAndroid Build Coastguard Worker return {};
294*89c4ff92SAndroid Build Coastguard Worker };
295*89c4ff92SAndroid Build Coastguard Worker
GetCapabilities() const296*89c4ff92SAndroid Build Coastguard Worker BackendCapabilities GetCapabilities() const override
297*89c4ff92SAndroid Build Coastguard Worker {
298*89c4ff92SAndroid Build Coastguard Worker return m_BackendCapabilities;
299*89c4ff92SAndroid Build Coastguard Worker };
300*89c4ff92SAndroid Build Coastguard Worker
301*89c4ff92SAndroid Build Coastguard Worker BackendCapabilities m_BackendCapabilities;
302*89c4ff92SAndroid Build Coastguard Worker };
303*89c4ff92SAndroid Build Coastguard Worker
304*89c4ff92SAndroid Build Coastguard Worker } // namespace
305*89c4ff92SAndroid Build Coastguard Worker
306*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("Optimizer")
307*89c4ff92SAndroid Build Coastguard Worker {
308*89c4ff92SAndroid Build Coastguard Worker using namespace armnn::optimizations;
309*89c4ff92SAndroid Build Coastguard Worker
310*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("LSTMValidateTensorShapesFromInputsCIFGDisabledTest")
311*89c4ff92SAndroid Build Coastguard Worker {
312*89c4ff92SAndroid Build Coastguard Worker Graph graph;
313*89c4ff92SAndroid Build Coastguard Worker
314*89c4ff92SAndroid Build Coastguard Worker //Helper function creates graph containing LSTM layer with required input and output layers
315*89c4ff92SAndroid Build Coastguard Worker CreateLSTMLayerHelper(graph, false);
316*89c4ff92SAndroid Build Coastguard Worker
317*89c4ff92SAndroid Build Coastguard Worker //This function used to call ValidateShapesFromInputs();
318*89c4ff92SAndroid Build Coastguard Worker CHECK_NOTHROW(graph.InferTensorInfos());
319*89c4ff92SAndroid Build Coastguard Worker }
320*89c4ff92SAndroid Build Coastguard Worker
321*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("LSTMValidateTensorShapesFromInputsCIFGEnabledTest")
322*89c4ff92SAndroid Build Coastguard Worker {
323*89c4ff92SAndroid Build Coastguard Worker Graph graph;
324*89c4ff92SAndroid Build Coastguard Worker
325*89c4ff92SAndroid Build Coastguard Worker //Helper function creates graph containing LSTM layer with required input and output layers
326*89c4ff92SAndroid Build Coastguard Worker CreateLSTMLayerHelper(graph, true);
327*89c4ff92SAndroid Build Coastguard Worker
328*89c4ff92SAndroid Build Coastguard Worker //This function used to call ValidateShapesFromInputs();
329*89c4ff92SAndroid Build Coastguard Worker CHECK_NOTHROW(graph.InferTensorInfos());
330*89c4ff92SAndroid Build Coastguard Worker }
331*89c4ff92SAndroid Build Coastguard Worker
332*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("InsertConvertersTest")
333*89c4ff92SAndroid Build Coastguard Worker {
334*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo info({ 1, 5, 2, 3 }, armnn::DataType::Float16);
335*89c4ff92SAndroid Build Coastguard Worker
336*89c4ff92SAndroid Build Coastguard Worker armnn::Graph graph;
337*89c4ff92SAndroid Build Coastguard Worker
338*89c4ff92SAndroid Build Coastguard Worker armnn::LayerBindingId inputId = 0;
339*89c4ff92SAndroid Build Coastguard Worker
340*89c4ff92SAndroid Build Coastguard Worker armnn::Layer* head = graph.AddLayer<armnn::OutputLayer>(0, "output");
341*89c4ff92SAndroid Build Coastguard Worker
342*89c4ff92SAndroid Build Coastguard Worker head = graph.InsertNewLayer<armnn::AdditionLayer>(head->GetInputSlot(0), "");
343*89c4ff92SAndroid Build Coastguard Worker head->GetOutputHandler().SetTensorInfo(info);
344*89c4ff92SAndroid Build Coastguard Worker
345*89c4ff92SAndroid Build Coastguard Worker graph.InsertNewLayer<armnn::InputLayer>(head->GetInputSlot(1), inputId++, "")
346*89c4ff92SAndroid Build Coastguard Worker ->GetOutputHandler().SetTensorInfo(info);
347*89c4ff92SAndroid Build Coastguard Worker
348*89c4ff92SAndroid Build Coastguard Worker head = graph.InsertNewLayer<armnn::FloorLayer>(head->GetInputSlot(0), "");
349*89c4ff92SAndroid Build Coastguard Worker head->GetOutputHandler().SetTensorInfo(info);
350*89c4ff92SAndroid Build Coastguard Worker
351*89c4ff92SAndroid Build Coastguard Worker head = graph.InsertNewLayer<armnn::MemCopyLayer>(head->GetInputSlot(0), "");
352*89c4ff92SAndroid Build Coastguard Worker head->GetOutputHandler().SetTensorInfo(info);
353*89c4ff92SAndroid Build Coastguard Worker
354*89c4ff92SAndroid Build Coastguard Worker graph.InsertNewLayer<armnn::InputLayer>(head->GetInputSlot(0), inputId++, "")
355*89c4ff92SAndroid Build Coastguard Worker ->GetOutputHandler().SetTensorInfo(info);
356*89c4ff92SAndroid Build Coastguard Worker
357*89c4ff92SAndroid Build Coastguard Worker // Check graph layer sequence before inserting convert layers
358*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckSequence(graph.cbegin(),
359*89c4ff92SAndroid Build Coastguard Worker graph.cend(),
360*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<armnn::InputLayer>,
361*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<armnn::InputLayer>,
362*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<armnn::MemCopyLayer>,
363*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<armnn::FloorLayer>,
364*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<armnn::AdditionLayer>,
365*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<armnn::OutputLayer>));
366*89c4ff92SAndroid Build Coastguard Worker
367*89c4ff92SAndroid Build Coastguard Worker // Check layers have Float16 DataType
368*89c4ff92SAndroid Build Coastguard Worker for (auto& layer : graph)
369*89c4ff92SAndroid Build Coastguard Worker {
370*89c4ff92SAndroid Build Coastguard Worker if(layer->GetType()==LayerType::Floor || layer->GetType() == LayerType::Addition)
371*89c4ff92SAndroid Build Coastguard Worker {
372*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(layer->GetOutputSlot(0).GetTensorInfo().GetDataType() == DataType::Float16);
373*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(layer->GetDataType() == DataType::Float16);
374*89c4ff92SAndroid Build Coastguard Worker }
375*89c4ff92SAndroid Build Coastguard Worker }
376*89c4ff92SAndroid Build Coastguard Worker
377*89c4ff92SAndroid Build Coastguard Worker // Insert convert layers either side of unsupported layer
378*89c4ff92SAndroid Build Coastguard Worker for (auto& layer : graph)
379*89c4ff92SAndroid Build Coastguard Worker {
380*89c4ff92SAndroid Build Coastguard Worker if(layer->GetType()==LayerType::Floor || layer->GetType() == LayerType::Addition)
381*89c4ff92SAndroid Build Coastguard Worker {
382*89c4ff92SAndroid Build Coastguard Worker InsertConvertFp16ToFp32LayersBefore(graph, *layer);
383*89c4ff92SAndroid Build Coastguard Worker InsertConvertFp32ToFp16LayersAfter(graph, *layer);
384*89c4ff92SAndroid Build Coastguard Worker }
385*89c4ff92SAndroid Build Coastguard Worker }
386*89c4ff92SAndroid Build Coastguard Worker
387*89c4ff92SAndroid Build Coastguard Worker // Check layers have correct DataType after inserting convert layers
388*89c4ff92SAndroid Build Coastguard Worker for (auto& layer : graph)
389*89c4ff92SAndroid Build Coastguard Worker {
390*89c4ff92SAndroid Build Coastguard Worker if (layer->GetType()==LayerType::Floor || layer->GetType() == LayerType::Addition)
391*89c4ff92SAndroid Build Coastguard Worker {
392*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(layer->GetOutputSlot(0).GetTensorInfo().GetDataType() == DataType::Float32);
393*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(layer->GetDataType() == DataType::Float32);
394*89c4ff92SAndroid Build Coastguard Worker }
395*89c4ff92SAndroid Build Coastguard Worker else if (layer->GetType() == LayerType::ConvertFp16ToFp32)
396*89c4ff92SAndroid Build Coastguard Worker {
397*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(layer->GetOutputSlot(0).GetTensorInfo().GetDataType() == DataType::Float32);
398*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(layer->GetDataType() == DataType::Float16);
399*89c4ff92SAndroid Build Coastguard Worker }
400*89c4ff92SAndroid Build Coastguard Worker else if (layer->GetType() == LayerType::ConvertFp32ToFp16)
401*89c4ff92SAndroid Build Coastguard Worker {
402*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(layer->GetOutputSlot(0).GetTensorInfo().GetDataType() == DataType::Float16);
403*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(layer->GetDataType() == DataType::Float32);
404*89c4ff92SAndroid Build Coastguard Worker }
405*89c4ff92SAndroid Build Coastguard Worker }
406*89c4ff92SAndroid Build Coastguard Worker
407*89c4ff92SAndroid Build Coastguard Worker // Check sequence of layers after inserting convert layers
408*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckSequence(graph.cbegin(),
409*89c4ff92SAndroid Build Coastguard Worker graph.cend(),
410*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<armnn::InputLayer>,
411*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<armnn::InputLayer>,
412*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<armnn::ConvertFp16ToFp32Layer>,
413*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<armnn::MemCopyLayer>,
414*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<armnn::ConvertFp16ToFp32Layer>,
415*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<armnn::FloorLayer>,
416*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<armnn::ConvertFp32ToFp16Layer>,
417*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<armnn::ConvertFp16ToFp32Layer>,
418*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<armnn::AdditionLayer>,
419*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<armnn::ConvertFp32ToFp16Layer>,
420*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<armnn::OutputLayer>));
421*89c4ff92SAndroid Build Coastguard Worker }
422*89c4ff92SAndroid Build Coastguard Worker
CreateConvolution2dGraph(Graph & graph,const unsigned int * inputShape,const unsigned int * weightsShape,const unsigned int * outputShape,DataLayout dataLayout=DataLayout::NCHW)423*89c4ff92SAndroid Build Coastguard Worker void CreateConvolution2dGraph(Graph &graph, const unsigned int* inputShape,
424*89c4ff92SAndroid Build Coastguard Worker const unsigned int* weightsShape, const unsigned int* outputShape,
425*89c4ff92SAndroid Build Coastguard Worker DataLayout dataLayout = DataLayout::NCHW)
426*89c4ff92SAndroid Build Coastguard Worker {
427*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo inputInfo(4, inputShape, DataType::Float32);
428*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputInfo(4, outputShape, DataType::Float32);
429*89c4ff92SAndroid Build Coastguard Worker
430*89c4ff92SAndroid Build Coastguard Worker std::vector<float> weightsVector(90);
431*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor weights(
432*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo(4, weightsShape, armnn::DataType::Float32, 0.0f, 0, true),
433*89c4ff92SAndroid Build Coastguard Worker weightsVector);
434*89c4ff92SAndroid Build Coastguard Worker
435*89c4ff92SAndroid Build Coastguard Worker Convolution2dDescriptor desc;
436*89c4ff92SAndroid Build Coastguard Worker desc.m_BiasEnabled = false;
437*89c4ff92SAndroid Build Coastguard Worker desc.m_StrideX = 1;
438*89c4ff92SAndroid Build Coastguard Worker desc.m_StrideY = 1;
439*89c4ff92SAndroid Build Coastguard Worker desc.m_DataLayout = dataLayout;
440*89c4ff92SAndroid Build Coastguard Worker
441*89c4ff92SAndroid Build Coastguard Worker Layer* input = graph.AddLayer<InputLayer>(0, "input");
442*89c4ff92SAndroid Build Coastguard Worker input->GetOutputSlot().SetTensorInfo(inputInfo);
443*89c4ff92SAndroid Build Coastguard Worker
444*89c4ff92SAndroid Build Coastguard Worker ConstantLayer* weightsLayer = graph.AddLayer<ConstantLayer>("Weights");
445*89c4ff92SAndroid Build Coastguard Worker weightsLayer->m_LayerOutput = std::make_shared<ScopedTensorHandle>(weights);
446*89c4ff92SAndroid Build Coastguard Worker weightsLayer->GetOutputSlot(0).SetTensorInfo(weightsLayer->m_LayerOutput->GetTensorInfo());
447*89c4ff92SAndroid Build Coastguard Worker
448*89c4ff92SAndroid Build Coastguard Worker Convolution2dLayer* layer = graph.AddLayer<Convolution2dLayer>(desc, "conv2d");
449*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot().SetTensorInfo(outputInfo);
450*89c4ff92SAndroid Build Coastguard Worker
451*89c4ff92SAndroid Build Coastguard Worker Layer* output = graph.AddLayer<OutputLayer>(0, "output");
452*89c4ff92SAndroid Build Coastguard Worker
453*89c4ff92SAndroid Build Coastguard Worker input->GetOutputSlot().Connect(layer->GetInputSlot(0));
454*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot().Connect(output->GetInputSlot(0));
455*89c4ff92SAndroid Build Coastguard Worker weightsLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(1));
456*89c4ff92SAndroid Build Coastguard Worker }
457*89c4ff92SAndroid Build Coastguard Worker
458*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("Conv2dValidateTensorShapesFromInputs")
459*89c4ff92SAndroid Build Coastguard Worker {
460*89c4ff92SAndroid Build Coastguard Worker Graph graph;
461*89c4ff92SAndroid Build Coastguard Worker const unsigned int inputShape[] = { 1, 3, 8, 16 };
462*89c4ff92SAndroid Build Coastguard Worker const unsigned int weightsShape[] = { 2, 3, 5, 3 };
463*89c4ff92SAndroid Build Coastguard Worker const unsigned int outputShape[] = { 1, 2, 4, 14 };
464*89c4ff92SAndroid Build Coastguard Worker CreateConvolution2dGraph(graph, inputShape, weightsShape, outputShape);
465*89c4ff92SAndroid Build Coastguard Worker
466*89c4ff92SAndroid Build Coastguard Worker CHECK_NOTHROW(graph.InferTensorInfos());
467*89c4ff92SAndroid Build Coastguard Worker }
468*89c4ff92SAndroid Build Coastguard Worker
469*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("Conv2dValidateTensorShapesFromInputsNhwc")
470*89c4ff92SAndroid Build Coastguard Worker {
471*89c4ff92SAndroid Build Coastguard Worker Graph graph;
472*89c4ff92SAndroid Build Coastguard Worker const unsigned int inputShape[] = { 1, 8, 16, 3 };
473*89c4ff92SAndroid Build Coastguard Worker const unsigned int weightsShape[] = { 2, 5, 3, 3 };
474*89c4ff92SAndroid Build Coastguard Worker const unsigned int outputShape[] = { 1, 4, 14, 2 };
475*89c4ff92SAndroid Build Coastguard Worker CreateConvolution2dGraph(graph, inputShape, weightsShape, outputShape, DataLayout::NHWC);
476*89c4ff92SAndroid Build Coastguard Worker
477*89c4ff92SAndroid Build Coastguard Worker CHECK_NOTHROW(graph.InferTensorInfos());
478*89c4ff92SAndroid Build Coastguard Worker }
479*89c4ff92SAndroid Build Coastguard Worker
CreateDepthwiseConvolution2dGraph(Graph & graph,const unsigned int * inputShape,const unsigned int * weightsShape,const unsigned int * outputShape,DataLayout dataLayout=DataLayout::NCHW)480*89c4ff92SAndroid Build Coastguard Worker void CreateDepthwiseConvolution2dGraph(Graph &graph, const unsigned int* inputShape,
481*89c4ff92SAndroid Build Coastguard Worker const unsigned int* weightsShape, const unsigned int* outputShape,
482*89c4ff92SAndroid Build Coastguard Worker DataLayout dataLayout = DataLayout::NCHW)
483*89c4ff92SAndroid Build Coastguard Worker {
484*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo inputInfo(4, inputShape, DataType::Float32);
485*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputInfo(4, outputShape, DataType::Float32);
486*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo weightsInfo(TensorShape(4, weightsShape), armnn::DataType::Float32, 0.0f, 0, true);
487*89c4ff92SAndroid Build Coastguard Worker
488*89c4ff92SAndroid Build Coastguard Worker std::vector<float> weightsVector(18);
489*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor weights(weightsInfo, weightsVector);
490*89c4ff92SAndroid Build Coastguard Worker
491*89c4ff92SAndroid Build Coastguard Worker DepthwiseConvolution2dDescriptor desc;
492*89c4ff92SAndroid Build Coastguard Worker desc.m_BiasEnabled = false;
493*89c4ff92SAndroid Build Coastguard Worker desc.m_StrideX = 1;
494*89c4ff92SAndroid Build Coastguard Worker desc.m_StrideY = 1;
495*89c4ff92SAndroid Build Coastguard Worker desc.m_DataLayout = dataLayout;
496*89c4ff92SAndroid Build Coastguard Worker
497*89c4ff92SAndroid Build Coastguard Worker InputLayer* input = graph.AddLayer<InputLayer>(0, "input");
498*89c4ff92SAndroid Build Coastguard Worker DepthwiseConvolution2dLayer* layer = graph.AddLayer<DepthwiseConvolution2dLayer>(desc, "depthwiseConv2d");
499*89c4ff92SAndroid Build Coastguard Worker ConstantLayer* weightsLayer = graph.AddLayer<ConstantLayer>("weights");
500*89c4ff92SAndroid Build Coastguard Worker OutputLayer* output = graph.AddLayer<OutputLayer>(0, "output");
501*89c4ff92SAndroid Build Coastguard Worker
502*89c4ff92SAndroid Build Coastguard Worker input->GetOutputSlot().SetTensorInfo(inputInfo);
503*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot().SetTensorInfo(outputInfo);
504*89c4ff92SAndroid Build Coastguard Worker weightsLayer->GetOutputSlot().SetTensorInfo(weightsInfo);
505*89c4ff92SAndroid Build Coastguard Worker
506*89c4ff92SAndroid Build Coastguard Worker weightsLayer->m_LayerOutput = std::make_unique<armnn::ScopedTensorHandle>(weights);
507*89c4ff92SAndroid Build Coastguard Worker
508*89c4ff92SAndroid Build Coastguard Worker input->GetOutputSlot().Connect(layer->GetInputSlot(0));
509*89c4ff92SAndroid Build Coastguard Worker weightsLayer->GetOutputSlot().Connect(layer->GetInputSlot(1));
510*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot().Connect(output->GetInputSlot(0));
511*89c4ff92SAndroid Build Coastguard Worker }
512*89c4ff92SAndroid Build Coastguard Worker
513*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("DepthwiseConv2dValidateTensorShapesFromInputs")
514*89c4ff92SAndroid Build Coastguard Worker {
515*89c4ff92SAndroid Build Coastguard Worker Graph graph;
516*89c4ff92SAndroid Build Coastguard Worker const unsigned int inputShape[] = { 1, 2, 3, 3 };
517*89c4ff92SAndroid Build Coastguard Worker const unsigned int weightsShape[] = { 1, 3, 3, 2 };
518*89c4ff92SAndroid Build Coastguard Worker const unsigned int outputShape[] = { 1, 2, 1, 1 };
519*89c4ff92SAndroid Build Coastguard Worker CreateDepthwiseConvolution2dGraph(graph, inputShape, weightsShape, outputShape);
520*89c4ff92SAndroid Build Coastguard Worker
521*89c4ff92SAndroid Build Coastguard Worker CHECK_NOTHROW(graph.InferTensorInfos());
522*89c4ff92SAndroid Build Coastguard Worker }
523*89c4ff92SAndroid Build Coastguard Worker
524*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("DepthwiseConv2dValidateTensorShapesFromInputsNhwc")
525*89c4ff92SAndroid Build Coastguard Worker {
526*89c4ff92SAndroid Build Coastguard Worker Graph graph;
527*89c4ff92SAndroid Build Coastguard Worker const unsigned int inputShape[] = { 1, 3, 3, 2 };
528*89c4ff92SAndroid Build Coastguard Worker const unsigned int weightsShape[] = { 1, 3, 3, 2 };
529*89c4ff92SAndroid Build Coastguard Worker const unsigned int outputShape[] = { 1, 1, 1, 2 };
530*89c4ff92SAndroid Build Coastguard Worker CreateDepthwiseConvolution2dGraph(graph, inputShape, weightsShape, outputShape, DataLayout::NHWC);
531*89c4ff92SAndroid Build Coastguard Worker
532*89c4ff92SAndroid Build Coastguard Worker CHECK_NOTHROW(graph.InferTensorInfos());
533*89c4ff92SAndroid Build Coastguard Worker }
534*89c4ff92SAndroid Build Coastguard Worker
CreatePooling2dGraph(Graph & graph,const unsigned int * inputShape,const unsigned int * outputShape,DataLayout dataLayout=DataLayout::NCHW)535*89c4ff92SAndroid Build Coastguard Worker void CreatePooling2dGraph(Graph& graph, const unsigned int* inputShape, const unsigned int* outputShape,
536*89c4ff92SAndroid Build Coastguard Worker DataLayout dataLayout = DataLayout::NCHW)
537*89c4ff92SAndroid Build Coastguard Worker {
538*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo inputInfo(4, inputShape, DataType::Float32);
539*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputInfo(4, outputShape, DataType::Float32);
540*89c4ff92SAndroid Build Coastguard Worker
541*89c4ff92SAndroid Build Coastguard Worker Pooling2dDescriptor desc;
542*89c4ff92SAndroid Build Coastguard Worker desc.m_PoolType = armnn::PoolingAlgorithm::Average;
543*89c4ff92SAndroid Build Coastguard Worker desc.m_PoolWidth = desc.m_PoolHeight = 100;
544*89c4ff92SAndroid Build Coastguard Worker desc.m_StrideX = desc.m_StrideY = 5;
545*89c4ff92SAndroid Build Coastguard Worker desc.m_PadLeft = 50;
546*89c4ff92SAndroid Build Coastguard Worker desc.m_PadRight = 50;
547*89c4ff92SAndroid Build Coastguard Worker desc.m_PadTop = 50;
548*89c4ff92SAndroid Build Coastguard Worker desc.m_PadBottom = 50;
549*89c4ff92SAndroid Build Coastguard Worker desc.m_PaddingMethod = armnn::PaddingMethod::Exclude;
550*89c4ff92SAndroid Build Coastguard Worker desc.m_DataLayout = dataLayout;
551*89c4ff92SAndroid Build Coastguard Worker
552*89c4ff92SAndroid Build Coastguard Worker Layer* input = graph.AddLayer<InputLayer>(0, "input");
553*89c4ff92SAndroid Build Coastguard Worker input->GetOutputSlot().SetTensorInfo(inputInfo);
554*89c4ff92SAndroid Build Coastguard Worker
555*89c4ff92SAndroid Build Coastguard Worker Pooling2dLayer* layer = graph.AddLayer<Pooling2dLayer>(desc, "pooling2d");
556*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot().SetTensorInfo(outputInfo);
557*89c4ff92SAndroid Build Coastguard Worker
558*89c4ff92SAndroid Build Coastguard Worker Layer* output = graph.AddLayer<OutputLayer>(0, "output");
559*89c4ff92SAndroid Build Coastguard Worker input->GetOutputSlot().Connect(layer->GetInputSlot(0));
560*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot().Connect(output->GetInputSlot(0));
561*89c4ff92SAndroid Build Coastguard Worker }
562*89c4ff92SAndroid Build Coastguard Worker
563*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("Pooling2dValidateTensorShapesFromInputs")
564*89c4ff92SAndroid Build Coastguard Worker {
565*89c4ff92SAndroid Build Coastguard Worker Graph graph;
566*89c4ff92SAndroid Build Coastguard Worker const unsigned int inputShape[] = { 5, 3, 52, 60 };
567*89c4ff92SAndroid Build Coastguard Worker const unsigned int outputShape[] = { 5, 3, 11, 13 };
568*89c4ff92SAndroid Build Coastguard Worker CreatePooling2dGraph(graph, inputShape, outputShape, DataLayout::NCHW);
569*89c4ff92SAndroid Build Coastguard Worker
570*89c4ff92SAndroid Build Coastguard Worker CHECK_NOTHROW(graph.InferTensorInfos());
571*89c4ff92SAndroid Build Coastguard Worker }
572*89c4ff92SAndroid Build Coastguard Worker
573*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("Pooling2dValidateTensorShapesFromInputsNhwc")
574*89c4ff92SAndroid Build Coastguard Worker {
575*89c4ff92SAndroid Build Coastguard Worker Graph graph;
576*89c4ff92SAndroid Build Coastguard Worker const unsigned int inputShape[] = { 5, 52, 60, 3 };
577*89c4ff92SAndroid Build Coastguard Worker const unsigned int outputShape[] = { 5, 11, 13, 3 };
578*89c4ff92SAndroid Build Coastguard Worker CreatePooling2dGraph(graph, inputShape, outputShape, DataLayout::NHWC);
579*89c4ff92SAndroid Build Coastguard Worker
580*89c4ff92SAndroid Build Coastguard Worker CHECK_NOTHROW(graph.InferTensorInfos());
581*89c4ff92SAndroid Build Coastguard Worker }
582*89c4ff92SAndroid Build Coastguard Worker
CreateResizeBilinearGraph(Graph & graph,const unsigned int * inputShape,const unsigned int * outputShape,DataLayout dataLayout=DataLayout::NCHW)583*89c4ff92SAndroid Build Coastguard Worker void CreateResizeBilinearGraph(Graph& graph,
584*89c4ff92SAndroid Build Coastguard Worker const unsigned int* inputShape,
585*89c4ff92SAndroid Build Coastguard Worker const unsigned int* outputShape,
586*89c4ff92SAndroid Build Coastguard Worker DataLayout dataLayout = DataLayout::NCHW)
587*89c4ff92SAndroid Build Coastguard Worker {
588*89c4ff92SAndroid Build Coastguard Worker TensorInfo inputInfo(4, inputShape, DataType::Float32);
589*89c4ff92SAndroid Build Coastguard Worker TensorInfo outputInfo(4, outputShape, DataType::Float32);
590*89c4ff92SAndroid Build Coastguard Worker
591*89c4ff92SAndroid Build Coastguard Worker ResizeDescriptor desc;
592*89c4ff92SAndroid Build Coastguard Worker desc.m_Method = ResizeMethod::Bilinear;
593*89c4ff92SAndroid Build Coastguard Worker desc.m_TargetHeight = 3;
594*89c4ff92SAndroid Build Coastguard Worker desc.m_TargetWidth = 4;
595*89c4ff92SAndroid Build Coastguard Worker desc.m_DataLayout = dataLayout;
596*89c4ff92SAndroid Build Coastguard Worker
597*89c4ff92SAndroid Build Coastguard Worker Layer* input = graph.AddLayer<InputLayer>(0, "input");
598*89c4ff92SAndroid Build Coastguard Worker input->GetOutputSlot().SetTensorInfo(inputInfo);
599*89c4ff92SAndroid Build Coastguard Worker
600*89c4ff92SAndroid Build Coastguard Worker ResizeLayer* layer = graph.AddLayer<ResizeLayer>(desc, "resizeBilinear");
601*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot().SetTensorInfo(outputInfo);
602*89c4ff92SAndroid Build Coastguard Worker
603*89c4ff92SAndroid Build Coastguard Worker Layer* output = graph.AddLayer<OutputLayer>(0, "output");
604*89c4ff92SAndroid Build Coastguard Worker input->GetOutputSlot().Connect(layer->GetInputSlot(0));
605*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot().Connect(output->GetInputSlot(0));
606*89c4ff92SAndroid Build Coastguard Worker }
607*89c4ff92SAndroid Build Coastguard Worker
608*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ResizeBilinearValidateTensorShapesFromInputs")
609*89c4ff92SAndroid Build Coastguard Worker {
610*89c4ff92SAndroid Build Coastguard Worker Graph graph;
611*89c4ff92SAndroid Build Coastguard Worker const unsigned int inputShape[] = { 1, 2, 4, 5 };
612*89c4ff92SAndroid Build Coastguard Worker const unsigned int outputShape[] = { 1, 2, 3, 4 };
613*89c4ff92SAndroid Build Coastguard Worker CreateResizeBilinearGraph(graph, inputShape, outputShape);
614*89c4ff92SAndroid Build Coastguard Worker
615*89c4ff92SAndroid Build Coastguard Worker CHECK_NOTHROW(graph.InferTensorInfos());
616*89c4ff92SAndroid Build Coastguard Worker }
617*89c4ff92SAndroid Build Coastguard Worker
618*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ResizeBilinearValidateTensorShapesFromInputsNhwc")
619*89c4ff92SAndroid Build Coastguard Worker {
620*89c4ff92SAndroid Build Coastguard Worker Graph graph;
621*89c4ff92SAndroid Build Coastguard Worker const unsigned int inputShape[] = { 1, 4, 5, 2 };
622*89c4ff92SAndroid Build Coastguard Worker const unsigned int outputShape[] = { 1, 3, 4, 2 };
623*89c4ff92SAndroid Build Coastguard Worker CreateResizeBilinearGraph(graph, inputShape, outputShape, DataLayout::NHWC);
624*89c4ff92SAndroid Build Coastguard Worker
625*89c4ff92SAndroid Build Coastguard Worker CHECK_NOTHROW(graph.InferTensorInfos());
626*89c4ff92SAndroid Build Coastguard Worker }
627*89c4ff92SAndroid Build Coastguard Worker
CreateGatherGraph(Graph & graph,const armnn::TensorInfo & paramsInfo,const armnn::TensorInfo & indicesInfo,const armnn::TensorInfo & outputInfo)628*89c4ff92SAndroid Build Coastguard Worker void CreateGatherGraph(Graph& graph,
629*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& paramsInfo,
630*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& indicesInfo,
631*89c4ff92SAndroid Build Coastguard Worker const armnn::TensorInfo& outputInfo)
632*89c4ff92SAndroid Build Coastguard Worker {
633*89c4ff92SAndroid Build Coastguard Worker Layer* input0 = graph.AddLayer<InputLayer>(0, "params");
634*89c4ff92SAndroid Build Coastguard Worker input0->GetOutputSlot().SetTensorInfo(paramsInfo);
635*89c4ff92SAndroid Build Coastguard Worker
636*89c4ff92SAndroid Build Coastguard Worker Layer* input1 = graph.AddLayer<InputLayer>(1, "indices");
637*89c4ff92SAndroid Build Coastguard Worker input1->GetOutputSlot().SetTensorInfo(indicesInfo);
638*89c4ff92SAndroid Build Coastguard Worker
639*89c4ff92SAndroid Build Coastguard Worker GatherDescriptor descriptor;
640*89c4ff92SAndroid Build Coastguard Worker GatherLayer* layer = graph.AddLayer<GatherLayer>(descriptor, "gather");
641*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot().SetTensorInfo(outputInfo);
642*89c4ff92SAndroid Build Coastguard Worker
643*89c4ff92SAndroid Build Coastguard Worker Layer* output = graph.AddLayer<OutputLayer>(0, "output");
644*89c4ff92SAndroid Build Coastguard Worker input0->GetOutputSlot().Connect(layer->GetInputSlot(0));
645*89c4ff92SAndroid Build Coastguard Worker input1->GetOutputSlot().Connect(layer->GetInputSlot(1));
646*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot().Connect(output->GetInputSlot(0));
647*89c4ff92SAndroid Build Coastguard Worker }
648*89c4ff92SAndroid Build Coastguard Worker
649*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("GatherValidateTensorShapesFromInputs")
650*89c4ff92SAndroid Build Coastguard Worker {
651*89c4ff92SAndroid Build Coastguard Worker Graph graph;
652*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo paramsInfo({10, 5}, DataType::Float32);
653*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo indicesInfo({3}, DataType::Signed32);
654*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputInfo({3, 5}, DataType::Float32);
655*89c4ff92SAndroid Build Coastguard Worker
656*89c4ff92SAndroid Build Coastguard Worker CreateGatherGraph(graph, paramsInfo, indicesInfo, outputInfo);
657*89c4ff92SAndroid Build Coastguard Worker
658*89c4ff92SAndroid Build Coastguard Worker CHECK_NOTHROW(graph.InferTensorInfos());
659*89c4ff92SAndroid Build Coastguard Worker }
660*89c4ff92SAndroid Build Coastguard Worker
661*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("GatherValidateTensorShapesFromInputs1DParams")
662*89c4ff92SAndroid Build Coastguard Worker {
663*89c4ff92SAndroid Build Coastguard Worker Graph graph;
664*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo paramsInfo({8}, DataType::Float32);
665*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo indicesInfo({5}, DataType::Signed32);
666*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputInfo( {5}, DataType::Float32);
667*89c4ff92SAndroid Build Coastguard Worker
668*89c4ff92SAndroid Build Coastguard Worker CreateGatherGraph(graph, paramsInfo, indicesInfo, outputInfo);
669*89c4ff92SAndroid Build Coastguard Worker
670*89c4ff92SAndroid Build Coastguard Worker CHECK_NOTHROW(graph.InferTensorInfos());
671*89c4ff92SAndroid Build Coastguard Worker }
672*89c4ff92SAndroid Build Coastguard Worker
673*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("GatherValidateTensorShapesFromInputsMultiDimIndices")
674*89c4ff92SAndroid Build Coastguard Worker {
675*89c4ff92SAndroid Build Coastguard Worker Graph graph;
676*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo paramsInfo({3, 2, 5}, DataType::Float32);
677*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo indicesInfo({2, 2}, DataType::Signed32);
678*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputInfo({2, 2, 2, 5}, DataType::Float32);
679*89c4ff92SAndroid Build Coastguard Worker
680*89c4ff92SAndroid Build Coastguard Worker CreateGatherGraph(graph, paramsInfo, indicesInfo, outputInfo);
681*89c4ff92SAndroid Build Coastguard Worker
682*89c4ff92SAndroid Build Coastguard Worker CHECK_NOTHROW(graph.InferTensorInfos());
683*89c4ff92SAndroid Build Coastguard Worker }
684*89c4ff92SAndroid Build Coastguard Worker
685*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("DetectionPostProcessValidateTensorShapes")
686*89c4ff92SAndroid Build Coastguard Worker {
687*89c4ff92SAndroid Build Coastguard Worker Graph graph;
688*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo boxEncodingsInfo({1, 10, 4}, DataType::QAsymmU8);
689*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo scoresInfo({1, 10, 4}, DataType::QAsymmU8);
690*89c4ff92SAndroid Build Coastguard Worker std::vector<uint8_t> anchorsVector(40);
691*89c4ff92SAndroid Build Coastguard Worker armnn::ConstTensor anchors(armnn::TensorInfo({10, 4}, armnn::DataType::QAsymmU8, 0.0f, 0, true), anchorsVector);
692*89c4ff92SAndroid Build Coastguard Worker
693*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo detectionBoxesInfo({1, 3, 4}, DataType::QAsymmU8);
694*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo detectionScoresInfo({1, 3}, DataType::QAsymmU8);
695*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo detectionClassesInfo({1, 3}, DataType::QAsymmU8);
696*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo numDetectionInfo({1}, DataType::QAsymmU8);
697*89c4ff92SAndroid Build Coastguard Worker
698*89c4ff92SAndroid Build Coastguard Worker Layer* input0 = graph.AddLayer<InputLayer>(0, "boxEncodings");
699*89c4ff92SAndroid Build Coastguard Worker input0->GetOutputSlot().SetTensorInfo(boxEncodingsInfo);
700*89c4ff92SAndroid Build Coastguard Worker
701*89c4ff92SAndroid Build Coastguard Worker Layer* input1 = graph.AddLayer<InputLayer>(1, "score");
702*89c4ff92SAndroid Build Coastguard Worker input1->GetOutputSlot().SetTensorInfo(scoresInfo);
703*89c4ff92SAndroid Build Coastguard Worker
704*89c4ff92SAndroid Build Coastguard Worker DetectionPostProcessDescriptor descriptor;
705*89c4ff92SAndroid Build Coastguard Worker descriptor.m_MaxDetections = 3;
706*89c4ff92SAndroid Build Coastguard Worker
707*89c4ff92SAndroid Build Coastguard Worker DetectionPostProcessLayer* layer = graph.AddLayer<DetectionPostProcessLayer>(descriptor, "detectionPostProcess");
708*89c4ff92SAndroid Build Coastguard Worker layer->m_Anchors = std::make_unique<armnn::ScopedTensorHandle>(anchors);
709*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(0).SetTensorInfo(detectionBoxesInfo);
710*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(1).SetTensorInfo(detectionScoresInfo);
711*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(2).SetTensorInfo(detectionClassesInfo);
712*89c4ff92SAndroid Build Coastguard Worker layer->GetOutputSlot(3).SetTensorInfo(numDetectionInfo);
713*89c4ff92SAndroid Build Coastguard Worker
714*89c4ff92SAndroid Build Coastguard Worker input0->GetOutputSlot().Connect(layer->GetInputSlot(0));
715*89c4ff92SAndroid Build Coastguard Worker input1->GetOutputSlot().Connect(layer->GetInputSlot(1));
716*89c4ff92SAndroid Build Coastguard Worker
717*89c4ff92SAndroid Build Coastguard Worker CHECK_NOTHROW(graph.InferTensorInfos());
718*89c4ff92SAndroid Build Coastguard Worker }
719*89c4ff92SAndroid Build Coastguard Worker
720*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("BackendCapabilityTest")
721*89c4ff92SAndroid Build Coastguard Worker {
722*89c4ff92SAndroid Build Coastguard Worker BackendId backendId = "MockBackend";
723*89c4ff92SAndroid Build Coastguard Worker
724*89c4ff92SAndroid Build Coastguard Worker armnn::BackendOptions::BackendOption nonConstWeights{"NonConstWeights", true};
725*89c4ff92SAndroid Build Coastguard Worker
726*89c4ff92SAndroid Build Coastguard Worker // MockBackend does not support the NonConstWeights capability
727*89c4ff92SAndroid Build Coastguard Worker CHECK(!armnn::HasCapability(nonConstWeights, backendId));
728*89c4ff92SAndroid Build Coastguard Worker CHECK(!armnn::HasCapability("NonConstWeights", backendId));
729*89c4ff92SAndroid Build Coastguard Worker
730*89c4ff92SAndroid Build Coastguard Worker // MockBackend does not support the AsyncExecution capability
731*89c4ff92SAndroid Build Coastguard Worker CHECK(!armnn::GetCapability("AsyncExecution", backendId).has_value());
732*89c4ff92SAndroid Build Coastguard Worker }
733*89c4ff92SAndroid Build Coastguard Worker
734*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("BackendHintTest")
735*89c4ff92SAndroid Build Coastguard Worker {
736*89c4ff92SAndroid Build Coastguard Worker class TestBackendAssignment : public StrategyBase<NoThrowStrategy>
737*89c4ff92SAndroid Build Coastguard Worker {
738*89c4ff92SAndroid Build Coastguard Worker public:
739*89c4ff92SAndroid Build Coastguard Worker
ExecuteStrategy(const armnn::IConnectableLayer * layer,const armnn::BaseDescriptor & descriptor,const std::vector<armnn::ConstTensor> & constants,const char * name,const armnn::LayerBindingId id=0)740*89c4ff92SAndroid Build Coastguard Worker void ExecuteStrategy(const armnn::IConnectableLayer* layer,
741*89c4ff92SAndroid Build Coastguard Worker const armnn::BaseDescriptor& descriptor,
742*89c4ff92SAndroid Build Coastguard Worker const std::vector<armnn::ConstTensor>& constants,
743*89c4ff92SAndroid Build Coastguard Worker const char* name,
744*89c4ff92SAndroid Build Coastguard Worker const armnn::LayerBindingId id = 0) override
745*89c4ff92SAndroid Build Coastguard Worker {
746*89c4ff92SAndroid Build Coastguard Worker armnn::IgnoreUnused(descriptor, constants, id, name);
747*89c4ff92SAndroid Build Coastguard Worker switch (layer->GetType())
748*89c4ff92SAndroid Build Coastguard Worker {
749*89c4ff92SAndroid Build Coastguard Worker case armnn::LayerType::Input:
750*89c4ff92SAndroid Build Coastguard Worker {
751*89c4ff92SAndroid Build Coastguard Worker auto inputLayer = PolymorphicDowncast<const InputLayer*>(layer);
752*89c4ff92SAndroid Build Coastguard Worker const auto connectedLayerBackendId = inputLayer->GetOutputSlot(0).GetOwningLayer().GetBackendId();
753*89c4ff92SAndroid Build Coastguard Worker CHECK((inputLayer->GetBackendId() == connectedLayerBackendId));
754*89c4ff92SAndroid Build Coastguard Worker break;
755*89c4ff92SAndroid Build Coastguard Worker }
756*89c4ff92SAndroid Build Coastguard Worker case armnn::LayerType::Output:
757*89c4ff92SAndroid Build Coastguard Worker {
758*89c4ff92SAndroid Build Coastguard Worker auto outputLayer = PolymorphicDowncast<const OutputLayer*>(layer);
759*89c4ff92SAndroid Build Coastguard Worker CHECK((outputLayer->GetBackendId() == "MockBackend"));
760*89c4ff92SAndroid Build Coastguard Worker break;
761*89c4ff92SAndroid Build Coastguard Worker }
762*89c4ff92SAndroid Build Coastguard Worker case armnn::LayerType::Activation:
763*89c4ff92SAndroid Build Coastguard Worker {
764*89c4ff92SAndroid Build Coastguard Worker auto activation = PolymorphicDowncast<const ActivationLayer*>(layer);
765*89c4ff92SAndroid Build Coastguard Worker CHECK((activation->GetBackendId() == "CustomBackend"));
766*89c4ff92SAndroid Build Coastguard Worker break;
767*89c4ff92SAndroid Build Coastguard Worker }
768*89c4ff92SAndroid Build Coastguard Worker default:
769*89c4ff92SAndroid Build Coastguard Worker {
770*89c4ff92SAndroid Build Coastguard Worker m_DefaultStrategy.Apply(GetLayerTypeAsCString(layer->GetType()));
771*89c4ff92SAndroid Build Coastguard Worker }
772*89c4ff92SAndroid Build Coastguard Worker }
773*89c4ff92SAndroid Build Coastguard Worker }
774*89c4ff92SAndroid Build Coastguard Worker };
775*89c4ff92SAndroid Build Coastguard Worker
776*89c4ff92SAndroid Build Coastguard Worker struct CustomPolicy
777*89c4ff92SAndroid Build Coastguard Worker {
GetIdStaticCustomPolicy778*89c4ff92SAndroid Build Coastguard Worker static const BackendId& GetIdStatic()
779*89c4ff92SAndroid Build Coastguard Worker {
780*89c4ff92SAndroid Build Coastguard Worker static BackendId id = "CustomBackend";
781*89c4ff92SAndroid Build Coastguard Worker return id;
782*89c4ff92SAndroid Build Coastguard Worker }
783*89c4ff92SAndroid Build Coastguard Worker };
784*89c4ff92SAndroid Build Coastguard Worker
785*89c4ff92SAndroid Build Coastguard Worker struct MockPolicy
786*89c4ff92SAndroid Build Coastguard Worker {
GetIdStaticMockPolicy787*89c4ff92SAndroid Build Coastguard Worker static const BackendId& GetIdStatic()
788*89c4ff92SAndroid Build Coastguard Worker {
789*89c4ff92SAndroid Build Coastguard Worker static BackendId id = "MockBackend";
790*89c4ff92SAndroid Build Coastguard Worker return id;
791*89c4ff92SAndroid Build Coastguard Worker }
792*89c4ff92SAndroid Build Coastguard Worker };
793*89c4ff92SAndroid Build Coastguard Worker
794*89c4ff92SAndroid Build Coastguard Worker auto& backendRegistry = BackendRegistryInstance();
795*89c4ff92SAndroid Build Coastguard Worker
__anonfee9358a0202() 796*89c4ff92SAndroid Build Coastguard Worker backendRegistry.Register("MockBackend", []() { return std::make_unique<CustomAllocatorBackend<MockPolicy>>(); });
797*89c4ff92SAndroid Build Coastguard Worker
798*89c4ff92SAndroid Build Coastguard Worker backendRegistry.Register("CustomBackend",
__anonfee9358a0302() 799*89c4ff92SAndroid Build Coastguard Worker []() { return std::make_unique<CustomAllocatorBackend<CustomPolicy>>(); });
800*89c4ff92SAndroid Build Coastguard Worker
801*89c4ff92SAndroid Build Coastguard Worker // Define the network
802*89c4ff92SAndroid Build Coastguard Worker auto network = INetwork::Create();
803*89c4ff92SAndroid Build Coastguard Worker ActivationDescriptor desc;
804*89c4ff92SAndroid Build Coastguard Worker desc.m_Function = ActivationFunction::Linear;
805*89c4ff92SAndroid Build Coastguard Worker
806*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<Graph> graph = std::make_unique<Graph>();
807*89c4ff92SAndroid Build Coastguard Worker auto input = graph->AddLayer<InputLayer>(0, "input");
808*89c4ff92SAndroid Build Coastguard Worker auto act = graph->AddLayer<ActivationLayer>(desc, "activation");
809*89c4ff92SAndroid Build Coastguard Worker auto output = graph->AddLayer<OutputLayer>(0, "output");
810*89c4ff92SAndroid Build Coastguard Worker
811*89c4ff92SAndroid Build Coastguard Worker BackendId customBackendId("CustomBackend");
812*89c4ff92SAndroid Build Coastguard Worker act->BackendSelectionHint(customBackendId);
813*89c4ff92SAndroid Build Coastguard Worker
814*89c4ff92SAndroid Build Coastguard Worker input->GetOutputSlot(0).Connect(act->GetInputSlot(0));
815*89c4ff92SAndroid Build Coastguard Worker act->GetOutputSlot(0).Connect(output->GetInputSlot(0));
816*89c4ff92SAndroid Build Coastguard Worker
817*89c4ff92SAndroid Build Coastguard Worker OptimizedNetworkImpl optNet(std::move(graph));
818*89c4ff92SAndroid Build Coastguard Worker
819*89c4ff92SAndroid Build Coastguard Worker // Get the optimized graph
820*89c4ff92SAndroid Build Coastguard Worker Graph& optGraph = optNet.GetGraph();
821*89c4ff92SAndroid Build Coastguard Worker
822*89c4ff92SAndroid Build Coastguard Worker std::vector<BackendId> prefs{ "MockBackend", "CustomBackend" };
823*89c4ff92SAndroid Build Coastguard Worker
824*89c4ff92SAndroid Build Coastguard Worker BackendIdSet availableBackends = { "CustomBackend", "MockBackend" };
825*89c4ff92SAndroid Build Coastguard Worker DeviceSpec spec(availableBackends);
826*89c4ff92SAndroid Build Coastguard Worker
827*89c4ff92SAndroid Build Coastguard Worker BackendSettings backendSettings(prefs, spec);
828*89c4ff92SAndroid Build Coastguard Worker
829*89c4ff92SAndroid Build Coastguard Worker // Assign an available backend to each layer
830*89c4ff92SAndroid Build Coastguard Worker Graph::Iterator firstLayer = optGraph.begin();
831*89c4ff92SAndroid Build Coastguard Worker Graph::Iterator lastLayer = optGraph.end();
832*89c4ff92SAndroid Build Coastguard Worker
833*89c4ff92SAndroid Build Coastguard Worker OptimizedNetworkImpl* optNetObjPtr = &optNet;
834*89c4ff92SAndroid Build Coastguard Worker OptimizationResult res = AssignBackends(optNetObjPtr,
835*89c4ff92SAndroid Build Coastguard Worker backendSettings,
836*89c4ff92SAndroid Build Coastguard Worker firstLayer,
837*89c4ff92SAndroid Build Coastguard Worker lastLayer,
838*89c4ff92SAndroid Build Coastguard Worker EmptyOptional());
839*89c4ff92SAndroid Build Coastguard Worker
840*89c4ff92SAndroid Build Coastguard Worker CHECK(res.IsOk());
841*89c4ff92SAndroid Build Coastguard Worker
842*89c4ff92SAndroid Build Coastguard Worker TestBackendAssignment visitor;
843*89c4ff92SAndroid Build Coastguard Worker for (auto it = firstLayer; it != lastLayer; ++it)
844*89c4ff92SAndroid Build Coastguard Worker {
845*89c4ff92SAndroid Build Coastguard Worker (*it)->ExecuteStrategy(visitor);
846*89c4ff92SAndroid Build Coastguard Worker }
847*89c4ff92SAndroid Build Coastguard Worker // Clean up the registry for the next test.
848*89c4ff92SAndroid Build Coastguard Worker backendRegistry.Deregister("MockBackend");
849*89c4ff92SAndroid Build Coastguard Worker backendRegistry.Deregister("CustomBackend");
850*89c4ff92SAndroid Build Coastguard Worker }
851*89c4ff92SAndroid Build Coastguard Worker
852*89c4ff92SAndroid Build Coastguard Worker // Tests that OptimizeForExclusiveConnections works, fusing when needed, using BatchNorm fusing as example
853*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("OptimizeForExclusiveConnectionsFuseTest")
854*89c4ff92SAndroid Build Coastguard Worker {
855*89c4ff92SAndroid Build Coastguard Worker using namespace armnn;
856*89c4ff92SAndroid Build Coastguard Worker // Define layers information
857*89c4ff92SAndroid Build Coastguard Worker Convolution2dDescriptor convolution2dDescriptor;
858*89c4ff92SAndroid Build Coastguard Worker convolution2dDescriptor.m_BiasEnabled = false;
859*89c4ff92SAndroid Build Coastguard Worker convolution2dDescriptor.m_DataLayout = DataLayout::NHWC;
860*89c4ff92SAndroid Build Coastguard Worker BatchNormalizationDescriptor batchNormDescriptor;
861*89c4ff92SAndroid Build Coastguard Worker batchNormDescriptor.m_DataLayout = DataLayout::NHWC;
862*89c4ff92SAndroid Build Coastguard Worker
863*89c4ff92SAndroid Build Coastguard Worker const unsigned int inputDimensionSizes[] = { 1, 4, 4, 3 }; // NHWCin
864*89c4ff92SAndroid Build Coastguard Worker const unsigned int weightsDimensionSizes[] = { 1, 2, 2, 3 }; // CoutHWCin
865*89c4ff92SAndroid Build Coastguard Worker const unsigned int outputDimensionSizes[] = { 1, 3, 3, 1 }; // NHWCout
866*89c4ff92SAndroid Build Coastguard Worker const unsigned int outputChannelSize[] = { outputDimensionSizes[3] }; // Cout
867*89c4ff92SAndroid Build Coastguard Worker
868*89c4ff92SAndroid Build Coastguard Worker TensorInfo inputInfo(4, inputDimensionSizes, DataType::Float32);
869*89c4ff92SAndroid Build Coastguard Worker TensorInfo outputInfo(4, outputDimensionSizes, DataType::Float32);
870*89c4ff92SAndroid Build Coastguard Worker
871*89c4ff92SAndroid Build Coastguard Worker std::vector<float> weightsVector = { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12 };
872*89c4ff92SAndroid Build Coastguard Worker ConstTensor weights(TensorInfo(4, weightsDimensionSizes, DataType::Float32, 0.0f, 0, true), weightsVector);
873*89c4ff92SAndroid Build Coastguard Worker
874*89c4ff92SAndroid Build Coastguard Worker std::vector<float> betaVector = { 0.1f };
875*89c4ff92SAndroid Build Coastguard Worker std::vector<float> gammaVector = { 0.5f };
876*89c4ff92SAndroid Build Coastguard Worker std::vector<float> meanVector = { 0 };
877*89c4ff92SAndroid Build Coastguard Worker std::vector<float> varianceVector = { 1 };
878*89c4ff92SAndroid Build Coastguard Worker ConstTensor beta(TensorInfo(1, outputChannelSize, DataType::Float32, 0.0f, 0, true), betaVector);
879*89c4ff92SAndroid Build Coastguard Worker ConstTensor gamma(TensorInfo(1, outputChannelSize, DataType::Float32, 0.0f, 0, true), gammaVector);
880*89c4ff92SAndroid Build Coastguard Worker ConstTensor mean(TensorInfo(1, outputChannelSize, DataType::Float32, 0.0f, 0, true), meanVector);
881*89c4ff92SAndroid Build Coastguard Worker ConstTensor variance(TensorInfo(1, outputChannelSize, DataType::Float32, 0.0f, 0, true), varianceVector);
882*89c4ff92SAndroid Build Coastguard Worker
883*89c4ff92SAndroid Build Coastguard Worker ConstantLayer* biasLayer = nullptr;
884*89c4ff92SAndroid Build Coastguard Worker
885*89c4ff92SAndroid Build Coastguard Worker // Define the network
886*89c4ff92SAndroid Build Coastguard Worker Graph graph;
887*89c4ff92SAndroid Build Coastguard Worker auto input = graph.AddLayer<InputLayer>(0, "input");
888*89c4ff92SAndroid Build Coastguard Worker auto weightsLayer = graph.AddLayer<ConstantLayer>("Weights");
889*89c4ff92SAndroid Build Coastguard Worker auto conv = graph.AddLayer<Convolution2dLayer>(convolution2dDescriptor, "convolution");
890*89c4ff92SAndroid Build Coastguard Worker auto batchNorm = graph.AddLayer<BatchNormalizationLayer>(batchNormDescriptor, "batchNorm");
891*89c4ff92SAndroid Build Coastguard Worker auto output = graph.AddLayer<OutputLayer>(0, "output");
892*89c4ff92SAndroid Build Coastguard Worker
893*89c4ff92SAndroid Build Coastguard Worker // Set layer information
894*89c4ff92SAndroid Build Coastguard Worker input->GetOutputSlot().SetTensorInfo(inputInfo);
895*89c4ff92SAndroid Build Coastguard Worker
896*89c4ff92SAndroid Build Coastguard Worker weightsLayer->m_LayerOutput = std::make_shared<ScopedTensorHandle>(weights);
897*89c4ff92SAndroid Build Coastguard Worker weightsLayer->GetOutputSlot(0).SetTensorInfo(weightsLayer->m_LayerOutput->GetTensorInfo());
898*89c4ff92SAndroid Build Coastguard Worker conv->GetOutputSlot().SetTensorInfo(outputInfo);
899*89c4ff92SAndroid Build Coastguard Worker
900*89c4ff92SAndroid Build Coastguard Worker batchNorm->GetOutputSlot().SetTensorInfo(outputInfo);
901*89c4ff92SAndroid Build Coastguard Worker batchNorm->m_Beta = std::make_unique<ScopedTensorHandle>(beta);
902*89c4ff92SAndroid Build Coastguard Worker batchNorm->m_Gamma = std::make_unique<ScopedTensorHandle>(gamma);
903*89c4ff92SAndroid Build Coastguard Worker batchNorm->m_Mean = std::make_unique<ScopedTensorHandle>(mean);
904*89c4ff92SAndroid Build Coastguard Worker batchNorm->m_Variance = std::make_unique<ScopedTensorHandle>(variance);
905*89c4ff92SAndroid Build Coastguard Worker
906*89c4ff92SAndroid Build Coastguard Worker if (convolution2dDescriptor.m_BiasEnabled)
907*89c4ff92SAndroid Build Coastguard Worker {
908*89c4ff92SAndroid Build Coastguard Worker std::vector<float> biasVector = { 11 };
909*89c4ff92SAndroid Build Coastguard Worker ConstTensor bias(TensorInfo(1, outputChannelSize, DataType::Float32, 0.0f, 0, true), biasVector);
910*89c4ff92SAndroid Build Coastguard Worker biasLayer = graph.AddLayer<ConstantLayer>("Bias");
911*89c4ff92SAndroid Build Coastguard Worker biasLayer->m_LayerOutput = std::make_shared<ScopedTensorHandle>(bias);
912*89c4ff92SAndroid Build Coastguard Worker biasLayer->GetOutputSlot(0).SetTensorInfo(biasLayer->m_LayerOutput->GetTensorInfo());
913*89c4ff92SAndroid Build Coastguard Worker biasLayer->GetOutputSlot(0).Connect(conv->GetInputSlot(2));
914*89c4ff92SAndroid Build Coastguard Worker }
915*89c4ff92SAndroid Build Coastguard Worker
916*89c4ff92SAndroid Build Coastguard Worker // Connect layers
917*89c4ff92SAndroid Build Coastguard Worker input->GetOutputSlot(0).Connect(conv->GetInputSlot(0));
918*89c4ff92SAndroid Build Coastguard Worker weightsLayer->GetOutputSlot(0).Connect(conv->GetInputSlot(1));
919*89c4ff92SAndroid Build Coastguard Worker conv->GetOutputSlot(0).Connect(batchNorm->GetInputSlot(0));
920*89c4ff92SAndroid Build Coastguard Worker batchNorm->GetOutputSlot(0).Connect(output->GetInputSlot(0));
921*89c4ff92SAndroid Build Coastguard Worker
922*89c4ff92SAndroid Build Coastguard Worker if (convolution2dDescriptor.m_BiasEnabled)
923*89c4ff92SAndroid Build Coastguard Worker {
924*89c4ff92SAndroid Build Coastguard Worker CHECK(6 == graph.GetNumLayers());
925*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckSequence(graph.cbegin(), graph.cend(),
926*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<InputLayer>,
927*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<ConstantLayer>,
928*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<ConstantLayer>,
929*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<Convolution2dLayer>,
930*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<BatchNormalizationLayer>,
931*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<OutputLayer>));
932*89c4ff92SAndroid Build Coastguard Worker }
933*89c4ff92SAndroid Build Coastguard Worker else
934*89c4ff92SAndroid Build Coastguard Worker {
935*89c4ff92SAndroid Build Coastguard Worker CHECK(5 == graph.GetNumLayers());
936*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckSequence(graph.cbegin(), graph.cend(),
937*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<InputLayer>,
938*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<ConstantLayer>,
939*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<Convolution2dLayer>,
940*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<BatchNormalizationLayer>,
941*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<OutputLayer>));
942*89c4ff92SAndroid Build Coastguard Worker }
943*89c4ff92SAndroid Build Coastguard Worker
944*89c4ff92SAndroid Build Coastguard Worker // Optimize graph
945*89c4ff92SAndroid Build Coastguard Worker armnn::Optimizer::Pass(graph, MakeOptimizations(FuseBatchNormIntoConvolution2DFloat32()));
946*89c4ff92SAndroid Build Coastguard Worker
__anonfee9358a0402(const armnn::Layer* const layer) 947*89c4ff92SAndroid Build Coastguard Worker auto checkFusedConv2d = [](const armnn::Layer* const layer) -> bool {
948*89c4ff92SAndroid Build Coastguard Worker return IsLayerOfType<armnn::Convolution2dLayer>(layer) &&
949*89c4ff92SAndroid Build Coastguard Worker (layer->GetNameStr() == "fused-batchNorm-into-convolution");
950*89c4ff92SAndroid Build Coastguard Worker };
951*89c4ff92SAndroid Build Coastguard Worker
952*89c4ff92SAndroid Build Coastguard Worker CHECK(5 == graph.GetNumLayers());
953*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckSequence(graph.cbegin(), graph.cend(),
954*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<InputLayer>,
955*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<ConstantLayer>,
956*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<ConstantLayer>,
957*89c4ff92SAndroid Build Coastguard Worker checkFusedConv2d,
958*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<OutputLayer>));
959*89c4ff92SAndroid Build Coastguard Worker }
960*89c4ff92SAndroid Build Coastguard Worker
961*89c4ff92SAndroid Build Coastguard Worker // Tests that OptimizeForExclusiveConnections works, not fusing when not needed, using BatchNorm fusing as example
962*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("OptimizeForExclusiveConnectionsWithoutFuseTest")
963*89c4ff92SAndroid Build Coastguard Worker {
964*89c4ff92SAndroid Build Coastguard Worker // Define the network
965*89c4ff92SAndroid Build Coastguard Worker Graph graph;
966*89c4ff92SAndroid Build Coastguard Worker Convolution2dDescriptor convolution2dDescriptor;
967*89c4ff92SAndroid Build Coastguard Worker BatchNormalizationDescriptor batchNormDescriptor;
968*89c4ff92SAndroid Build Coastguard Worker
969*89c4ff92SAndroid Build Coastguard Worker auto input = graph.AddLayer<InputLayer>(0, "input");
970*89c4ff92SAndroid Build Coastguard Worker auto conv = graph.AddLayer<Convolution2dLayer>(convolution2dDescriptor, "convolution");
971*89c4ff92SAndroid Build Coastguard Worker auto batchNorm = graph.AddLayer<BatchNormalizationLayer>(batchNormDescriptor, "batchNorm");
972*89c4ff92SAndroid Build Coastguard Worker auto output = graph.AddLayer<OutputLayer>(0, "output");
973*89c4ff92SAndroid Build Coastguard Worker auto output2 = graph.AddLayer<OutputLayer>(1, "output2");
974*89c4ff92SAndroid Build Coastguard Worker
975*89c4ff92SAndroid Build Coastguard Worker // Connect layers
976*89c4ff92SAndroid Build Coastguard Worker input->GetOutputSlot(0).Connect(conv->GetInputSlot(0));
977*89c4ff92SAndroid Build Coastguard Worker conv->GetOutputSlot(0).Connect(batchNorm->GetInputSlot(0));
978*89c4ff92SAndroid Build Coastguard Worker batchNorm->GetOutputSlot(0).Connect(output->GetInputSlot(0));
979*89c4ff92SAndroid Build Coastguard Worker conv->GetOutputSlot(0).Connect(output2->GetInputSlot(0));
980*89c4ff92SAndroid Build Coastguard Worker
981*89c4ff92SAndroid Build Coastguard Worker CHECK((5 == graph.GetNumLayers()));
982*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckSequence(graph.cbegin(), graph.cend(),
983*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<armnn::InputLayer>,
984*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<armnn::Convolution2dLayer>,
985*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<armnn::BatchNormalizationLayer>,
986*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<armnn::OutputLayer>,
987*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<armnn::OutputLayer>));
988*89c4ff92SAndroid Build Coastguard Worker // Optimize graph
989*89c4ff92SAndroid Build Coastguard Worker armnn::Optimizer::Pass(graph, armnn::MakeOptimizations(FuseBatchNormIntoConvolution2DFloat32()));
990*89c4ff92SAndroid Build Coastguard Worker
991*89c4ff92SAndroid Build Coastguard Worker CHECK(5 == graph.GetNumLayers());
992*89c4ff92SAndroid Build Coastguard Worker CHECK(CheckSequence(graph.cbegin(), graph.cend(),
993*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<armnn::InputLayer>,
994*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<armnn::Convolution2dLayer>,
995*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<armnn::BatchNormalizationLayer>,
996*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<armnn::OutputLayer>,
997*89c4ff92SAndroid Build Coastguard Worker &IsLayerOfType<armnn::OutputLayer>));
998*89c4ff92SAndroid Build Coastguard Worker }
999*89c4ff92SAndroid Build Coastguard Worker } // Optimizer TestSuite
1000