xref: /aosp_15_r20/external/armnn/src/backends/cl/test/ClImportTensorHandleFactoryTests.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2021 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include <armnn/utility/Assert.hpp>
7 
8 #include <cl/ClImportTensorHandleFactory.hpp>
9 
10 #include <doctest/doctest.h>
11 
12 TEST_SUITE("ClImportTensorHandleFactoryTests")
13 {
14 using namespace armnn;
15 
16 TEST_CASE("ImportTensorFactoryAskedToCreateManagedTensorThrowsException")
17 {
18     // Create the factory to import tensors.
19     ClImportTensorHandleFactory factory(static_cast<MemorySourceFlags>(MemorySource::Malloc),
20                                         static_cast<MemorySourceFlags>(MemorySource::Malloc));
21     TensorInfo tensorInfo;
22     // This factory is designed to import the memory of tensors. Asking for a handle that requires
23     // a memory manager should result in an exception.
24     REQUIRE_THROWS_AS(factory.CreateTensorHandle(tensorInfo, true), InvalidArgumentException);
25     REQUIRE_THROWS_AS(factory.CreateTensorHandle(tensorInfo, DataLayout::NCHW, true), InvalidArgumentException);
26 }
27 
28 TEST_CASE("ImportTensorFactoryCreateMallocTensorHandle")
29 {
30     // Create the factory to import tensors.
31     ClImportTensorHandleFactory factory(static_cast<MemorySourceFlags>(MemorySource::Malloc),
32                                         static_cast<MemorySourceFlags>(MemorySource::Malloc));
33     TensorShape tensorShape{ 6, 7, 8, 9 };
34     TensorInfo tensorInfo(tensorShape, armnn::DataType::Float32);
35     // Start with the TensorInfo factory method. Create an import tensor handle and verify the data is
36     // passed through correctly.
37     auto tensorHandle = factory.CreateTensorHandle(tensorInfo);
38     ARMNN_ASSERT(tensorHandle);
39     ARMNN_ASSERT(tensorHandle->GetImportFlags() == static_cast<MemorySourceFlags>(MemorySource::Malloc));
40     ARMNN_ASSERT(tensorHandle->GetShape() == tensorShape);
41 
42     // Same method but explicitly specifying isManaged = false.
43     tensorHandle = factory.CreateTensorHandle(tensorInfo, false);
44     CHECK(tensorHandle);
45     ARMNN_ASSERT(tensorHandle->GetImportFlags() == static_cast<MemorySourceFlags>(MemorySource::Malloc));
46     ARMNN_ASSERT(tensorHandle->GetShape() == tensorShape);
47 
48     // Now try TensorInfo and DataLayout factory method.
49     tensorHandle = factory.CreateTensorHandle(tensorInfo, DataLayout::NHWC);
50     CHECK(tensorHandle);
51     ARMNN_ASSERT(tensorHandle->GetImportFlags() == static_cast<MemorySourceFlags>(MemorySource::Malloc));
52     ARMNN_ASSERT(tensorHandle->GetShape() == tensorShape);
53 }
54 
55 TEST_CASE("CreateSubtensorOfImportTensor")
56 {
57     // Create the factory to import tensors.
58     ClImportTensorHandleFactory factory(static_cast<MemorySourceFlags>(MemorySource::Malloc),
59                                         static_cast<MemorySourceFlags>(MemorySource::Malloc));
60     // Create a standard inport tensor.
61     TensorShape tensorShape{ 224, 224, 1, 1 };
62     TensorInfo tensorInfo(tensorShape, armnn::DataType::Float32);
63     auto tensorHandle = factory.CreateTensorHandle(tensorInfo);
64     // Use the factory to create a 16x16 sub tensor.
65     TensorShape subTensorShape{ 16, 16, 1, 1 };
66     // Starting at an offset of 1x1.
67     uint32_t origin[4] = { 1, 1, 0, 0 };
68     auto subTensor     = factory.CreateSubTensorHandle(*tensorHandle, subTensorShape, origin);
69     CHECK(subTensor);
70     ARMNN_ASSERT(subTensor->GetShape() == subTensorShape);
71     ARMNN_ASSERT(subTensor->GetParent() == tensorHandle.get());
72 }
73 
74 TEST_CASE("CreateSubtensorNonZeroXYIsInvalid")
75 {
76     // Create the factory to import tensors.
77     ClImportTensorHandleFactory factory(static_cast<MemorySourceFlags>(MemorySource::Malloc),
78                                         static_cast<MemorySourceFlags>(MemorySource::Malloc));
79     // Create a standard import tensor.
80     TensorShape tensorShape{ 224, 224, 1, 1 };
81     TensorInfo tensorInfo(tensorShape, armnn::DataType::Float32);
82     auto tensorHandle = factory.CreateTensorHandle(tensorInfo);
83     // Use the factory to create a 16x16 sub tensor.
84     TensorShape subTensorShape{ 16, 16, 1, 1 };
85     // This looks a bit backwards because of how Cl specifies tensors. Essentially we want to trigger our
86     // check "(coords.x() != 0 || coords.y() != 0)"
87     uint32_t origin[4] = { 0, 0, 1, 1 };
88     auto subTensor     = factory.CreateSubTensorHandle(*tensorHandle, subTensorShape, origin);
89     // We expect a nullptr.
90     ARMNN_ASSERT(subTensor == nullptr);
91 }
92 
93 TEST_CASE("CreateSubtensorXYMustMatchParent")
94 {
95     // Create the factory to import tensors.
96     ClImportTensorHandleFactory factory(static_cast<MemorySourceFlags>(MemorySource::Malloc),
97                                         static_cast<MemorySourceFlags>(MemorySource::Malloc));
98     // Create a standard import tensor.
99     TensorShape tensorShape{ 224, 224, 1, 1 };
100     TensorInfo tensorInfo(tensorShape, armnn::DataType::Float32);
101     auto tensorHandle = factory.CreateTensorHandle(tensorInfo);
102     // Use the factory to create a 16x16 sub tensor but make the CL x and y axis different.
103     TensorShape subTensorShape{ 16, 16, 2, 2 };
104     // We want to trigger our ((parentShape.x() != shape.x()) || (parentShape.y() != shape.y()))
105     uint32_t origin[4] = { 1, 1, 0, 0 };
106     auto subTensor     = factory.CreateSubTensorHandle(*tensorHandle, subTensorShape, origin);
107     // We expect a nullptr.
108     ARMNN_ASSERT(subTensor == nullptr);
109 }
110 
111 TEST_CASE("CreateSubtensorMustBeSmallerThanParent")
112 {
113     // Create the factory to import tensors.
114     ClImportTensorHandleFactory factory(static_cast<MemorySourceFlags>(MemorySource::Malloc),
115                                         static_cast<MemorySourceFlags>(MemorySource::Malloc));
116     // Create a standard import tensor.
117     TensorShape tensorShape{ 224, 224, 1, 1 };
118     TensorInfo tensorInfo(tensorShape, armnn::DataType::Float32);
119     auto tensorHandle = factory.CreateTensorHandle(tensorInfo);
120     // Ask for a subtensor that's the same size as the parent.
121     TensorShape subTensorShape{ 224, 224, 1, 1 };
122     uint32_t origin[4] = { 1, 1, 0, 0 };
123     // This should result in a nullptr.
124     auto subTensor = factory.CreateSubTensorHandle(*tensorHandle, subTensorShape, origin);
125     ARMNN_ASSERT(subTensor == nullptr);
126 }
127 
128 }
129