xref: /aosp_15_r20/external/armnn/src/backends/reference/test/RefTensorHandleTests.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2017,2022 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 
7 #include <doctest/doctest.h>
8 #include <armnn/BackendId.hpp>
9 #include <armnn/INetwork.hpp>
10 #include <armnn/Tensor.hpp>
11 #include <armnn/Types.hpp>
12 #include <armnn/backends/ITensorHandle.hpp>
13 #include <armnn/backends/ITensorHandleFactory.hpp>
14 #include <armnn/backends/TensorHandle.hpp>
15 #include <armnn/utility/Assert.hpp>
16 #include <reference/RefTensorHandle.hpp>
17 #include <reference/RefTensorHandleFactory.hpp>
18 #include <reference/RefMemoryManager.hpp>
19 #include <memory>
20 #include <vector>
21 
// Forward declarations of classes living in the armnn namespace.
// armnn::Exception is the type named by CHECK_THROWS_AS further down.
namespace armnn
{
class NullPointerException;
class Exception;
} // namespace armnn
27 
28 TEST_SUITE("RefTensorHandleTests")
29 {
30 using namespace armnn;
31 
32 TEST_CASE("AcquireAndRelease")
33 {
34     std::shared_ptr<RefMemoryManager> memoryManager = std::make_shared<RefMemoryManager>();
35 
36     TensorInfo info({ 1, 1, 1, 1 }, DataType::Float32);
37     RefTensorHandle handle(info, memoryManager);
38 
39     handle.Manage();
40     handle.Allocate();
41 
42     memoryManager->Acquire();
43     {
44         float* buffer = reinterpret_cast<float*>(handle.Map());
45 
46         CHECK(buffer != nullptr); // Yields a valid pointer
47 
48         buffer[0] = 2.5f;
49 
50         CHECK(buffer[0] == 2.5f); // Memory is writable and readable
51 
52     }
53     memoryManager->Release();
54 
55     memoryManager->Acquire();
56     {
57         float* buffer = reinterpret_cast<float*>(handle.Map());
58 
59         CHECK(buffer != nullptr); // Yields a valid pointer
60 
61         buffer[0] = 3.5f;
62 
63         CHECK(buffer[0] == 3.5f); // Memory is writable and readable
64     }
65     memoryManager->Release();
66 }
67 
68 TEST_CASE("RefTensorHandleFactoryMemoryManaged")
69 {
70     std::shared_ptr<RefMemoryManager> memoryManager = std::make_shared<RefMemoryManager>();
71     RefTensorHandleFactory handleFactory(memoryManager);
72     TensorInfo info({ 1, 1, 2, 1 }, DataType::Float32);
73 
74     // create TensorHandle with memory managed
75     auto handle = handleFactory.CreateTensorHandle(info, true);
76     handle->Manage();
77     handle->Allocate();
78 
79     memoryManager->Acquire();
80     {
81         float* buffer = reinterpret_cast<float*>(handle->Map());
82         CHECK(buffer != nullptr); // Yields a valid pointer
83         buffer[0] = 1.5f;
84         buffer[1] = 2.5f;
85         CHECK(buffer[0] == 1.5f); // Memory is writable and readable
86         CHECK(buffer[1] == 2.5f); // Memory is writable and readable
87     }
88     memoryManager->Release();
89 
90     memoryManager->Acquire();
91     {
92         float* buffer = reinterpret_cast<float*>(handle->Map());
93         CHECK(buffer != nullptr); // Yields a valid pointer
94         buffer[0] = 3.5f;
95         buffer[1] = 4.5f;
96         CHECK(buffer[0] == 3.5f); // Memory is writable and readable
97         CHECK(buffer[1] == 4.5f); // Memory is writable and readable
98     }
99     memoryManager->Release();
100 
101     float testPtr[2] = { 2.5f, 5.5f };
102     // Check import overlays contents
103     CHECK(handle->Import(static_cast<void*>(testPtr), MemorySource::Malloc));
104     {
105         float* buffer = reinterpret_cast<float*>(handle->Map());
106         CHECK(buffer != nullptr); // Yields a valid pointer
107         CHECK(buffer[0] == 2.5f); // Memory is writable and readable
108         CHECK(buffer[1] == 5.5f); // Memory is writable and readable
109     }
110 }
111 
112 TEST_CASE("RefTensorHandleFactoryImport")
113 {
114     std::shared_ptr<RefMemoryManager> memoryManager = std::make_shared<RefMemoryManager>();
115     RefTensorHandleFactory handleFactory(memoryManager);
116     TensorInfo info({ 1, 1, 2, 1 }, DataType::Float32);
117 
118     // create TensorHandle without memory managed
119     auto handle = handleFactory.CreateTensorHandle(info, false);
120     handle->Manage();
121     handle->Allocate();
122     memoryManager->Acquire();
123 
124     // Check storage has been allocated
125     void* unmanagedStorage = handle->Map();
126     CHECK(unmanagedStorage != nullptr);
127 
128     // Check importing overlays the storage
129     float testPtr[2] = { 2.5f, 5.5f };
130     CHECK(handle->Import(static_cast<void*>(testPtr), MemorySource::Malloc));
131     float* buffer = reinterpret_cast<float*>(handle->Map());
132     CHECK(buffer != nullptr); // Yields a valid pointer after import
133     CHECK(buffer == testPtr); // buffer is pointing to testPtr
134     // Memory is writable and readable with correct value
135     CHECK(buffer[0] == 2.5f);
136     CHECK(buffer[1] == 5.5f);
137     buffer[0] = 3.5f;
138     buffer[1] = 10.0f;
139     CHECK(buffer[0] == 3.5f);
140     CHECK(buffer[1] == 10.0f);
141     memoryManager->Release();
142 }
143 
144 TEST_CASE("RefTensorHandleImport")
145 {
146     TensorInfo info({ 1, 1, 2, 1 }, DataType::Float32);
147     RefTensorHandle handle(info);
148 
149     handle.Manage();
150     handle.Allocate();
151 
152     // Check unmanaged memory allocated
153     CHECK(handle.Map());
154 
155     float testPtr[2] = { 2.5f, 5.5f };
156     // Check imoport overlays the unamaged memory
157     CHECK(handle.Import(static_cast<void*>(testPtr), MemorySource::Malloc));
158     float* buffer = reinterpret_cast<float*>(handle.Map());
159     CHECK(buffer != nullptr); // Yields a valid pointer after import
160     CHECK(buffer == testPtr); // buffer is pointing to testPtr
161     // Memory is writable and readable with correct value
162     CHECK(buffer[0] == 2.5f);
163     CHECK(buffer[1] == 5.5f);
164     buffer[0] = 3.5f;
165     buffer[1] = 10.0f;
166     CHECK(buffer[0] == 3.5f);
167     CHECK(buffer[1] == 10.0f);
168 }
169 
170 TEST_CASE("RefTensorHandleGetCapabilities")
171 {
172     std::shared_ptr<RefMemoryManager> memoryManager = std::make_shared<RefMemoryManager>();
173     RefTensorHandleFactory handleFactory(memoryManager);
174 
175     // Builds up the structure of the network.
176     INetworkPtr net(INetwork::Create());
177     IConnectableLayer* input = net->AddInputLayer(0);
178     IConnectableLayer* output = net->AddOutputLayer(0);
179     input->GetOutputSlot(0).Connect(output->GetInputSlot(0));
180 
181     std::vector<Capability> capabilities = handleFactory.GetCapabilities(input,
182                                                                          output,
183                                                                          CapabilityClass::PaddingRequired);
184     CHECK(capabilities.empty());
185 }
186 
187 TEST_CASE("RefTensorHandleSupportsInPlaceComputation")
188 {
189     std::shared_ptr<RefMemoryManager> memoryManager = std::make_shared<RefMemoryManager>();
190     RefTensorHandleFactory handleFactory(memoryManager);
191 
192     // RefTensorHandleFactory does not support InPlaceComputation
193     ARMNN_ASSERT(!(handleFactory.SupportsInPlaceComputation()));
194 }
195 
196 TEST_CASE("TestManagedConstTensorHandle")
197 {
198     // Initialize arguments
199     void* mem = nullptr;
200     TensorInfo info;
201 
202     // Use PassthroughTensor as others are abstract
203     auto passThroughHandle = std::make_shared<PassthroughTensorHandle>(info, mem);
204 
205     // Test managed handle is initialized with m_Mapped unset and once Map() called its set
206     ManagedConstTensorHandle managedHandle(passThroughHandle);
207     CHECK(!managedHandle.IsMapped());
208     managedHandle.Map();
209     CHECK(managedHandle.IsMapped());
210 
211     // Test it can then be unmapped
212     managedHandle.Unmap();
213     CHECK(!managedHandle.IsMapped());
214 
215     // Test member function
216     CHECK(managedHandle.GetTensorInfo() == info);
217 
218     // Test that nullptr tensor handle doesn't get mapped
219     ManagedConstTensorHandle managedHandleNull(nullptr);
220     CHECK(!managedHandleNull.IsMapped());
221     CHECK_THROWS_AS(managedHandleNull.Map(), armnn::Exception);
222     CHECK(!managedHandleNull.IsMapped());
223 
224     // Check Unmap() when m_Mapped already false
225     managedHandleNull.Unmap();
226     CHECK(!managedHandleNull.IsMapped());
227 }
228 
229 #if !defined(__ANDROID__)
230 // Only run these tests on non Android platforms
231 TEST_CASE("CheckSourceType")
232 {
233     TensorInfo info({1}, DataType::Float32);
234     RefTensorHandle handle(info);
235 
236     int* testPtr = new int(4);
237 
238     // Not supported
239     CHECK(!handle.Import(static_cast<void *>(testPtr), MemorySource::DmaBuf));
240 
241     // Not supported
242     CHECK(!handle.Import(static_cast<void *>(testPtr), MemorySource::DmaBufProtected));
243 
244     // Supported
245     CHECK(handle.Import(static_cast<void *>(testPtr), MemorySource::Malloc));
246 
247     delete testPtr;
248 }
249 
250 TEST_CASE("ReusePointer")
251 {
252     TensorInfo info({1}, DataType::Float32);
253     RefTensorHandle handle(info);
254 
255     int* testPtr = new int(4);
256 
257     handle.Import(static_cast<void *>(testPtr), MemorySource::Malloc);
258 
259     // Reusing previously Imported pointer
260     CHECK(handle.Import(static_cast<void *>(testPtr), MemorySource::Malloc));
261 
262     delete testPtr;
263 }
264 
265 TEST_CASE("MisalignedPointer")
266 {
267     TensorInfo info({2}, DataType::Float32);
268     RefTensorHandle handle(info);
269 
270     // Allocate a 2 int array
271     int* testPtr = new int[2];
272 
273     // Increment pointer by 1 byte
274     void* misalignedPtr = static_cast<void*>(reinterpret_cast<char*>(testPtr) + 1);
275 
276     CHECK(!handle.Import(misalignedPtr, MemorySource::Malloc));
277 
278     delete[] testPtr;
279 }
280 
281 TEST_CASE("CheckCanBeImported")
282 {
283     TensorInfo info({1}, DataType::Float32);
284     RefTensorHandle handle(info);
285 
286     int* testPtr = new int(4);
287 
288     // Not supported
289     CHECK(!handle.CanBeImported(static_cast<void *>(testPtr), MemorySource::DmaBuf));
290 
291     // Supported
292     CHECK(handle.CanBeImported(static_cast<void *>(testPtr), MemorySource::Malloc));
293 
294     delete testPtr;
295 
296 }
297 
298 TEST_CASE("MisalignedCanBeImported")
299 {
300     TensorInfo info({2}, DataType::Float32);
301     RefTensorHandle handle(info);
302 
303     // Allocate a 2 int array
304     int* testPtr = new int[2];
305 
306     // Increment pointer by 1 byte
307     void* misalignedPtr = static_cast<void*>(reinterpret_cast<char*>(testPtr) + 1);
308 
309     CHECK(!handle.Import(misalignedPtr, MemorySource::Malloc));
310 
311     delete[] testPtr;
312 }
313 
314 #endif
315 
316 }
317