xref: /aosp_15_r20/external/armnn/include/armnnTestUtils/MockBackend.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker #pragma once
6*89c4ff92SAndroid Build Coastguard Worker 
7*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Deprecated.hpp>
8*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Descriptors.hpp>
9*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Exceptions.hpp>
10*89c4ff92SAndroid Build Coastguard Worker #include <armnn/IRuntime.hpp>
11*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Optional.hpp>
12*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Tensor.hpp>
13*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Types.hpp>
14*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/IBackendInternal.hpp>
15*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/MemCopyWorkload.hpp>
16*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/ITensorHandle.hpp>
17*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/IWorkload.hpp>
18*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/OptimizationViews.hpp>
19*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/SubgraphView.hpp>
20*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/WorkloadData.hpp>
21*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/WorkloadFactory.hpp>
22*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/WorkloadInfo.hpp>
23*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/IgnoreUnused.hpp>
24*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/PolymorphicDowncast.hpp>
25*89c4ff92SAndroid Build Coastguard Worker #include <armnnTestUtils/MockTensorHandle.hpp>
26*89c4ff92SAndroid Build Coastguard Worker #include <backendsCommon/LayerSupportBase.hpp>
27*89c4ff92SAndroid Build Coastguard Worker 
28*89c4ff92SAndroid Build Coastguard Worker #include <client/include/CounterValue.hpp>
29*89c4ff92SAndroid Build Coastguard Worker #include <client/include/ISendTimelinePacket.hpp>
30*89c4ff92SAndroid Build Coastguard Worker #include <client/include/Timestamp.hpp>
31*89c4ff92SAndroid Build Coastguard Worker #include <client/include/backends/IBackendProfiling.hpp>
32*89c4ff92SAndroid Build Coastguard Worker #include <client/include/backends/IBackendProfilingContext.hpp>
33*89c4ff92SAndroid Build Coastguard Worker #include <common/include/Optional.hpp>
34*89c4ff92SAndroid Build Coastguard Worker 
35*89c4ff92SAndroid Build Coastguard Worker #include <atomic>
36*89c4ff92SAndroid Build Coastguard Worker #include <cstdint>
37*89c4ff92SAndroid Build Coastguard Worker #include <memory>
38*89c4ff92SAndroid Build Coastguard Worker #include <string>
39*89c4ff92SAndroid Build Coastguard Worker #include <utility>
40*89c4ff92SAndroid Build Coastguard Worker #include <vector>
41*89c4ff92SAndroid Build Coastguard Worker 
42*89c4ff92SAndroid Build Coastguard Worker namespace armnn
43*89c4ff92SAndroid Build Coastguard Worker {
// Forward declarations of types that the declarations below refer to.
class BackendId;
class ICustomAllocator;
class MockMemoryManager;

struct LstmInputParamsInfo;
struct QuantizedLstmInputParamsInfo;
49*89c4ff92SAndroid Build Coastguard Worker 
50*89c4ff92SAndroid Build Coastguard Worker // A bare bones Mock backend to enable unit testing of simple tensor manipulation features.
51*89c4ff92SAndroid Build Coastguard Worker class MockBackend : public IBackendInternal
52*89c4ff92SAndroid Build Coastguard Worker {
53*89c4ff92SAndroid Build Coastguard Worker public:
54*89c4ff92SAndroid Build Coastguard Worker     MockBackend() = default;
55*89c4ff92SAndroid Build Coastguard Worker 
56*89c4ff92SAndroid Build Coastguard Worker     ~MockBackend() = default;
57*89c4ff92SAndroid Build Coastguard Worker 
58*89c4ff92SAndroid Build Coastguard Worker     static const BackendId& GetIdStatic();
59*89c4ff92SAndroid Build Coastguard Worker 
GetId() const60*89c4ff92SAndroid Build Coastguard Worker     const BackendId& GetId() const override
61*89c4ff92SAndroid Build Coastguard Worker     {
62*89c4ff92SAndroid Build Coastguard Worker         return GetIdStatic();
63*89c4ff92SAndroid Build Coastguard Worker     }
64*89c4ff92SAndroid Build Coastguard Worker     IBackendInternal::IWorkloadFactoryPtr
65*89c4ff92SAndroid Build Coastguard Worker         CreateWorkloadFactory(const IBackendInternal::IMemoryManagerSharedPtr& memoryManager = nullptr) const override;
66*89c4ff92SAndroid Build Coastguard Worker 
67*89c4ff92SAndroid Build Coastguard Worker     IBackendInternal::ILayerSupportSharedPtr GetLayerSupport() const override;
68*89c4ff92SAndroid Build Coastguard Worker 
69*89c4ff92SAndroid Build Coastguard Worker     IBackendInternal::IMemoryManagerUniquePtr CreateMemoryManager() const override;
70*89c4ff92SAndroid Build Coastguard Worker 
71*89c4ff92SAndroid Build Coastguard Worker     IBackendInternal::IBackendContextPtr CreateBackendContext(const IRuntime::CreationOptions&) const override;
72*89c4ff92SAndroid Build Coastguard Worker     IBackendInternal::IBackendProfilingContextPtr
73*89c4ff92SAndroid Build Coastguard Worker     CreateBackendProfilingContext(const IRuntime::CreationOptions& creationOptions,
74*89c4ff92SAndroid Build Coastguard Worker                                   IBackendProfilingPtr& backendProfiling) override;
75*89c4ff92SAndroid Build Coastguard Worker 
76*89c4ff92SAndroid Build Coastguard Worker     OptimizationViews OptimizeSubgraphView(const SubgraphView& subgraph) const override;
77*89c4ff92SAndroid Build Coastguard Worker 
78*89c4ff92SAndroid Build Coastguard Worker     std::unique_ptr<ICustomAllocator> GetDefaultAllocator() const override;
79*89c4ff92SAndroid Build Coastguard Worker };
80*89c4ff92SAndroid Build Coastguard Worker 
81*89c4ff92SAndroid Build Coastguard Worker class MockWorkloadFactory : public IWorkloadFactory
82*89c4ff92SAndroid Build Coastguard Worker {
83*89c4ff92SAndroid Build Coastguard Worker 
84*89c4ff92SAndroid Build Coastguard Worker public:
85*89c4ff92SAndroid Build Coastguard Worker     explicit MockWorkloadFactory(const std::shared_ptr<MockMemoryManager>& memoryManager);
86*89c4ff92SAndroid Build Coastguard Worker     MockWorkloadFactory();
87*89c4ff92SAndroid Build Coastguard Worker 
~MockWorkloadFactory()88*89c4ff92SAndroid Build Coastguard Worker     ~MockWorkloadFactory()
89*89c4ff92SAndroid Build Coastguard Worker     {}
90*89c4ff92SAndroid Build Coastguard Worker 
91*89c4ff92SAndroid Build Coastguard Worker     const BackendId& GetBackendId() const override;
92*89c4ff92SAndroid Build Coastguard Worker 
SupportsSubTensors() const93*89c4ff92SAndroid Build Coastguard Worker     bool SupportsSubTensors() const override
94*89c4ff92SAndroid Build Coastguard Worker     {
95*89c4ff92SAndroid Build Coastguard Worker         return false;
96*89c4ff92SAndroid Build Coastguard Worker     }
97*89c4ff92SAndroid Build Coastguard Worker 
98*89c4ff92SAndroid Build Coastguard Worker     ARMNN_DEPRECATED_MSG("Use ITensorHandleFactory::CreateSubTensorHandle instead")
CreateSubTensorHandle(ITensorHandle &,TensorShape const &,unsigned int const *) const99*89c4ff92SAndroid Build Coastguard Worker     std::unique_ptr<ITensorHandle> CreateSubTensorHandle(ITensorHandle&,
100*89c4ff92SAndroid Build Coastguard Worker                                                          TensorShape const&,
101*89c4ff92SAndroid Build Coastguard Worker                                                          unsigned int const*) const override
102*89c4ff92SAndroid Build Coastguard Worker     {
103*89c4ff92SAndroid Build Coastguard Worker         return nullptr;
104*89c4ff92SAndroid Build Coastguard Worker     }
105*89c4ff92SAndroid Build Coastguard Worker 
106*89c4ff92SAndroid Build Coastguard Worker     ARMNN_DEPRECATED_MSG("Use ITensorHandleFactory::CreateTensorHandle instead")
CreateTensorHandle(const TensorInfo & tensorInfo,const bool IsMemoryManaged=true) const107*89c4ff92SAndroid Build Coastguard Worker     std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
108*89c4ff92SAndroid Build Coastguard Worker                                                       const bool IsMemoryManaged = true) const override
109*89c4ff92SAndroid Build Coastguard Worker     {
110*89c4ff92SAndroid Build Coastguard Worker         IgnoreUnused(IsMemoryManaged);
111*89c4ff92SAndroid Build Coastguard Worker         return std::make_unique<MockTensorHandle>(tensorInfo, m_MemoryManager);
112*89c4ff92SAndroid Build Coastguard Worker     };
113*89c4ff92SAndroid Build Coastguard Worker 
114*89c4ff92SAndroid Build Coastguard Worker     ARMNN_DEPRECATED_MSG("Use ITensorHandleFactory::CreateTensorHandle instead")
CreateTensorHandle(const TensorInfo & tensorInfo,DataLayout dataLayout,const bool IsMemoryManaged=true) const115*89c4ff92SAndroid Build Coastguard Worker     std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
116*89c4ff92SAndroid Build Coastguard Worker                                                       DataLayout dataLayout,
117*89c4ff92SAndroid Build Coastguard Worker                                                       const bool IsMemoryManaged = true) const override
118*89c4ff92SAndroid Build Coastguard Worker     {
119*89c4ff92SAndroid Build Coastguard Worker         IgnoreUnused(dataLayout, IsMemoryManaged);
120*89c4ff92SAndroid Build Coastguard Worker         return std::make_unique<MockTensorHandle>(tensorInfo, static_cast<unsigned int>(MemorySource::Malloc));
121*89c4ff92SAndroid Build Coastguard Worker     };
122*89c4ff92SAndroid Build Coastguard Worker 
123*89c4ff92SAndroid Build Coastguard Worker     ARMNN_DEPRECATED_MSG_REMOVAL_DATE(
124*89c4ff92SAndroid Build Coastguard Worker         "Use ABI stable "
125*89c4ff92SAndroid Build Coastguard Worker         "CreateWorkload(LayerType, const QueueDescriptor&, const WorkloadInfo& info) instead.",
126*89c4ff92SAndroid Build Coastguard Worker         "23.08")
CreateInput(const InputQueueDescriptor & descriptor,const WorkloadInfo & info) const127*89c4ff92SAndroid Build Coastguard Worker     std::unique_ptr<IWorkload> CreateInput(const InputQueueDescriptor& descriptor,
128*89c4ff92SAndroid Build Coastguard Worker                                            const WorkloadInfo& info) const override
129*89c4ff92SAndroid Build Coastguard Worker     {
130*89c4ff92SAndroid Build Coastguard Worker         if (info.m_InputTensorInfos.empty())
131*89c4ff92SAndroid Build Coastguard Worker         {
132*89c4ff92SAndroid Build Coastguard Worker             throw InvalidArgumentException("MockWorkloadFactory::CreateInput: Input cannot be zero length");
133*89c4ff92SAndroid Build Coastguard Worker         }
134*89c4ff92SAndroid Build Coastguard Worker         if (info.m_OutputTensorInfos.empty())
135*89c4ff92SAndroid Build Coastguard Worker         {
136*89c4ff92SAndroid Build Coastguard Worker             throw InvalidArgumentException("MockWorkloadFactory::CreateInput: Output cannot be zero length");
137*89c4ff92SAndroid Build Coastguard Worker         }
138*89c4ff92SAndroid Build Coastguard Worker 
139*89c4ff92SAndroid Build Coastguard Worker         if (info.m_InputTensorInfos[0].GetNumBytes() != info.m_OutputTensorInfos[0].GetNumBytes())
140*89c4ff92SAndroid Build Coastguard Worker         {
141*89c4ff92SAndroid Build Coastguard Worker             throw InvalidArgumentException(
142*89c4ff92SAndroid Build Coastguard Worker                 "MockWorkloadFactory::CreateInput: data input and output differ in byte count.");
143*89c4ff92SAndroid Build Coastguard Worker         }
144*89c4ff92SAndroid Build Coastguard Worker 
145*89c4ff92SAndroid Build Coastguard Worker         return std::make_unique<CopyMemGenericWorkload>(descriptor, info);
146*89c4ff92SAndroid Build Coastguard Worker     };
147*89c4ff92SAndroid Build Coastguard Worker 
148*89c4ff92SAndroid Build Coastguard Worker     std::unique_ptr<IWorkload>
149*89c4ff92SAndroid Build Coastguard Worker         CreateWorkload(LayerType type, const QueueDescriptor& descriptor, const WorkloadInfo& info) const override;
150*89c4ff92SAndroid Build Coastguard Worker 
151*89c4ff92SAndroid Build Coastguard Worker private:
152*89c4ff92SAndroid Build Coastguard Worker     mutable std::shared_ptr<MockMemoryManager> m_MemoryManager;
153*89c4ff92SAndroid Build Coastguard Worker };
154*89c4ff92SAndroid Build Coastguard Worker 
/// Scope helper for MockBackend tests. Its constructor and destructor are defined out of line.
/// NOTE(review): presumably it registers MockBackend on construction and unregisters it on
/// destruction. Confirm against the definitions.
class MockBackendInitialiser
{
public:
    MockBackendInitialiser();

    ~MockBackendInitialiser();
};
161*89c4ff92SAndroid Build Coastguard Worker 
162*89c4ff92SAndroid Build Coastguard Worker class MockBackendProfilingContext : public arm::pipe::IBackendProfilingContext
163*89c4ff92SAndroid Build Coastguard Worker {
164*89c4ff92SAndroid Build Coastguard Worker public:
MockBackendProfilingContext(IBackendInternal::IBackendProfilingPtr & backendProfiling)165*89c4ff92SAndroid Build Coastguard Worker     MockBackendProfilingContext(IBackendInternal::IBackendProfilingPtr& backendProfiling)
166*89c4ff92SAndroid Build Coastguard Worker         : m_BackendProfiling(std::move(backendProfiling))
167*89c4ff92SAndroid Build Coastguard Worker         , m_CapturePeriod(0)
168*89c4ff92SAndroid Build Coastguard Worker         , m_IsTimelineEnabled(true)
169*89c4ff92SAndroid Build Coastguard Worker     {}
170*89c4ff92SAndroid Build Coastguard Worker 
171*89c4ff92SAndroid Build Coastguard Worker     ~MockBackendProfilingContext() = default;
172*89c4ff92SAndroid Build Coastguard Worker 
GetBackendProfiling()173*89c4ff92SAndroid Build Coastguard Worker     IBackendInternal::IBackendProfilingPtr& GetBackendProfiling()
174*89c4ff92SAndroid Build Coastguard Worker     {
175*89c4ff92SAndroid Build Coastguard Worker         return m_BackendProfiling;
176*89c4ff92SAndroid Build Coastguard Worker     }
177*89c4ff92SAndroid Build Coastguard Worker 
RegisterCounters(uint16_t currentMaxGlobalCounterId)178*89c4ff92SAndroid Build Coastguard Worker     uint16_t RegisterCounters(uint16_t currentMaxGlobalCounterId)
179*89c4ff92SAndroid Build Coastguard Worker     {
180*89c4ff92SAndroid Build Coastguard Worker         std::unique_ptr<arm::pipe::IRegisterBackendCounters> counterRegistrar =
181*89c4ff92SAndroid Build Coastguard Worker             m_BackendProfiling->GetCounterRegistrationInterface(static_cast<uint16_t>(currentMaxGlobalCounterId));
182*89c4ff92SAndroid Build Coastguard Worker 
183*89c4ff92SAndroid Build Coastguard Worker         std::string categoryName("MockCounters");
184*89c4ff92SAndroid Build Coastguard Worker         counterRegistrar->RegisterCategory(categoryName);
185*89c4ff92SAndroid Build Coastguard Worker 
186*89c4ff92SAndroid Build Coastguard Worker         counterRegistrar->RegisterCounter(0, categoryName, 0, 0, 1.f, "Mock Counter One", "Some notional counter");
187*89c4ff92SAndroid Build Coastguard Worker 
188*89c4ff92SAndroid Build Coastguard Worker         counterRegistrar->RegisterCounter(1, categoryName, 0, 0, 1.f, "Mock Counter Two",
189*89c4ff92SAndroid Build Coastguard Worker                                           "Another notional counter");
190*89c4ff92SAndroid Build Coastguard Worker 
191*89c4ff92SAndroid Build Coastguard Worker         std::string units("microseconds");
192*89c4ff92SAndroid Build Coastguard Worker         uint16_t nextMaxGlobalCounterId =
193*89c4ff92SAndroid Build Coastguard Worker                         counterRegistrar->RegisterCounter(2, categoryName, 0, 0, 1.f, "Mock MultiCore Counter",
194*89c4ff92SAndroid Build Coastguard Worker                                                           "A dummy four core counter", units, 4);
195*89c4ff92SAndroid Build Coastguard Worker         return nextMaxGlobalCounterId;
196*89c4ff92SAndroid Build Coastguard Worker     }
197*89c4ff92SAndroid Build Coastguard Worker 
ActivateCounters(uint32_t capturePeriod,const std::vector<uint16_t> & counterIds)198*89c4ff92SAndroid Build Coastguard Worker     arm::pipe::Optional<std::string> ActivateCounters(uint32_t capturePeriod, const std::vector<uint16_t>& counterIds)
199*89c4ff92SAndroid Build Coastguard Worker     {
200*89c4ff92SAndroid Build Coastguard Worker         if (capturePeriod == 0 || counterIds.size() == 0)
201*89c4ff92SAndroid Build Coastguard Worker         {
202*89c4ff92SAndroid Build Coastguard Worker             m_ActiveCounters.clear();
203*89c4ff92SAndroid Build Coastguard Worker         }
204*89c4ff92SAndroid Build Coastguard Worker         else if (capturePeriod == 15939u)
205*89c4ff92SAndroid Build Coastguard Worker         {
206*89c4ff92SAndroid Build Coastguard Worker             return arm::pipe::Optional<std::string>("ActivateCounters example test error");
207*89c4ff92SAndroid Build Coastguard Worker         }
208*89c4ff92SAndroid Build Coastguard Worker         m_CapturePeriod  = capturePeriod;
209*89c4ff92SAndroid Build Coastguard Worker         m_ActiveCounters = counterIds;
210*89c4ff92SAndroid Build Coastguard Worker         return arm::pipe::Optional<std::string>();
211*89c4ff92SAndroid Build Coastguard Worker     }
212*89c4ff92SAndroid Build Coastguard Worker 
ReportCounterValues()213*89c4ff92SAndroid Build Coastguard Worker     std::vector<arm::pipe::Timestamp> ReportCounterValues()
214*89c4ff92SAndroid Build Coastguard Worker     {
215*89c4ff92SAndroid Build Coastguard Worker         std::vector<arm::pipe::CounterValue> counterValues;
216*89c4ff92SAndroid Build Coastguard Worker 
217*89c4ff92SAndroid Build Coastguard Worker         for (auto counterId : m_ActiveCounters)
218*89c4ff92SAndroid Build Coastguard Worker         {
219*89c4ff92SAndroid Build Coastguard Worker             counterValues.emplace_back(arm::pipe::CounterValue{ counterId, counterId + 1u });
220*89c4ff92SAndroid Build Coastguard Worker         }
221*89c4ff92SAndroid Build Coastguard Worker 
222*89c4ff92SAndroid Build Coastguard Worker         uint64_t timestamp = m_CapturePeriod;
223*89c4ff92SAndroid Build Coastguard Worker         return { arm::pipe::Timestamp{ timestamp, counterValues } };
224*89c4ff92SAndroid Build Coastguard Worker     }
225*89c4ff92SAndroid Build Coastguard Worker 
EnableProfiling(bool)226*89c4ff92SAndroid Build Coastguard Worker     bool EnableProfiling(bool)
227*89c4ff92SAndroid Build Coastguard Worker     {
228*89c4ff92SAndroid Build Coastguard Worker         auto sendTimelinePacket = m_BackendProfiling->GetSendTimelinePacket();
229*89c4ff92SAndroid Build Coastguard Worker         sendTimelinePacket->SendTimelineEntityBinaryPacket(4256);
230*89c4ff92SAndroid Build Coastguard Worker         sendTimelinePacket->Commit();
231*89c4ff92SAndroid Build Coastguard Worker         return true;
232*89c4ff92SAndroid Build Coastguard Worker     }
233*89c4ff92SAndroid Build Coastguard Worker 
EnableTimelineReporting(bool isEnabled)234*89c4ff92SAndroid Build Coastguard Worker     bool EnableTimelineReporting(bool isEnabled)
235*89c4ff92SAndroid Build Coastguard Worker     {
236*89c4ff92SAndroid Build Coastguard Worker         m_IsTimelineEnabled = isEnabled;
237*89c4ff92SAndroid Build Coastguard Worker         return isEnabled;
238*89c4ff92SAndroid Build Coastguard Worker     }
239*89c4ff92SAndroid Build Coastguard Worker 
TimelineReportingEnabled()240*89c4ff92SAndroid Build Coastguard Worker     bool TimelineReportingEnabled()
241*89c4ff92SAndroid Build Coastguard Worker     {
242*89c4ff92SAndroid Build Coastguard Worker         return m_IsTimelineEnabled;
243*89c4ff92SAndroid Build Coastguard Worker     }
244*89c4ff92SAndroid Build Coastguard Worker 
245*89c4ff92SAndroid Build Coastguard Worker private:
246*89c4ff92SAndroid Build Coastguard Worker     IBackendInternal::IBackendProfilingPtr m_BackendProfiling;
247*89c4ff92SAndroid Build Coastguard Worker     uint32_t m_CapturePeriod;
248*89c4ff92SAndroid Build Coastguard Worker     std::vector<uint16_t> m_ActiveCounters;
249*89c4ff92SAndroid Build Coastguard Worker     std::atomic<bool> m_IsTimelineEnabled;
250*89c4ff92SAndroid Build Coastguard Worker };
251*89c4ff92SAndroid Build Coastguard Worker 
252*89c4ff92SAndroid Build Coastguard Worker class MockBackendProfilingService
253*89c4ff92SAndroid Build Coastguard Worker {
254*89c4ff92SAndroid Build Coastguard Worker public:
255*89c4ff92SAndroid Build Coastguard Worker     // Getter for the singleton instance
Instance()256*89c4ff92SAndroid Build Coastguard Worker     static MockBackendProfilingService& Instance()
257*89c4ff92SAndroid Build Coastguard Worker     {
258*89c4ff92SAndroid Build Coastguard Worker         static MockBackendProfilingService instance;
259*89c4ff92SAndroid Build Coastguard Worker         return instance;
260*89c4ff92SAndroid Build Coastguard Worker     }
261*89c4ff92SAndroid Build Coastguard Worker 
GetContext()262*89c4ff92SAndroid Build Coastguard Worker     MockBackendProfilingContext* GetContext()
263*89c4ff92SAndroid Build Coastguard Worker     {
264*89c4ff92SAndroid Build Coastguard Worker         return m_sharedContext.get();
265*89c4ff92SAndroid Build Coastguard Worker     }
266*89c4ff92SAndroid Build Coastguard Worker 
SetProfilingContextPtr(std::shared_ptr<MockBackendProfilingContext> shared)267*89c4ff92SAndroid Build Coastguard Worker     void SetProfilingContextPtr(std::shared_ptr<MockBackendProfilingContext> shared)
268*89c4ff92SAndroid Build Coastguard Worker     {
269*89c4ff92SAndroid Build Coastguard Worker         m_sharedContext = shared;
270*89c4ff92SAndroid Build Coastguard Worker     }
271*89c4ff92SAndroid Build Coastguard Worker 
272*89c4ff92SAndroid Build Coastguard Worker private:
273*89c4ff92SAndroid Build Coastguard Worker     std::shared_ptr<MockBackendProfilingContext> m_sharedContext;
274*89c4ff92SAndroid Build Coastguard Worker };
275*89c4ff92SAndroid Build Coastguard Worker 
276*89c4ff92SAndroid Build Coastguard Worker class MockLayerSupport : public LayerSupportBase
277*89c4ff92SAndroid Build Coastguard Worker {
278*89c4ff92SAndroid Build Coastguard Worker public:
IsLayerSupported(const LayerType & type,const std::vector<TensorInfo> & infos,const BaseDescriptor & descriptor,const Optional<LstmInputParamsInfo> &,const Optional<QuantizedLstmInputParamsInfo> &,Optional<std::string &> reasonIfUnsupported) const279*89c4ff92SAndroid Build Coastguard Worker     bool IsLayerSupported(const LayerType& type,
280*89c4ff92SAndroid Build Coastguard Worker                           const std::vector<TensorInfo>& infos,
281*89c4ff92SAndroid Build Coastguard Worker                           const BaseDescriptor& descriptor,
282*89c4ff92SAndroid Build Coastguard Worker                           const Optional<LstmInputParamsInfo>& /*lstmParamsInfo*/,
283*89c4ff92SAndroid Build Coastguard Worker                           const Optional<QuantizedLstmInputParamsInfo>& /*quantizedLstmParamsInfo*/,
284*89c4ff92SAndroid Build Coastguard Worker                           Optional<std::string&> reasonIfUnsupported) const override
285*89c4ff92SAndroid Build Coastguard Worker     {
286*89c4ff92SAndroid Build Coastguard Worker         switch(type)
287*89c4ff92SAndroid Build Coastguard Worker         {
288*89c4ff92SAndroid Build Coastguard Worker             case LayerType::Input:
289*89c4ff92SAndroid Build Coastguard Worker                 return IsInputSupported(infos[0], reasonIfUnsupported);
290*89c4ff92SAndroid Build Coastguard Worker             case LayerType::Output:
291*89c4ff92SAndroid Build Coastguard Worker                 return IsOutputSupported(infos[0], reasonIfUnsupported);
292*89c4ff92SAndroid Build Coastguard Worker             case LayerType::Addition:
293*89c4ff92SAndroid Build Coastguard Worker                 return IsAdditionSupported(infos[0], infos[1], infos[2], reasonIfUnsupported);
294*89c4ff92SAndroid Build Coastguard Worker             case LayerType::Convolution2d:
295*89c4ff92SAndroid Build Coastguard Worker             {
296*89c4ff92SAndroid Build Coastguard Worker                 if (infos.size() != 4)
297*89c4ff92SAndroid Build Coastguard Worker                 {
298*89c4ff92SAndroid Build Coastguard Worker                     throw InvalidArgumentException("Invalid number of TransposeConvolution2d "
299*89c4ff92SAndroid Build Coastguard Worker                                                    "TensorInfos. TensorInfos should be of format: "
300*89c4ff92SAndroid Build Coastguard Worker                                                    "{input, output, weights, biases}.");
301*89c4ff92SAndroid Build Coastguard Worker                 }
302*89c4ff92SAndroid Build Coastguard Worker 
303*89c4ff92SAndroid Build Coastguard Worker                 auto desc = *(PolymorphicDowncast<const Convolution2dDescriptor*>(&descriptor));
304*89c4ff92SAndroid Build Coastguard Worker                 if (infos[3] == TensorInfo())
305*89c4ff92SAndroid Build Coastguard Worker                 {
306*89c4ff92SAndroid Build Coastguard Worker                     return IsConvolution2dSupported(infos[0],
307*89c4ff92SAndroid Build Coastguard Worker                                                     infos[1],
308*89c4ff92SAndroid Build Coastguard Worker                                                     desc,
309*89c4ff92SAndroid Build Coastguard Worker                                                     infos[2],
310*89c4ff92SAndroid Build Coastguard Worker                                                     EmptyOptional(),
311*89c4ff92SAndroid Build Coastguard Worker                                                     reasonIfUnsupported);
312*89c4ff92SAndroid Build Coastguard Worker                 }
313*89c4ff92SAndroid Build Coastguard Worker                 else
314*89c4ff92SAndroid Build Coastguard Worker                 {
315*89c4ff92SAndroid Build Coastguard Worker                     return IsConvolution2dSupported(infos[0],
316*89c4ff92SAndroid Build Coastguard Worker                                                     infos[1],
317*89c4ff92SAndroid Build Coastguard Worker                                                     desc,
318*89c4ff92SAndroid Build Coastguard Worker                                                     infos[2],
319*89c4ff92SAndroid Build Coastguard Worker                                                     infos[3],
320*89c4ff92SAndroid Build Coastguard Worker                                                     reasonIfUnsupported);
321*89c4ff92SAndroid Build Coastguard Worker                 }
322*89c4ff92SAndroid Build Coastguard Worker             }
323*89c4ff92SAndroid Build Coastguard Worker             case LayerType::ElementwiseBinary:
324*89c4ff92SAndroid Build Coastguard Worker             {
325*89c4ff92SAndroid Build Coastguard Worker                 auto elementwiseDesc = *(PolymorphicDowncast<const ElementwiseBinaryDescriptor*>(&descriptor));
326*89c4ff92SAndroid Build Coastguard Worker                 return (elementwiseDesc.m_Operation == BinaryOperation::Add);
327*89c4ff92SAndroid Build Coastguard Worker             }
328*89c4ff92SAndroid Build Coastguard Worker             default:
329*89c4ff92SAndroid Build Coastguard Worker                 return false;
330*89c4ff92SAndroid Build Coastguard Worker         }
331*89c4ff92SAndroid Build Coastguard Worker     }
332*89c4ff92SAndroid Build Coastguard Worker 
IsInputSupported(const TensorInfo &,Optional<std::string &>) const333*89c4ff92SAndroid Build Coastguard Worker     bool IsInputSupported(const TensorInfo& /*input*/,
334*89c4ff92SAndroid Build Coastguard Worker                           Optional<std::string&> /*reasonIfUnsupported = EmptyOptional()*/) const override
335*89c4ff92SAndroid Build Coastguard Worker     {
336*89c4ff92SAndroid Build Coastguard Worker         return true;
337*89c4ff92SAndroid Build Coastguard Worker     }
338*89c4ff92SAndroid Build Coastguard Worker 
IsOutputSupported(const TensorInfo &,Optional<std::string &>) const339*89c4ff92SAndroid Build Coastguard Worker     bool IsOutputSupported(const TensorInfo& /*input*/,
340*89c4ff92SAndroid Build Coastguard Worker                            Optional<std::string&> /*reasonIfUnsupported = EmptyOptional()*/) const override
341*89c4ff92SAndroid Build Coastguard Worker     {
342*89c4ff92SAndroid Build Coastguard Worker         return true;
343*89c4ff92SAndroid Build Coastguard Worker     }
344*89c4ff92SAndroid Build Coastguard Worker 
IsAdditionSupported(const TensorInfo &,const TensorInfo &,const TensorInfo &,Optional<std::string &>) const345*89c4ff92SAndroid Build Coastguard Worker     bool IsAdditionSupported(const TensorInfo& /*input0*/,
346*89c4ff92SAndroid Build Coastguard Worker                              const TensorInfo& /*input1*/,
347*89c4ff92SAndroid Build Coastguard Worker                              const TensorInfo& /*output*/,
348*89c4ff92SAndroid Build Coastguard Worker                              Optional<std::string&> /*reasonIfUnsupported = EmptyOptional()*/) const override
349*89c4ff92SAndroid Build Coastguard Worker     {
350*89c4ff92SAndroid Build Coastguard Worker         return true;
351*89c4ff92SAndroid Build Coastguard Worker     }
352*89c4ff92SAndroid Build Coastguard Worker 
IsConvolution2dSupported(const TensorInfo &,const TensorInfo &,const Convolution2dDescriptor &,const TensorInfo &,const Optional<TensorInfo> &,Optional<std::string &>) const353*89c4ff92SAndroid Build Coastguard Worker     bool IsConvolution2dSupported(const TensorInfo& /*input*/,
354*89c4ff92SAndroid Build Coastguard Worker                                   const TensorInfo& /*output*/,
355*89c4ff92SAndroid Build Coastguard Worker                                   const Convolution2dDescriptor& /*descriptor*/,
356*89c4ff92SAndroid Build Coastguard Worker                                   const TensorInfo& /*weights*/,
357*89c4ff92SAndroid Build Coastguard Worker                                   const Optional<TensorInfo>& /*biases*/,
358*89c4ff92SAndroid Build Coastguard Worker                                   Optional<std::string&> /*reasonIfUnsupported = EmptyOptional()*/) const override
359*89c4ff92SAndroid Build Coastguard Worker     {
360*89c4ff92SAndroid Build Coastguard Worker         return true;
361*89c4ff92SAndroid Build Coastguard Worker     }
362*89c4ff92SAndroid Build Coastguard Worker };
363*89c4ff92SAndroid Build Coastguard Worker 
364*89c4ff92SAndroid Build Coastguard Worker }    // namespace armnn
365