1 //
2 // Copyright © 2021 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6
7 #include <aclCommon/BaseMemoryManager.hpp>
8 #include <armnn/MemorySources.hpp>
9 #include <armnn/backends/IMemoryManager.hpp>
10 #include <armnn/backends/ITensorHandleFactory.hpp>
11
12 namespace armnn
13 {
14
/**
 * Compile-time identifier string for the CL import tensor handle factory.
 *
 * @return Pointer to the string literal "Arm/Cl/ImportTensorHandleFactory".
 */
constexpr auto ClImportTensorHandleFactoryId() -> const char*
{
    return "Arm/Cl/ImportTensorHandleFactory";
}
19
20 /**
21 * This factory creates ClImportTensorHandles that refer to imported memory tensors.
22 */
23 class ClImportTensorHandleFactory : public ITensorHandleFactory
24 {
25 public:
26 static const FactoryId m_Id;
27
28 /**
29 * Create a tensor handle factory for tensors that will be imported or exported.
30 *
31 * @param importFlags
32 * @param exportFlags
33 */
ClImportTensorHandleFactory(MemorySourceFlags importFlags,MemorySourceFlags exportFlags)34 ClImportTensorHandleFactory(MemorySourceFlags importFlags, MemorySourceFlags exportFlags)
35 : m_ImportFlags(importFlags)
36 , m_ExportFlags(exportFlags)
37 {}
38
39 std::unique_ptr<ITensorHandle> CreateSubTensorHandle(ITensorHandle& parent,
40 const TensorShape& subTensorShape,
41 const unsigned int* subTensorOrigin) const override;
42
43 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo) const override;
44
45 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
46 DataLayout dataLayout) const override;
47
48 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
49 const bool IsMemoryManaged) const override;
50
51 std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
52 DataLayout dataLayout,
53 const bool IsMemoryManaged) const override;
54
55 static const FactoryId& GetIdStatic();
56
57 const FactoryId& GetId() const override;
58
59 bool SupportsSubTensors() const override;
60
61 bool SupportsMapUnmap() const override;
62
63 MemorySourceFlags GetExportFlags() const override;
64
65 MemorySourceFlags GetImportFlags() const override;
66
67 std::vector<Capability> GetCapabilities(const IConnectableLayer* layer,
68 const IConnectableLayer* connectedLayer,
69 CapabilityClass capabilityClass) override;
70
71 private:
72 MemorySourceFlags m_ImportFlags;
73 MemorySourceFlags m_ExportFlags;
74 };
75
76 } // namespace armnn