xref: /aosp_15_r20/external/armnn/src/armnnTestUtils/MockTensorHandleFactory.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
#include <armnn/backends/ITensorHandleFactory.hpp>
#include <armnnTestUtils/MockMemoryManager.hpp>

#include <memory>
#include <utility>
10 
11 namespace armnn
12 {
13 
/// Returns the identifier string of the mock tensor handle factory.
///
/// The function is constexpr, so the id can be used in constant expressions.
/// The returned pointer refers to a string literal and is valid for the
/// lifetime of the program.
constexpr const char* MockTensorHandleFactoryId()
{
    return "Arm/Mock/TensorHandleFactory";
}
18 
19 class MockTensorHandleFactory : public ITensorHandleFactory
20 {
21 
22 public:
MockTensorHandleFactory(std::shared_ptr<MockMemoryManager> mgr)23     explicit MockTensorHandleFactory(std::shared_ptr<MockMemoryManager> mgr)
24         : m_MemoryManager(mgr)
25         , m_ImportFlags(static_cast<MemorySourceFlags>(MemorySource::Malloc))
26         , m_ExportFlags(static_cast<MemorySourceFlags>(MemorySource::Malloc))
27     {}
28 
29     std::unique_ptr<ITensorHandle> CreateSubTensorHandle(ITensorHandle& parent,
30                                                          TensorShape const& subTensorShape,
31                                                          unsigned int const* subTensorOrigin) const override;
32 
33     std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo) const override;
34 
35     std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
36                                                       DataLayout dataLayout) const override;
37 
38     std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
39                                                       const bool IsMemoryManaged) const override;
40 
41     std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
42                                                       DataLayout dataLayout,
43                                                       const bool IsMemoryManaged) const override;
44 
45     static const FactoryId& GetIdStatic();
46 
47     const FactoryId& GetId() const override;
48 
49     bool SupportsSubTensors() const override;
50 
51     MemorySourceFlags GetExportFlags() const override;
52 
53     MemorySourceFlags GetImportFlags() const override;
54 
55 private:
56     mutable std::shared_ptr<MockMemoryManager> m_MemoryManager;
57     MemorySourceFlags m_ImportFlags;
58     MemorySourceFlags m_ExportFlags;
59 };
60 
61 }    // namespace armnn
62