xref: /aosp_15_r20/external/armnn/src/backends/reference/RefTensorHandle.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
7 #include <armnn/backends/TensorHandle.hpp>
8 
9 #include "RefMemoryManager.hpp"
10 
11 namespace armnn
12 {
13 
14 // An implementation of ITensorHandle with simple "bump the pointer" memory-management behaviour
15 class RefTensorHandle : public ITensorHandle
16 {
17 public:
18     RefTensorHandle(const TensorInfo& tensorInfo, std::shared_ptr<RefMemoryManager> &memoryManager);
19 
20     RefTensorHandle(const TensorInfo& tensorInfo);
21 
22     ~RefTensorHandle();
23 
24     virtual void Manage() override;
25 
26     virtual void Allocate() override;
27 
GetParent() const28     virtual ITensorHandle* GetParent() const override
29     {
30         return nullptr;
31     }
32 
33     virtual const void* Map(bool /* blocking = true */) const override;
34     using ITensorHandle::Map;
35 
Unmap() const36     virtual void Unmap() const override
37     {}
38 
GetStrides() const39     TensorShape GetStrides() const override
40     {
41         return GetUnpaddedTensorStrides(m_TensorInfo);
42     }
43 
GetShape() const44     TensorShape GetShape() const override
45     {
46         return m_TensorInfo.GetShape();
47     }
48 
GetTensorInfo() const49     const TensorInfo& GetTensorInfo() const
50     {
51         return m_TensorInfo;
52     }
53 
54     virtual MemorySourceFlags GetImportFlags() const override;
55 
56     virtual bool Import(void* memory, MemorySource source) override;
57     virtual bool CanBeImported(void* memory, MemorySource source) override;
58 
59 private:
60     // Only used for testing
61     void CopyOutTo(void*) const override;
62     void CopyInFrom(const void*) override;
63 
64     void* GetPointer() const;
65 
66     RefTensorHandle(const RefTensorHandle& other) = delete; // noncopyable
67     RefTensorHandle& operator=(const RefTensorHandle& other) = delete; //noncopyable
68 
69     TensorInfo m_TensorInfo;
70 
71     std::shared_ptr<RefMemoryManager> m_MemoryManager;
72     RefMemoryManager::Pool* m_Pool;
73     mutable void* m_UnmanagedMemory;
74     void* m_ImportedMemory;
75 };
76 
77 }
78