xref: /aosp_15_r20/external/armnn/src/backends/cl/ClBackend.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017-2023 Arm Ltd. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include "ClBackend.hpp"
7*89c4ff92SAndroid Build Coastguard Worker #include "ClBackendContext.hpp"
8*89c4ff92SAndroid Build Coastguard Worker #include "ClBackendDefaultAllocator.hpp"
9*89c4ff92SAndroid Build Coastguard Worker #include "ClBackendId.hpp"
10*89c4ff92SAndroid Build Coastguard Worker #include "ClBackendModelContext.hpp"
11*89c4ff92SAndroid Build Coastguard Worker #include "ClImportTensorHandleFactory.hpp"
12*89c4ff92SAndroid Build Coastguard Worker #include "ClLayerSupport.hpp"
13*89c4ff92SAndroid Build Coastguard Worker #include "ClTensorHandleFactory.hpp"
14*89c4ff92SAndroid Build Coastguard Worker #include "ClWorkloadFactory.hpp"
15*89c4ff92SAndroid Build Coastguard Worker 
16*89c4ff92SAndroid Build Coastguard Worker #include <armnn/BackendRegistry.hpp>
17*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Descriptors.hpp>
18*89c4ff92SAndroid Build Coastguard Worker 
19*89c4ff92SAndroid Build Coastguard Worker #include <aclCommon/ArmComputeSubgraphUtils.hpp>
20*89c4ff92SAndroid Build Coastguard Worker #include <aclCommon/ArmComputeUtils.hpp>
21*89c4ff92SAndroid Build Coastguard Worker #include <aclCommon/BaseMemoryManager.hpp>
22*89c4ff92SAndroid Build Coastguard Worker 
23*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/IBackendContext.hpp>
24*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/IMemoryManager.hpp>
25*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/PolymorphicDowncast.hpp>
26*89c4ff92SAndroid Build Coastguard Worker 
27*89c4ff92SAndroid Build Coastguard Worker #include "workloads/ClAdditionWorkload.hpp"
28*89c4ff92SAndroid Build Coastguard Worker #include "workloads/ClBatchNormalizationFloatWorkload.hpp"
29*89c4ff92SAndroid Build Coastguard Worker #include "workloads/ClConvolution2dWorkload.hpp"
30*89c4ff92SAndroid Build Coastguard Worker #include "workloads/ClDepthwiseConvolutionWorkload.hpp"
31*89c4ff92SAndroid Build Coastguard Worker #include "workloads/ClDivisionWorkload.hpp"
32*89c4ff92SAndroid Build Coastguard Worker #include "workloads/ClFullyConnectedWorkload.hpp"
33*89c4ff92SAndroid Build Coastguard Worker #include "workloads/ClMultiplicationWorkload.hpp"
34*89c4ff92SAndroid Build Coastguard Worker #include "workloads/ClReduceWorkload.hpp"
35*89c4ff92SAndroid Build Coastguard Worker #include "workloads/ClSubtractionWorkload.hpp"
36*89c4ff92SAndroid Build Coastguard Worker 
37*89c4ff92SAndroid Build Coastguard Worker #include <Optimizer.hpp>
38*89c4ff92SAndroid Build Coastguard Worker 
39*89c4ff92SAndroid Build Coastguard Worker #include <arm_compute/core/Types.h>
40*89c4ff92SAndroid Build Coastguard Worker #include <arm_compute/runtime/CL/CLBufferAllocator.h>
41*89c4ff92SAndroid Build Coastguard Worker 
42*89c4ff92SAndroid Build Coastguard Worker namespace armnn
43*89c4ff92SAndroid Build Coastguard Worker {
44*89c4ff92SAndroid Build Coastguard Worker 
GetIdStatic()45*89c4ff92SAndroid Build Coastguard Worker const BackendId& ClBackend::GetIdStatic()
46*89c4ff92SAndroid Build Coastguard Worker {
47*89c4ff92SAndroid Build Coastguard Worker     static const BackendId s_Id{ClBackendId()};
48*89c4ff92SAndroid Build Coastguard Worker     return s_Id;
49*89c4ff92SAndroid Build Coastguard Worker }
50*89c4ff92SAndroid Build Coastguard Worker 
CreateMemoryManager() const51*89c4ff92SAndroid Build Coastguard Worker IBackendInternal::IMemoryManagerUniquePtr ClBackend::CreateMemoryManager() const
52*89c4ff92SAndroid Build Coastguard Worker {
53*89c4ff92SAndroid Build Coastguard Worker     if (m_UsingCustomAllocator)
54*89c4ff92SAndroid Build Coastguard Worker     {
55*89c4ff92SAndroid Build Coastguard Worker         return std::make_unique<ClMemoryManager>(m_CustomAllocator);
56*89c4ff92SAndroid Build Coastguard Worker     }
57*89c4ff92SAndroid Build Coastguard Worker     return std::make_unique<ClMemoryManager>(std::make_unique<arm_compute::CLBufferAllocator>());
58*89c4ff92SAndroid Build Coastguard Worker }
59*89c4ff92SAndroid Build Coastguard Worker 
// Creates a ClWorkloadFactory that uses the supplied memory manager.
// The generic memory-manager pointer is downcast to ClMemoryManager; callers are presumably expected to pass a
// manager created by this backend (e.g. via CreateMemoryManager) — TODO confirm against callers.
IBackendInternal::IWorkloadFactoryPtr ClBackend::CreateWorkloadFactory(
    const IBackendInternal::IMemoryManagerSharedPtr& memoryManager) const
{
    auto clMemoryManager = PolymorphicPointerDowncast<ClMemoryManager>(memoryManager);
    return std::make_unique<ClWorkloadFactory>(clMemoryManager);
}
66*89c4ff92SAndroid Build Coastguard Worker 
// Creates a ClWorkloadFactory that uses the supplied memory manager (downcast to ClMemoryManager) together with
// a CL-specific model context built from the given model options.
IBackendInternal::IWorkloadFactoryPtr ClBackend::CreateWorkloadFactory(
    const IBackendInternal::IMemoryManagerSharedPtr& memoryManager, const ModelOptions& modelOptions) const
{
    auto clMemoryManager = PolymorphicPointerDowncast<ClMemoryManager>(memoryManager);
    auto modelContext    = CreateBackendSpecificModelContext(modelOptions);
    return std::make_unique<ClWorkloadFactory>(clMemoryManager, modelContext);
}
73*89c4ff92SAndroid Build Coastguard Worker 
CreateWorkloadFactory(TensorHandleFactoryRegistry & registry) const74*89c4ff92SAndroid Build Coastguard Worker IBackendInternal::IWorkloadFactoryPtr ClBackend::CreateWorkloadFactory(
75*89c4ff92SAndroid Build Coastguard Worker     TensorHandleFactoryRegistry& registry) const
76*89c4ff92SAndroid Build Coastguard Worker {
77*89c4ff92SAndroid Build Coastguard Worker     std::shared_ptr<ClMemoryManager> memoryManager;
78*89c4ff92SAndroid Build Coastguard Worker     if (m_UsingCustomAllocator)
79*89c4ff92SAndroid Build Coastguard Worker     {
80*89c4ff92SAndroid Build Coastguard Worker         memoryManager = std::make_shared<ClMemoryManager>(m_CustomAllocator);
81*89c4ff92SAndroid Build Coastguard Worker     }
82*89c4ff92SAndroid Build Coastguard Worker     else
83*89c4ff92SAndroid Build Coastguard Worker     {
84*89c4ff92SAndroid Build Coastguard Worker         memoryManager = std::make_shared<ClMemoryManager>(std::make_unique<arm_compute::CLBufferAllocator>());
85*89c4ff92SAndroid Build Coastguard Worker     }
86*89c4ff92SAndroid Build Coastguard Worker 
87*89c4ff92SAndroid Build Coastguard Worker     std::unique_ptr<ITensorHandleFactory> factory = std::make_unique<ClTensorHandleFactory>(memoryManager);
88*89c4ff92SAndroid Build Coastguard Worker     std::unique_ptr<ITensorHandleFactory> importFactory = std::make_unique<ClImportTensorHandleFactory>(
89*89c4ff92SAndroid Build Coastguard Worker         static_cast<MemorySourceFlags>(MemorySource::Malloc), static_cast<MemorySourceFlags>(MemorySource::Malloc));
90*89c4ff92SAndroid Build Coastguard Worker 
91*89c4ff92SAndroid Build Coastguard Worker     registry.RegisterCopyAndImportFactoryPair(factory->GetId(), importFactory->GetId());
92*89c4ff92SAndroid Build Coastguard Worker     registry.RegisterCopyAndImportFactoryPair(importFactory->GetId(), factory->GetId());
93*89c4ff92SAndroid Build Coastguard Worker 
94*89c4ff92SAndroid Build Coastguard Worker     registry.RegisterMemoryManager(memoryManager);
95*89c4ff92SAndroid Build Coastguard Worker     registry.RegisterFactory(std::move(factory));
96*89c4ff92SAndroid Build Coastguard Worker     registry.RegisterFactory(std::move(importFactory));
97*89c4ff92SAndroid Build Coastguard Worker 
98*89c4ff92SAndroid Build Coastguard Worker     return std::make_unique<ClWorkloadFactory>(
99*89c4ff92SAndroid Build Coastguard Worker             PolymorphicPointerDowncast<ClMemoryManager>(memoryManager));
100*89c4ff92SAndroid Build Coastguard Worker }
101*89c4ff92SAndroid Build Coastguard Worker 
CreateWorkloadFactory(TensorHandleFactoryRegistry & registry,const ModelOptions & modelOptions) const102*89c4ff92SAndroid Build Coastguard Worker IBackendInternal::IWorkloadFactoryPtr ClBackend::CreateWorkloadFactory(
103*89c4ff92SAndroid Build Coastguard Worker     TensorHandleFactoryRegistry& registry, const ModelOptions& modelOptions) const
104*89c4ff92SAndroid Build Coastguard Worker {
105*89c4ff92SAndroid Build Coastguard Worker     std::shared_ptr<ClMemoryManager> memoryManager;
106*89c4ff92SAndroid Build Coastguard Worker     if (m_UsingCustomAllocator)
107*89c4ff92SAndroid Build Coastguard Worker     {
108*89c4ff92SAndroid Build Coastguard Worker         memoryManager = std::make_shared<ClMemoryManager>(m_CustomAllocator);
109*89c4ff92SAndroid Build Coastguard Worker     }
110*89c4ff92SAndroid Build Coastguard Worker     else
111*89c4ff92SAndroid Build Coastguard Worker     {
112*89c4ff92SAndroid Build Coastguard Worker         memoryManager = std::make_shared<ClMemoryManager>(std::make_unique<arm_compute::CLBufferAllocator>());
113*89c4ff92SAndroid Build Coastguard Worker     }
114*89c4ff92SAndroid Build Coastguard Worker 
115*89c4ff92SAndroid Build Coastguard Worker     std::unique_ptr<ITensorHandleFactory> factory = std::make_unique<ClTensorHandleFactory>(memoryManager);
116*89c4ff92SAndroid Build Coastguard Worker     std::unique_ptr<ITensorHandleFactory> importFactory = std::make_unique<ClImportTensorHandleFactory>(
117*89c4ff92SAndroid Build Coastguard Worker         static_cast<MemorySourceFlags>(MemorySource::Malloc), static_cast<MemorySourceFlags>(MemorySource::Malloc));
118*89c4ff92SAndroid Build Coastguard Worker 
119*89c4ff92SAndroid Build Coastguard Worker     registry.RegisterCopyAndImportFactoryPair(factory->GetId(), importFactory->GetId());
120*89c4ff92SAndroid Build Coastguard Worker     registry.RegisterCopyAndImportFactoryPair(importFactory->GetId(), factory->GetId());
121*89c4ff92SAndroid Build Coastguard Worker 
122*89c4ff92SAndroid Build Coastguard Worker     registry.RegisterMemoryManager(memoryManager);
123*89c4ff92SAndroid Build Coastguard Worker     registry.RegisterFactory(std::move(factory));
124*89c4ff92SAndroid Build Coastguard Worker     registry.RegisterFactory(std::move(importFactory));
125*89c4ff92SAndroid Build Coastguard Worker 
126*89c4ff92SAndroid Build Coastguard Worker     return std::make_unique<ClWorkloadFactory>(
127*89c4ff92SAndroid Build Coastguard Worker         PolymorphicPointerDowncast<ClMemoryManager>(memoryManager), CreateBackendSpecificModelContext(modelOptions));
128*89c4ff92SAndroid Build Coastguard Worker }
129*89c4ff92SAndroid Build Coastguard Worker 
CreateWorkloadFactory(TensorHandleFactoryRegistry & registry,const ModelOptions & modelOptions,MemorySourceFlags inputFlags,MemorySourceFlags outputFlags) const130*89c4ff92SAndroid Build Coastguard Worker IBackendInternal::IWorkloadFactoryPtr ClBackend::CreateWorkloadFactory(
131*89c4ff92SAndroid Build Coastguard Worker     TensorHandleFactoryRegistry& registry,
132*89c4ff92SAndroid Build Coastguard Worker     const ModelOptions& modelOptions,
133*89c4ff92SAndroid Build Coastguard Worker     MemorySourceFlags inputFlags,
134*89c4ff92SAndroid Build Coastguard Worker     MemorySourceFlags outputFlags) const
135*89c4ff92SAndroid Build Coastguard Worker {
136*89c4ff92SAndroid Build Coastguard Worker     // To allow force import if inputFlags/outputFlags are Undefined, set it as Malloc
137*89c4ff92SAndroid Build Coastguard Worker     if (inputFlags == static_cast<MemorySourceFlags>(MemorySource::Undefined))
138*89c4ff92SAndroid Build Coastguard Worker     {
139*89c4ff92SAndroid Build Coastguard Worker         inputFlags = static_cast<MemorySourceFlags>(MemorySource::Malloc);
140*89c4ff92SAndroid Build Coastguard Worker     }
141*89c4ff92SAndroid Build Coastguard Worker     if (outputFlags == static_cast<MemorySourceFlags>(MemorySource::Undefined))
142*89c4ff92SAndroid Build Coastguard Worker     {
143*89c4ff92SAndroid Build Coastguard Worker         outputFlags = static_cast<MemorySourceFlags>(MemorySource::Malloc);
144*89c4ff92SAndroid Build Coastguard Worker     }
145*89c4ff92SAndroid Build Coastguard Worker     std::shared_ptr<ClMemoryManager> memoryManager;
146*89c4ff92SAndroid Build Coastguard Worker     if (m_UsingCustomAllocator)
147*89c4ff92SAndroid Build Coastguard Worker     {
148*89c4ff92SAndroid Build Coastguard Worker         memoryManager = std::make_shared<ClMemoryManager>(m_CustomAllocator);
149*89c4ff92SAndroid Build Coastguard Worker     }
150*89c4ff92SAndroid Build Coastguard Worker     else
151*89c4ff92SAndroid Build Coastguard Worker     {
152*89c4ff92SAndroid Build Coastguard Worker         memoryManager = std::make_shared<ClMemoryManager>(std::make_unique<arm_compute::CLBufferAllocator>());
153*89c4ff92SAndroid Build Coastguard Worker     }
154*89c4ff92SAndroid Build Coastguard Worker 
155*89c4ff92SAndroid Build Coastguard Worker     std::unique_ptr<ITensorHandleFactory> factory = std::make_unique<ClTensorHandleFactory>(memoryManager);
156*89c4ff92SAndroid Build Coastguard Worker     std::unique_ptr<ITensorHandleFactory> importFactory = std::make_unique<ClImportTensorHandleFactory>(
157*89c4ff92SAndroid Build Coastguard Worker             inputFlags, outputFlags);
158*89c4ff92SAndroid Build Coastguard Worker 
159*89c4ff92SAndroid Build Coastguard Worker     registry.RegisterCopyAndImportFactoryPair(factory->GetId(), importFactory->GetId());
160*89c4ff92SAndroid Build Coastguard Worker     registry.RegisterCopyAndImportFactoryPair(importFactory->GetId(), factory->GetId());
161*89c4ff92SAndroid Build Coastguard Worker 
162*89c4ff92SAndroid Build Coastguard Worker     registry.RegisterMemoryManager(memoryManager);
163*89c4ff92SAndroid Build Coastguard Worker     registry.RegisterFactory(std::move(factory));
164*89c4ff92SAndroid Build Coastguard Worker     registry.RegisterFactory(std::move(importFactory));
165*89c4ff92SAndroid Build Coastguard Worker 
166*89c4ff92SAndroid Build Coastguard Worker     return std::make_unique<ClWorkloadFactory>(
167*89c4ff92SAndroid Build Coastguard Worker         PolymorphicPointerDowncast<ClMemoryManager>(memoryManager), CreateBackendSpecificModelContext(modelOptions));
168*89c4ff92SAndroid Build Coastguard Worker }
169*89c4ff92SAndroid Build Coastguard Worker 
GetHandleFactoryPreferences() const170*89c4ff92SAndroid Build Coastguard Worker std::vector<ITensorHandleFactory::FactoryId> ClBackend::GetHandleFactoryPreferences() const
171*89c4ff92SAndroid Build Coastguard Worker {
172*89c4ff92SAndroid Build Coastguard Worker     return std::vector<ITensorHandleFactory::FactoryId> {ClTensorHandleFactory::GetIdStatic(),
173*89c4ff92SAndroid Build Coastguard Worker                                                          ClImportTensorHandleFactory::GetIdStatic()};
174*89c4ff92SAndroid Build Coastguard Worker }
175*89c4ff92SAndroid Build Coastguard Worker 
RegisterTensorHandleFactories(TensorHandleFactoryRegistry & registry)176*89c4ff92SAndroid Build Coastguard Worker void ClBackend::RegisterTensorHandleFactories(TensorHandleFactoryRegistry& registry)
177*89c4ff92SAndroid Build Coastguard Worker {
178*89c4ff92SAndroid Build Coastguard Worker     std::shared_ptr<ClMemoryManager> memoryManager;
179*89c4ff92SAndroid Build Coastguard Worker     if (m_UsingCustomAllocator)
180*89c4ff92SAndroid Build Coastguard Worker     {
181*89c4ff92SAndroid Build Coastguard Worker         memoryManager = std::make_shared<ClMemoryManager>(m_CustomAllocator);
182*89c4ff92SAndroid Build Coastguard Worker     }
183*89c4ff92SAndroid Build Coastguard Worker     else
184*89c4ff92SAndroid Build Coastguard Worker     {
185*89c4ff92SAndroid Build Coastguard Worker         memoryManager = std::make_shared<ClMemoryManager>(std::make_unique<arm_compute::CLBufferAllocator>());
186*89c4ff92SAndroid Build Coastguard Worker     }
187*89c4ff92SAndroid Build Coastguard Worker 
188*89c4ff92SAndroid Build Coastguard Worker     std::unique_ptr<ITensorHandleFactory> factory = std::make_unique<ClTensorHandleFactory>(memoryManager);
189*89c4ff92SAndroid Build Coastguard Worker     std::unique_ptr<ITensorHandleFactory> importFactory = std::make_unique<ClImportTensorHandleFactory>(
190*89c4ff92SAndroid Build Coastguard Worker         static_cast<MemorySourceFlags>(MemorySource::Malloc), static_cast<MemorySourceFlags>(MemorySource::Malloc));
191*89c4ff92SAndroid Build Coastguard Worker 
192*89c4ff92SAndroid Build Coastguard Worker     registry.RegisterCopyAndImportFactoryPair(factory->GetId(), importFactory->GetId());
193*89c4ff92SAndroid Build Coastguard Worker     registry.RegisterCopyAndImportFactoryPair(importFactory->GetId(), factory->GetId());
194*89c4ff92SAndroid Build Coastguard Worker 
195*89c4ff92SAndroid Build Coastguard Worker     registry.RegisterMemoryManager(memoryManager);
196*89c4ff92SAndroid Build Coastguard Worker     registry.RegisterFactory(std::move(factory));
197*89c4ff92SAndroid Build Coastguard Worker     registry.RegisterFactory(std::move(importFactory));
198*89c4ff92SAndroid Build Coastguard Worker 
199*89c4ff92SAndroid Build Coastguard Worker }
200*89c4ff92SAndroid Build Coastguard Worker 
RegisterTensorHandleFactories(TensorHandleFactoryRegistry & registry,MemorySourceFlags inputFlags,MemorySourceFlags outputFlags)201*89c4ff92SAndroid Build Coastguard Worker void ClBackend::RegisterTensorHandleFactories(TensorHandleFactoryRegistry& registry,
202*89c4ff92SAndroid Build Coastguard Worker                                               MemorySourceFlags inputFlags,
203*89c4ff92SAndroid Build Coastguard Worker                                               MemorySourceFlags outputFlags)
204*89c4ff92SAndroid Build Coastguard Worker {
205*89c4ff92SAndroid Build Coastguard Worker     // To allow force import if inputFlags/outputFlags are Undefined, set it as Malloc
206*89c4ff92SAndroid Build Coastguard Worker     if (inputFlags == static_cast<MemorySourceFlags>(MemorySource::Undefined))
207*89c4ff92SAndroid Build Coastguard Worker     {
208*89c4ff92SAndroid Build Coastguard Worker         inputFlags = static_cast<MemorySourceFlags>(MemorySource::Malloc);
209*89c4ff92SAndroid Build Coastguard Worker     }
210*89c4ff92SAndroid Build Coastguard Worker     if (outputFlags == static_cast<MemorySourceFlags>(MemorySource::Undefined))
211*89c4ff92SAndroid Build Coastguard Worker     {
212*89c4ff92SAndroid Build Coastguard Worker         outputFlags = static_cast<MemorySourceFlags>(MemorySource::Malloc);
213*89c4ff92SAndroid Build Coastguard Worker     }
214*89c4ff92SAndroid Build Coastguard Worker     std::shared_ptr<ClMemoryManager> memoryManager;
215*89c4ff92SAndroid Build Coastguard Worker     if (m_UsingCustomAllocator)
216*89c4ff92SAndroid Build Coastguard Worker     {
217*89c4ff92SAndroid Build Coastguard Worker         memoryManager = std::make_shared<ClMemoryManager>(m_CustomAllocator);
218*89c4ff92SAndroid Build Coastguard Worker     }
219*89c4ff92SAndroid Build Coastguard Worker     else
220*89c4ff92SAndroid Build Coastguard Worker     {
221*89c4ff92SAndroid Build Coastguard Worker         memoryManager = std::make_shared<ClMemoryManager>(std::make_unique<arm_compute::CLBufferAllocator>());
222*89c4ff92SAndroid Build Coastguard Worker     }
223*89c4ff92SAndroid Build Coastguard Worker 
224*89c4ff92SAndroid Build Coastguard Worker     std::unique_ptr<ITensorHandleFactory> factory = std::make_unique<ClTensorHandleFactory>(memoryManager);
225*89c4ff92SAndroid Build Coastguard Worker     std::unique_ptr<ITensorHandleFactory> importFactory = std::make_unique<ClImportTensorHandleFactory>(
226*89c4ff92SAndroid Build Coastguard Worker             inputFlags, outputFlags);
227*89c4ff92SAndroid Build Coastguard Worker 
228*89c4ff92SAndroid Build Coastguard Worker     registry.RegisterCopyAndImportFactoryPair(factory->GetId(), importFactory->GetId());
229*89c4ff92SAndroid Build Coastguard Worker     registry.RegisterCopyAndImportFactoryPair(importFactory->GetId(), factory->GetId());
230*89c4ff92SAndroid Build Coastguard Worker 
231*89c4ff92SAndroid Build Coastguard Worker     registry.RegisterMemoryManager(memoryManager);
232*89c4ff92SAndroid Build Coastguard Worker     registry.RegisterFactory(std::move(factory));
233*89c4ff92SAndroid Build Coastguard Worker     registry.RegisterFactory(std::move(importFactory));
234*89c4ff92SAndroid Build Coastguard Worker }
235*89c4ff92SAndroid Build Coastguard Worker 
CreateBackendContext(const IRuntime::CreationOptions & options) const236*89c4ff92SAndroid Build Coastguard Worker IBackendInternal::IBackendContextPtr ClBackend::CreateBackendContext(const IRuntime::CreationOptions& options) const
237*89c4ff92SAndroid Build Coastguard Worker {
238*89c4ff92SAndroid Build Coastguard Worker     return IBackendContextPtr{new ClBackendContext{options}};
239*89c4ff92SAndroid Build Coastguard Worker }
240*89c4ff92SAndroid Build Coastguard Worker 
CreateBackendProfilingContext(const IRuntime::CreationOptions &,IBackendProfilingPtr &)241*89c4ff92SAndroid Build Coastguard Worker IBackendInternal::IBackendProfilingContextPtr ClBackend::CreateBackendProfilingContext(
242*89c4ff92SAndroid Build Coastguard Worker     const IRuntime::CreationOptions&, IBackendProfilingPtr&)
243*89c4ff92SAndroid Build Coastguard Worker {
244*89c4ff92SAndroid Build Coastguard Worker     return IBackendProfilingContextPtr{};
245*89c4ff92SAndroid Build Coastguard Worker }
246*89c4ff92SAndroid Build Coastguard Worker 
CreateBackendSpecificModelContext(const ModelOptions & modelOptions) const247*89c4ff92SAndroid Build Coastguard Worker IBackendInternal::IBackendSpecificModelContextPtr ClBackend::CreateBackendSpecificModelContext(
248*89c4ff92SAndroid Build Coastguard Worker     const ModelOptions& modelOptions) const
249*89c4ff92SAndroid Build Coastguard Worker {
250*89c4ff92SAndroid Build Coastguard Worker     return IBackendSpecificModelContextPtr{new ClBackendModelContext{modelOptions}};
251*89c4ff92SAndroid Build Coastguard Worker }
252*89c4ff92SAndroid Build Coastguard Worker 
GetLayerSupport() const253*89c4ff92SAndroid Build Coastguard Worker IBackendInternal::ILayerSupportSharedPtr ClBackend::GetLayerSupport() const
254*89c4ff92SAndroid Build Coastguard Worker {
255*89c4ff92SAndroid Build Coastguard Worker     static ILayerSupportSharedPtr layerSupport
256*89c4ff92SAndroid Build Coastguard Worker         {
257*89c4ff92SAndroid Build Coastguard Worker             new ClLayerSupport(IBackendInternal::IBackendSpecificModelContextPtr{})
258*89c4ff92SAndroid Build Coastguard Worker         };
259*89c4ff92SAndroid Build Coastguard Worker     return layerSupport;
260*89c4ff92SAndroid Build Coastguard Worker }
261*89c4ff92SAndroid Build Coastguard Worker 
GetLayerSupport(const ModelOptions & modelOptions) const262*89c4ff92SAndroid Build Coastguard Worker IBackendInternal::ILayerSupportSharedPtr ClBackend::GetLayerSupport(const ModelOptions& modelOptions) const
263*89c4ff92SAndroid Build Coastguard Worker {
264*89c4ff92SAndroid Build Coastguard Worker     static ILayerSupportSharedPtr layerSupport
265*89c4ff92SAndroid Build Coastguard Worker     {
266*89c4ff92SAndroid Build Coastguard Worker         new ClLayerSupport(CreateBackendSpecificModelContext(modelOptions))
267*89c4ff92SAndroid Build Coastguard Worker     };
268*89c4ff92SAndroid Build Coastguard Worker     return layerSupport;
269*89c4ff92SAndroid Build Coastguard Worker }
270*89c4ff92SAndroid Build Coastguard Worker 
GetDefaultAllocator() const271*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<ICustomAllocator> ClBackend::GetDefaultAllocator() const
272*89c4ff92SAndroid Build Coastguard Worker {
273*89c4ff92SAndroid Build Coastguard Worker     return std::make_unique<ClBackendDefaultAllocator>();
274*89c4ff92SAndroid Build Coastguard Worker }
275*89c4ff92SAndroid Build Coastguard Worker 
OptimizeSubgraphView(const SubgraphView & subgraph,const ModelOptions & modelOptions) const276*89c4ff92SAndroid Build Coastguard Worker OptimizationViews ClBackend::OptimizeSubgraphView(const SubgraphView& subgraph,
277*89c4ff92SAndroid Build Coastguard Worker                                                   const ModelOptions& modelOptions) const
278*89c4ff92SAndroid Build Coastguard Worker {
279*89c4ff92SAndroid Build Coastguard Worker     OptimizationViews optimizationViews(modelOptions);
280*89c4ff92SAndroid Build Coastguard Worker 
281*89c4ff92SAndroid Build Coastguard Worker     auto it = subgraph.endIConnectable();
282*89c4ff92SAndroid Build Coastguard Worker     bool isFastMathEnabled = false;
283*89c4ff92SAndroid Build Coastguard Worker     std::map<LayerGuid, Layer*> untouched;
284*89c4ff92SAndroid Build Coastguard Worker 
285*89c4ff92SAndroid Build Coastguard Worker     while (it != subgraph.beginIConnectable())
286*89c4ff92SAndroid Build Coastguard Worker     {
287*89c4ff92SAndroid Build Coastguard Worker         --it;
288*89c4ff92SAndroid Build Coastguard Worker         Layer& base = *(PolymorphicDowncast<Layer*>(*it));
289*89c4ff92SAndroid Build Coastguard Worker         untouched.insert({base.GetGuid(), &base});
290*89c4ff92SAndroid Build Coastguard Worker     }
291*89c4ff92SAndroid Build Coastguard Worker 
292*89c4ff92SAndroid Build Coastguard Worker     it = subgraph.endIConnectable();
293*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMCOMPUTECL_ENABLED)
294*89c4ff92SAndroid Build Coastguard Worker     IBackendInternal::IBackendSpecificModelContextPtr modelContextPtr = CreateBackendSpecificModelContext(modelOptions);
295*89c4ff92SAndroid Build Coastguard Worker 
296*89c4ff92SAndroid Build Coastguard Worker     if (modelContextPtr)
297*89c4ff92SAndroid Build Coastguard Worker     {
298*89c4ff92SAndroid Build Coastguard Worker         auto clModelOptions = dynamic_cast<ClBackendModelContext*>(modelContextPtr.get());
299*89c4ff92SAndroid Build Coastguard Worker         if (clModelOptions)
300*89c4ff92SAndroid Build Coastguard Worker         {
301*89c4ff92SAndroid Build Coastguard Worker             isFastMathEnabled = clModelOptions->IsFastMathEnabled();
302*89c4ff92SAndroid Build Coastguard Worker         }
303*89c4ff92SAndroid Build Coastguard Worker     }
304*89c4ff92SAndroid Build Coastguard Worker #endif
305*89c4ff92SAndroid Build Coastguard Worker     while (it != subgraph.beginIConnectable())
306*89c4ff92SAndroid Build Coastguard Worker     {
307*89c4ff92SAndroid Build Coastguard Worker         --it;
308*89c4ff92SAndroid Build Coastguard Worker         Layer& base = *(PolymorphicDowncast<Layer*>(*it));
309*89c4ff92SAndroid Build Coastguard Worker 
310*89c4ff92SAndroid Build Coastguard Worker         // Fuse activation into previous layer if supported by backend
311*89c4ff92SAndroid Build Coastguard Worker         if ((base.GetType() == LayerType::DepthwiseConvolution2d || base.GetType() == LayerType::Convolution2d
312*89c4ff92SAndroid Build Coastguard Worker             || base.GetType() == LayerType::BatchNormalization || base.GetType() == LayerType::FullyConnected
313*89c4ff92SAndroid Build Coastguard Worker             || base.GetType() == LayerType::Addition || base.GetType() == LayerType::Multiplication
314*89c4ff92SAndroid Build Coastguard Worker             || base.GetType() == LayerType::Subtraction || base.GetType() == LayerType::Division
315*89c4ff92SAndroid Build Coastguard Worker             || base.GetType() == LayerType::ElementwiseBinary)
316*89c4ff92SAndroid Build Coastguard Worker             && (base.GetAdditionalInformation<ActivationDescriptor>() == nullptr))
317*89c4ff92SAndroid Build Coastguard Worker         {
318*89c4ff92SAndroid Build Coastguard Worker             for (auto output = base.BeginOutputSlots(); output != base.EndOutputSlots(); ++output)
319*89c4ff92SAndroid Build Coastguard Worker             {
320*89c4ff92SAndroid Build Coastguard Worker                 if (output->GetNumConnections() == 1)
321*89c4ff92SAndroid Build Coastguard Worker                 {
322*89c4ff92SAndroid Build Coastguard Worker                     for (auto&& childInput : output->GetConnections())
323*89c4ff92SAndroid Build Coastguard Worker                     {
324*89c4ff92SAndroid Build Coastguard Worker                         if ((childInput->GetOwningLayer().GetType() == LayerType::Activation) &&
325*89c4ff92SAndroid Build Coastguard Worker                             (checkDataTypeInputandOutput(childInput->GetOwningLayer())))
326*89c4ff92SAndroid Build Coastguard Worker                         {
327*89c4ff92SAndroid Build Coastguard Worker                             Layer& child = childInput->GetOwningLayer();
328*89c4ff92SAndroid Build Coastguard Worker 
329*89c4ff92SAndroid Build Coastguard Worker                             auto* activationLayer = PolymorphicDowncast<ActivationLayer*>(&child);
330*89c4ff92SAndroid Build Coastguard Worker 
331*89c4ff92SAndroid Build Coastguard Worker                             const std::string name = std::string("fused-") + child.GetName() + std::string("-into-") +
332*89c4ff92SAndroid Build Coastguard Worker                                                      base.GetName();
333*89c4ff92SAndroid Build Coastguard Worker 
334*89c4ff92SAndroid Build Coastguard Worker                             // Get params from activation layer
335*89c4ff92SAndroid Build Coastguard Worker                             ActivationDescriptor activationDesc = activationLayer->GetParameters();
336*89c4ff92SAndroid Build Coastguard Worker 
337*89c4ff92SAndroid Build Coastguard Worker                             if (base.GetType() == LayerType::Convolution2d)
338*89c4ff92SAndroid Build Coastguard Worker                             {
339*89c4ff92SAndroid Build Coastguard Worker                                 Convolution2dLayer* baseLayer = PolymorphicDowncast<Convolution2dLayer*>(&base);
340*89c4ff92SAndroid Build Coastguard Worker 
341*89c4ff92SAndroid Build Coastguard Worker                                 Optional<TensorInfo> biases;
342*89c4ff92SAndroid Build Coastguard Worker 
343*89c4ff92SAndroid Build Coastguard Worker                                 if (baseLayer->GetParameters().m_BiasEnabled)
344*89c4ff92SAndroid Build Coastguard Worker                                 {
345*89c4ff92SAndroid Build Coastguard Worker                                     biases = baseLayer->GetInputSlot(2).GetConnectedOutputSlot()->GetTensorInfo();
346*89c4ff92SAndroid Build Coastguard Worker                                 }
347*89c4ff92SAndroid Build Coastguard Worker 
348*89c4ff92SAndroid Build Coastguard Worker                                 arm_compute::Status status = ClConvolution2dWorkloadValidate(
349*89c4ff92SAndroid Build Coastguard Worker                                         baseLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
350*89c4ff92SAndroid Build Coastguard Worker                                         activationLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
351*89c4ff92SAndroid Build Coastguard Worker                                         baseLayer->GetParameters(),
352*89c4ff92SAndroid Build Coastguard Worker                                         baseLayer->GetInputSlot(1).GetConnectedOutputSlot()->GetTensorInfo(),
353*89c4ff92SAndroid Build Coastguard Worker                                         biases,
354*89c4ff92SAndroid Build Coastguard Worker                                         isFastMathEnabled,
355*89c4ff92SAndroid Build Coastguard Worker                                         &activationDesc);
356*89c4ff92SAndroid Build Coastguard Worker 
357*89c4ff92SAndroid Build Coastguard Worker                                 if (status)
358*89c4ff92SAndroid Build Coastguard Worker                                 {
359*89c4ff92SAndroid Build Coastguard Worker                                     FuseConvolution2dLayer<Convolution2dLayer>(optimizationViews,
360*89c4ff92SAndroid Build Coastguard Worker                                                                                baseLayer,
361*89c4ff92SAndroid Build Coastguard Worker                                                                                activationLayer,
362*89c4ff92SAndroid Build Coastguard Worker                                                                                activationDesc,
363*89c4ff92SAndroid Build Coastguard Worker                                                                                name);
364*89c4ff92SAndroid Build Coastguard Worker                                     untouched.erase(baseLayer->GetGuid());
365*89c4ff92SAndroid Build Coastguard Worker                                     untouched.erase(activationLayer->GetGuid());
366*89c4ff92SAndroid Build Coastguard Worker                                 }
367*89c4ff92SAndroid Build Coastguard Worker                             }
368*89c4ff92SAndroid Build Coastguard Worker                             else if (base.GetType() == LayerType::DepthwiseConvolution2d)
369*89c4ff92SAndroid Build Coastguard Worker                             {
370*89c4ff92SAndroid Build Coastguard Worker                                 DepthwiseConvolution2dLayer* baseLayer =
371*89c4ff92SAndroid Build Coastguard Worker                                         PolymorphicDowncast<DepthwiseConvolution2dLayer*>(&base);
372*89c4ff92SAndroid Build Coastguard Worker 
373*89c4ff92SAndroid Build Coastguard Worker                                 Optional<TensorInfo> biases;
374*89c4ff92SAndroid Build Coastguard Worker 
375*89c4ff92SAndroid Build Coastguard Worker                                 if (baseLayer->GetParameters().m_BiasEnabled)
376*89c4ff92SAndroid Build Coastguard Worker                                 {
377*89c4ff92SAndroid Build Coastguard Worker                                     biases = baseLayer->GetInputSlot(2).GetConnectedOutputSlot()->GetTensorInfo();
378*89c4ff92SAndroid Build Coastguard Worker                                 }
379*89c4ff92SAndroid Build Coastguard Worker 
380*89c4ff92SAndroid Build Coastguard Worker                                 arm_compute::Status status = ClDepthwiseConvolutionWorkloadValidate(
381*89c4ff92SAndroid Build Coastguard Worker                                         baseLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
382*89c4ff92SAndroid Build Coastguard Worker                                         activationLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
383*89c4ff92SAndroid Build Coastguard Worker                                         baseLayer->GetParameters(),
384*89c4ff92SAndroid Build Coastguard Worker                                         baseLayer->GetInputSlot(1).GetConnectedOutputSlot()->GetTensorInfo(),
385*89c4ff92SAndroid Build Coastguard Worker                                         biases,
386*89c4ff92SAndroid Build Coastguard Worker                                         &activationDesc);
387*89c4ff92SAndroid Build Coastguard Worker 
388*89c4ff92SAndroid Build Coastguard Worker                                 if (status)
389*89c4ff92SAndroid Build Coastguard Worker                                 {
390*89c4ff92SAndroid Build Coastguard Worker                                     FuseDepthwiseConvolution2dLayer<DepthwiseConvolution2dLayer>(optimizationViews,
391*89c4ff92SAndroid Build Coastguard Worker                                                                                                  baseLayer,
392*89c4ff92SAndroid Build Coastguard Worker                                                                                                  activationLayer,
393*89c4ff92SAndroid Build Coastguard Worker                                                                                                  activationDesc,
394*89c4ff92SAndroid Build Coastguard Worker                                                                                                  name);
395*89c4ff92SAndroid Build Coastguard Worker                                     untouched.erase(baseLayer->GetGuid());
396*89c4ff92SAndroid Build Coastguard Worker                                     untouched.erase(activationLayer->GetGuid());
397*89c4ff92SAndroid Build Coastguard Worker                                 }
398*89c4ff92SAndroid Build Coastguard Worker                             }
399*89c4ff92SAndroid Build Coastguard Worker                             else if (base.GetType() == LayerType::FullyConnected)
400*89c4ff92SAndroid Build Coastguard Worker                             {
401*89c4ff92SAndroid Build Coastguard Worker                                 FullyConnectedLayer* baseLayer = PolymorphicDowncast<FullyConnectedLayer*>(&base);
402*89c4ff92SAndroid Build Coastguard Worker                                 FullyConnectedDescriptor descriptor = baseLayer->GetParameters();
403*89c4ff92SAndroid Build Coastguard Worker 
404*89c4ff92SAndroid Build Coastguard Worker                                 // As bias is optional only try to get TensorInfo from input if bias is enabled.
405*89c4ff92SAndroid Build Coastguard Worker                                 Optional<TensorInfo> biases;
406*89c4ff92SAndroid Build Coastguard Worker                                 if (descriptor.m_BiasEnabled)
407*89c4ff92SAndroid Build Coastguard Worker                                 {
408*89c4ff92SAndroid Build Coastguard Worker                                     biases = baseLayer->GetInputSlot(2).GetConnectedOutputSlot()->GetTensorInfo();
409*89c4ff92SAndroid Build Coastguard Worker                                 }
410*89c4ff92SAndroid Build Coastguard Worker 
411*89c4ff92SAndroid Build Coastguard Worker                                 arm_compute::Status status = ClFullyConnectedWorkloadValidate(
412*89c4ff92SAndroid Build Coastguard Worker                                         baseLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
413*89c4ff92SAndroid Build Coastguard Worker                                         activationLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
414*89c4ff92SAndroid Build Coastguard Worker                                         baseLayer->GetInputSlot(1).GetConnectedOutputSlot()->GetTensorInfo(),
415*89c4ff92SAndroid Build Coastguard Worker                                         biases,
416*89c4ff92SAndroid Build Coastguard Worker                                         baseLayer->GetParameters(),
417*89c4ff92SAndroid Build Coastguard Worker                                         &activationDesc);
418*89c4ff92SAndroid Build Coastguard Worker 
419*89c4ff92SAndroid Build Coastguard Worker                                 if (status)
420*89c4ff92SAndroid Build Coastguard Worker                                 {
421*89c4ff92SAndroid Build Coastguard Worker                                     FuseFullyConnectedLayer<FullyConnectedLayer>(optimizationViews,
422*89c4ff92SAndroid Build Coastguard Worker                                                                                  baseLayer,
423*89c4ff92SAndroid Build Coastguard Worker                                                                                  activationLayer,
424*89c4ff92SAndroid Build Coastguard Worker                                                                                  activationDesc,
425*89c4ff92SAndroid Build Coastguard Worker                                                                                  name);
426*89c4ff92SAndroid Build Coastguard Worker                                     untouched.erase(baseLayer->GetGuid());
427*89c4ff92SAndroid Build Coastguard Worker                                     untouched.erase(activationLayer->GetGuid());
428*89c4ff92SAndroid Build Coastguard Worker                                 }
429*89c4ff92SAndroid Build Coastguard Worker                             }
430*89c4ff92SAndroid Build Coastguard Worker                             else if (base.GetType() == LayerType::BatchNormalization)
431*89c4ff92SAndroid Build Coastguard Worker                             {
432*89c4ff92SAndroid Build Coastguard Worker                                 BatchNormalizationLayer* baseLayer =
433*89c4ff92SAndroid Build Coastguard Worker                                         PolymorphicDowncast<BatchNormalizationLayer*>(&base);
434*89c4ff92SAndroid Build Coastguard Worker 
435*89c4ff92SAndroid Build Coastguard Worker                                 arm_compute::Status status = ClBatchNormalizationValidate(
436*89c4ff92SAndroid Build Coastguard Worker                                         baseLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
437*89c4ff92SAndroid Build Coastguard Worker                                         activationLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
438*89c4ff92SAndroid Build Coastguard Worker                                         baseLayer->m_Mean->GetTensorInfo(),
439*89c4ff92SAndroid Build Coastguard Worker                                         baseLayer->m_Variance->GetTensorInfo(),
440*89c4ff92SAndroid Build Coastguard Worker                                         baseLayer->m_Beta->GetTensorInfo(),
441*89c4ff92SAndroid Build Coastguard Worker                                         baseLayer->m_Gamma->GetTensorInfo(),
442*89c4ff92SAndroid Build Coastguard Worker                                         baseLayer->GetParameters(),
443*89c4ff92SAndroid Build Coastguard Worker                                         &activationDesc);
444*89c4ff92SAndroid Build Coastguard Worker 
445*89c4ff92SAndroid Build Coastguard Worker                                 if (status)
446*89c4ff92SAndroid Build Coastguard Worker                                 {
447*89c4ff92SAndroid Build Coastguard Worker                                     BatchNormalizationLayer* replacementLayer =
448*89c4ff92SAndroid Build Coastguard Worker                                         FuseBatchNormalizationLayer<BatchNormalizationLayer>(optimizationViews,
449*89c4ff92SAndroid Build Coastguard Worker                                                                                              baseLayer,
450*89c4ff92SAndroid Build Coastguard Worker                                                                                              activationLayer,
451*89c4ff92SAndroid Build Coastguard Worker                                                                                              activationDesc,
452*89c4ff92SAndroid Build Coastguard Worker                                                                                              name);
453*89c4ff92SAndroid Build Coastguard Worker 
454*89c4ff92SAndroid Build Coastguard Worker                                     replacementLayer->m_Beta     = std::move(baseLayer->m_Beta);
455*89c4ff92SAndroid Build Coastguard Worker                                     replacementLayer->m_Gamma    = std::move(baseLayer->m_Gamma);
456*89c4ff92SAndroid Build Coastguard Worker                                     replacementLayer->m_Mean     = std::move(baseLayer->m_Mean);
457*89c4ff92SAndroid Build Coastguard Worker                                     replacementLayer->m_Variance = std::move(baseLayer->m_Variance);
458*89c4ff92SAndroid Build Coastguard Worker                                     untouched.erase(baseLayer->GetGuid());
459*89c4ff92SAndroid Build Coastguard Worker                                     untouched.erase(activationLayer->GetGuid());
460*89c4ff92SAndroid Build Coastguard Worker                                 }
461*89c4ff92SAndroid Build Coastguard Worker                             }
462*89c4ff92SAndroid Build Coastguard Worker                             else if (base.GetType() == LayerType::Addition)
463*89c4ff92SAndroid Build Coastguard Worker                             {
464*89c4ff92SAndroid Build Coastguard Worker                                 AdditionLayer* baseLayer = PolymorphicDowncast<AdditionLayer*>(&base);
465*89c4ff92SAndroid Build Coastguard Worker 
466*89c4ff92SAndroid Build Coastguard Worker                                 arm_compute::Status status = ClAdditionValidate(
467*89c4ff92SAndroid Build Coastguard Worker                                         baseLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
468*89c4ff92SAndroid Build Coastguard Worker                                         baseLayer->GetInputSlot(1).GetConnectedOutputSlot()->GetTensorInfo(),
469*89c4ff92SAndroid Build Coastguard Worker                                         activationLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
470*89c4ff92SAndroid Build Coastguard Worker                                         &activationDesc);
471*89c4ff92SAndroid Build Coastguard Worker 
472*89c4ff92SAndroid Build Coastguard Worker                                 if (status)
473*89c4ff92SAndroid Build Coastguard Worker                                 {
474*89c4ff92SAndroid Build Coastguard Worker                                     FuseAdditionLayer<AdditionLayer>(optimizationViews,
475*89c4ff92SAndroid Build Coastguard Worker                                                                      baseLayer,
476*89c4ff92SAndroid Build Coastguard Worker                                                                      activationLayer,
477*89c4ff92SAndroid Build Coastguard Worker                                                                      activationDesc,
478*89c4ff92SAndroid Build Coastguard Worker                                                                      name);
479*89c4ff92SAndroid Build Coastguard Worker                                     untouched.erase(baseLayer->GetGuid());
480*89c4ff92SAndroid Build Coastguard Worker                                     untouched.erase(activationLayer->GetGuid());
481*89c4ff92SAndroid Build Coastguard Worker                                 }
482*89c4ff92SAndroid Build Coastguard Worker                             }
483*89c4ff92SAndroid Build Coastguard Worker                             else if (base.GetType() == LayerType::Division)
484*89c4ff92SAndroid Build Coastguard Worker                             {
485*89c4ff92SAndroid Build Coastguard Worker                                 DivisionLayer* baseLayer = PolymorphicDowncast<DivisionLayer*>(&base);
486*89c4ff92SAndroid Build Coastguard Worker 
487*89c4ff92SAndroid Build Coastguard Worker                                 arm_compute::Status status = ClDivisionWorkloadValidate(
488*89c4ff92SAndroid Build Coastguard Worker                                         baseLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
489*89c4ff92SAndroid Build Coastguard Worker                                         baseLayer->GetInputSlot(1).GetConnectedOutputSlot()->GetTensorInfo(),
490*89c4ff92SAndroid Build Coastguard Worker                                         activationLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
491*89c4ff92SAndroid Build Coastguard Worker                                         &activationDesc);
492*89c4ff92SAndroid Build Coastguard Worker 
493*89c4ff92SAndroid Build Coastguard Worker                                 if (status)
494*89c4ff92SAndroid Build Coastguard Worker                                 {
495*89c4ff92SAndroid Build Coastguard Worker                                     FuseDivisionLayer<DivisionLayer>(optimizationViews,
496*89c4ff92SAndroid Build Coastguard Worker                                                                      baseLayer,
497*89c4ff92SAndroid Build Coastguard Worker                                                                      activationLayer,
498*89c4ff92SAndroid Build Coastguard Worker                                                                      activationDesc,
499*89c4ff92SAndroid Build Coastguard Worker                                                                      name);
500*89c4ff92SAndroid Build Coastguard Worker                                     untouched.erase(baseLayer->GetGuid());
501*89c4ff92SAndroid Build Coastguard Worker                                     untouched.erase(activationLayer->GetGuid());
502*89c4ff92SAndroid Build Coastguard Worker                                 }
503*89c4ff92SAndroid Build Coastguard Worker                             }
504*89c4ff92SAndroid Build Coastguard Worker                             else if (base.GetType() == LayerType::Multiplication)
505*89c4ff92SAndroid Build Coastguard Worker                             {
506*89c4ff92SAndroid Build Coastguard Worker                                 MultiplicationLayer* baseLayer = PolymorphicDowncast<MultiplicationLayer*>(&base);
507*89c4ff92SAndroid Build Coastguard Worker 
508*89c4ff92SAndroid Build Coastguard Worker                                 arm_compute::Status status = ClMultiplicationWorkloadValidate(
509*89c4ff92SAndroid Build Coastguard Worker                                         baseLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
510*89c4ff92SAndroid Build Coastguard Worker                                         baseLayer->GetInputSlot(1).GetConnectedOutputSlot()->GetTensorInfo(),
511*89c4ff92SAndroid Build Coastguard Worker                                         activationLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
512*89c4ff92SAndroid Build Coastguard Worker                                         &activationDesc);
513*89c4ff92SAndroid Build Coastguard Worker 
514*89c4ff92SAndroid Build Coastguard Worker                                 if (status)
515*89c4ff92SAndroid Build Coastguard Worker                                 {
516*89c4ff92SAndroid Build Coastguard Worker                                     FuseMultiplicationLayer<MultiplicationLayer>(optimizationViews,
517*89c4ff92SAndroid Build Coastguard Worker                                                                                  baseLayer,
518*89c4ff92SAndroid Build Coastguard Worker                                                                                  activationLayer,
519*89c4ff92SAndroid Build Coastguard Worker                                                                                  activationDesc,
520*89c4ff92SAndroid Build Coastguard Worker                                                                                  name);
521*89c4ff92SAndroid Build Coastguard Worker                                     untouched.erase(baseLayer->GetGuid());
522*89c4ff92SAndroid Build Coastguard Worker                                     untouched.erase(activationLayer->GetGuid());
523*89c4ff92SAndroid Build Coastguard Worker                                 }
524*89c4ff92SAndroid Build Coastguard Worker                             }
525*89c4ff92SAndroid Build Coastguard Worker                             else if (base.GetType() == LayerType::Subtraction)
526*89c4ff92SAndroid Build Coastguard Worker                             {
527*89c4ff92SAndroid Build Coastguard Worker                                 SubtractionLayer* baseLayer = PolymorphicDowncast<SubtractionLayer*>(&base);
528*89c4ff92SAndroid Build Coastguard Worker 
529*89c4ff92SAndroid Build Coastguard Worker                                 arm_compute::Status status = ClSubtractionValidate(
530*89c4ff92SAndroid Build Coastguard Worker                                         baseLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
531*89c4ff92SAndroid Build Coastguard Worker                                         baseLayer->GetInputSlot(1).GetConnectedOutputSlot()->GetTensorInfo(),
532*89c4ff92SAndroid Build Coastguard Worker                                         activationLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
533*89c4ff92SAndroid Build Coastguard Worker                                         &activationDesc);
534*89c4ff92SAndroid Build Coastguard Worker 
535*89c4ff92SAndroid Build Coastguard Worker                                 if (status)
536*89c4ff92SAndroid Build Coastguard Worker                                 {
537*89c4ff92SAndroid Build Coastguard Worker                                     FuseSubtractionLayer<SubtractionLayer>(optimizationViews,
538*89c4ff92SAndroid Build Coastguard Worker                                                                            baseLayer,
539*89c4ff92SAndroid Build Coastguard Worker                                                                            activationLayer,
540*89c4ff92SAndroid Build Coastguard Worker                                                                            activationDesc,
541*89c4ff92SAndroid Build Coastguard Worker                                                                            name);
542*89c4ff92SAndroid Build Coastguard Worker                                     untouched.erase(baseLayer->GetGuid());
543*89c4ff92SAndroid Build Coastguard Worker                                     untouched.erase(activationLayer->GetGuid());
544*89c4ff92SAndroid Build Coastguard Worker                                 }
545*89c4ff92SAndroid Build Coastguard Worker                             }
546*89c4ff92SAndroid Build Coastguard Worker                             else if (base.GetType() == LayerType::ElementwiseBinary)
547*89c4ff92SAndroid Build Coastguard Worker                             {
548*89c4ff92SAndroid Build Coastguard Worker                                 ElementwiseBinaryLayer* baseLayer = PolymorphicDowncast<ElementwiseBinaryLayer*>(&base);
549*89c4ff92SAndroid Build Coastguard Worker 
550*89c4ff92SAndroid Build Coastguard Worker                                 if (baseLayer->GetParameters().m_Operation == BinaryOperation::Add)
551*89c4ff92SAndroid Build Coastguard Worker                                 {
552*89c4ff92SAndroid Build Coastguard Worker                                     arm_compute::Status status = ClAdditionValidate(
553*89c4ff92SAndroid Build Coastguard Worker                                             baseLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
554*89c4ff92SAndroid Build Coastguard Worker                                             baseLayer->GetInputSlot(1).GetConnectedOutputSlot()->GetTensorInfo(),
555*89c4ff92SAndroid Build Coastguard Worker                                             activationLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
556*89c4ff92SAndroid Build Coastguard Worker                                             &activationDesc);
557*89c4ff92SAndroid Build Coastguard Worker 
558*89c4ff92SAndroid Build Coastguard Worker                                     if (status)
559*89c4ff92SAndroid Build Coastguard Worker                                     {
560*89c4ff92SAndroid Build Coastguard Worker                                         FuseElementwiseBinaryLayer<ElementwiseBinaryLayer>(optimizationViews,
561*89c4ff92SAndroid Build Coastguard Worker                                                                                            baseLayer,
562*89c4ff92SAndroid Build Coastguard Worker                                                                                            activationLayer,
563*89c4ff92SAndroid Build Coastguard Worker                                                                                            activationDesc,
564*89c4ff92SAndroid Build Coastguard Worker                                                                                            BinaryOperation::Add,
565*89c4ff92SAndroid Build Coastguard Worker                                                                                            name);
566*89c4ff92SAndroid Build Coastguard Worker                                         untouched.erase(baseLayer->GetGuid());
567*89c4ff92SAndroid Build Coastguard Worker                                         untouched.erase(activationLayer->GetGuid());
568*89c4ff92SAndroid Build Coastguard Worker                                     }
569*89c4ff92SAndroid Build Coastguard Worker                                 }
570*89c4ff92SAndroid Build Coastguard Worker                                 else if (baseLayer->GetParameters().m_Operation == BinaryOperation::Div)
571*89c4ff92SAndroid Build Coastguard Worker                                 {
572*89c4ff92SAndroid Build Coastguard Worker                                     arm_compute::Status status = ClDivisionWorkloadValidate(
573*89c4ff92SAndroid Build Coastguard Worker                                             baseLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
574*89c4ff92SAndroid Build Coastguard Worker                                             baseLayer->GetInputSlot(1).GetConnectedOutputSlot()->GetTensorInfo(),
575*89c4ff92SAndroid Build Coastguard Worker                                             activationLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
576*89c4ff92SAndroid Build Coastguard Worker                                             &activationDesc);
577*89c4ff92SAndroid Build Coastguard Worker 
578*89c4ff92SAndroid Build Coastguard Worker                                     if (status)
579*89c4ff92SAndroid Build Coastguard Worker                                     {
580*89c4ff92SAndroid Build Coastguard Worker                                         FuseElementwiseBinaryLayer<ElementwiseBinaryLayer>(optimizationViews,
581*89c4ff92SAndroid Build Coastguard Worker                                                                                            baseLayer,
582*89c4ff92SAndroid Build Coastguard Worker                                                                                            activationLayer,
583*89c4ff92SAndroid Build Coastguard Worker                                                                                            activationDesc,
584*89c4ff92SAndroid Build Coastguard Worker                                                                                            BinaryOperation::Div,
585*89c4ff92SAndroid Build Coastguard Worker                                                                                            name);
586*89c4ff92SAndroid Build Coastguard Worker                                         untouched.erase(baseLayer->GetGuid());
587*89c4ff92SAndroid Build Coastguard Worker                                         untouched.erase(activationLayer->GetGuid());
588*89c4ff92SAndroid Build Coastguard Worker                                     }
589*89c4ff92SAndroid Build Coastguard Worker                                 }
590*89c4ff92SAndroid Build Coastguard Worker                                 else if (baseLayer->GetParameters().m_Operation == BinaryOperation::Mul)
591*89c4ff92SAndroid Build Coastguard Worker                                 {
592*89c4ff92SAndroid Build Coastguard Worker                                     arm_compute::Status status = ClMultiplicationWorkloadValidate(
593*89c4ff92SAndroid Build Coastguard Worker                                             baseLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
594*89c4ff92SAndroid Build Coastguard Worker                                             baseLayer->GetInputSlot(1).GetConnectedOutputSlot()->GetTensorInfo(),
595*89c4ff92SAndroid Build Coastguard Worker                                             activationLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
596*89c4ff92SAndroid Build Coastguard Worker                                             &activationDesc);
597*89c4ff92SAndroid Build Coastguard Worker 
598*89c4ff92SAndroid Build Coastguard Worker                                     if (status)
599*89c4ff92SAndroid Build Coastguard Worker                                     {
600*89c4ff92SAndroid Build Coastguard Worker                                         FuseElementwiseBinaryLayer<ElementwiseBinaryLayer>(optimizationViews,
601*89c4ff92SAndroid Build Coastguard Worker                                                                                            baseLayer,
602*89c4ff92SAndroid Build Coastguard Worker                                                                                            activationLayer,
603*89c4ff92SAndroid Build Coastguard Worker                                                                                            activationDesc,
604*89c4ff92SAndroid Build Coastguard Worker                                                                                            BinaryOperation::Mul,
605*89c4ff92SAndroid Build Coastguard Worker                                                                                            name);
606*89c4ff92SAndroid Build Coastguard Worker                                         untouched.erase(baseLayer->GetGuid());
607*89c4ff92SAndroid Build Coastguard Worker                                         untouched.erase(activationLayer->GetGuid());
608*89c4ff92SAndroid Build Coastguard Worker                                     }
609*89c4ff92SAndroid Build Coastguard Worker                                 }
610*89c4ff92SAndroid Build Coastguard Worker                                 else if (baseLayer->GetParameters().m_Operation == BinaryOperation::Sub)
611*89c4ff92SAndroid Build Coastguard Worker                                 {
612*89c4ff92SAndroid Build Coastguard Worker                                     arm_compute::Status status = ClSubtractionValidate(
613*89c4ff92SAndroid Build Coastguard Worker                                             baseLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
614*89c4ff92SAndroid Build Coastguard Worker                                             baseLayer->GetInputSlot(1).GetConnectedOutputSlot()->GetTensorInfo(),
615*89c4ff92SAndroid Build Coastguard Worker                                             activationLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(),
616*89c4ff92SAndroid Build Coastguard Worker                                             &activationDesc);
617*89c4ff92SAndroid Build Coastguard Worker 
618*89c4ff92SAndroid Build Coastguard Worker                                     if (status)
619*89c4ff92SAndroid Build Coastguard Worker                                     {
620*89c4ff92SAndroid Build Coastguard Worker                                         FuseElementwiseBinaryLayer<ElementwiseBinaryLayer>(optimizationViews,
621*89c4ff92SAndroid Build Coastguard Worker                                                                                            baseLayer,
622*89c4ff92SAndroid Build Coastguard Worker                                                                                            activationLayer,
623*89c4ff92SAndroid Build Coastguard Worker                                                                                            activationDesc,
624*89c4ff92SAndroid Build Coastguard Worker                                                                                            BinaryOperation::Sub,
625*89c4ff92SAndroid Build Coastguard Worker                                                                                            name);
626*89c4ff92SAndroid Build Coastguard Worker                                     }
627*89c4ff92SAndroid Build Coastguard Worker                                 }
628*89c4ff92SAndroid Build Coastguard Worker                                 // No fusion available for other BinaryOperations
629*89c4ff92SAndroid Build Coastguard Worker                             }
630*89c4ff92SAndroid Build Coastguard Worker                         }
631*89c4ff92SAndroid Build Coastguard Worker                     }
632*89c4ff92SAndroid Build Coastguard Worker                 }
633*89c4ff92SAndroid Build Coastguard Worker             }
634*89c4ff92SAndroid Build Coastguard Worker         }
635*89c4ff92SAndroid Build Coastguard Worker 
636*89c4ff92SAndroid Build Coastguard Worker         // Separate reduce layer with multiple axes into multiple reduce layers with 1 axis.
637*89c4ff92SAndroid Build Coastguard Worker         if (base.GetType() == LayerType::Reduce)
638*89c4ff92SAndroid Build Coastguard Worker         {
639*89c4ff92SAndroid Build Coastguard Worker             ReduceLayer* baseLayer            = PolymorphicDowncast<ReduceLayer*>(&base);
640*89c4ff92SAndroid Build Coastguard Worker             ReduceDescriptor reduceDescriptor = baseLayer->GetParameters();
641*89c4ff92SAndroid Build Coastguard Worker 
642*89c4ff92SAndroid Build Coastguard Worker             if (!reduceDescriptor.m_vAxis.empty() && reduceDescriptor.m_vAxis.size() > 1)
643*89c4ff92SAndroid Build Coastguard Worker             {
644*89c4ff92SAndroid Build Coastguard Worker                 // Add new layers to the graph and connect them.
645*89c4ff92SAndroid Build Coastguard Worker                 std::vector<IConnectableLayer*> layers = ChainReduceLayers<ReduceLayer>(optimizationViews,
646*89c4ff92SAndroid Build Coastguard Worker                                                                                         baseLayer,
647*89c4ff92SAndroid Build Coastguard Worker                                                                                         reduceDescriptor);
648*89c4ff92SAndroid Build Coastguard Worker 
649*89c4ff92SAndroid Build Coastguard Worker                 // Replace existing baselayer with new subgraph.
650*89c4ff92SAndroid Build Coastguard Worker                 ReplaceLayers<ReduceLayer>(optimizationViews, baseLayer, layers);
651*89c4ff92SAndroid Build Coastguard Worker                 untouched.erase(baseLayer->GetGuid());
652*89c4ff92SAndroid Build Coastguard Worker             }
653*89c4ff92SAndroid Build Coastguard Worker         }
654*89c4ff92SAndroid Build Coastguard Worker 
655*89c4ff92SAndroid Build Coastguard Worker         // Special case to fuse padding into average pooling 2d for quantized datatype.
656*89c4ff92SAndroid Build Coastguard Worker         // Required to be done as a backend specific optimization as Neon does not support this special case.
657*89c4ff92SAndroid Build Coastguard Worker         if (base.GetType() == LayerType::Pooling2d)
658*89c4ff92SAndroid Build Coastguard Worker         {
659*89c4ff92SAndroid Build Coastguard Worker             Pooling2dLayer* baseLayer = PolymorphicDowncast<Pooling2dLayer*>(&base);
660*89c4ff92SAndroid Build Coastguard Worker             Pooling2dDescriptor poolingDescriptor = baseLayer->GetParameters();
661*89c4ff92SAndroid Build Coastguard Worker 
662*89c4ff92SAndroid Build Coastguard Worker             if (baseLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetOwningLayer().GetType() == LayerType::Pad)
663*89c4ff92SAndroid Build Coastguard Worker             {
664*89c4ff92SAndroid Build Coastguard Worker                 PadLayer* padLayer = PolymorphicDowncast<PadLayer*>(
665*89c4ff92SAndroid Build Coastguard Worker                     &baseLayer->GetInputSlot(0).GetConnectedOutputSlot()->GetOwningLayer());
666*89c4ff92SAndroid Build Coastguard Worker                 if (padLayer->GetOutputSlot(0).GetNumConnections() == 1 &&
667*89c4ff92SAndroid Build Coastguard Worker                     optimizations::pad_fold::TryFoldPadIntoLayer2d(padLayer->GetParameters(),
668*89c4ff92SAndroid Build Coastguard Worker                                                                    poolingDescriptor,
669*89c4ff92SAndroid Build Coastguard Worker                                                                    padLayer->GetOutputSlot().GetTensorInfo(),
670*89c4ff92SAndroid Build Coastguard Worker                                                                    true))
671*89c4ff92SAndroid Build Coastguard Worker                 {
672*89c4ff92SAndroid Build Coastguard Worker                     FoldPadIntoAveragePool2d<Pooling2dLayer>(optimizationViews, baseLayer,
673*89c4ff92SAndroid Build Coastguard Worker                                                              poolingDescriptor, padLayer);
674*89c4ff92SAndroid Build Coastguard Worker                     untouched.erase(baseLayer->GetGuid());
675*89c4ff92SAndroid Build Coastguard Worker                     untouched.erase(padLayer->GetGuid());
676*89c4ff92SAndroid Build Coastguard Worker                 }
677*89c4ff92SAndroid Build Coastguard Worker             }
678*89c4ff92SAndroid Build Coastguard Worker         }
679*89c4ff92SAndroid Build Coastguard Worker     }
680*89c4ff92SAndroid Build Coastguard Worker 
681*89c4ff92SAndroid Build Coastguard Worker     if (optimizationViews.GetSubstitutions().empty())
682*89c4ff92SAndroid Build Coastguard Worker     {
683*89c4ff92SAndroid Build Coastguard Worker         optimizationViews.AddUntouchedSubgraph(SubgraphView(subgraph));
684*89c4ff92SAndroid Build Coastguard Worker     }
685*89c4ff92SAndroid Build Coastguard Worker     else
686*89c4ff92SAndroid Build Coastguard Worker     {
687*89c4ff92SAndroid Build Coastguard Worker         ReportUntouchedLayers(optimizationViews, untouched);
688*89c4ff92SAndroid Build Coastguard Worker     }
689*89c4ff92SAndroid Build Coastguard Worker 
690*89c4ff92SAndroid Build Coastguard Worker     return optimizationViews;
691*89c4ff92SAndroid Build Coastguard Worker }
692*89c4ff92SAndroid Build Coastguard Worker 
693*89c4ff92SAndroid Build Coastguard Worker } // namespace armnn
694