1*89c4ff92SAndroid Build Coastguard Worker // 2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved. 3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT 4*89c4ff92SAndroid Build Coastguard Worker // 5*89c4ff92SAndroid Build Coastguard Worker #pragma once 6*89c4ff92SAndroid Build Coastguard Worker 7*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Deprecated.hpp> 8*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Descriptors.hpp> 9*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Exceptions.hpp> 10*89c4ff92SAndroid Build Coastguard Worker #include <armnn/IRuntime.hpp> 11*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Optional.hpp> 12*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Tensor.hpp> 13*89c4ff92SAndroid Build Coastguard Worker #include <armnn/Types.hpp> 14*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/IBackendInternal.hpp> 15*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/MemCopyWorkload.hpp> 16*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/ITensorHandle.hpp> 17*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/IWorkload.hpp> 18*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/OptimizationViews.hpp> 19*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/SubgraphView.hpp> 20*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/WorkloadData.hpp> 21*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/WorkloadFactory.hpp> 22*89c4ff92SAndroid Build Coastguard Worker #include <armnn/backends/WorkloadInfo.hpp> 23*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/IgnoreUnused.hpp> 24*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/PolymorphicDowncast.hpp> 25*89c4ff92SAndroid Build Coastguard Worker #include <armnnTestUtils/MockTensorHandle.hpp> 26*89c4ff92SAndroid Build Coastguard Worker #include <backendsCommon/LayerSupportBase.hpp> 27*89c4ff92SAndroid Build Coastguard Worker 28*89c4ff92SAndroid Build Coastguard Worker #include <client/include/CounterValue.hpp> 29*89c4ff92SAndroid Build Coastguard Worker #include <client/include/ISendTimelinePacket.hpp> 30*89c4ff92SAndroid Build Coastguard Worker #include <client/include/Timestamp.hpp> 31*89c4ff92SAndroid Build Coastguard Worker #include <client/include/backends/IBackendProfiling.hpp> 32*89c4ff92SAndroid Build Coastguard Worker #include <client/include/backends/IBackendProfilingContext.hpp> 33*89c4ff92SAndroid Build Coastguard Worker #include <common/include/Optional.hpp> 34*89c4ff92SAndroid Build Coastguard Worker 35*89c4ff92SAndroid Build Coastguard Worker #include <atomic> 36*89c4ff92SAndroid Build Coastguard Worker #include <cstdint> 37*89c4ff92SAndroid Build Coastguard Worker #include <memory> 38*89c4ff92SAndroid Build Coastguard Worker #include <string> 39*89c4ff92SAndroid Build Coastguard Worker #include <utility> 40*89c4ff92SAndroid Build Coastguard Worker #include <vector> 41*89c4ff92SAndroid Build Coastguard Worker 42*89c4ff92SAndroid Build Coastguard Worker namespace armnn 43*89c4ff92SAndroid Build Coastguard Worker { 44*89c4ff92SAndroid Build Coastguard Worker class BackendId; 45*89c4ff92SAndroid Build Coastguard Worker class ICustomAllocator; 46*89c4ff92SAndroid Build Coastguard Worker class MockMemoryManager; 47*89c4ff92SAndroid Build Coastguard Worker struct LstmInputParamsInfo; 48*89c4ff92SAndroid Build Coastguard Worker struct QuantizedLstmInputParamsInfo; 49*89c4ff92SAndroid Build Coastguard Worker 50*89c4ff92SAndroid Build Coastguard Worker // A bare bones Mock backend to enable unit testing of simple tensor manipulation features. 51*89c4ff92SAndroid Build Coastguard Worker class MockBackend : public IBackendInternal 52*89c4ff92SAndroid Build Coastguard Worker { 53*89c4ff92SAndroid Build Coastguard Worker public: 54*89c4ff92SAndroid Build Coastguard Worker MockBackend() = default; 55*89c4ff92SAndroid Build Coastguard Worker 56*89c4ff92SAndroid Build Coastguard Worker ~MockBackend() = default; 57*89c4ff92SAndroid Build Coastguard Worker 58*89c4ff92SAndroid Build Coastguard Worker static const BackendId& GetIdStatic(); 59*89c4ff92SAndroid Build Coastguard Worker GetId() const60*89c4ff92SAndroid Build Coastguard Worker const BackendId& GetId() const override 61*89c4ff92SAndroid Build Coastguard Worker { 62*89c4ff92SAndroid Build Coastguard Worker return GetIdStatic(); 63*89c4ff92SAndroid Build Coastguard Worker } 64*89c4ff92SAndroid Build Coastguard Worker IBackendInternal::IWorkloadFactoryPtr 65*89c4ff92SAndroid Build Coastguard Worker CreateWorkloadFactory(const IBackendInternal::IMemoryManagerSharedPtr& memoryManager = nullptr) const override; 66*89c4ff92SAndroid Build Coastguard Worker 67*89c4ff92SAndroid Build Coastguard Worker IBackendInternal::ILayerSupportSharedPtr GetLayerSupport() const override; 68*89c4ff92SAndroid Build Coastguard Worker 69*89c4ff92SAndroid Build Coastguard Worker IBackendInternal::IMemoryManagerUniquePtr CreateMemoryManager() const override; 70*89c4ff92SAndroid Build Coastguard Worker 71*89c4ff92SAndroid Build Coastguard Worker IBackendInternal::IBackendContextPtr CreateBackendContext(const IRuntime::CreationOptions&) const override; 72*89c4ff92SAndroid Build Coastguard Worker IBackendInternal::IBackendProfilingContextPtr 73*89c4ff92SAndroid Build Coastguard Worker CreateBackendProfilingContext(const IRuntime::CreationOptions& creationOptions, 74*89c4ff92SAndroid Build Coastguard Worker IBackendProfilingPtr& backendProfiling) override; 75*89c4ff92SAndroid Build Coastguard Worker 76*89c4ff92SAndroid Build Coastguard Worker OptimizationViews OptimizeSubgraphView(const SubgraphView& subgraph) const override; 77*89c4ff92SAndroid Build Coastguard Worker 78*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<ICustomAllocator> GetDefaultAllocator() const override; 79*89c4ff92SAndroid Build Coastguard Worker }; 80*89c4ff92SAndroid Build Coastguard Worker 81*89c4ff92SAndroid Build Coastguard Worker class MockWorkloadFactory : public IWorkloadFactory 82*89c4ff92SAndroid Build Coastguard Worker { 83*89c4ff92SAndroid Build Coastguard Worker 84*89c4ff92SAndroid Build Coastguard Worker public: 85*89c4ff92SAndroid Build Coastguard Worker explicit MockWorkloadFactory(const std::shared_ptr<MockMemoryManager>& memoryManager); 86*89c4ff92SAndroid Build Coastguard Worker MockWorkloadFactory(); 87*89c4ff92SAndroid Build Coastguard Worker ~MockWorkloadFactory()88*89c4ff92SAndroid Build Coastguard Worker ~MockWorkloadFactory() 89*89c4ff92SAndroid Build Coastguard Worker {} 90*89c4ff92SAndroid Build Coastguard Worker 91*89c4ff92SAndroid Build Coastguard Worker const BackendId& GetBackendId() const override; 92*89c4ff92SAndroid Build Coastguard Worker SupportsSubTensors() const93*89c4ff92SAndroid Build Coastguard Worker bool SupportsSubTensors() const override 94*89c4ff92SAndroid Build Coastguard Worker { 95*89c4ff92SAndroid Build Coastguard Worker return false; 96*89c4ff92SAndroid Build Coastguard Worker } 97*89c4ff92SAndroid Build Coastguard Worker 98*89c4ff92SAndroid Build Coastguard Worker ARMNN_DEPRECATED_MSG("Use ITensorHandleFactory::CreateSubTensorHandle instead") CreateSubTensorHandle(ITensorHandle &,TensorShape const &,unsigned int const *) const99*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<ITensorHandle> CreateSubTensorHandle(ITensorHandle&, 100*89c4ff92SAndroid Build Coastguard Worker TensorShape const&, 101*89c4ff92SAndroid Build Coastguard Worker unsigned int const*) const override 102*89c4ff92SAndroid Build Coastguard Worker { 103*89c4ff92SAndroid Build Coastguard Worker return nullptr; 104*89c4ff92SAndroid Build Coastguard Worker } 105*89c4ff92SAndroid Build Coastguard Worker 106*89c4ff92SAndroid Build Coastguard Worker ARMNN_DEPRECATED_MSG("Use ITensorHandleFactory::CreateTensorHandle instead") CreateTensorHandle(const TensorInfo & tensorInfo,const bool IsMemoryManaged=true) const107*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo, 108*89c4ff92SAndroid Build Coastguard Worker const bool IsMemoryManaged = true) const override 109*89c4ff92SAndroid Build Coastguard Worker { 110*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(IsMemoryManaged); 111*89c4ff92SAndroid Build Coastguard Worker return std::make_unique<MockTensorHandle>(tensorInfo, m_MemoryManager); 112*89c4ff92SAndroid Build Coastguard Worker }; 113*89c4ff92SAndroid Build Coastguard Worker 114*89c4ff92SAndroid Build Coastguard Worker ARMNN_DEPRECATED_MSG("Use ITensorHandleFactory::CreateTensorHandle instead") CreateTensorHandle(const TensorInfo & tensorInfo,DataLayout dataLayout,const bool IsMemoryManaged=true) const115*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo, 116*89c4ff92SAndroid Build Coastguard Worker DataLayout dataLayout, 117*89c4ff92SAndroid Build Coastguard Worker const bool IsMemoryManaged = true) const override 118*89c4ff92SAndroid Build Coastguard Worker { 119*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(dataLayout, IsMemoryManaged); 120*89c4ff92SAndroid Build Coastguard Worker return std::make_unique<MockTensorHandle>(tensorInfo, static_cast<unsigned int>(MemorySource::Malloc)); 121*89c4ff92SAndroid Build Coastguard Worker }; 122*89c4ff92SAndroid Build Coastguard Worker 123*89c4ff92SAndroid Build Coastguard Worker ARMNN_DEPRECATED_MSG_REMOVAL_DATE( 124*89c4ff92SAndroid Build Coastguard Worker "Use ABI stable " 125*89c4ff92SAndroid Build Coastguard Worker "CreateWorkload(LayerType, const QueueDescriptor&, const WorkloadInfo& info) instead.", 126*89c4ff92SAndroid Build Coastguard Worker "23.08") CreateInput(const InputQueueDescriptor & descriptor,const WorkloadInfo & info) const127*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<IWorkload> CreateInput(const InputQueueDescriptor& descriptor, 128*89c4ff92SAndroid Build Coastguard Worker const WorkloadInfo& info) const override 129*89c4ff92SAndroid Build Coastguard Worker { 130*89c4ff92SAndroid Build Coastguard Worker if (info.m_InputTensorInfos.empty()) 131*89c4ff92SAndroid Build Coastguard Worker { 132*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException("MockWorkloadFactory::CreateInput: Input cannot be zero length"); 133*89c4ff92SAndroid Build Coastguard Worker } 134*89c4ff92SAndroid Build Coastguard Worker if (info.m_OutputTensorInfos.empty()) 135*89c4ff92SAndroid Build Coastguard Worker { 136*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException("MockWorkloadFactory::CreateInput: Output cannot be zero length"); 137*89c4ff92SAndroid Build Coastguard Worker } 138*89c4ff92SAndroid Build Coastguard Worker 139*89c4ff92SAndroid Build Coastguard Worker if (info.m_InputTensorInfos[0].GetNumBytes() != info.m_OutputTensorInfos[0].GetNumBytes()) 140*89c4ff92SAndroid Build Coastguard Worker { 141*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException( 142*89c4ff92SAndroid Build Coastguard Worker "MockWorkloadFactory::CreateInput: data input and output differ in byte count."); 143*89c4ff92SAndroid Build Coastguard Worker } 144*89c4ff92SAndroid Build Coastguard Worker 145*89c4ff92SAndroid Build Coastguard Worker return std::make_unique<CopyMemGenericWorkload>(descriptor, info); 146*89c4ff92SAndroid Build Coastguard Worker }; 147*89c4ff92SAndroid Build Coastguard Worker 148*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<IWorkload> 149*89c4ff92SAndroid Build Coastguard Worker CreateWorkload(LayerType type, const QueueDescriptor& descriptor, const WorkloadInfo& info) const override; 150*89c4ff92SAndroid Build Coastguard Worker 151*89c4ff92SAndroid Build Coastguard Worker private: 152*89c4ff92SAndroid Build Coastguard Worker mutable std::shared_ptr<MockMemoryManager> m_MemoryManager; 153*89c4ff92SAndroid Build Coastguard Worker }; 154*89c4ff92SAndroid Build Coastguard Worker 155*89c4ff92SAndroid Build Coastguard Worker class MockBackendInitialiser 156*89c4ff92SAndroid Build Coastguard Worker { 157*89c4ff92SAndroid Build Coastguard Worker public: 158*89c4ff92SAndroid Build Coastguard Worker MockBackendInitialiser(); 159*89c4ff92SAndroid Build Coastguard Worker ~MockBackendInitialiser(); 160*89c4ff92SAndroid Build Coastguard Worker }; 161*89c4ff92SAndroid Build Coastguard Worker 162*89c4ff92SAndroid Build Coastguard Worker class MockBackendProfilingContext : public arm::pipe::IBackendProfilingContext 163*89c4ff92SAndroid Build Coastguard Worker { 164*89c4ff92SAndroid Build Coastguard Worker public: MockBackendProfilingContext(IBackendInternal::IBackendProfilingPtr & backendProfiling)165*89c4ff92SAndroid Build Coastguard Worker MockBackendProfilingContext(IBackendInternal::IBackendProfilingPtr& backendProfiling) 166*89c4ff92SAndroid Build Coastguard Worker : m_BackendProfiling(std::move(backendProfiling)) 167*89c4ff92SAndroid Build Coastguard Worker , m_CapturePeriod(0) 168*89c4ff92SAndroid Build Coastguard Worker , m_IsTimelineEnabled(true) 169*89c4ff92SAndroid Build Coastguard Worker {} 170*89c4ff92SAndroid Build Coastguard Worker 171*89c4ff92SAndroid Build Coastguard Worker ~MockBackendProfilingContext() = default; 172*89c4ff92SAndroid Build Coastguard Worker GetBackendProfiling()173*89c4ff92SAndroid Build Coastguard Worker IBackendInternal::IBackendProfilingPtr& GetBackendProfiling() 174*89c4ff92SAndroid Build Coastguard Worker { 175*89c4ff92SAndroid Build Coastguard Worker return m_BackendProfiling; 176*89c4ff92SAndroid Build Coastguard Worker } 177*89c4ff92SAndroid Build Coastguard Worker RegisterCounters(uint16_t currentMaxGlobalCounterId)178*89c4ff92SAndroid Build Coastguard Worker uint16_t RegisterCounters(uint16_t currentMaxGlobalCounterId) 179*89c4ff92SAndroid Build Coastguard Worker { 180*89c4ff92SAndroid Build Coastguard Worker std::unique_ptr<arm::pipe::IRegisterBackendCounters> counterRegistrar = 181*89c4ff92SAndroid Build Coastguard Worker m_BackendProfiling->GetCounterRegistrationInterface(static_cast<uint16_t>(currentMaxGlobalCounterId)); 182*89c4ff92SAndroid Build Coastguard Worker 183*89c4ff92SAndroid Build Coastguard Worker std::string categoryName("MockCounters"); 184*89c4ff92SAndroid Build Coastguard Worker counterRegistrar->RegisterCategory(categoryName); 185*89c4ff92SAndroid Build Coastguard Worker 186*89c4ff92SAndroid Build Coastguard Worker counterRegistrar->RegisterCounter(0, categoryName, 0, 0, 1.f, "Mock Counter One", "Some notional counter"); 187*89c4ff92SAndroid Build Coastguard Worker 188*89c4ff92SAndroid Build Coastguard Worker counterRegistrar->RegisterCounter(1, categoryName, 0, 0, 1.f, "Mock Counter Two", 189*89c4ff92SAndroid Build Coastguard Worker "Another notional counter"); 190*89c4ff92SAndroid Build Coastguard Worker 191*89c4ff92SAndroid Build Coastguard Worker std::string units("microseconds"); 192*89c4ff92SAndroid Build Coastguard Worker uint16_t nextMaxGlobalCounterId = 193*89c4ff92SAndroid Build Coastguard Worker counterRegistrar->RegisterCounter(2, categoryName, 0, 0, 1.f, "Mock MultiCore Counter", 194*89c4ff92SAndroid Build Coastguard Worker "A dummy four core counter", units, 4); 195*89c4ff92SAndroid Build Coastguard Worker return nextMaxGlobalCounterId; 196*89c4ff92SAndroid Build Coastguard Worker } 197*89c4ff92SAndroid Build Coastguard Worker ActivateCounters(uint32_t capturePeriod,const std::vector<uint16_t> & counterIds)198*89c4ff92SAndroid Build Coastguard Worker arm::pipe::Optional<std::string> ActivateCounters(uint32_t capturePeriod, const std::vector<uint16_t>& counterIds) 199*89c4ff92SAndroid Build Coastguard Worker { 200*89c4ff92SAndroid Build Coastguard Worker if (capturePeriod == 0 || counterIds.size() == 0) 201*89c4ff92SAndroid Build Coastguard Worker { 202*89c4ff92SAndroid Build Coastguard Worker m_ActiveCounters.clear(); 203*89c4ff92SAndroid Build Coastguard Worker } 204*89c4ff92SAndroid Build Coastguard Worker else if (capturePeriod == 15939u) 205*89c4ff92SAndroid Build Coastguard Worker { 206*89c4ff92SAndroid Build Coastguard Worker return arm::pipe::Optional<std::string>("ActivateCounters example test error"); 207*89c4ff92SAndroid Build Coastguard Worker } 208*89c4ff92SAndroid Build Coastguard Worker m_CapturePeriod = capturePeriod; 209*89c4ff92SAndroid Build Coastguard Worker m_ActiveCounters = counterIds; 210*89c4ff92SAndroid Build Coastguard Worker return arm::pipe::Optional<std::string>(); 211*89c4ff92SAndroid Build Coastguard Worker } 212*89c4ff92SAndroid Build Coastguard Worker ReportCounterValues()213*89c4ff92SAndroid Build Coastguard Worker std::vector<arm::pipe::Timestamp> ReportCounterValues() 214*89c4ff92SAndroid Build Coastguard Worker { 215*89c4ff92SAndroid Build Coastguard Worker std::vector<arm::pipe::CounterValue> counterValues; 216*89c4ff92SAndroid Build Coastguard Worker 217*89c4ff92SAndroid Build Coastguard Worker for (auto counterId : m_ActiveCounters) 218*89c4ff92SAndroid Build Coastguard Worker { 219*89c4ff92SAndroid Build Coastguard Worker counterValues.emplace_back(arm::pipe::CounterValue{ counterId, counterId + 1u }); 220*89c4ff92SAndroid Build Coastguard Worker } 221*89c4ff92SAndroid Build Coastguard Worker 222*89c4ff92SAndroid Build Coastguard Worker uint64_t timestamp = m_CapturePeriod; 223*89c4ff92SAndroid Build Coastguard Worker return { arm::pipe::Timestamp{ timestamp, counterValues } }; 224*89c4ff92SAndroid Build Coastguard Worker } 225*89c4ff92SAndroid Build Coastguard Worker EnableProfiling(bool)226*89c4ff92SAndroid Build Coastguard Worker bool EnableProfiling(bool) 227*89c4ff92SAndroid Build Coastguard Worker { 228*89c4ff92SAndroid Build Coastguard Worker auto sendTimelinePacket = m_BackendProfiling->GetSendTimelinePacket(); 229*89c4ff92SAndroid Build Coastguard Worker sendTimelinePacket->SendTimelineEntityBinaryPacket(4256); 230*89c4ff92SAndroid Build Coastguard Worker sendTimelinePacket->Commit(); 231*89c4ff92SAndroid Build Coastguard Worker return true; 232*89c4ff92SAndroid Build Coastguard Worker } 233*89c4ff92SAndroid Build Coastguard Worker EnableTimelineReporting(bool isEnabled)234*89c4ff92SAndroid Build Coastguard Worker bool EnableTimelineReporting(bool isEnabled) 235*89c4ff92SAndroid Build Coastguard Worker { 236*89c4ff92SAndroid Build Coastguard Worker m_IsTimelineEnabled = isEnabled; 237*89c4ff92SAndroid Build Coastguard Worker return isEnabled; 238*89c4ff92SAndroid Build Coastguard Worker } 239*89c4ff92SAndroid Build Coastguard Worker TimelineReportingEnabled()240*89c4ff92SAndroid Build Coastguard Worker bool TimelineReportingEnabled() 241*89c4ff92SAndroid Build Coastguard Worker { 242*89c4ff92SAndroid Build Coastguard Worker return m_IsTimelineEnabled; 243*89c4ff92SAndroid Build Coastguard Worker } 244*89c4ff92SAndroid Build Coastguard Worker 245*89c4ff92SAndroid Build Coastguard Worker private: 246*89c4ff92SAndroid Build Coastguard Worker IBackendInternal::IBackendProfilingPtr m_BackendProfiling; 247*89c4ff92SAndroid Build Coastguard Worker uint32_t m_CapturePeriod; 248*89c4ff92SAndroid Build Coastguard Worker std::vector<uint16_t> m_ActiveCounters; 249*89c4ff92SAndroid Build Coastguard Worker std::atomic<bool> m_IsTimelineEnabled; 250*89c4ff92SAndroid Build Coastguard Worker }; 251*89c4ff92SAndroid Build Coastguard Worker 252*89c4ff92SAndroid Build Coastguard Worker class MockBackendProfilingService 253*89c4ff92SAndroid Build Coastguard Worker { 254*89c4ff92SAndroid Build Coastguard Worker public: 255*89c4ff92SAndroid Build Coastguard Worker // Getter for the singleton instance Instance()256*89c4ff92SAndroid Build Coastguard Worker static MockBackendProfilingService& Instance() 257*89c4ff92SAndroid Build Coastguard Worker { 258*89c4ff92SAndroid Build Coastguard Worker static MockBackendProfilingService instance; 259*89c4ff92SAndroid Build Coastguard Worker return instance; 260*89c4ff92SAndroid Build Coastguard Worker } 261*89c4ff92SAndroid Build Coastguard Worker GetContext()262*89c4ff92SAndroid Build Coastguard Worker MockBackendProfilingContext* GetContext() 263*89c4ff92SAndroid Build Coastguard Worker { 264*89c4ff92SAndroid Build Coastguard Worker return m_sharedContext.get(); 265*89c4ff92SAndroid Build Coastguard Worker } 266*89c4ff92SAndroid Build Coastguard Worker SetProfilingContextPtr(std::shared_ptr<MockBackendProfilingContext> shared)267*89c4ff92SAndroid Build Coastguard Worker void SetProfilingContextPtr(std::shared_ptr<MockBackendProfilingContext> shared) 268*89c4ff92SAndroid Build Coastguard Worker { 269*89c4ff92SAndroid Build Coastguard Worker m_sharedContext = shared; 270*89c4ff92SAndroid Build Coastguard Worker } 271*89c4ff92SAndroid Build Coastguard Worker 272*89c4ff92SAndroid Build Coastguard Worker private: 273*89c4ff92SAndroid Build Coastguard Worker std::shared_ptr<MockBackendProfilingContext> m_sharedContext; 274*89c4ff92SAndroid Build Coastguard Worker }; 275*89c4ff92SAndroid Build Coastguard Worker 276*89c4ff92SAndroid Build Coastguard Worker class MockLayerSupport : public LayerSupportBase 277*89c4ff92SAndroid Build Coastguard Worker { 278*89c4ff92SAndroid Build Coastguard Worker public: IsLayerSupported(const LayerType & type,const std::vector<TensorInfo> & infos,const BaseDescriptor & descriptor,const Optional<LstmInputParamsInfo> &,const Optional<QuantizedLstmInputParamsInfo> &,Optional<std::string &> reasonIfUnsupported) const279*89c4ff92SAndroid Build Coastguard Worker bool IsLayerSupported(const LayerType& type, 280*89c4ff92SAndroid Build Coastguard Worker const std::vector<TensorInfo>& infos, 281*89c4ff92SAndroid Build Coastguard Worker const BaseDescriptor& descriptor, 282*89c4ff92SAndroid Build Coastguard Worker const Optional<LstmInputParamsInfo>& /*lstmParamsInfo*/, 283*89c4ff92SAndroid Build Coastguard Worker const Optional<QuantizedLstmInputParamsInfo>& /*quantizedLstmParamsInfo*/, 284*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> reasonIfUnsupported) const override 285*89c4ff92SAndroid Build Coastguard Worker { 286*89c4ff92SAndroid Build Coastguard Worker switch(type) 287*89c4ff92SAndroid Build Coastguard Worker { 288*89c4ff92SAndroid Build Coastguard Worker case LayerType::Input: 289*89c4ff92SAndroid Build Coastguard Worker return IsInputSupported(infos[0], reasonIfUnsupported); 290*89c4ff92SAndroid Build Coastguard Worker case LayerType::Output: 291*89c4ff92SAndroid Build Coastguard Worker return IsOutputSupported(infos[0], reasonIfUnsupported); 292*89c4ff92SAndroid Build Coastguard Worker case LayerType::Addition: 293*89c4ff92SAndroid Build Coastguard Worker return IsAdditionSupported(infos[0], infos[1], infos[2], reasonIfUnsupported); 294*89c4ff92SAndroid Build Coastguard Worker case LayerType::Convolution2d: 295*89c4ff92SAndroid Build Coastguard Worker { 296*89c4ff92SAndroid Build Coastguard Worker if (infos.size() != 4) 297*89c4ff92SAndroid Build Coastguard Worker { 298*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException("Invalid number of TransposeConvolution2d " 299*89c4ff92SAndroid Build Coastguard Worker "TensorInfos. TensorInfos should be of format: " 300*89c4ff92SAndroid Build Coastguard Worker "{input, output, weights, biases}."); 301*89c4ff92SAndroid Build Coastguard Worker } 302*89c4ff92SAndroid Build Coastguard Worker 303*89c4ff92SAndroid Build Coastguard Worker auto desc = *(PolymorphicDowncast<const Convolution2dDescriptor*>(&descriptor)); 304*89c4ff92SAndroid Build Coastguard Worker if (infos[3] == TensorInfo()) 305*89c4ff92SAndroid Build Coastguard Worker { 306*89c4ff92SAndroid Build Coastguard Worker return IsConvolution2dSupported(infos[0], 307*89c4ff92SAndroid Build Coastguard Worker infos[1], 308*89c4ff92SAndroid Build Coastguard Worker desc, 309*89c4ff92SAndroid Build Coastguard Worker infos[2], 310*89c4ff92SAndroid Build Coastguard Worker EmptyOptional(), 311*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported); 312*89c4ff92SAndroid Build Coastguard Worker } 313*89c4ff92SAndroid Build Coastguard Worker else 314*89c4ff92SAndroid Build Coastguard Worker { 315*89c4ff92SAndroid Build Coastguard Worker return IsConvolution2dSupported(infos[0], 316*89c4ff92SAndroid Build Coastguard Worker infos[1], 317*89c4ff92SAndroid Build Coastguard Worker desc, 318*89c4ff92SAndroid Build Coastguard Worker infos[2], 319*89c4ff92SAndroid Build Coastguard Worker infos[3], 320*89c4ff92SAndroid Build Coastguard Worker reasonIfUnsupported); 321*89c4ff92SAndroid Build Coastguard Worker } 322*89c4ff92SAndroid Build Coastguard Worker } 323*89c4ff92SAndroid Build Coastguard Worker case LayerType::ElementwiseBinary: 324*89c4ff92SAndroid Build Coastguard Worker { 325*89c4ff92SAndroid Build Coastguard Worker auto elementwiseDesc = *(PolymorphicDowncast<const ElementwiseBinaryDescriptor*>(&descriptor)); 326*89c4ff92SAndroid Build Coastguard Worker return (elementwiseDesc.m_Operation == BinaryOperation::Add); 327*89c4ff92SAndroid Build Coastguard Worker } 328*89c4ff92SAndroid Build Coastguard Worker default: 329*89c4ff92SAndroid Build Coastguard Worker return false; 330*89c4ff92SAndroid Build Coastguard Worker } 331*89c4ff92SAndroid Build Coastguard Worker } 332*89c4ff92SAndroid Build Coastguard Worker IsInputSupported(const TensorInfo &,Optional<std::string &>) const333*89c4ff92SAndroid Build Coastguard Worker bool IsInputSupported(const TensorInfo& /*input*/, 334*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> /*reasonIfUnsupported = EmptyOptional()*/) const override 335*89c4ff92SAndroid Build Coastguard Worker { 336*89c4ff92SAndroid Build Coastguard Worker return true; 337*89c4ff92SAndroid Build Coastguard Worker } 338*89c4ff92SAndroid Build Coastguard Worker IsOutputSupported(const TensorInfo &,Optional<std::string &>) const339*89c4ff92SAndroid Build Coastguard Worker bool IsOutputSupported(const TensorInfo& /*input*/, 340*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> /*reasonIfUnsupported = EmptyOptional()*/) const override 341*89c4ff92SAndroid Build Coastguard Worker { 342*89c4ff92SAndroid Build Coastguard Worker return true; 343*89c4ff92SAndroid Build Coastguard Worker } 344*89c4ff92SAndroid Build Coastguard Worker IsAdditionSupported(const TensorInfo &,const TensorInfo &,const TensorInfo &,Optional<std::string &>) const345*89c4ff92SAndroid Build Coastguard Worker bool IsAdditionSupported(const TensorInfo& /*input0*/, 346*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& /*input1*/, 347*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& /*output*/, 348*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> /*reasonIfUnsupported = EmptyOptional()*/) const override 349*89c4ff92SAndroid Build Coastguard Worker { 350*89c4ff92SAndroid Build Coastguard Worker return true; 351*89c4ff92SAndroid Build Coastguard Worker } 352*89c4ff92SAndroid Build Coastguard Worker IsConvolution2dSupported(const TensorInfo &,const TensorInfo &,const Convolution2dDescriptor &,const TensorInfo &,const Optional<TensorInfo> &,Optional<std::string &>) const353*89c4ff92SAndroid Build Coastguard Worker bool IsConvolution2dSupported(const TensorInfo& /*input*/, 354*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& /*output*/, 355*89c4ff92SAndroid Build Coastguard Worker const Convolution2dDescriptor& /*descriptor*/, 356*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& /*weights*/, 357*89c4ff92SAndroid Build Coastguard Worker const Optional<TensorInfo>& /*biases*/, 358*89c4ff92SAndroid Build Coastguard Worker Optional<std::string&> /*reasonIfUnsupported = EmptyOptional()*/) const override 359*89c4ff92SAndroid Build Coastguard Worker { 360*89c4ff92SAndroid Build Coastguard Worker return true; 361*89c4ff92SAndroid Build Coastguard Worker } 362*89c4ff92SAndroid Build Coastguard Worker }; 363*89c4ff92SAndroid Build Coastguard Worker 364*89c4ff92SAndroid Build Coastguard Worker } // namespace armnn 365