xref: /aosp_15_r20/external/armnn/src/backends/reference/RefTensorHandle.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #include "RefTensorHandle.hpp"
6 
7 namespace armnn
8 {
9 
RefTensorHandle(const TensorInfo & tensorInfo,std::shared_ptr<RefMemoryManager> & memoryManager)10 RefTensorHandle::RefTensorHandle(const TensorInfo &tensorInfo, std::shared_ptr<RefMemoryManager> &memoryManager):
11     m_TensorInfo(tensorInfo),
12     m_MemoryManager(memoryManager),
13     m_Pool(nullptr),
14     m_UnmanagedMemory(nullptr),
15     m_ImportedMemory(nullptr)
16 {
17 
18 }
19 
RefTensorHandle(const TensorInfo & tensorInfo)20 RefTensorHandle::RefTensorHandle(const TensorInfo& tensorInfo)
21                                  : m_TensorInfo(tensorInfo),
22                                    m_Pool(nullptr),
23                                    m_UnmanagedMemory(nullptr),
24                                    m_ImportedMemory(nullptr)
25 {
26 
27 }
28 
~RefTensorHandle()29 RefTensorHandle::~RefTensorHandle()
30 {
31     ::operator delete(m_UnmanagedMemory);
32 }
33 
Manage()34 void RefTensorHandle::Manage()
35 {
36     ARMNN_ASSERT_MSG(!m_Pool, "RefTensorHandle::Manage() called twice");
37     ARMNN_ASSERT_MSG(!m_UnmanagedMemory, "RefTensorHandle::Manage() called after Allocate()");
38 
39     if (m_MemoryManager)
40     {
41         m_Pool = m_MemoryManager->Manage(m_TensorInfo.GetNumBytes());
42     }
43 }
44 
Allocate()45 void RefTensorHandle::Allocate()
46 {
47     if (!m_UnmanagedMemory)
48     {
49         if (!m_Pool)
50         {
51             // unmanaged
52             m_UnmanagedMemory = ::operator new(m_TensorInfo.GetNumBytes());
53         }
54         else
55         {
56             m_MemoryManager->Allocate(m_Pool);
57         }
58     }
59     else
60     {
61         throw InvalidArgumentException("RefTensorHandle::Allocate Trying to allocate a RefTensorHandle"
62                                        "that already has allocated memory.");
63     }
64 }
65 
Map(bool) const66 const void* RefTensorHandle::Map(bool /*unused*/) const
67 {
68     return GetPointer();
69 }
70 
GetPointer() const71 void* RefTensorHandle::GetPointer() const
72 {
73     if (m_ImportedMemory)
74     {
75         return m_ImportedMemory;
76     }
77     else if (m_UnmanagedMemory)
78     {
79         return m_UnmanagedMemory;
80     }
81     else if (m_Pool)
82     {
83         return m_MemoryManager->GetPointer(m_Pool);
84     }
85     else
86     {
87         throw NullPointerException("RefTensorHandle::GetPointer called on unmanaged, unallocated tensor handle");
88     }
89 }
90 
CopyOutTo(void * dest) const91 void RefTensorHandle::CopyOutTo(void* dest) const
92 {
93     const void *src = GetPointer();
94     ARMNN_ASSERT(src);
95     memcpy(dest, src, m_TensorInfo.GetNumBytes());
96 }
97 
CopyInFrom(const void * src)98 void RefTensorHandle::CopyInFrom(const void* src)
99 {
100     void *dest = GetPointer();
101     ARMNN_ASSERT(dest);
102     memcpy(dest, src, m_TensorInfo.GetNumBytes());
103 }
104 
GetImportFlags() const105 MemorySourceFlags RefTensorHandle::GetImportFlags() const
106 {
107     return static_cast<MemorySourceFlags>(MemorySource::Malloc);
108 }
109 
Import(void * memory,MemorySource source)110 bool RefTensorHandle::Import(void* memory, MemorySource source)
111 {
112     if (source == MemorySource::Malloc)
113     {
114         // Check memory alignment
115         if(!CanBeImported(memory, source))
116         {
117             m_ImportedMemory = nullptr;
118             return false;
119         }
120 
121         m_ImportedMemory = memory;
122         return true;
123     }
124 
125     return false;
126 }
127 
CanBeImported(void * memory,MemorySource source)128 bool RefTensorHandle::CanBeImported(void *memory, MemorySource source)
129 {
130     if (source == MemorySource::Malloc)
131     {
132         uintptr_t alignment = GetDataTypeSize(m_TensorInfo.GetDataType());
133         if (reinterpret_cast<uintptr_t>(memory) % alignment)
134         {
135             return false;
136         }
137         return true;
138     }
139     return false;
140 }
141 
142 }
143