xref: /aosp_15_r20/external/armnn/src/armnn/test/OptimizerTests.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017,2022 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include <TestUtils.hpp>
7*89c4ff92SAndroid Build Coastguard Worker 
8*89c4ff92SAndroid Build Coastguard Worker #include <BackendSettings.hpp>
9*89c4ff92SAndroid Build Coastguard Worker #include <Graph.hpp>
10*89c4ff92SAndroid Build Coastguard Worker #include <Network.hpp>
11*89c4ff92SAndroid Build Coastguard Worker #include <Optimizer.hpp>
12*89c4ff92SAndroid Build Coastguard Worker 
13*89c4ff92SAndroid Build Coastguard Worker #include <armnn/BackendHelper.hpp>
14*89c4ff92SAndroid Build Coastguard Worker #include <armnn/BackendRegistry.hpp>
15*89c4ff92SAndroid Build Coastguard Worker #include <armnn/INetwork.hpp>
16*89c4ff92SAndroid Build Coastguard Worker #include <armnn/StrategyBase.hpp>
17*89c4ff92SAndroid Build Coastguard Worker 
18*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/Assert.hpp>
19*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/PolymorphicDowncast.hpp>
20*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/IBackendInternal.hpp>
21*89c4ff92SAndroid Build Coastguard Worker 
22*89c4ff92SAndroid Build Coastguard Worker #include <backendsCommon/LayerSupportBase.hpp>
23*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/TensorHandle.hpp>
24*89c4ff92SAndroid Build Coastguard Worker 
25*89c4ff92SAndroid Build Coastguard Worker #include <doctest/doctest.h>
26*89c4ff92SAndroid Build Coastguard Worker 
27*89c4ff92SAndroid Build Coastguard Worker using namespace armnn;
28*89c4ff92SAndroid Build Coastguard Worker 
29*89c4ff92SAndroid Build Coastguard Worker namespace
30*89c4ff92SAndroid Build Coastguard Worker {
31*89c4ff92SAndroid Build Coastguard Worker 
CreateLSTMLayerHelper(Graph & graph,bool CifgEnabled)32*89c4ff92SAndroid Build Coastguard Worker void CreateLSTMLayerHelper(Graph &graph, bool CifgEnabled)
33*89c4ff92SAndroid Build Coastguard Worker {
34*89c4ff92SAndroid Build Coastguard Worker     LstmDescriptor layerDesc;
35*89c4ff92SAndroid Build Coastguard Worker     layerDesc.m_ActivationFunc = 4;
36*89c4ff92SAndroid Build Coastguard Worker     layerDesc.m_ClippingThresCell = 0.2f;
37*89c4ff92SAndroid Build Coastguard Worker     layerDesc.m_ClippingThresProj = 0.4f;
38*89c4ff92SAndroid Build Coastguard Worker     layerDesc.m_CifgEnabled = CifgEnabled;
39*89c4ff92SAndroid Build Coastguard Worker     layerDesc.m_PeepholeEnabled = false;
40*89c4ff92SAndroid Build Coastguard Worker     layerDesc.m_ProjectionEnabled = false;
41*89c4ff92SAndroid Build Coastguard Worker 
42*89c4ff92SAndroid Build Coastguard Worker     LstmLayer* const layer = graph.AddLayer<LstmLayer>(layerDesc, "layer");
43*89c4ff92SAndroid Build Coastguard Worker     unsigned int batchSize = 3;
44*89c4ff92SAndroid Build Coastguard Worker     unsigned int inputSize = 2;
45*89c4ff92SAndroid Build Coastguard Worker     unsigned int numUnits = 4;
46*89c4ff92SAndroid Build Coastguard Worker     unsigned int outputSize = 4;
47*89c4ff92SAndroid Build Coastguard Worker 
48*89c4ff92SAndroid Build Coastguard Worker     layer->m_BasicParameters.m_InputToForgetWeights = std::make_unique<ScopedTensorHandle>
49*89c4ff92SAndroid Build Coastguard Worker             (TensorInfo({ numUnits, inputSize }, DataType::Float32));
50*89c4ff92SAndroid Build Coastguard Worker     layer->m_BasicParameters.m_InputToCellWeights = std::make_unique<ScopedTensorHandle>
51*89c4ff92SAndroid Build Coastguard Worker             (TensorInfo({ numUnits, inputSize }, DataType::Float32));
52*89c4ff92SAndroid Build Coastguard Worker     layer->m_BasicParameters.m_InputToOutputWeights = std::make_unique<ScopedTensorHandle>
53*89c4ff92SAndroid Build Coastguard Worker             (TensorInfo({ numUnits, inputSize }, DataType::Float32));
54*89c4ff92SAndroid Build Coastguard Worker     layer->m_BasicParameters.m_RecurrentToForgetWeights = std::make_unique<ScopedTensorHandle>
55*89c4ff92SAndroid Build Coastguard Worker             (TensorInfo({ numUnits, outputSize }, DataType::Float32));
56*89c4ff92SAndroid Build Coastguard Worker     layer->m_BasicParameters.m_RecurrentToCellWeights = std::make_unique<ScopedTensorHandle>
57*89c4ff92SAndroid Build Coastguard Worker             (TensorInfo({ numUnits, outputSize }, DataType::Float32));
58*89c4ff92SAndroid Build Coastguard Worker     layer->m_BasicParameters.m_RecurrentToOutputWeights = std::make_unique<ScopedTensorHandle>
59*89c4ff92SAndroid Build Coastguard Worker             (TensorInfo({ numUnits, outputSize }, DataType::Float32));
60*89c4ff92SAndroid Build Coastguard Worker     layer->m_BasicParameters.m_ForgetGateBias = std::make_unique<ScopedTensorHandle>
61*89c4ff92SAndroid Build Coastguard Worker             (TensorInfo({ numUnits }, DataType::Float32));
62*89c4ff92SAndroid Build Coastguard Worker     layer->m_BasicParameters.m_CellBias = std::make_unique<ScopedTensorHandle>
63*89c4ff92SAndroid Build Coastguard Worker             (TensorInfo({ numUnits }, DataType::Float32));
64*89c4ff92SAndroid Build Coastguard Worker     layer->m_BasicParameters.m_OutputGateBias = std::make_unique<ScopedTensorHandle>
65*89c4ff92SAndroid Build Coastguard Worker             (TensorInfo({ numUnits }, DataType::Float32));
66*89c4ff92SAndroid Build Coastguard Worker 
67*89c4ff92SAndroid Build Coastguard Worker     layer->m_BasicParameters.m_InputToForgetWeights->Allocate();
68*89c4ff92SAndroid Build Coastguard Worker     layer->m_BasicParameters.m_InputToCellWeights->Allocate();
69*89c4ff92SAndroid Build Coastguard Worker     layer->m_BasicParameters.m_InputToOutputWeights->Allocate();
70*89c4ff92SAndroid Build Coastguard Worker     layer->m_BasicParameters.m_RecurrentToForgetWeights->Allocate();
71*89c4ff92SAndroid Build Coastguard Worker     layer->m_BasicParameters.m_RecurrentToCellWeights->Allocate();
72*89c4ff92SAndroid Build Coastguard Worker     layer->m_BasicParameters.m_RecurrentToOutputWeights->Allocate();
73*89c4ff92SAndroid Build Coastguard Worker     layer->m_BasicParameters.m_ForgetGateBias->Allocate();
74*89c4ff92SAndroid Build Coastguard Worker     layer->m_BasicParameters.m_CellBias->Allocate();
75*89c4ff92SAndroid Build Coastguard Worker     layer->m_BasicParameters.m_OutputGateBias->Allocate();
76*89c4ff92SAndroid Build Coastguard Worker 
77*89c4ff92SAndroid Build Coastguard Worker     if (!layerDesc.m_CifgEnabled)
78*89c4ff92SAndroid Build Coastguard Worker     {
79*89c4ff92SAndroid Build Coastguard Worker         layer->m_CifgParameters.m_InputToInputWeights = std::make_unique<ScopedTensorHandle>
80*89c4ff92SAndroid Build Coastguard Worker                 (TensorInfo({ numUnits, inputSize }, DataType::Float32));
81*89c4ff92SAndroid Build Coastguard Worker         layer->m_CifgParameters.m_RecurrentToInputWeights = std::make_unique<ScopedTensorHandle>
82*89c4ff92SAndroid Build Coastguard Worker                 (TensorInfo({ numUnits, outputSize }, DataType::Float32));
83*89c4ff92SAndroid Build Coastguard Worker         layer->m_CifgParameters.m_InputGateBias = std::make_unique<ScopedTensorHandle>
84*89c4ff92SAndroid Build Coastguard Worker                 (TensorInfo({ numUnits }, DataType::Float32));
85*89c4ff92SAndroid Build Coastguard Worker         layer->m_CifgParameters.m_InputToInputWeights->Allocate();
86*89c4ff92SAndroid Build Coastguard Worker         layer->m_CifgParameters.m_RecurrentToInputWeights->Allocate();
87*89c4ff92SAndroid Build Coastguard Worker         layer->m_CifgParameters.m_InputGateBias->Allocate();
88*89c4ff92SAndroid Build Coastguard Worker     }
89*89c4ff92SAndroid Build Coastguard Worker 
90*89c4ff92SAndroid Build Coastguard Worker     if (layerDesc.m_ProjectionEnabled)
91*89c4ff92SAndroid Build Coastguard Worker     {
92*89c4ff92SAndroid Build Coastguard Worker         layer->m_ProjectionParameters.m_ProjectionWeights = std::make_unique<ScopedTensorHandle>
93*89c4ff92SAndroid Build Coastguard Worker                 (TensorInfo({ outputSize, numUnits }, DataType::Float32));
94*89c4ff92SAndroid Build Coastguard Worker         layer->m_ProjectionParameters.m_ProjectionBias = std::make_unique<ScopedTensorHandle>
95*89c4ff92SAndroid Build Coastguard Worker                 (TensorInfo({ outputSize }, DataType::Float32));
96*89c4ff92SAndroid Build Coastguard Worker         layer->m_ProjectionParameters.m_ProjectionWeights->Allocate();
97*89c4ff92SAndroid Build Coastguard Worker         layer->m_ProjectionParameters.m_ProjectionBias->Allocate();
98*89c4ff92SAndroid Build Coastguard Worker     }
99*89c4ff92SAndroid Build Coastguard Worker 
100*89c4ff92SAndroid Build Coastguard Worker     if (layerDesc.m_PeepholeEnabled)
101*89c4ff92SAndroid Build Coastguard Worker     {
102*89c4ff92SAndroid Build Coastguard Worker         if (!layerDesc.m_CifgEnabled)
103*89c4ff92SAndroid Build Coastguard Worker         {
104*89c4ff92SAndroid Build Coastguard Worker             layer->m_PeepholeParameters.m_CellToInputWeights = std::make_unique<ScopedTensorHandle>
105*89c4ff92SAndroid Build Coastguard Worker                     (TensorInfo({ numUnits }, DataType::Float32));
106*89c4ff92SAndroid Build Coastguard Worker             layer->m_PeepholeParameters.m_CellToInputWeights->Allocate();
107*89c4ff92SAndroid Build Coastguard Worker         }
108*89c4ff92SAndroid Build Coastguard Worker         layer->m_PeepholeParameters.m_CellToForgetWeights = std::make_unique<ScopedTensorHandle>
109*89c4ff92SAndroid Build Coastguard Worker                 (TensorInfo({ numUnits }, DataType::Float32));
110*89c4ff92SAndroid Build Coastguard Worker         layer->m_PeepholeParameters.m_CellToOutputWeights = std::make_unique<ScopedTensorHandle>
111*89c4ff92SAndroid Build Coastguard Worker                 (TensorInfo({ numUnits }, DataType::Float32));
112*89c4ff92SAndroid Build Coastguard Worker         layer->m_PeepholeParameters.m_CellToForgetWeights->Allocate();
113*89c4ff92SAndroid Build Coastguard Worker         layer->m_PeepholeParameters.m_CellToOutputWeights->Allocate();
114*89c4ff92SAndroid Build Coastguard Worker     }
115*89c4ff92SAndroid Build Coastguard Worker 
116*89c4ff92SAndroid Build Coastguard Worker     // create input and output layers
117*89c4ff92SAndroid Build Coastguard Worker     Layer* const input = graph.AddLayer<InputLayer>(0, "input");
118*89c4ff92SAndroid Build Coastguard Worker     Layer* const outputStateIn = graph.AddLayer<InputLayer>(1, "outputStateIn");
119*89c4ff92SAndroid Build Coastguard Worker     Layer* const cellStateIn = graph.AddLayer<InputLayer>(2, "cellStateIn");
120*89c4ff92SAndroid Build Coastguard Worker     Layer* const scratchBuffer = graph.AddLayer<OutputLayer>(0, "scratchBuffer");
121*89c4ff92SAndroid Build Coastguard Worker     Layer* const outputStateOut = graph.AddLayer<OutputLayer>(1, "outputStateOut");
122*89c4ff92SAndroid Build Coastguard Worker     Layer* const cellStateOut = graph.AddLayer<OutputLayer>(2, "cellStateOut");
123*89c4ff92SAndroid Build Coastguard Worker     Layer* const output = graph.AddLayer<OutputLayer>(3, "output");
124*89c4ff92SAndroid Build Coastguard Worker 
125*89c4ff92SAndroid Build Coastguard Worker     // connect up
126*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo lstmTensorInfo1({ batchSize, inputSize }, DataType::Float32);
127*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo lstmTensorInfo2({ batchSize, numUnits}, DataType::Float32);
128*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo lstmTensorInfo3({ batchSize, outputSize }, DataType::Float32);
129*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo lstmTensorInfoScratchBuff({ batchSize, numUnits * (layerDesc.m_CifgEnabled ? 3 : 4) },
130*89c4ff92SAndroid Build Coastguard Worker                                                 DataType::Float32);
131*89c4ff92SAndroid Build Coastguard Worker 
132*89c4ff92SAndroid Build Coastguard Worker     Connect(input, layer, lstmTensorInfo1, 0, 0);
133*89c4ff92SAndroid Build Coastguard Worker     Connect(cellStateIn, layer, lstmTensorInfo2, 0, 1);
134*89c4ff92SAndroid Build Coastguard Worker     Connect(outputStateIn, layer, lstmTensorInfo3, 0, 2);
135*89c4ff92SAndroid Build Coastguard Worker     Connect(layer, scratchBuffer, lstmTensorInfoScratchBuff, 0, 0);
136*89c4ff92SAndroid Build Coastguard Worker     Connect(layer, outputStateOut, lstmTensorInfo3, 1, 0);
137*89c4ff92SAndroid Build Coastguard Worker     Connect(layer, cellStateOut, lstmTensorInfo2, 2, 0);
138*89c4ff92SAndroid Build Coastguard Worker     Connect(layer, output, lstmTensorInfo3, 3, 0);
139*89c4ff92SAndroid Build Coastguard Worker }
140*89c4ff92SAndroid Build Coastguard Worker 
141*89c4ff92SAndroid Build Coastguard Worker 
142*89c4ff92SAndroid Build Coastguard Worker class MockLayerSupport : public LayerSupportBase
143*89c4ff92SAndroid Build Coastguard Worker {
144*89c4ff92SAndroid Build Coastguard Worker public:
IsLayerSupported(const LayerType & type,const std::vector<TensorInfo> & infos,const BaseDescriptor & descriptor,const Optional<LstmInputParamsInfo> &,const Optional<QuantizedLstmInputParamsInfo> &,Optional<std::string &> reasonIfUnsupported) const145*89c4ff92SAndroid Build Coastguard Worker     bool IsLayerSupported(const LayerType& type,
146*89c4ff92SAndroid Build Coastguard Worker                           const std::vector<TensorInfo>& infos,
147*89c4ff92SAndroid Build Coastguard Worker                           const BaseDescriptor& descriptor,
148*89c4ff92SAndroid Build Coastguard Worker                           const Optional<LstmInputParamsInfo>& /*lstmParamsInfo*/,
149*89c4ff92SAndroid Build Coastguard Worker                           const Optional<QuantizedLstmInputParamsInfo>& /*quantizedLstmParamsInfo*/,
150*89c4ff92SAndroid Build Coastguard Worker                           Optional<std::string&> reasonIfUnsupported) const override
151*89c4ff92SAndroid Build Coastguard Worker     {
152*89c4ff92SAndroid Build Coastguard Worker         switch (type)
153*89c4ff92SAndroid Build Coastguard Worker         {
154*89c4ff92SAndroid Build Coastguard Worker             case LayerType::Input:
155*89c4ff92SAndroid Build Coastguard Worker                 return IsInputSupported(infos[0], reasonIfUnsupported);
156*89c4ff92SAndroid Build Coastguard Worker             case LayerType::Output:
157*89c4ff92SAndroid Build Coastguard Worker                 return IsOutputSupported(infos[0], reasonIfUnsupported);
158*89c4ff92SAndroid Build Coastguard Worker             case LayerType::Activation:
159*89c4ff92SAndroid Build Coastguard Worker                 return IsActivationSupported(infos[0],
160*89c4ff92SAndroid Build Coastguard Worker                                              infos[1],
161*89c4ff92SAndroid Build Coastguard Worker                                              *(PolymorphicDowncast<const ActivationDescriptor*>(&descriptor)),
162*89c4ff92SAndroid Build Coastguard Worker                                              reasonIfUnsupported);
163*89c4ff92SAndroid Build Coastguard Worker             default:
164*89c4ff92SAndroid Build Coastguard Worker                 return false;
165*89c4ff92SAndroid Build Coastguard Worker         }
166*89c4ff92SAndroid Build Coastguard Worker     }
167*89c4ff92SAndroid Build Coastguard Worker 
IsInputSupported(const TensorInfo &,Optional<std::string &>) const168*89c4ff92SAndroid Build Coastguard Worker     bool IsInputSupported(const TensorInfo& /*input*/,
169*89c4ff92SAndroid Build Coastguard Worker                           Optional<std::string&> /*reasonIfUnsupported = EmptyOptional()*/) const override
170*89c4ff92SAndroid Build Coastguard Worker     {
171*89c4ff92SAndroid Build Coastguard Worker         return true;
172*89c4ff92SAndroid Build Coastguard Worker     }
173*89c4ff92SAndroid Build Coastguard Worker 
IsOutputSupported(const TensorInfo &,Optional<std::string &>) const174*89c4ff92SAndroid Build Coastguard Worker     bool IsOutputSupported(const TensorInfo& /*input*/,
175*89c4ff92SAndroid Build Coastguard Worker                            Optional<std::string&> /*reasonIfUnsupported = EmptyOptional()*/) const override
176*89c4ff92SAndroid Build Coastguard Worker     {
177*89c4ff92SAndroid Build Coastguard Worker         return true;
178*89c4ff92SAndroid Build Coastguard Worker     }
179*89c4ff92SAndroid Build Coastguard Worker 
IsActivationSupported(const TensorInfo &,const TensorInfo &,const ActivationDescriptor &,Optional<std::string &>) const180*89c4ff92SAndroid Build Coastguard Worker     bool IsActivationSupported(const TensorInfo& /*input0*/,
181*89c4ff92SAndroid Build Coastguard Worker                                const TensorInfo& /*output*/,
182*89c4ff92SAndroid Build Coastguard Worker                                const ActivationDescriptor& /*descriptor*/,
183*89c4ff92SAndroid Build Coastguard Worker                                Optional<std::string&> /*reasonIfUnsupported = EmptyOptional()*/) const override
184*89c4ff92SAndroid Build Coastguard Worker     {
185*89c4ff92SAndroid Build Coastguard Worker         return true;
186*89c4ff92SAndroid Build Coastguard Worker     }
187*89c4ff92SAndroid Build Coastguard Worker };
188*89c4ff92SAndroid Build Coastguard Worker 
189*89c4ff92SAndroid Build Coastguard Worker template <typename NamePolicy>
190*89c4ff92SAndroid Build Coastguard Worker class CustomAllocatorBackend : public IBackendInternal
191*89c4ff92SAndroid Build Coastguard Worker {
192*89c4ff92SAndroid Build Coastguard Worker public:
CustomAllocatorBackend()193*89c4ff92SAndroid Build Coastguard Worker     CustomAllocatorBackend() :
194*89c4ff92SAndroid Build Coastguard Worker             m_BackendCapabilities(NamePolicy::GetIdStatic(), {{"NullCapability", false}}),
195*89c4ff92SAndroid Build Coastguard Worker             m_CustomAllocator(false) {};
CustomAllocatorBackend(const BackendCapabilities & capabilities)196*89c4ff92SAndroid Build Coastguard Worker     CustomAllocatorBackend(const BackendCapabilities& capabilities) :
197*89c4ff92SAndroid Build Coastguard Worker             m_BackendCapabilities(capabilities),
198*89c4ff92SAndroid Build Coastguard Worker             m_CustomAllocator(false) {};
199*89c4ff92SAndroid Build Coastguard Worker     ~CustomAllocatorBackend() = default;
200*89c4ff92SAndroid Build Coastguard Worker 
GetIdStatic()201*89c4ff92SAndroid Build Coastguard Worker     static const BackendId& GetIdStatic()
202*89c4ff92SAndroid Build Coastguard Worker     {
203*89c4ff92SAndroid Build Coastguard Worker         return NamePolicy::GetIdStatic();
204*89c4ff92SAndroid Build Coastguard Worker     }
GetId() const205*89c4ff92SAndroid Build Coastguard Worker     const BackendId& GetId() const override
206*89c4ff92SAndroid Build Coastguard Worker     {
207*89c4ff92SAndroid Build Coastguard Worker         return GetIdStatic();
208*89c4ff92SAndroid Build Coastguard Worker     }
209*89c4ff92SAndroid Build Coastguard Worker 
CreateMemoryManager() const210*89c4ff92SAndroid Build Coastguard Worker     IBackendInternal::IMemoryManagerUniquePtr CreateMemoryManager() const override
211*89c4ff92SAndroid Build Coastguard Worker     {
212*89c4ff92SAndroid Build Coastguard Worker         return nullptr;
213*89c4ff92SAndroid Build Coastguard Worker     };
214*89c4ff92SAndroid Build Coastguard Worker 
215*89c4ff92SAndroid Build Coastguard Worker     IBackendInternal::IWorkloadFactoryPtr
CreateWorkloadFactory(const IBackendInternal::IMemoryManagerSharedPtr &) const216*89c4ff92SAndroid Build Coastguard Worker         CreateWorkloadFactory(const IBackendInternal::IMemoryManagerSharedPtr&) const override
217*89c4ff92SAndroid Build Coastguard Worker     {
218*89c4ff92SAndroid Build Coastguard Worker         return nullptr;
219*89c4ff92SAndroid Build Coastguard Worker     }
220*89c4ff92SAndroid Build Coastguard Worker 
CreateBackendContext(const IRuntime::CreationOptions &) const221*89c4ff92SAndroid Build Coastguard Worker     IBackendInternal::IBackendContextPtr CreateBackendContext(const IRuntime::CreationOptions&) const override
222*89c4ff92SAndroid Build Coastguard Worker     {
223*89c4ff92SAndroid Build Coastguard Worker         return nullptr;
224*89c4ff92SAndroid Build Coastguard Worker     }
225*89c4ff92SAndroid Build Coastguard Worker 
GetLayerSupport() const226*89c4ff92SAndroid Build Coastguard Worker     IBackendInternal::ILayerSupportSharedPtr GetLayerSupport() const override
227*89c4ff92SAndroid Build Coastguard Worker     {
228*89c4ff92SAndroid Build Coastguard Worker         return std::make_shared<MockLayerSupport>();
229*89c4ff92SAndroid Build Coastguard Worker     }
230*89c4ff92SAndroid Build Coastguard Worker 
OptimizeSubgraphView(const SubgraphView &) const231*89c4ff92SAndroid Build Coastguard Worker     OptimizationViews OptimizeSubgraphView(const SubgraphView&) const override
232*89c4ff92SAndroid Build Coastguard Worker     {
233*89c4ff92SAndroid Build Coastguard Worker         return {};
234*89c4ff92SAndroid Build Coastguard Worker     };
235*89c4ff92SAndroid Build Coastguard Worker 
GetCapabilities() const236*89c4ff92SAndroid Build Coastguard Worker     BackendCapabilities GetCapabilities() const override
237*89c4ff92SAndroid Build Coastguard Worker     {
238*89c4ff92SAndroid Build Coastguard Worker         return m_BackendCapabilities;
239*89c4ff92SAndroid Build Coastguard Worker     };
240*89c4ff92SAndroid Build Coastguard Worker 
UseCustomMemoryAllocator(std::shared_ptr<ICustomAllocator> allocator,armnn::Optional<std::string &> errMsg)241*89c4ff92SAndroid Build Coastguard Worker     virtual bool UseCustomMemoryAllocator(std::shared_ptr<ICustomAllocator> allocator,
242*89c4ff92SAndroid Build Coastguard Worker                                           armnn::Optional<std::string&> errMsg) override
243*89c4ff92SAndroid Build Coastguard Worker     {
244*89c4ff92SAndroid Build Coastguard Worker         IgnoreUnused(errMsg, allocator);
245*89c4ff92SAndroid Build Coastguard Worker         m_CustomAllocator = true;
246*89c4ff92SAndroid Build Coastguard Worker         return m_CustomAllocator;
247*89c4ff92SAndroid Build Coastguard Worker     }
248*89c4ff92SAndroid Build Coastguard Worker 
249*89c4ff92SAndroid Build Coastguard Worker     BackendCapabilities m_BackendCapabilities;
250*89c4ff92SAndroid Build Coastguard Worker     bool m_CustomAllocator;
251*89c4ff92SAndroid Build Coastguard Worker };
252*89c4ff92SAndroid Build Coastguard Worker 
253*89c4ff92SAndroid Build Coastguard Worker template <typename NamePolicy>
254*89c4ff92SAndroid Build Coastguard Worker class NoProtectedModeMockBackend : public IBackendInternal
255*89c4ff92SAndroid Build Coastguard Worker {
256*89c4ff92SAndroid Build Coastguard Worker public:
NoProtectedModeMockBackend()257*89c4ff92SAndroid Build Coastguard Worker     NoProtectedModeMockBackend() : m_BackendCapabilities(NamePolicy::GetIdStatic(), {{"NullCapability", false}}) {};
NoProtectedModeMockBackend(const BackendCapabilities & capabilities)258*89c4ff92SAndroid Build Coastguard Worker     NoProtectedModeMockBackend(const BackendCapabilities& capabilities) : m_BackendCapabilities(capabilities) {};
259*89c4ff92SAndroid Build Coastguard Worker     ~NoProtectedModeMockBackend() = default;
260*89c4ff92SAndroid Build Coastguard Worker 
GetIdStatic()261*89c4ff92SAndroid Build Coastguard Worker     static const BackendId& GetIdStatic()
262*89c4ff92SAndroid Build Coastguard Worker     {
263*89c4ff92SAndroid Build Coastguard Worker         return NamePolicy::GetIdStatic();
264*89c4ff92SAndroid Build Coastguard Worker     }
GetId() const265*89c4ff92SAndroid Build Coastguard Worker     const BackendId& GetId() const override
266*89c4ff92SAndroid Build Coastguard Worker     {
267*89c4ff92SAndroid Build Coastguard Worker         return GetIdStatic();
268*89c4ff92SAndroid Build Coastguard Worker     }
269*89c4ff92SAndroid Build Coastguard Worker 
CreateMemoryManager() const270*89c4ff92SAndroid Build Coastguard Worker     IBackendInternal::IMemoryManagerUniquePtr CreateMemoryManager() const override
271*89c4ff92SAndroid Build Coastguard Worker     {
272*89c4ff92SAndroid Build Coastguard Worker         return nullptr;
273*89c4ff92SAndroid Build Coastguard Worker     };
274*89c4ff92SAndroid Build Coastguard Worker 
275*89c4ff92SAndroid Build Coastguard Worker     IBackendInternal::IWorkloadFactoryPtr
CreateWorkloadFactory(const IBackendInternal::IMemoryManagerSharedPtr &) const276*89c4ff92SAndroid Build Coastguard Worker         CreateWorkloadFactory(const IBackendInternal::IMemoryManagerSharedPtr&) const override
277*89c4ff92SAndroid Build Coastguard Worker     {
278*89c4ff92SAndroid Build Coastguard Worker         return nullptr;
279*89c4ff92SAndroid Build Coastguard Worker     }
280*89c4ff92SAndroid Build Coastguard Worker 
CreateBackendContext(const IRuntime::CreationOptions &) const281*89c4ff92SAndroid Build Coastguard Worker     IBackendInternal::IBackendContextPtr CreateBackendContext(const IRuntime::CreationOptions&) const override
282*89c4ff92SAndroid Build Coastguard Worker     {
283*89c4ff92SAndroid Build Coastguard Worker         return nullptr;
284*89c4ff92SAndroid Build Coastguard Worker     }
285*89c4ff92SAndroid Build Coastguard Worker 
GetLayerSupport() const286*89c4ff92SAndroid Build Coastguard Worker     IBackendInternal::ILayerSupportSharedPtr GetLayerSupport() const override
287*89c4ff92SAndroid Build Coastguard Worker     {
288*89c4ff92SAndroid Build Coastguard Worker         return std::make_shared<MockLayerSupport>();
289*89c4ff92SAndroid Build Coastguard Worker     }
290*89c4ff92SAndroid Build Coastguard Worker 
OptimizeSubgraphView(const SubgraphView &) const291*89c4ff92SAndroid Build Coastguard Worker     OptimizationViews OptimizeSubgraphView(const SubgraphView&) const override
292*89c4ff92SAndroid Build Coastguard Worker     {
293*89c4ff92SAndroid Build Coastguard Worker         return {};
294*89c4ff92SAndroid Build Coastguard Worker     };
295*89c4ff92SAndroid Build Coastguard Worker 
GetCapabilities() const296*89c4ff92SAndroid Build Coastguard Worker     BackendCapabilities GetCapabilities() const override
297*89c4ff92SAndroid Build Coastguard Worker     {
298*89c4ff92SAndroid Build Coastguard Worker         return m_BackendCapabilities;
299*89c4ff92SAndroid Build Coastguard Worker     };
300*89c4ff92SAndroid Build Coastguard Worker 
301*89c4ff92SAndroid Build Coastguard Worker     BackendCapabilities m_BackendCapabilities;
302*89c4ff92SAndroid Build Coastguard Worker };
303*89c4ff92SAndroid Build Coastguard Worker 
304*89c4ff92SAndroid Build Coastguard Worker }    // namespace
305*89c4ff92SAndroid Build Coastguard Worker 
306*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("Optimizer")
307*89c4ff92SAndroid Build Coastguard Worker {
308*89c4ff92SAndroid Build Coastguard Worker using namespace armnn::optimizations;
309*89c4ff92SAndroid Build Coastguard Worker 
310*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("LSTMValidateTensorShapesFromInputsCIFGDisabledTest")
311*89c4ff92SAndroid Build Coastguard Worker {
312*89c4ff92SAndroid Build Coastguard Worker     Graph graph;
313*89c4ff92SAndroid Build Coastguard Worker 
314*89c4ff92SAndroid Build Coastguard Worker     //Helper function creates graph containing LSTM layer with required input and output layers
315*89c4ff92SAndroid Build Coastguard Worker     CreateLSTMLayerHelper(graph, false);
316*89c4ff92SAndroid Build Coastguard Worker 
317*89c4ff92SAndroid Build Coastguard Worker     //This function used to call ValidateShapesFromInputs();
318*89c4ff92SAndroid Build Coastguard Worker     CHECK_NOTHROW(graph.InferTensorInfos());
319*89c4ff92SAndroid Build Coastguard Worker }
320*89c4ff92SAndroid Build Coastguard Worker 
321*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("LSTMValidateTensorShapesFromInputsCIFGEnabledTest")
322*89c4ff92SAndroid Build Coastguard Worker {
323*89c4ff92SAndroid Build Coastguard Worker     Graph graph;
324*89c4ff92SAndroid Build Coastguard Worker 
325*89c4ff92SAndroid Build Coastguard Worker     //Helper function creates graph containing LSTM layer with required input and output layers
326*89c4ff92SAndroid Build Coastguard Worker     CreateLSTMLayerHelper(graph, true);
327*89c4ff92SAndroid Build Coastguard Worker 
328*89c4ff92SAndroid Build Coastguard Worker     //This function used to call ValidateShapesFromInputs();
329*89c4ff92SAndroid Build Coastguard Worker     CHECK_NOTHROW(graph.InferTensorInfos());
330*89c4ff92SAndroid Build Coastguard Worker }
331*89c4ff92SAndroid Build Coastguard Worker 
332*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("InsertConvertersTest")
333*89c4ff92SAndroid Build Coastguard Worker {
334*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorInfo info({ 1, 5, 2, 3 }, armnn::DataType::Float16);
335*89c4ff92SAndroid Build Coastguard Worker 
336*89c4ff92SAndroid Build Coastguard Worker     armnn::Graph graph;
337*89c4ff92SAndroid Build Coastguard Worker 
338*89c4ff92SAndroid Build Coastguard Worker     armnn::LayerBindingId inputId = 0;
339*89c4ff92SAndroid Build Coastguard Worker 
340*89c4ff92SAndroid Build Coastguard Worker     armnn::Layer* head = graph.AddLayer<armnn::OutputLayer>(0, "output");
341*89c4ff92SAndroid Build Coastguard Worker 
342*89c4ff92SAndroid Build Coastguard Worker     head = graph.InsertNewLayer<armnn::AdditionLayer>(head->GetInputSlot(0), "");
343*89c4ff92SAndroid Build Coastguard Worker     head->GetOutputHandler().SetTensorInfo(info);
344*89c4ff92SAndroid Build Coastguard Worker 
345*89c4ff92SAndroid Build Coastguard Worker     graph.InsertNewLayer<armnn::InputLayer>(head->GetInputSlot(1), inputId++, "")
346*89c4ff92SAndroid Build Coastguard Worker         ->GetOutputHandler().SetTensorInfo(info);
347*89c4ff92SAndroid Build Coastguard Worker 
348*89c4ff92SAndroid Build Coastguard Worker     head = graph.InsertNewLayer<armnn::FloorLayer>(head->GetInputSlot(0), "");
349*89c4ff92SAndroid Build Coastguard Worker     head->GetOutputHandler().SetTensorInfo(info);
350*89c4ff92SAndroid Build Coastguard Worker 
351*89c4ff92SAndroid Build Coastguard Worker     head = graph.InsertNewLayer<armnn::MemCopyLayer>(head->GetInputSlot(0), "");
352*89c4ff92SAndroid Build Coastguard Worker     head->GetOutputHandler().SetTensorInfo(info);
353*89c4ff92SAndroid Build Coastguard Worker 
354*89c4ff92SAndroid Build Coastguard Worker     graph.InsertNewLayer<armnn::InputLayer>(head->GetInputSlot(0), inputId++, "")
355*89c4ff92SAndroid Build Coastguard Worker         ->GetOutputHandler().SetTensorInfo(info);
356*89c4ff92SAndroid Build Coastguard Worker 
357*89c4ff92SAndroid Build Coastguard Worker     // Check graph layer sequence before inserting convert layers
358*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckSequence(graph.cbegin(),
359*89c4ff92SAndroid Build Coastguard Worker                              graph.cend(),
360*89c4ff92SAndroid Build Coastguard Worker                              &IsLayerOfType<armnn::InputLayer>,
361*89c4ff92SAndroid Build Coastguard Worker                              &IsLayerOfType<armnn::InputLayer>,
362*89c4ff92SAndroid Build Coastguard Worker                              &IsLayerOfType<armnn::MemCopyLayer>,
363*89c4ff92SAndroid Build Coastguard Worker                              &IsLayerOfType<armnn::FloorLayer>,
364*89c4ff92SAndroid Build Coastguard Worker                              &IsLayerOfType<armnn::AdditionLayer>,
365*89c4ff92SAndroid Build Coastguard Worker                              &IsLayerOfType<armnn::OutputLayer>));
366*89c4ff92SAndroid Build Coastguard Worker 
367*89c4ff92SAndroid Build Coastguard Worker     // Check layers have Float16 DataType
368*89c4ff92SAndroid Build Coastguard Worker     for (auto& layer : graph)
369*89c4ff92SAndroid Build Coastguard Worker     {
370*89c4ff92SAndroid Build Coastguard Worker         if(layer->GetType()==LayerType::Floor || layer->GetType() == LayerType::Addition)
371*89c4ff92SAndroid Build Coastguard Worker         {
372*89c4ff92SAndroid Build Coastguard Worker             ARMNN_ASSERT(layer->GetOutputSlot(0).GetTensorInfo().GetDataType() == DataType::Float16);
373*89c4ff92SAndroid Build Coastguard Worker             ARMNN_ASSERT(layer->GetDataType() == DataType::Float16);
374*89c4ff92SAndroid Build Coastguard Worker         }
375*89c4ff92SAndroid Build Coastguard Worker     }
376*89c4ff92SAndroid Build Coastguard Worker 
377*89c4ff92SAndroid Build Coastguard Worker     // Insert convert layers either side of unsupported layer
378*89c4ff92SAndroid Build Coastguard Worker     for (auto& layer : graph)
379*89c4ff92SAndroid Build Coastguard Worker     {
380*89c4ff92SAndroid Build Coastguard Worker         if(layer->GetType()==LayerType::Floor || layer->GetType() == LayerType::Addition)
381*89c4ff92SAndroid Build Coastguard Worker         {
382*89c4ff92SAndroid Build Coastguard Worker             InsertConvertFp16ToFp32LayersBefore(graph, *layer);
383*89c4ff92SAndroid Build Coastguard Worker             InsertConvertFp32ToFp16LayersAfter(graph, *layer);
384*89c4ff92SAndroid Build Coastguard Worker         }
385*89c4ff92SAndroid Build Coastguard Worker     }
386*89c4ff92SAndroid Build Coastguard Worker 
387*89c4ff92SAndroid Build Coastguard Worker     // Check layers have correct DataType after inserting convert layers
388*89c4ff92SAndroid Build Coastguard Worker     for (auto& layer : graph)
389*89c4ff92SAndroid Build Coastguard Worker     {
390*89c4ff92SAndroid Build Coastguard Worker         if (layer->GetType()==LayerType::Floor || layer->GetType() == LayerType::Addition)
391*89c4ff92SAndroid Build Coastguard Worker         {
392*89c4ff92SAndroid Build Coastguard Worker             ARMNN_ASSERT(layer->GetOutputSlot(0).GetTensorInfo().GetDataType() == DataType::Float32);
393*89c4ff92SAndroid Build Coastguard Worker             ARMNN_ASSERT(layer->GetDataType() == DataType::Float32);
394*89c4ff92SAndroid Build Coastguard Worker         }
395*89c4ff92SAndroid Build Coastguard Worker         else if (layer->GetType() == LayerType::ConvertFp16ToFp32)
396*89c4ff92SAndroid Build Coastguard Worker         {
397*89c4ff92SAndroid Build Coastguard Worker             ARMNN_ASSERT(layer->GetOutputSlot(0).GetTensorInfo().GetDataType() == DataType::Float32);
398*89c4ff92SAndroid Build Coastguard Worker             ARMNN_ASSERT(layer->GetDataType() == DataType::Float16);
399*89c4ff92SAndroid Build Coastguard Worker         }
400*89c4ff92SAndroid Build Coastguard Worker         else if (layer->GetType() == LayerType::ConvertFp32ToFp16)
401*89c4ff92SAndroid Build Coastguard Worker         {
402*89c4ff92SAndroid Build Coastguard Worker             ARMNN_ASSERT(layer->GetOutputSlot(0).GetTensorInfo().GetDataType() == DataType::Float16);
403*89c4ff92SAndroid Build Coastguard Worker             ARMNN_ASSERT(layer->GetDataType() == DataType::Float32);
404*89c4ff92SAndroid Build Coastguard Worker         }
405*89c4ff92SAndroid Build Coastguard Worker     }
406*89c4ff92SAndroid Build Coastguard Worker 
407*89c4ff92SAndroid Build Coastguard Worker     // Check sequence of layers after inserting convert layers
408*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckSequence(graph.cbegin(),
409*89c4ff92SAndroid Build Coastguard Worker                              graph.cend(),
410*89c4ff92SAndroid Build Coastguard Worker                              &IsLayerOfType<armnn::InputLayer>,
411*89c4ff92SAndroid Build Coastguard Worker                              &IsLayerOfType<armnn::InputLayer>,
412*89c4ff92SAndroid Build Coastguard Worker                              &IsLayerOfType<armnn::ConvertFp16ToFp32Layer>,
413*89c4ff92SAndroid Build Coastguard Worker                              &IsLayerOfType<armnn::MemCopyLayer>,
414*89c4ff92SAndroid Build Coastguard Worker                              &IsLayerOfType<armnn::ConvertFp16ToFp32Layer>,
415*89c4ff92SAndroid Build Coastguard Worker                              &IsLayerOfType<armnn::FloorLayer>,
416*89c4ff92SAndroid Build Coastguard Worker                              &IsLayerOfType<armnn::ConvertFp32ToFp16Layer>,
417*89c4ff92SAndroid Build Coastguard Worker                              &IsLayerOfType<armnn::ConvertFp16ToFp32Layer>,
418*89c4ff92SAndroid Build Coastguard Worker                              &IsLayerOfType<armnn::AdditionLayer>,
419*89c4ff92SAndroid Build Coastguard Worker                              &IsLayerOfType<armnn::ConvertFp32ToFp16Layer>,
420*89c4ff92SAndroid Build Coastguard Worker                              &IsLayerOfType<armnn::OutputLayer>));
421*89c4ff92SAndroid Build Coastguard Worker }
422*89c4ff92SAndroid Build Coastguard Worker 
CreateConvolution2dGraph(Graph & graph,const unsigned int * inputShape,const unsigned int * weightsShape,const unsigned int * outputShape,DataLayout dataLayout=DataLayout::NCHW)423*89c4ff92SAndroid Build Coastguard Worker void CreateConvolution2dGraph(Graph &graph, const unsigned int* inputShape,
424*89c4ff92SAndroid Build Coastguard Worker                               const unsigned int* weightsShape, const unsigned int* outputShape,
425*89c4ff92SAndroid Build Coastguard Worker                               DataLayout dataLayout = DataLayout::NCHW)
426*89c4ff92SAndroid Build Coastguard Worker {
427*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo inputInfo(4, inputShape, DataType::Float32);
428*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo outputInfo(4, outputShape, DataType::Float32);
429*89c4ff92SAndroid Build Coastguard Worker 
430*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> weightsVector(90);
431*89c4ff92SAndroid Build Coastguard Worker     armnn::ConstTensor weights(
432*89c4ff92SAndroid Build Coastguard Worker             armnn::TensorInfo(4, weightsShape, armnn::DataType::Float32, 0.0f, 0, true),
433*89c4ff92SAndroid Build Coastguard Worker             weightsVector);
434*89c4ff92SAndroid Build Coastguard Worker 
435*89c4ff92SAndroid Build Coastguard Worker     Convolution2dDescriptor desc;
436*89c4ff92SAndroid Build Coastguard Worker     desc.m_BiasEnabled = false;
437*89c4ff92SAndroid Build Coastguard Worker     desc.m_StrideX     = 1;
438*89c4ff92SAndroid Build Coastguard Worker     desc.m_StrideY     = 1;
439*89c4ff92SAndroid Build Coastguard Worker     desc.m_DataLayout  = dataLayout;
440*89c4ff92SAndroid Build Coastguard Worker 
441*89c4ff92SAndroid Build Coastguard Worker     Layer* input = graph.AddLayer<InputLayer>(0, "input");
442*89c4ff92SAndroid Build Coastguard Worker     input->GetOutputSlot().SetTensorInfo(inputInfo);
443*89c4ff92SAndroid Build Coastguard Worker 
444*89c4ff92SAndroid Build Coastguard Worker     ConstantLayer* weightsLayer = graph.AddLayer<ConstantLayer>("Weights");
445*89c4ff92SAndroid Build Coastguard Worker     weightsLayer->m_LayerOutput = std::make_shared<ScopedTensorHandle>(weights);
446*89c4ff92SAndroid Build Coastguard Worker     weightsLayer->GetOutputSlot(0).SetTensorInfo(weightsLayer->m_LayerOutput->GetTensorInfo());
447*89c4ff92SAndroid Build Coastguard Worker 
448*89c4ff92SAndroid Build Coastguard Worker     Convolution2dLayer* layer = graph.AddLayer<Convolution2dLayer>(desc, "conv2d");
449*89c4ff92SAndroid Build Coastguard Worker     layer->GetOutputSlot().SetTensorInfo(outputInfo);
450*89c4ff92SAndroid Build Coastguard Worker 
451*89c4ff92SAndroid Build Coastguard Worker     Layer* output = graph.AddLayer<OutputLayer>(0, "output");
452*89c4ff92SAndroid Build Coastguard Worker 
453*89c4ff92SAndroid Build Coastguard Worker     input->GetOutputSlot().Connect(layer->GetInputSlot(0));
454*89c4ff92SAndroid Build Coastguard Worker     layer->GetOutputSlot().Connect(output->GetInputSlot(0));
455*89c4ff92SAndroid Build Coastguard Worker     weightsLayer->GetOutputSlot(0).Connect(layer->GetInputSlot(1));
456*89c4ff92SAndroid Build Coastguard Worker }
457*89c4ff92SAndroid Build Coastguard Worker 
458*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("Conv2dValidateTensorShapesFromInputs")
459*89c4ff92SAndroid Build Coastguard Worker {
460*89c4ff92SAndroid Build Coastguard Worker     Graph graph;
461*89c4ff92SAndroid Build Coastguard Worker     const unsigned int inputShape[] = { 1, 3, 8, 16 };
462*89c4ff92SAndroid Build Coastguard Worker     const unsigned int weightsShape[] = { 2, 3, 5, 3 };
463*89c4ff92SAndroid Build Coastguard Worker     const unsigned int outputShape[] = { 1, 2, 4, 14 };
464*89c4ff92SAndroid Build Coastguard Worker     CreateConvolution2dGraph(graph, inputShape, weightsShape, outputShape);
465*89c4ff92SAndroid Build Coastguard Worker 
466*89c4ff92SAndroid Build Coastguard Worker     CHECK_NOTHROW(graph.InferTensorInfos());
467*89c4ff92SAndroid Build Coastguard Worker }
468*89c4ff92SAndroid Build Coastguard Worker 
469*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("Conv2dValidateTensorShapesFromInputsNhwc")
470*89c4ff92SAndroid Build Coastguard Worker {
471*89c4ff92SAndroid Build Coastguard Worker     Graph graph;
472*89c4ff92SAndroid Build Coastguard Worker     const unsigned int inputShape[] = { 1, 8, 16, 3 };
473*89c4ff92SAndroid Build Coastguard Worker     const unsigned int weightsShape[] = { 2, 5, 3, 3 };
474*89c4ff92SAndroid Build Coastguard Worker     const unsigned int outputShape[] = { 1, 4, 14, 2 };
475*89c4ff92SAndroid Build Coastguard Worker     CreateConvolution2dGraph(graph, inputShape, weightsShape, outputShape, DataLayout::NHWC);
476*89c4ff92SAndroid Build Coastguard Worker 
477*89c4ff92SAndroid Build Coastguard Worker     CHECK_NOTHROW(graph.InferTensorInfos());
478*89c4ff92SAndroid Build Coastguard Worker }
479*89c4ff92SAndroid Build Coastguard Worker 
CreateDepthwiseConvolution2dGraph(Graph & graph,const unsigned int * inputShape,const unsigned int * weightsShape,const unsigned int * outputShape,DataLayout dataLayout=DataLayout::NCHW)480*89c4ff92SAndroid Build Coastguard Worker void CreateDepthwiseConvolution2dGraph(Graph &graph, const unsigned int* inputShape,
481*89c4ff92SAndroid Build Coastguard Worker                                        const unsigned int* weightsShape, const unsigned int* outputShape,
482*89c4ff92SAndroid Build Coastguard Worker                                        DataLayout dataLayout = DataLayout::NCHW)
483*89c4ff92SAndroid Build Coastguard Worker {
484*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo inputInfo(4, inputShape, DataType::Float32);
485*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo outputInfo(4, outputShape, DataType::Float32);
486*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo weightsInfo(TensorShape(4, weightsShape), armnn::DataType::Float32, 0.0f, 0, true);
487*89c4ff92SAndroid Build Coastguard Worker 
488*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> weightsVector(18);
489*89c4ff92SAndroid Build Coastguard Worker     armnn::ConstTensor weights(weightsInfo, weightsVector);
490*89c4ff92SAndroid Build Coastguard Worker 
491*89c4ff92SAndroid Build Coastguard Worker     DepthwiseConvolution2dDescriptor desc;
492*89c4ff92SAndroid Build Coastguard Worker     desc.m_BiasEnabled = false;
493*89c4ff92SAndroid Build Coastguard Worker     desc.m_StrideX     = 1;
494*89c4ff92SAndroid Build Coastguard Worker     desc.m_StrideY     = 1;
495*89c4ff92SAndroid Build Coastguard Worker     desc.m_DataLayout  = dataLayout;
496*89c4ff92SAndroid Build Coastguard Worker 
497*89c4ff92SAndroid Build Coastguard Worker     InputLayer* input                  = graph.AddLayer<InputLayer>(0, "input");
498*89c4ff92SAndroid Build Coastguard Worker     DepthwiseConvolution2dLayer* layer = graph.AddLayer<DepthwiseConvolution2dLayer>(desc, "depthwiseConv2d");
499*89c4ff92SAndroid Build Coastguard Worker     ConstantLayer* weightsLayer        = graph.AddLayer<ConstantLayer>("weights");
500*89c4ff92SAndroid Build Coastguard Worker     OutputLayer* output                = graph.AddLayer<OutputLayer>(0, "output");
501*89c4ff92SAndroid Build Coastguard Worker 
502*89c4ff92SAndroid Build Coastguard Worker     input->GetOutputSlot().SetTensorInfo(inputInfo);
503*89c4ff92SAndroid Build Coastguard Worker     layer->GetOutputSlot().SetTensorInfo(outputInfo);
504*89c4ff92SAndroid Build Coastguard Worker     weightsLayer->GetOutputSlot().SetTensorInfo(weightsInfo);
505*89c4ff92SAndroid Build Coastguard Worker 
506*89c4ff92SAndroid Build Coastguard Worker     weightsLayer->m_LayerOutput = std::make_unique<armnn::ScopedTensorHandle>(weights);
507*89c4ff92SAndroid Build Coastguard Worker 
508*89c4ff92SAndroid Build Coastguard Worker     input->GetOutputSlot().Connect(layer->GetInputSlot(0));
509*89c4ff92SAndroid Build Coastguard Worker     weightsLayer->GetOutputSlot().Connect(layer->GetInputSlot(1));
510*89c4ff92SAndroid Build Coastguard Worker     layer->GetOutputSlot().Connect(output->GetInputSlot(0));
511*89c4ff92SAndroid Build Coastguard Worker }
512*89c4ff92SAndroid Build Coastguard Worker 
513*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("DepthwiseConv2dValidateTensorShapesFromInputs")
514*89c4ff92SAndroid Build Coastguard Worker {
515*89c4ff92SAndroid Build Coastguard Worker     Graph graph;
516*89c4ff92SAndroid Build Coastguard Worker     const unsigned int inputShape[] = { 1, 2, 3, 3 };
517*89c4ff92SAndroid Build Coastguard Worker     const unsigned int weightsShape[] = { 1, 3, 3, 2 };
518*89c4ff92SAndroid Build Coastguard Worker     const unsigned int outputShape[] = { 1, 2, 1, 1 };
519*89c4ff92SAndroid Build Coastguard Worker     CreateDepthwiseConvolution2dGraph(graph, inputShape, weightsShape, outputShape);
520*89c4ff92SAndroid Build Coastguard Worker 
521*89c4ff92SAndroid Build Coastguard Worker     CHECK_NOTHROW(graph.InferTensorInfos());
522*89c4ff92SAndroid Build Coastguard Worker }
523*89c4ff92SAndroid Build Coastguard Worker 
524*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("DepthwiseConv2dValidateTensorShapesFromInputsNhwc")
525*89c4ff92SAndroid Build Coastguard Worker {
526*89c4ff92SAndroid Build Coastguard Worker     Graph graph;
527*89c4ff92SAndroid Build Coastguard Worker     const unsigned int inputShape[] = { 1, 3, 3, 2 };
528*89c4ff92SAndroid Build Coastguard Worker     const unsigned int weightsShape[] = { 1, 3, 3, 2 };
529*89c4ff92SAndroid Build Coastguard Worker     const unsigned int outputShape[] = { 1, 1, 1, 2 };
530*89c4ff92SAndroid Build Coastguard Worker     CreateDepthwiseConvolution2dGraph(graph, inputShape, weightsShape, outputShape, DataLayout::NHWC);
531*89c4ff92SAndroid Build Coastguard Worker 
532*89c4ff92SAndroid Build Coastguard Worker     CHECK_NOTHROW(graph.InferTensorInfos());
533*89c4ff92SAndroid Build Coastguard Worker }
534*89c4ff92SAndroid Build Coastguard Worker 
CreatePooling2dGraph(Graph & graph,const unsigned int * inputShape,const unsigned int * outputShape,DataLayout dataLayout=DataLayout::NCHW)535*89c4ff92SAndroid Build Coastguard Worker void CreatePooling2dGraph(Graph& graph, const unsigned int* inputShape,  const unsigned int* outputShape,
536*89c4ff92SAndroid Build Coastguard Worker                           DataLayout dataLayout = DataLayout::NCHW)
537*89c4ff92SAndroid Build Coastguard Worker {
538*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo inputInfo(4, inputShape, DataType::Float32);
539*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo outputInfo(4, outputShape, DataType::Float32);
540*89c4ff92SAndroid Build Coastguard Worker 
541*89c4ff92SAndroid Build Coastguard Worker     Pooling2dDescriptor desc;
542*89c4ff92SAndroid Build Coastguard Worker     desc.m_PoolType  = armnn::PoolingAlgorithm::Average;
543*89c4ff92SAndroid Build Coastguard Worker     desc.m_PoolWidth = desc.m_PoolHeight = 100;
544*89c4ff92SAndroid Build Coastguard Worker     desc.m_StrideX = desc.m_StrideY = 5;
545*89c4ff92SAndroid Build Coastguard Worker     desc.m_PadLeft                  = 50;
546*89c4ff92SAndroid Build Coastguard Worker     desc.m_PadRight                 = 50;
547*89c4ff92SAndroid Build Coastguard Worker     desc.m_PadTop                   = 50;
548*89c4ff92SAndroid Build Coastguard Worker     desc.m_PadBottom                = 50;
549*89c4ff92SAndroid Build Coastguard Worker     desc.m_PaddingMethod            = armnn::PaddingMethod::Exclude;
550*89c4ff92SAndroid Build Coastguard Worker     desc.m_DataLayout               = dataLayout;
551*89c4ff92SAndroid Build Coastguard Worker 
552*89c4ff92SAndroid Build Coastguard Worker     Layer* input = graph.AddLayer<InputLayer>(0, "input");
553*89c4ff92SAndroid Build Coastguard Worker     input->GetOutputSlot().SetTensorInfo(inputInfo);
554*89c4ff92SAndroid Build Coastguard Worker 
555*89c4ff92SAndroid Build Coastguard Worker     Pooling2dLayer* layer = graph.AddLayer<Pooling2dLayer>(desc, "pooling2d");
556*89c4ff92SAndroid Build Coastguard Worker     layer->GetOutputSlot().SetTensorInfo(outputInfo);
557*89c4ff92SAndroid Build Coastguard Worker 
558*89c4ff92SAndroid Build Coastguard Worker     Layer* output = graph.AddLayer<OutputLayer>(0, "output");
559*89c4ff92SAndroid Build Coastguard Worker     input->GetOutputSlot().Connect(layer->GetInputSlot(0));
560*89c4ff92SAndroid Build Coastguard Worker     layer->GetOutputSlot().Connect(output->GetInputSlot(0));
561*89c4ff92SAndroid Build Coastguard Worker }
562*89c4ff92SAndroid Build Coastguard Worker 
563*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("Pooling2dValidateTensorShapesFromInputs")
564*89c4ff92SAndroid Build Coastguard Worker {
565*89c4ff92SAndroid Build Coastguard Worker     Graph graph;
566*89c4ff92SAndroid Build Coastguard Worker     const unsigned int inputShape[]  = { 5, 3, 52, 60 };
567*89c4ff92SAndroid Build Coastguard Worker     const unsigned int outputShape[] = { 5, 3, 11, 13 };
568*89c4ff92SAndroid Build Coastguard Worker     CreatePooling2dGraph(graph, inputShape, outputShape, DataLayout::NCHW);
569*89c4ff92SAndroid Build Coastguard Worker 
570*89c4ff92SAndroid Build Coastguard Worker     CHECK_NOTHROW(graph.InferTensorInfos());
571*89c4ff92SAndroid Build Coastguard Worker }
572*89c4ff92SAndroid Build Coastguard Worker 
573*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("Pooling2dValidateTensorShapesFromInputsNhwc")
574*89c4ff92SAndroid Build Coastguard Worker {
575*89c4ff92SAndroid Build Coastguard Worker     Graph graph;
576*89c4ff92SAndroid Build Coastguard Worker     const unsigned int inputShape[]  = { 5, 52, 60, 3 };
577*89c4ff92SAndroid Build Coastguard Worker     const unsigned int outputShape[] = { 5, 11, 13, 3 };
578*89c4ff92SAndroid Build Coastguard Worker     CreatePooling2dGraph(graph, inputShape, outputShape, DataLayout::NHWC);
579*89c4ff92SAndroid Build Coastguard Worker 
580*89c4ff92SAndroid Build Coastguard Worker     CHECK_NOTHROW(graph.InferTensorInfos());
581*89c4ff92SAndroid Build Coastguard Worker }
582*89c4ff92SAndroid Build Coastguard Worker 
CreateResizeBilinearGraph(Graph & graph,const unsigned int * inputShape,const unsigned int * outputShape,DataLayout dataLayout=DataLayout::NCHW)583*89c4ff92SAndroid Build Coastguard Worker void CreateResizeBilinearGraph(Graph& graph,
584*89c4ff92SAndroid Build Coastguard Worker                                const unsigned int* inputShape,
585*89c4ff92SAndroid Build Coastguard Worker                                const unsigned int* outputShape,
586*89c4ff92SAndroid Build Coastguard Worker                                DataLayout dataLayout = DataLayout::NCHW)
587*89c4ff92SAndroid Build Coastguard Worker {
588*89c4ff92SAndroid Build Coastguard Worker     TensorInfo inputInfo(4, inputShape, DataType::Float32);
589*89c4ff92SAndroid Build Coastguard Worker     TensorInfo outputInfo(4, outputShape, DataType::Float32);
590*89c4ff92SAndroid Build Coastguard Worker 
591*89c4ff92SAndroid Build Coastguard Worker     ResizeDescriptor desc;
592*89c4ff92SAndroid Build Coastguard Worker     desc.m_Method       = ResizeMethod::Bilinear;
593*89c4ff92SAndroid Build Coastguard Worker     desc.m_TargetHeight = 3;
594*89c4ff92SAndroid Build Coastguard Worker     desc.m_TargetWidth  = 4;
595*89c4ff92SAndroid Build Coastguard Worker     desc.m_DataLayout   = dataLayout;
596*89c4ff92SAndroid Build Coastguard Worker 
597*89c4ff92SAndroid Build Coastguard Worker     Layer* input = graph.AddLayer<InputLayer>(0, "input");
598*89c4ff92SAndroid Build Coastguard Worker     input->GetOutputSlot().SetTensorInfo(inputInfo);
599*89c4ff92SAndroid Build Coastguard Worker 
600*89c4ff92SAndroid Build Coastguard Worker     ResizeLayer* layer = graph.AddLayer<ResizeLayer>(desc, "resizeBilinear");
601*89c4ff92SAndroid Build Coastguard Worker     layer->GetOutputSlot().SetTensorInfo(outputInfo);
602*89c4ff92SAndroid Build Coastguard Worker 
603*89c4ff92SAndroid Build Coastguard Worker     Layer* output = graph.AddLayer<OutputLayer>(0, "output");
604*89c4ff92SAndroid Build Coastguard Worker     input->GetOutputSlot().Connect(layer->GetInputSlot(0));
605*89c4ff92SAndroid Build Coastguard Worker     layer->GetOutputSlot().Connect(output->GetInputSlot(0));
606*89c4ff92SAndroid Build Coastguard Worker }
607*89c4ff92SAndroid Build Coastguard Worker 
608*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ResizeBilinearValidateTensorShapesFromInputs")
609*89c4ff92SAndroid Build Coastguard Worker {
610*89c4ff92SAndroid Build Coastguard Worker     Graph graph;
611*89c4ff92SAndroid Build Coastguard Worker     const unsigned int inputShape[]  = { 1, 2, 4, 5 };
612*89c4ff92SAndroid Build Coastguard Worker     const unsigned int outputShape[] = { 1, 2, 3, 4 };
613*89c4ff92SAndroid Build Coastguard Worker     CreateResizeBilinearGraph(graph, inputShape, outputShape);
614*89c4ff92SAndroid Build Coastguard Worker 
615*89c4ff92SAndroid Build Coastguard Worker     CHECK_NOTHROW(graph.InferTensorInfos());
616*89c4ff92SAndroid Build Coastguard Worker }
617*89c4ff92SAndroid Build Coastguard Worker 
618*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("ResizeBilinearValidateTensorShapesFromInputsNhwc")
619*89c4ff92SAndroid Build Coastguard Worker {
620*89c4ff92SAndroid Build Coastguard Worker     Graph graph;
621*89c4ff92SAndroid Build Coastguard Worker     const unsigned int inputShape[]  = { 1, 4, 5, 2 };
622*89c4ff92SAndroid Build Coastguard Worker     const unsigned int outputShape[] = { 1, 3, 4, 2 };
623*89c4ff92SAndroid Build Coastguard Worker     CreateResizeBilinearGraph(graph, inputShape, outputShape, DataLayout::NHWC);
624*89c4ff92SAndroid Build Coastguard Worker 
625*89c4ff92SAndroid Build Coastguard Worker     CHECK_NOTHROW(graph.InferTensorInfos());
626*89c4ff92SAndroid Build Coastguard Worker }
627*89c4ff92SAndroid Build Coastguard Worker 
CreateGatherGraph(Graph & graph,const armnn::TensorInfo & paramsInfo,const armnn::TensorInfo & indicesInfo,const armnn::TensorInfo & outputInfo)628*89c4ff92SAndroid Build Coastguard Worker void CreateGatherGraph(Graph& graph,
629*89c4ff92SAndroid Build Coastguard Worker                        const armnn::TensorInfo& paramsInfo,
630*89c4ff92SAndroid Build Coastguard Worker                        const armnn::TensorInfo& indicesInfo,
631*89c4ff92SAndroid Build Coastguard Worker                        const armnn::TensorInfo& outputInfo)
632*89c4ff92SAndroid Build Coastguard Worker {
633*89c4ff92SAndroid Build Coastguard Worker     Layer* input0 = graph.AddLayer<InputLayer>(0, "params");
634*89c4ff92SAndroid Build Coastguard Worker     input0->GetOutputSlot().SetTensorInfo(paramsInfo);
635*89c4ff92SAndroid Build Coastguard Worker 
636*89c4ff92SAndroid Build Coastguard Worker     Layer* input1 = graph.AddLayer<InputLayer>(1, "indices");
637*89c4ff92SAndroid Build Coastguard Worker     input1->GetOutputSlot().SetTensorInfo(indicesInfo);
638*89c4ff92SAndroid Build Coastguard Worker 
639*89c4ff92SAndroid Build Coastguard Worker     GatherDescriptor descriptor;
640*89c4ff92SAndroid Build Coastguard Worker     GatherLayer* layer = graph.AddLayer<GatherLayer>(descriptor, "gather");
641*89c4ff92SAndroid Build Coastguard Worker     layer->GetOutputSlot().SetTensorInfo(outputInfo);
642*89c4ff92SAndroid Build Coastguard Worker 
643*89c4ff92SAndroid Build Coastguard Worker     Layer* output = graph.AddLayer<OutputLayer>(0, "output");
644*89c4ff92SAndroid Build Coastguard Worker     input0->GetOutputSlot().Connect(layer->GetInputSlot(0));
645*89c4ff92SAndroid Build Coastguard Worker     input1->GetOutputSlot().Connect(layer->GetInputSlot(1));
646*89c4ff92SAndroid Build Coastguard Worker     layer->GetOutputSlot().Connect(output->GetInputSlot(0));
647*89c4ff92SAndroid Build Coastguard Worker }
648*89c4ff92SAndroid Build Coastguard Worker 
649*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("GatherValidateTensorShapesFromInputs")
650*89c4ff92SAndroid Build Coastguard Worker {
651*89c4ff92SAndroid Build Coastguard Worker     Graph graph;
652*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo paramsInfo({10, 5}, DataType::Float32);
653*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo indicesInfo({3}, DataType::Signed32);
654*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo outputInfo({3, 5}, DataType::Float32);
655*89c4ff92SAndroid Build Coastguard Worker 
656*89c4ff92SAndroid Build Coastguard Worker     CreateGatherGraph(graph, paramsInfo, indicesInfo, outputInfo);
657*89c4ff92SAndroid Build Coastguard Worker 
658*89c4ff92SAndroid Build Coastguard Worker     CHECK_NOTHROW(graph.InferTensorInfos());
659*89c4ff92SAndroid Build Coastguard Worker }
660*89c4ff92SAndroid Build Coastguard Worker 
661*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("GatherValidateTensorShapesFromInputs1DParams")
662*89c4ff92SAndroid Build Coastguard Worker {
663*89c4ff92SAndroid Build Coastguard Worker     Graph graph;
664*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo paramsInfo({8}, DataType::Float32);
665*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo indicesInfo({5}, DataType::Signed32);
666*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo outputInfo( {5}, DataType::Float32);
667*89c4ff92SAndroid Build Coastguard Worker 
668*89c4ff92SAndroid Build Coastguard Worker     CreateGatherGraph(graph, paramsInfo, indicesInfo, outputInfo);
669*89c4ff92SAndroid Build Coastguard Worker 
670*89c4ff92SAndroid Build Coastguard Worker     CHECK_NOTHROW(graph.InferTensorInfos());
671*89c4ff92SAndroid Build Coastguard Worker }
672*89c4ff92SAndroid Build Coastguard Worker 
673*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("GatherValidateTensorShapesFromInputsMultiDimIndices")
674*89c4ff92SAndroid Build Coastguard Worker {
675*89c4ff92SAndroid Build Coastguard Worker     Graph graph;
676*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo paramsInfo({3, 2, 5}, DataType::Float32);
677*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo indicesInfo({2, 2}, DataType::Signed32);
678*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo outputInfo({2, 2, 2, 5}, DataType::Float32);
679*89c4ff92SAndroid Build Coastguard Worker 
680*89c4ff92SAndroid Build Coastguard Worker     CreateGatherGraph(graph, paramsInfo, indicesInfo, outputInfo);
681*89c4ff92SAndroid Build Coastguard Worker 
682*89c4ff92SAndroid Build Coastguard Worker     CHECK_NOTHROW(graph.InferTensorInfos());
683*89c4ff92SAndroid Build Coastguard Worker }
684*89c4ff92SAndroid Build Coastguard Worker 
685*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("DetectionPostProcessValidateTensorShapes")
686*89c4ff92SAndroid Build Coastguard Worker {
687*89c4ff92SAndroid Build Coastguard Worker     Graph graph;
688*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo boxEncodingsInfo({1, 10, 4}, DataType::QAsymmU8);
689*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo scoresInfo({1, 10, 4}, DataType::QAsymmU8);
690*89c4ff92SAndroid Build Coastguard Worker     std::vector<uint8_t> anchorsVector(40);
691*89c4ff92SAndroid Build Coastguard Worker     armnn::ConstTensor anchors(armnn::TensorInfo({10, 4}, armnn::DataType::QAsymmU8, 0.0f, 0, true), anchorsVector);
692*89c4ff92SAndroid Build Coastguard Worker 
693*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo detectionBoxesInfo({1, 3, 4}, DataType::QAsymmU8);
694*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo detectionScoresInfo({1, 3}, DataType::QAsymmU8);
695*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo detectionClassesInfo({1, 3}, DataType::QAsymmU8);
696*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo numDetectionInfo({1}, DataType::QAsymmU8);
697*89c4ff92SAndroid Build Coastguard Worker 
698*89c4ff92SAndroid Build Coastguard Worker     Layer* input0 = graph.AddLayer<InputLayer>(0, "boxEncodings");
699*89c4ff92SAndroid Build Coastguard Worker     input0->GetOutputSlot().SetTensorInfo(boxEncodingsInfo);
700*89c4ff92SAndroid Build Coastguard Worker 
701*89c4ff92SAndroid Build Coastguard Worker     Layer* input1 = graph.AddLayer<InputLayer>(1, "score");
702*89c4ff92SAndroid Build Coastguard Worker     input1->GetOutputSlot().SetTensorInfo(scoresInfo);
703*89c4ff92SAndroid Build Coastguard Worker 
704*89c4ff92SAndroid Build Coastguard Worker     DetectionPostProcessDescriptor descriptor;
705*89c4ff92SAndroid Build Coastguard Worker     descriptor.m_MaxDetections = 3;
706*89c4ff92SAndroid Build Coastguard Worker 
707*89c4ff92SAndroid Build Coastguard Worker     DetectionPostProcessLayer* layer = graph.AddLayer<DetectionPostProcessLayer>(descriptor, "detectionPostProcess");
708*89c4ff92SAndroid Build Coastguard Worker     layer->m_Anchors = std::make_unique<armnn::ScopedTensorHandle>(anchors);
709*89c4ff92SAndroid Build Coastguard Worker     layer->GetOutputSlot(0).SetTensorInfo(detectionBoxesInfo);
710*89c4ff92SAndroid Build Coastguard Worker     layer->GetOutputSlot(1).SetTensorInfo(detectionScoresInfo);
711*89c4ff92SAndroid Build Coastguard Worker     layer->GetOutputSlot(2).SetTensorInfo(detectionClassesInfo);
712*89c4ff92SAndroid Build Coastguard Worker     layer->GetOutputSlot(3).SetTensorInfo(numDetectionInfo);
713*89c4ff92SAndroid Build Coastguard Worker 
714*89c4ff92SAndroid Build Coastguard Worker     input0->GetOutputSlot().Connect(layer->GetInputSlot(0));
715*89c4ff92SAndroid Build Coastguard Worker     input1->GetOutputSlot().Connect(layer->GetInputSlot(1));
716*89c4ff92SAndroid Build Coastguard Worker 
717*89c4ff92SAndroid Build Coastguard Worker     CHECK_NOTHROW(graph.InferTensorInfos());
718*89c4ff92SAndroid Build Coastguard Worker }
719*89c4ff92SAndroid Build Coastguard Worker 
720*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("BackendCapabilityTest")
721*89c4ff92SAndroid Build Coastguard Worker {
722*89c4ff92SAndroid Build Coastguard Worker     BackendId backendId = "MockBackend";
723*89c4ff92SAndroid Build Coastguard Worker 
724*89c4ff92SAndroid Build Coastguard Worker     armnn::BackendOptions::BackendOption nonConstWeights{"NonConstWeights", true};
725*89c4ff92SAndroid Build Coastguard Worker 
726*89c4ff92SAndroid Build Coastguard Worker     // MockBackend does not support the NonConstWeights capability
727*89c4ff92SAndroid Build Coastguard Worker     CHECK(!armnn::HasCapability(nonConstWeights, backendId));
728*89c4ff92SAndroid Build Coastguard Worker     CHECK(!armnn::HasCapability("NonConstWeights", backendId));
729*89c4ff92SAndroid Build Coastguard Worker 
730*89c4ff92SAndroid Build Coastguard Worker     // MockBackend does not support the AsyncExecution capability
731*89c4ff92SAndroid Build Coastguard Worker     CHECK(!armnn::GetCapability("AsyncExecution", backendId).has_value());
732*89c4ff92SAndroid Build Coastguard Worker }
733*89c4ff92SAndroid Build Coastguard Worker 
734*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("BackendHintTest")
735*89c4ff92SAndroid Build Coastguard Worker {
736*89c4ff92SAndroid Build Coastguard Worker     class TestBackendAssignment : public StrategyBase<NoThrowStrategy>
737*89c4ff92SAndroid Build Coastguard Worker     {
738*89c4ff92SAndroid Build Coastguard Worker     public:
739*89c4ff92SAndroid Build Coastguard Worker 
ExecuteStrategy(const armnn::IConnectableLayer * layer,const armnn::BaseDescriptor & descriptor,const std::vector<armnn::ConstTensor> & constants,const char * name,const armnn::LayerBindingId id=0)740*89c4ff92SAndroid Build Coastguard Worker         void ExecuteStrategy(const armnn::IConnectableLayer* layer,
741*89c4ff92SAndroid Build Coastguard Worker                              const armnn::BaseDescriptor& descriptor,
742*89c4ff92SAndroid Build Coastguard Worker                              const std::vector<armnn::ConstTensor>& constants,
743*89c4ff92SAndroid Build Coastguard Worker                              const char* name,
744*89c4ff92SAndroid Build Coastguard Worker                              const armnn::LayerBindingId id = 0) override
745*89c4ff92SAndroid Build Coastguard Worker         {
746*89c4ff92SAndroid Build Coastguard Worker             armnn::IgnoreUnused(descriptor, constants, id, name);
747*89c4ff92SAndroid Build Coastguard Worker             switch (layer->GetType())
748*89c4ff92SAndroid Build Coastguard Worker             {
749*89c4ff92SAndroid Build Coastguard Worker                 case armnn::LayerType::Input:
750*89c4ff92SAndroid Build Coastguard Worker                 {
751*89c4ff92SAndroid Build Coastguard Worker                     auto inputLayer = PolymorphicDowncast<const InputLayer*>(layer);
752*89c4ff92SAndroid Build Coastguard Worker                     const auto connectedLayerBackendId = inputLayer->GetOutputSlot(0).GetOwningLayer().GetBackendId();
753*89c4ff92SAndroid Build Coastguard Worker                     CHECK((inputLayer->GetBackendId() == connectedLayerBackendId));
754*89c4ff92SAndroid Build Coastguard Worker                     break;
755*89c4ff92SAndroid Build Coastguard Worker                 }
756*89c4ff92SAndroid Build Coastguard Worker                 case armnn::LayerType::Output:
757*89c4ff92SAndroid Build Coastguard Worker                 {
758*89c4ff92SAndroid Build Coastguard Worker                     auto outputLayer = PolymorphicDowncast<const OutputLayer*>(layer);
759*89c4ff92SAndroid Build Coastguard Worker                     CHECK((outputLayer->GetBackendId() == "MockBackend"));
760*89c4ff92SAndroid Build Coastguard Worker                     break;
761*89c4ff92SAndroid Build Coastguard Worker                 }
762*89c4ff92SAndroid Build Coastguard Worker                 case armnn::LayerType::Activation:
763*89c4ff92SAndroid Build Coastguard Worker                 {
764*89c4ff92SAndroid Build Coastguard Worker                     auto activation = PolymorphicDowncast<const ActivationLayer*>(layer);
765*89c4ff92SAndroid Build Coastguard Worker                     CHECK((activation->GetBackendId() == "CustomBackend"));
766*89c4ff92SAndroid Build Coastguard Worker                     break;
767*89c4ff92SAndroid Build Coastguard Worker                 }
768*89c4ff92SAndroid Build Coastguard Worker                 default:
769*89c4ff92SAndroid Build Coastguard Worker                 {
770*89c4ff92SAndroid Build Coastguard Worker                     m_DefaultStrategy.Apply(GetLayerTypeAsCString(layer->GetType()));
771*89c4ff92SAndroid Build Coastguard Worker                 }
772*89c4ff92SAndroid Build Coastguard Worker             }
773*89c4ff92SAndroid Build Coastguard Worker         }
774*89c4ff92SAndroid Build Coastguard Worker     };
775*89c4ff92SAndroid Build Coastguard Worker 
776*89c4ff92SAndroid Build Coastguard Worker     struct CustomPolicy
777*89c4ff92SAndroid Build Coastguard Worker     {
GetIdStaticCustomPolicy778*89c4ff92SAndroid Build Coastguard Worker         static const BackendId& GetIdStatic()
779*89c4ff92SAndroid Build Coastguard Worker         {
780*89c4ff92SAndroid Build Coastguard Worker             static BackendId id = "CustomBackend";
781*89c4ff92SAndroid Build Coastguard Worker             return id;
782*89c4ff92SAndroid Build Coastguard Worker         }
783*89c4ff92SAndroid Build Coastguard Worker     };
784*89c4ff92SAndroid Build Coastguard Worker 
785*89c4ff92SAndroid Build Coastguard Worker     struct MockPolicy
786*89c4ff92SAndroid Build Coastguard Worker     {
GetIdStaticMockPolicy787*89c4ff92SAndroid Build Coastguard Worker         static const BackendId& GetIdStatic()
788*89c4ff92SAndroid Build Coastguard Worker         {
789*89c4ff92SAndroid Build Coastguard Worker             static BackendId id = "MockBackend";
790*89c4ff92SAndroid Build Coastguard Worker             return id;
791*89c4ff92SAndroid Build Coastguard Worker         }
792*89c4ff92SAndroid Build Coastguard Worker     };
793*89c4ff92SAndroid Build Coastguard Worker 
794*89c4ff92SAndroid Build Coastguard Worker     auto& backendRegistry = BackendRegistryInstance();
795*89c4ff92SAndroid Build Coastguard Worker 
__anonfee9358a0202() 796*89c4ff92SAndroid Build Coastguard Worker     backendRegistry.Register("MockBackend", []() { return std::make_unique<CustomAllocatorBackend<MockPolicy>>(); });
797*89c4ff92SAndroid Build Coastguard Worker 
798*89c4ff92SAndroid Build Coastguard Worker     backendRegistry.Register("CustomBackend",
__anonfee9358a0302() 799*89c4ff92SAndroid Build Coastguard Worker                              []() { return std::make_unique<CustomAllocatorBackend<CustomPolicy>>(); });
800*89c4ff92SAndroid Build Coastguard Worker 
801*89c4ff92SAndroid Build Coastguard Worker     // Define the network
802*89c4ff92SAndroid Build Coastguard Worker     auto network = INetwork::Create();
803*89c4ff92SAndroid Build Coastguard Worker     ActivationDescriptor desc;
804*89c4ff92SAndroid Build Coastguard Worker     desc.m_Function = ActivationFunction::Linear;
805*89c4ff92SAndroid Build Coastguard Worker 
806*89c4ff92SAndroid Build Coastguard Worker     std::unique_ptr<Graph> graph = std::make_unique<Graph>();
807*89c4ff92SAndroid Build Coastguard Worker     auto input                   = graph->AddLayer<InputLayer>(0, "input");
808*89c4ff92SAndroid Build Coastguard Worker     auto act                     = graph->AddLayer<ActivationLayer>(desc, "activation");
809*89c4ff92SAndroid Build Coastguard Worker     auto output                  = graph->AddLayer<OutputLayer>(0, "output");
810*89c4ff92SAndroid Build Coastguard Worker 
811*89c4ff92SAndroid Build Coastguard Worker     BackendId customBackendId("CustomBackend");
812*89c4ff92SAndroid Build Coastguard Worker     act->BackendSelectionHint(customBackendId);
813*89c4ff92SAndroid Build Coastguard Worker 
814*89c4ff92SAndroid Build Coastguard Worker     input->GetOutputSlot(0).Connect(act->GetInputSlot(0));
815*89c4ff92SAndroid Build Coastguard Worker     act->GetOutputSlot(0).Connect(output->GetInputSlot(0));
816*89c4ff92SAndroid Build Coastguard Worker 
817*89c4ff92SAndroid Build Coastguard Worker     OptimizedNetworkImpl optNet(std::move(graph));
818*89c4ff92SAndroid Build Coastguard Worker 
819*89c4ff92SAndroid Build Coastguard Worker     // Get the optimized graph
820*89c4ff92SAndroid Build Coastguard Worker     Graph& optGraph = optNet.GetGraph();
821*89c4ff92SAndroid Build Coastguard Worker 
822*89c4ff92SAndroid Build Coastguard Worker     std::vector<BackendId> prefs{ "MockBackend", "CustomBackend" };
823*89c4ff92SAndroid Build Coastguard Worker 
824*89c4ff92SAndroid Build Coastguard Worker     BackendIdSet availableBackends = { "CustomBackend", "MockBackend" };
825*89c4ff92SAndroid Build Coastguard Worker     DeviceSpec spec(availableBackends);
826*89c4ff92SAndroid Build Coastguard Worker 
827*89c4ff92SAndroid Build Coastguard Worker     BackendSettings backendSettings(prefs, spec);
828*89c4ff92SAndroid Build Coastguard Worker 
829*89c4ff92SAndroid Build Coastguard Worker     // Assign an available backend to each layer
830*89c4ff92SAndroid Build Coastguard Worker     Graph::Iterator firstLayer = optGraph.begin();
831*89c4ff92SAndroid Build Coastguard Worker     Graph::Iterator lastLayer  = optGraph.end();
832*89c4ff92SAndroid Build Coastguard Worker 
833*89c4ff92SAndroid Build Coastguard Worker     OptimizedNetworkImpl* optNetObjPtr = &optNet;
834*89c4ff92SAndroid Build Coastguard Worker     OptimizationResult res = AssignBackends(optNetObjPtr,
835*89c4ff92SAndroid Build Coastguard Worker                                             backendSettings,
836*89c4ff92SAndroid Build Coastguard Worker                                             firstLayer,
837*89c4ff92SAndroid Build Coastguard Worker                                             lastLayer,
838*89c4ff92SAndroid Build Coastguard Worker                                             EmptyOptional());
839*89c4ff92SAndroid Build Coastguard Worker 
840*89c4ff92SAndroid Build Coastguard Worker     CHECK(res.IsOk());
841*89c4ff92SAndroid Build Coastguard Worker 
842*89c4ff92SAndroid Build Coastguard Worker     TestBackendAssignment visitor;
843*89c4ff92SAndroid Build Coastguard Worker     for (auto it = firstLayer; it != lastLayer; ++it)
844*89c4ff92SAndroid Build Coastguard Worker     {
845*89c4ff92SAndroid Build Coastguard Worker         (*it)->ExecuteStrategy(visitor);
846*89c4ff92SAndroid Build Coastguard Worker     }
847*89c4ff92SAndroid Build Coastguard Worker     // Clean up the registry for the next test.
848*89c4ff92SAndroid Build Coastguard Worker     backendRegistry.Deregister("MockBackend");
849*89c4ff92SAndroid Build Coastguard Worker     backendRegistry.Deregister("CustomBackend");
850*89c4ff92SAndroid Build Coastguard Worker }
851*89c4ff92SAndroid Build Coastguard Worker 
852*89c4ff92SAndroid Build Coastguard Worker // Tests that OptimizeForExclusiveConnections works, fusing when needed, using BatchNorm fusing as example
853*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("OptimizeForExclusiveConnectionsFuseTest")
854*89c4ff92SAndroid Build Coastguard Worker {
855*89c4ff92SAndroid Build Coastguard Worker     using namespace armnn;
856*89c4ff92SAndroid Build Coastguard Worker     // Define layers information
857*89c4ff92SAndroid Build Coastguard Worker     Convolution2dDescriptor convolution2dDescriptor;
858*89c4ff92SAndroid Build Coastguard Worker     convolution2dDescriptor.m_BiasEnabled = false;
859*89c4ff92SAndroid Build Coastguard Worker     convolution2dDescriptor.m_DataLayout  = DataLayout::NHWC;
860*89c4ff92SAndroid Build Coastguard Worker     BatchNormalizationDescriptor batchNormDescriptor;
861*89c4ff92SAndroid Build Coastguard Worker     batchNormDescriptor.m_DataLayout = DataLayout::NHWC;
862*89c4ff92SAndroid Build Coastguard Worker 
863*89c4ff92SAndroid Build Coastguard Worker     const unsigned int inputDimensionSizes[]   = { 1, 4, 4, 3 };                 // NHWCin
864*89c4ff92SAndroid Build Coastguard Worker     const unsigned int weightsDimensionSizes[] = { 1, 2, 2, 3 };                 // CoutHWCin
865*89c4ff92SAndroid Build Coastguard Worker     const unsigned int outputDimensionSizes[]  = { 1, 3, 3, 1 };                 // NHWCout
866*89c4ff92SAndroid Build Coastguard Worker     const unsigned int outputChannelSize[]     = { outputDimensionSizes[3] };    // Cout
867*89c4ff92SAndroid Build Coastguard Worker 
868*89c4ff92SAndroid Build Coastguard Worker     TensorInfo inputInfo(4, inputDimensionSizes, DataType::Float32);
869*89c4ff92SAndroid Build Coastguard Worker     TensorInfo outputInfo(4, outputDimensionSizes, DataType::Float32);
870*89c4ff92SAndroid Build Coastguard Worker 
871*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> weightsVector = { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12 };
872*89c4ff92SAndroid Build Coastguard Worker     ConstTensor weights(TensorInfo(4, weightsDimensionSizes, DataType::Float32, 0.0f, 0, true), weightsVector);
873*89c4ff92SAndroid Build Coastguard Worker 
874*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> betaVector     = { 0.1f };
875*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> gammaVector    = { 0.5f };
876*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> meanVector     = { 0 };
877*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> varianceVector = { 1 };
878*89c4ff92SAndroid Build Coastguard Worker     ConstTensor beta(TensorInfo(1, outputChannelSize, DataType::Float32, 0.0f, 0, true), betaVector);
879*89c4ff92SAndroid Build Coastguard Worker     ConstTensor gamma(TensorInfo(1, outputChannelSize, DataType::Float32, 0.0f, 0, true), gammaVector);
880*89c4ff92SAndroid Build Coastguard Worker     ConstTensor mean(TensorInfo(1, outputChannelSize, DataType::Float32, 0.0f, 0, true), meanVector);
881*89c4ff92SAndroid Build Coastguard Worker     ConstTensor variance(TensorInfo(1, outputChannelSize, DataType::Float32, 0.0f, 0, true), varianceVector);
882*89c4ff92SAndroid Build Coastguard Worker 
883*89c4ff92SAndroid Build Coastguard Worker     ConstantLayer* biasLayer = nullptr;
884*89c4ff92SAndroid Build Coastguard Worker 
885*89c4ff92SAndroid Build Coastguard Worker     // Define the network
886*89c4ff92SAndroid Build Coastguard Worker     Graph graph;
887*89c4ff92SAndroid Build Coastguard Worker     auto input        = graph.AddLayer<InputLayer>(0, "input");
888*89c4ff92SAndroid Build Coastguard Worker     auto weightsLayer = graph.AddLayer<ConstantLayer>("Weights");
889*89c4ff92SAndroid Build Coastguard Worker     auto conv         = graph.AddLayer<Convolution2dLayer>(convolution2dDescriptor, "convolution");
890*89c4ff92SAndroid Build Coastguard Worker     auto batchNorm    = graph.AddLayer<BatchNormalizationLayer>(batchNormDescriptor, "batchNorm");
891*89c4ff92SAndroid Build Coastguard Worker     auto output       = graph.AddLayer<OutputLayer>(0, "output");
892*89c4ff92SAndroid Build Coastguard Worker 
893*89c4ff92SAndroid Build Coastguard Worker     // Set layer information
894*89c4ff92SAndroid Build Coastguard Worker     input->GetOutputSlot().SetTensorInfo(inputInfo);
895*89c4ff92SAndroid Build Coastguard Worker 
896*89c4ff92SAndroid Build Coastguard Worker     weightsLayer->m_LayerOutput = std::make_shared<ScopedTensorHandle>(weights);
897*89c4ff92SAndroid Build Coastguard Worker     weightsLayer->GetOutputSlot(0).SetTensorInfo(weightsLayer->m_LayerOutput->GetTensorInfo());
898*89c4ff92SAndroid Build Coastguard Worker     conv->GetOutputSlot().SetTensorInfo(outputInfo);
899*89c4ff92SAndroid Build Coastguard Worker 
900*89c4ff92SAndroid Build Coastguard Worker     batchNorm->GetOutputSlot().SetTensorInfo(outputInfo);
901*89c4ff92SAndroid Build Coastguard Worker     batchNorm->m_Beta     = std::make_unique<ScopedTensorHandle>(beta);
902*89c4ff92SAndroid Build Coastguard Worker     batchNorm->m_Gamma    = std::make_unique<ScopedTensorHandle>(gamma);
903*89c4ff92SAndroid Build Coastguard Worker     batchNorm->m_Mean     = std::make_unique<ScopedTensorHandle>(mean);
904*89c4ff92SAndroid Build Coastguard Worker     batchNorm->m_Variance = std::make_unique<ScopedTensorHandle>(variance);
905*89c4ff92SAndroid Build Coastguard Worker 
906*89c4ff92SAndroid Build Coastguard Worker     if (convolution2dDescriptor.m_BiasEnabled)
907*89c4ff92SAndroid Build Coastguard Worker     {
908*89c4ff92SAndroid Build Coastguard Worker         std::vector<float> biasVector = { 11 };
909*89c4ff92SAndroid Build Coastguard Worker         ConstTensor bias(TensorInfo(1, outputChannelSize, DataType::Float32, 0.0f, 0, true), biasVector);
910*89c4ff92SAndroid Build Coastguard Worker         biasLayer = graph.AddLayer<ConstantLayer>("Bias");
911*89c4ff92SAndroid Build Coastguard Worker         biasLayer->m_LayerOutput = std::make_shared<ScopedTensorHandle>(bias);
912*89c4ff92SAndroid Build Coastguard Worker         biasLayer->GetOutputSlot(0).SetTensorInfo(biasLayer->m_LayerOutput->GetTensorInfo());
913*89c4ff92SAndroid Build Coastguard Worker         biasLayer->GetOutputSlot(0).Connect(conv->GetInputSlot(2));
914*89c4ff92SAndroid Build Coastguard Worker     }
915*89c4ff92SAndroid Build Coastguard Worker 
916*89c4ff92SAndroid Build Coastguard Worker     // Connect layers
917*89c4ff92SAndroid Build Coastguard Worker     input->GetOutputSlot(0).Connect(conv->GetInputSlot(0));
918*89c4ff92SAndroid Build Coastguard Worker     weightsLayer->GetOutputSlot(0).Connect(conv->GetInputSlot(1));
919*89c4ff92SAndroid Build Coastguard Worker     conv->GetOutputSlot(0).Connect(batchNorm->GetInputSlot(0));
920*89c4ff92SAndroid Build Coastguard Worker     batchNorm->GetOutputSlot(0).Connect(output->GetInputSlot(0));
921*89c4ff92SAndroid Build Coastguard Worker 
922*89c4ff92SAndroid Build Coastguard Worker     if (convolution2dDescriptor.m_BiasEnabled)
923*89c4ff92SAndroid Build Coastguard Worker     {
924*89c4ff92SAndroid Build Coastguard Worker         CHECK(6 == graph.GetNumLayers());
925*89c4ff92SAndroid Build Coastguard Worker         CHECK(CheckSequence(graph.cbegin(), graph.cend(),
926*89c4ff92SAndroid Build Coastguard Worker                             &IsLayerOfType<InputLayer>,
927*89c4ff92SAndroid Build Coastguard Worker                             &IsLayerOfType<ConstantLayer>,
928*89c4ff92SAndroid Build Coastguard Worker                             &IsLayerOfType<ConstantLayer>,
929*89c4ff92SAndroid Build Coastguard Worker                             &IsLayerOfType<Convolution2dLayer>,
930*89c4ff92SAndroid Build Coastguard Worker                             &IsLayerOfType<BatchNormalizationLayer>,
931*89c4ff92SAndroid Build Coastguard Worker                             &IsLayerOfType<OutputLayer>));
932*89c4ff92SAndroid Build Coastguard Worker     }
933*89c4ff92SAndroid Build Coastguard Worker     else
934*89c4ff92SAndroid Build Coastguard Worker     {
935*89c4ff92SAndroid Build Coastguard Worker         CHECK(5 == graph.GetNumLayers());
936*89c4ff92SAndroid Build Coastguard Worker         CHECK(CheckSequence(graph.cbegin(), graph.cend(),
937*89c4ff92SAndroid Build Coastguard Worker                             &IsLayerOfType<InputLayer>,
938*89c4ff92SAndroid Build Coastguard Worker                             &IsLayerOfType<ConstantLayer>,
939*89c4ff92SAndroid Build Coastguard Worker                             &IsLayerOfType<Convolution2dLayer>,
940*89c4ff92SAndroid Build Coastguard Worker                             &IsLayerOfType<BatchNormalizationLayer>,
941*89c4ff92SAndroid Build Coastguard Worker                             &IsLayerOfType<OutputLayer>));
942*89c4ff92SAndroid Build Coastguard Worker     }
943*89c4ff92SAndroid Build Coastguard Worker 
944*89c4ff92SAndroid Build Coastguard Worker     // Optimize graph
945*89c4ff92SAndroid Build Coastguard Worker     armnn::Optimizer::Pass(graph, MakeOptimizations(FuseBatchNormIntoConvolution2DFloat32()));
946*89c4ff92SAndroid Build Coastguard Worker 
__anonfee9358a0402(const armnn::Layer* const layer) 947*89c4ff92SAndroid Build Coastguard Worker     auto checkFusedConv2d = [](const armnn::Layer* const layer) -> bool {
948*89c4ff92SAndroid Build Coastguard Worker         return IsLayerOfType<armnn::Convolution2dLayer>(layer) &&
949*89c4ff92SAndroid Build Coastguard Worker                (layer->GetNameStr() == "fused-batchNorm-into-convolution");
950*89c4ff92SAndroid Build Coastguard Worker     };
951*89c4ff92SAndroid Build Coastguard Worker 
952*89c4ff92SAndroid Build Coastguard Worker     CHECK(5 == graph.GetNumLayers());
953*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckSequence(graph.cbegin(), graph.cend(),
954*89c4ff92SAndroid Build Coastguard Worker                         &IsLayerOfType<InputLayer>,
955*89c4ff92SAndroid Build Coastguard Worker                         &IsLayerOfType<ConstantLayer>,
956*89c4ff92SAndroid Build Coastguard Worker                         &IsLayerOfType<ConstantLayer>,
957*89c4ff92SAndroid Build Coastguard Worker                         checkFusedConv2d,
958*89c4ff92SAndroid Build Coastguard Worker                         &IsLayerOfType<OutputLayer>));
959*89c4ff92SAndroid Build Coastguard Worker }
960*89c4ff92SAndroid Build Coastguard Worker 
961*89c4ff92SAndroid Build Coastguard Worker // Tests that OptimizeForExclusiveConnections works, not fusing when not needed, using BatchNorm fusing as example
962*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("OptimizeForExclusiveConnectionsWithoutFuseTest")
963*89c4ff92SAndroid Build Coastguard Worker {
964*89c4ff92SAndroid Build Coastguard Worker     // Define the network
965*89c4ff92SAndroid Build Coastguard Worker     Graph graph;
966*89c4ff92SAndroid Build Coastguard Worker     Convolution2dDescriptor convolution2dDescriptor;
967*89c4ff92SAndroid Build Coastguard Worker     BatchNormalizationDescriptor batchNormDescriptor;
968*89c4ff92SAndroid Build Coastguard Worker 
969*89c4ff92SAndroid Build Coastguard Worker     auto input     = graph.AddLayer<InputLayer>(0, "input");
970*89c4ff92SAndroid Build Coastguard Worker     auto conv      = graph.AddLayer<Convolution2dLayer>(convolution2dDescriptor, "convolution");
971*89c4ff92SAndroid Build Coastguard Worker     auto batchNorm = graph.AddLayer<BatchNormalizationLayer>(batchNormDescriptor, "batchNorm");
972*89c4ff92SAndroid Build Coastguard Worker     auto output    = graph.AddLayer<OutputLayer>(0, "output");
973*89c4ff92SAndroid Build Coastguard Worker     auto output2   = graph.AddLayer<OutputLayer>(1, "output2");
974*89c4ff92SAndroid Build Coastguard Worker 
975*89c4ff92SAndroid Build Coastguard Worker     // Connect layers
976*89c4ff92SAndroid Build Coastguard Worker     input->GetOutputSlot(0).Connect(conv->GetInputSlot(0));
977*89c4ff92SAndroid Build Coastguard Worker     conv->GetOutputSlot(0).Connect(batchNorm->GetInputSlot(0));
978*89c4ff92SAndroid Build Coastguard Worker     batchNorm->GetOutputSlot(0).Connect(output->GetInputSlot(0));
979*89c4ff92SAndroid Build Coastguard Worker     conv->GetOutputSlot(0).Connect(output2->GetInputSlot(0));
980*89c4ff92SAndroid Build Coastguard Worker 
981*89c4ff92SAndroid Build Coastguard Worker     CHECK((5 == graph.GetNumLayers()));
982*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckSequence(graph.cbegin(), graph.cend(),
983*89c4ff92SAndroid Build Coastguard Worker                         &IsLayerOfType<armnn::InputLayer>,
984*89c4ff92SAndroid Build Coastguard Worker                         &IsLayerOfType<armnn::Convolution2dLayer>,
985*89c4ff92SAndroid Build Coastguard Worker                         &IsLayerOfType<armnn::BatchNormalizationLayer>,
986*89c4ff92SAndroid Build Coastguard Worker                         &IsLayerOfType<armnn::OutputLayer>,
987*89c4ff92SAndroid Build Coastguard Worker                         &IsLayerOfType<armnn::OutputLayer>));
988*89c4ff92SAndroid Build Coastguard Worker     // Optimize graph
989*89c4ff92SAndroid Build Coastguard Worker     armnn::Optimizer::Pass(graph, armnn::MakeOptimizations(FuseBatchNormIntoConvolution2DFloat32()));
990*89c4ff92SAndroid Build Coastguard Worker 
991*89c4ff92SAndroid Build Coastguard Worker     CHECK(5 == graph.GetNumLayers());
992*89c4ff92SAndroid Build Coastguard Worker     CHECK(CheckSequence(graph.cbegin(), graph.cend(),
993*89c4ff92SAndroid Build Coastguard Worker                         &IsLayerOfType<armnn::InputLayer>,
994*89c4ff92SAndroid Build Coastguard Worker                         &IsLayerOfType<armnn::Convolution2dLayer>,
995*89c4ff92SAndroid Build Coastguard Worker                         &IsLayerOfType<armnn::BatchNormalizationLayer>,
996*89c4ff92SAndroid Build Coastguard Worker                         &IsLayerOfType<armnn::OutputLayer>,
997*89c4ff92SAndroid Build Coastguard Worker                         &IsLayerOfType<armnn::OutputLayer>));
998*89c4ff92SAndroid Build Coastguard Worker }
999*89c4ff92SAndroid Build Coastguard Worker } // Optimizer TestSuite
1000