1 //
2 // Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #include "armnnTestUtils/MockTensorHandle.hpp"
7
8 namespace armnn
9 {
10
MockTensorHandle(const TensorInfo & tensorInfo,std::shared_ptr<MockMemoryManager> & memoryManager)11 MockTensorHandle::MockTensorHandle(const TensorInfo& tensorInfo, std::shared_ptr<MockMemoryManager>& memoryManager)
12 : m_TensorInfo(tensorInfo)
13 , m_MemoryManager(memoryManager)
14 , m_Pool(nullptr)
15 , m_UnmanagedMemory(nullptr)
16 , m_ImportFlags(static_cast<MemorySourceFlags>(MemorySource::Undefined))
17 , m_Imported(false)
18 , m_IsImportEnabled(false)
19 {}
20
MockTensorHandle(const TensorInfo & tensorInfo,MemorySourceFlags importFlags)21 MockTensorHandle::MockTensorHandle(const TensorInfo& tensorInfo, MemorySourceFlags importFlags)
22 : m_TensorInfo(tensorInfo)
23 , m_Pool(nullptr)
24 , m_UnmanagedMemory(nullptr)
25 , m_ImportFlags(importFlags)
26 , m_Imported(false)
27 , m_IsImportEnabled(true)
28 {}
29
~MockTensorHandle()30 MockTensorHandle::~MockTensorHandle()
31 {
32 if (!m_Pool)
33 {
34 // unmanaged
35 if (!m_Imported)
36 {
37 ::operator delete(m_UnmanagedMemory);
38 }
39 }
40 }
41
Manage()42 void MockTensorHandle::Manage()
43 {
44 if (!m_IsImportEnabled)
45 {
46 ARMNN_ASSERT_MSG(!m_Pool, "MockTensorHandle::Manage() called twice");
47 ARMNN_ASSERT_MSG(!m_UnmanagedMemory, "MockTensorHandle::Manage() called after Allocate()");
48
49 m_Pool = m_MemoryManager->Manage(m_TensorInfo.GetNumBytes());
50 }
51 }
52
Allocate()53 void MockTensorHandle::Allocate()
54 {
55 // If import is enabled, do not allocate the tensor
56 if (!m_IsImportEnabled)
57 {
58
59 if (!m_UnmanagedMemory)
60 {
61 if (!m_Pool)
62 {
63 // unmanaged
64 m_UnmanagedMemory = ::operator new(m_TensorInfo.GetNumBytes());
65 }
66 else
67 {
68 m_MemoryManager->Allocate(m_Pool);
69 }
70 }
71 else
72 {
73 throw InvalidArgumentException("MockTensorHandle::Allocate Trying to allocate a MockTensorHandle"
74 "that already has allocated memory.");
75 }
76 }
77 }
78
Map(bool) const79 const void* MockTensorHandle::Map(bool /*unused*/) const
80 {
81 return GetPointer();
82 }
83
GetPointer() const84 void* MockTensorHandle::GetPointer() const
85 {
86 if (m_UnmanagedMemory)
87 {
88 return m_UnmanagedMemory;
89 }
90 else if (m_Pool)
91 {
92 return m_MemoryManager->GetPointer(m_Pool);
93 }
94 else
95 {
96 throw NullPointerException("MockTensorHandle::GetPointer called on unmanaged, unallocated tensor handle");
97 }
98 }
99
CopyOutTo(void * dest) const100 void MockTensorHandle::CopyOutTo(void* dest) const
101 {
102 const void* src = GetPointer();
103 ARMNN_ASSERT(src);
104 memcpy(dest, src, m_TensorInfo.GetNumBytes());
105 }
106
CopyInFrom(const void * src)107 void MockTensorHandle::CopyInFrom(const void* src)
108 {
109 void* dest = GetPointer();
110 ARMNN_ASSERT(dest);
111 memcpy(dest, src, m_TensorInfo.GetNumBytes());
112 }
113
Import(void * memory,MemorySource source)114 bool MockTensorHandle::Import(void* memory, MemorySource source)
115 {
116 if (m_ImportFlags & static_cast<MemorySourceFlags>(source))
117 {
118 if (m_IsImportEnabled && source == MemorySource::Malloc)
119 {
120 // Check memory alignment
121 if (!CanBeImported(memory, source))
122 {
123 if (m_Imported)
124 {
125 m_Imported = false;
126 m_UnmanagedMemory = nullptr;
127 }
128
129 return false;
130 }
131
132 // m_UnmanagedMemory not yet allocated.
133 if (!m_Imported && !m_UnmanagedMemory)
134 {
135 m_UnmanagedMemory = memory;
136 m_Imported = true;
137 return true;
138 }
139
140 // m_UnmanagedMemory initially allocated with Allocate().
141 if (!m_Imported && m_UnmanagedMemory)
142 {
143 return false;
144 }
145
146 // m_UnmanagedMemory previously imported.
147 if (m_Imported)
148 {
149 m_UnmanagedMemory = memory;
150 return true;
151 }
152 }
153 }
154
155 return false;
156 }
157
CanBeImported(void * memory,MemorySource source)158 bool MockTensorHandle::CanBeImported(void* memory, MemorySource source)
159 {
160 if (m_ImportFlags & static_cast<MemorySourceFlags>(source))
161 {
162 if (m_IsImportEnabled && source == MemorySource::Malloc)
163 {
164 uintptr_t alignment = GetDataTypeSize(m_TensorInfo.GetDataType());
165 if (reinterpret_cast<uintptr_t>(memory) % alignment)
166 {
167 return false;
168 }
169
170 return true;
171 }
172 }
173 return false;
174 }
175
176 } // namespace armnn
177