xref: /aosp_15_r20/external/armnn/src/armnnTestUtils/MockTensorHandleFactory.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #pragma once
7*89c4ff92SAndroid Build Coastguard Worker 
#include <armnn/backends/ITensorHandleFactory.hpp>
#include <armnnTestUtils/MockMemoryManager.hpp>

#include <memory>
#include <utility>
10*89c4ff92SAndroid Build Coastguard Worker 
11*89c4ff92SAndroid Build Coastguard Worker namespace armnn
12*89c4ff92SAndroid Build Coastguard Worker {
13*89c4ff92SAndroid Build Coastguard Worker 
/// Compile-time identifier string for the mock tensor handle factory.
///
/// @return Pointer to the string literal "Arm/Mock/TensorHandleFactory"
///         (static storage duration, never null).
constexpr const char* MockTensorHandleFactoryId()
{
    constexpr const char* factoryId = "Arm/Mock/TensorHandleFactory";
    return factoryId;
}
18*89c4ff92SAndroid Build Coastguard Worker 
19*89c4ff92SAndroid Build Coastguard Worker class MockTensorHandleFactory : public ITensorHandleFactory
20*89c4ff92SAndroid Build Coastguard Worker {
21*89c4ff92SAndroid Build Coastguard Worker 
22*89c4ff92SAndroid Build Coastguard Worker public:
MockTensorHandleFactory(std::shared_ptr<MockMemoryManager> mgr)23*89c4ff92SAndroid Build Coastguard Worker     explicit MockTensorHandleFactory(std::shared_ptr<MockMemoryManager> mgr)
24*89c4ff92SAndroid Build Coastguard Worker         : m_MemoryManager(mgr)
25*89c4ff92SAndroid Build Coastguard Worker         , m_ImportFlags(static_cast<MemorySourceFlags>(MemorySource::Malloc))
26*89c4ff92SAndroid Build Coastguard Worker         , m_ExportFlags(static_cast<MemorySourceFlags>(MemorySource::Malloc))
27*89c4ff92SAndroid Build Coastguard Worker     {}
28*89c4ff92SAndroid Build Coastguard Worker 
29*89c4ff92SAndroid Build Coastguard Worker     std::unique_ptr<ITensorHandle> CreateSubTensorHandle(ITensorHandle& parent,
30*89c4ff92SAndroid Build Coastguard Worker                                                          TensorShape const& subTensorShape,
31*89c4ff92SAndroid Build Coastguard Worker                                                          unsigned int const* subTensorOrigin) const override;
32*89c4ff92SAndroid Build Coastguard Worker 
33*89c4ff92SAndroid Build Coastguard Worker     std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo) const override;
34*89c4ff92SAndroid Build Coastguard Worker 
35*89c4ff92SAndroid Build Coastguard Worker     std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
36*89c4ff92SAndroid Build Coastguard Worker                                                       DataLayout dataLayout) const override;
37*89c4ff92SAndroid Build Coastguard Worker 
38*89c4ff92SAndroid Build Coastguard Worker     std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
39*89c4ff92SAndroid Build Coastguard Worker                                                       const bool IsMemoryManaged) const override;
40*89c4ff92SAndroid Build Coastguard Worker 
41*89c4ff92SAndroid Build Coastguard Worker     std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
42*89c4ff92SAndroid Build Coastguard Worker                                                       DataLayout dataLayout,
43*89c4ff92SAndroid Build Coastguard Worker                                                       const bool IsMemoryManaged) const override;
44*89c4ff92SAndroid Build Coastguard Worker 
45*89c4ff92SAndroid Build Coastguard Worker     static const FactoryId& GetIdStatic();
46*89c4ff92SAndroid Build Coastguard Worker 
47*89c4ff92SAndroid Build Coastguard Worker     const FactoryId& GetId() const override;
48*89c4ff92SAndroid Build Coastguard Worker 
49*89c4ff92SAndroid Build Coastguard Worker     bool SupportsSubTensors() const override;
50*89c4ff92SAndroid Build Coastguard Worker 
51*89c4ff92SAndroid Build Coastguard Worker     MemorySourceFlags GetExportFlags() const override;
52*89c4ff92SAndroid Build Coastguard Worker 
53*89c4ff92SAndroid Build Coastguard Worker     MemorySourceFlags GetImportFlags() const override;
54*89c4ff92SAndroid Build Coastguard Worker 
55*89c4ff92SAndroid Build Coastguard Worker private:
56*89c4ff92SAndroid Build Coastguard Worker     mutable std::shared_ptr<MockMemoryManager> m_MemoryManager;
57*89c4ff92SAndroid Build Coastguard Worker     MemorySourceFlags m_ImportFlags;
58*89c4ff92SAndroid Build Coastguard Worker     MemorySourceFlags m_ExportFlags;
59*89c4ff92SAndroid Build Coastguard Worker };
60*89c4ff92SAndroid Build Coastguard Worker 
61*89c4ff92SAndroid Build Coastguard Worker }    // namespace armnn
62