xref: /aosp_15_r20/external/armnn/src/armnnTestUtils/MockTensorHandle.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "armnnTestUtils/MockTensorHandle.hpp"
7 
8 namespace armnn
9 {
10 
MockTensorHandle(const TensorInfo & tensorInfo,std::shared_ptr<MockMemoryManager> & memoryManager)11 MockTensorHandle::MockTensorHandle(const TensorInfo& tensorInfo, std::shared_ptr<MockMemoryManager>& memoryManager)
12     : m_TensorInfo(tensorInfo)
13     , m_MemoryManager(memoryManager)
14     , m_Pool(nullptr)
15     , m_UnmanagedMemory(nullptr)
16     , m_ImportFlags(static_cast<MemorySourceFlags>(MemorySource::Undefined))
17     , m_Imported(false)
18     , m_IsImportEnabled(false)
19 {}
20 
MockTensorHandle(const TensorInfo & tensorInfo,MemorySourceFlags importFlags)21 MockTensorHandle::MockTensorHandle(const TensorInfo& tensorInfo, MemorySourceFlags importFlags)
22     : m_TensorInfo(tensorInfo)
23     , m_Pool(nullptr)
24     , m_UnmanagedMemory(nullptr)
25     , m_ImportFlags(importFlags)
26     , m_Imported(false)
27     , m_IsImportEnabled(true)
28 {}
29 
~MockTensorHandle()30 MockTensorHandle::~MockTensorHandle()
31 {
32     if (!m_Pool)
33     {
34         // unmanaged
35         if (!m_Imported)
36         {
37             ::operator delete(m_UnmanagedMemory);
38         }
39     }
40 }
41 
Manage()42 void MockTensorHandle::Manage()
43 {
44     if (!m_IsImportEnabled)
45     {
46         ARMNN_ASSERT_MSG(!m_Pool, "MockTensorHandle::Manage() called twice");
47         ARMNN_ASSERT_MSG(!m_UnmanagedMemory, "MockTensorHandle::Manage() called after Allocate()");
48 
49         m_Pool = m_MemoryManager->Manage(m_TensorInfo.GetNumBytes());
50     }
51 }
52 
Allocate()53 void MockTensorHandle::Allocate()
54 {
55     // If import is enabled, do not allocate the tensor
56     if (!m_IsImportEnabled)
57     {
58 
59         if (!m_UnmanagedMemory)
60         {
61             if (!m_Pool)
62             {
63                 // unmanaged
64                 m_UnmanagedMemory = ::operator new(m_TensorInfo.GetNumBytes());
65             }
66             else
67             {
68                 m_MemoryManager->Allocate(m_Pool);
69             }
70         }
71         else
72         {
73             throw InvalidArgumentException("MockTensorHandle::Allocate Trying to allocate a MockTensorHandle"
74                                            "that already has allocated memory.");
75         }
76     }
77 }
78 
Map(bool) const79 const void* MockTensorHandle::Map(bool /*unused*/) const
80 {
81     return GetPointer();
82 }
83 
GetPointer() const84 void* MockTensorHandle::GetPointer() const
85 {
86     if (m_UnmanagedMemory)
87     {
88         return m_UnmanagedMemory;
89     }
90     else if (m_Pool)
91     {
92         return m_MemoryManager->GetPointer(m_Pool);
93     }
94     else
95     {
96         throw NullPointerException("MockTensorHandle::GetPointer called on unmanaged, unallocated tensor handle");
97     }
98 }
99 
CopyOutTo(void * dest) const100 void MockTensorHandle::CopyOutTo(void* dest) const
101 {
102     const void* src = GetPointer();
103     ARMNN_ASSERT(src);
104     memcpy(dest, src, m_TensorInfo.GetNumBytes());
105 }
106 
CopyInFrom(const void * src)107 void MockTensorHandle::CopyInFrom(const void* src)
108 {
109     void* dest = GetPointer();
110     ARMNN_ASSERT(dest);
111     memcpy(dest, src, m_TensorInfo.GetNumBytes());
112 }
113 
Import(void * memory,MemorySource source)114 bool MockTensorHandle::Import(void* memory, MemorySource source)
115 {
116     if (m_ImportFlags & static_cast<MemorySourceFlags>(source))
117     {
118         if (m_IsImportEnabled && source == MemorySource::Malloc)
119         {
120             // Check memory alignment
121             if (!CanBeImported(memory, source))
122             {
123                 if (m_Imported)
124                 {
125                     m_Imported        = false;
126                     m_UnmanagedMemory = nullptr;
127                 }
128 
129                 return false;
130             }
131 
132             // m_UnmanagedMemory not yet allocated.
133             if (!m_Imported && !m_UnmanagedMemory)
134             {
135                 m_UnmanagedMemory = memory;
136                 m_Imported        = true;
137                 return true;
138             }
139 
140             // m_UnmanagedMemory initially allocated with Allocate().
141             if (!m_Imported && m_UnmanagedMemory)
142             {
143                 return false;
144             }
145 
146             // m_UnmanagedMemory previously imported.
147             if (m_Imported)
148             {
149                 m_UnmanagedMemory = memory;
150                 return true;
151             }
152         }
153     }
154 
155     return false;
156 }
157 
CanBeImported(void * memory,MemorySource source)158 bool MockTensorHandle::CanBeImported(void* memory, MemorySource source)
159 {
160     if (m_ImportFlags & static_cast<MemorySourceFlags>(source))
161     {
162         if (m_IsImportEnabled && source == MemorySource::Malloc)
163         {
164             uintptr_t alignment = GetDataTypeSize(m_TensorInfo.GetDataType());
165             if (reinterpret_cast<uintptr_t>(memory) % alignment)
166             {
167                 return false;
168             }
169 
170             return true;
171         }
172     }
173     return false;
174 }
175 
176 }    // namespace armnn
177