1 // 2 // Copyright © 2022 Arm Ltd and Contributors. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #pragma once 7 8 #include <armnn/backends/ITensorHandleFactory.hpp> 9 #include <armnnTestUtils/MockMemoryManager.hpp> 10 11 namespace armnn 12 { 13 MockTensorHandleFactoryId()14constexpr const char* MockTensorHandleFactoryId() 15 { 16 return "Arm/Mock/TensorHandleFactory"; 17 } 18 19 class MockTensorHandleFactory : public ITensorHandleFactory 20 { 21 22 public: MockTensorHandleFactory(std::shared_ptr<MockMemoryManager> mgr)23 explicit MockTensorHandleFactory(std::shared_ptr<MockMemoryManager> mgr) 24 : m_MemoryManager(mgr) 25 , m_ImportFlags(static_cast<MemorySourceFlags>(MemorySource::Malloc)) 26 , m_ExportFlags(static_cast<MemorySourceFlags>(MemorySource::Malloc)) 27 {} 28 29 std::unique_ptr<ITensorHandle> CreateSubTensorHandle(ITensorHandle& parent, 30 TensorShape const& subTensorShape, 31 unsigned int const* subTensorOrigin) const override; 32 33 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo) const override; 34 35 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo, 36 DataLayout dataLayout) const override; 37 38 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo, 39 const bool IsMemoryManaged) const override; 40 41 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo, 42 DataLayout dataLayout, 43 const bool IsMemoryManaged) const override; 44 45 static const FactoryId& GetIdStatic(); 46 47 const FactoryId& GetId() const override; 48 49 bool SupportsSubTensors() const override; 50 51 MemorySourceFlags GetExportFlags() const override; 52 53 MemorySourceFlags GetImportFlags() const override; 54 55 private: 56 mutable std::shared_ptr<MockMemoryManager> m_MemoryManager; 57 MemorySourceFlags m_ImportFlags; 58 MemorySourceFlags m_ExportFlags; 59 }; 60 61 } // namespace armnn 62