xref: /aosp_15_r20/external/armnn/src/backends/cl/ClImportTensorHandleFactory.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2021 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
7 #include <aclCommon/BaseMemoryManager.hpp>
8 #include <armnn/MemorySources.hpp>
9 #include <armnn/backends/IMemoryManager.hpp>
10 #include <armnn/backends/ITensorHandleFactory.hpp>
11 
12 namespace armnn
13 {
14 
/**
 * Returns the identifier string "Arm/Cl/ImportTensorHandleFactory".
 *
 * The function is constexpr, so the value is available in constant expressions.
 * The returned pointer refers to a string literal (static storage duration), so it
 * stays valid for the lifetime of the program.
 * NOTE(review): presumably used to initialise ClImportTensorHandleFactory::m_Id;
 * confirm in the implementation file.
 *
 * @return Pointer to the null-terminated identifier string.
 */
constexpr const char* ClImportTensorHandleFactoryId()
{
    return "Arm/Cl/ImportTensorHandleFactory";
}
19 
20 /**
21  * This factory creates ClImportTensorHandles that refer to imported memory tensors.
22  */
23 class ClImportTensorHandleFactory : public ITensorHandleFactory
24 {
25 public:
26     static const FactoryId m_Id;
27 
28     /**
29      * Create a tensor handle factory for tensors that will be imported or exported.
30      *
31      * @param importFlags
32      * @param exportFlags
33      */
ClImportTensorHandleFactory(MemorySourceFlags importFlags,MemorySourceFlags exportFlags)34     ClImportTensorHandleFactory(MemorySourceFlags importFlags, MemorySourceFlags exportFlags)
35         : m_ImportFlags(importFlags)
36         , m_ExportFlags(exportFlags)
37     {}
38 
39     std::unique_ptr<ITensorHandle> CreateSubTensorHandle(ITensorHandle& parent,
40                                                          const TensorShape& subTensorShape,
41                                                          const unsigned int* subTensorOrigin) const override;
42 
43     std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo) const override;
44 
45     std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
46                                                       DataLayout dataLayout) const override;
47 
48     std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
49                                                       const bool IsMemoryManaged) const override;
50 
51     std::unique_ptr<ITensorHandle> CreateTensorHandle(const TensorInfo& tensorInfo,
52                                                       DataLayout dataLayout,
53                                                       const bool IsMemoryManaged) const override;
54 
55     static const FactoryId& GetIdStatic();
56 
57     const FactoryId& GetId() const override;
58 
59     bool SupportsSubTensors() const override;
60 
61     bool SupportsMapUnmap() const override;
62 
63     MemorySourceFlags GetExportFlags() const override;
64 
65     MemorySourceFlags GetImportFlags() const override;
66 
67     std::vector<Capability> GetCapabilities(const IConnectableLayer* layer,
68                                             const IConnectableLayer* connectedLayer,
69                                             CapabilityClass capabilityClass) override;
70 
71 private:
72     MemorySourceFlags m_ImportFlags;
73     MemorySourceFlags m_ExportFlags;
74 };
75 
76 }    // namespace armnn