xref: /aosp_15_r20/external/armnn/src/backends/backendsCommon/test/layerTests/GatherNdTestImpl.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "GatherNdTestImpl.hpp"
7 
8 #include <DataTypeUtils.hpp>
9 #include <armnnTestUtils/TensorCopyUtils.hpp>
10 #include <armnnTestUtils/WorkloadTestUtils.hpp>
11 
12 namespace
13 {
14 
15 template<armnn::DataType ArmnnType,
16         typename T = armnn::ResolveType<ArmnnType>,
17         size_t ParamsDim,
18         size_t IndicesDim,
19         size_t OutputDim>
GatherNdTestImpl(armnn::IWorkloadFactory & workloadFactory,const armnn::IBackendInternal::IMemoryManagerSharedPtr & memoryManager,const armnn::ITensorHandleFactory & tensorHandleFactory,const armnn::TensorInfo & paramsInfo,const armnn::TensorInfo & indicesInfo,const armnn::TensorInfo & outputInfo,const std::vector<T> & paramsData,const std::vector<int32_t> & indicesData,const std::vector<T> & outputData)20 LayerTestResult<T, OutputDim> GatherNdTestImpl(
21         armnn::IWorkloadFactory &workloadFactory,
22         const armnn::IBackendInternal::IMemoryManagerSharedPtr &memoryManager,
23         const armnn::ITensorHandleFactory &tensorHandleFactory,
24         const armnn::TensorInfo &paramsInfo,
25         const armnn::TensorInfo &indicesInfo,
26         const armnn::TensorInfo &outputInfo,
27         const std::vector<T> &paramsData,
28         const std::vector<int32_t> &indicesData,
29         const std::vector<T> &outputData)
30 {
31     IgnoreUnused(memoryManager);
32 
33     std::vector<T> actualOutput(outputInfo.GetNumElements());
34 
35     std::unique_ptr<armnn::ITensorHandle> paramsHandle = tensorHandleFactory.CreateTensorHandle(paramsInfo);
36     std::unique_ptr<armnn::ITensorHandle> indicesHandle = tensorHandleFactory.CreateTensorHandle(indicesInfo);
37     std::unique_ptr<armnn::ITensorHandle> outputHandle = tensorHandleFactory.CreateTensorHandle(outputInfo);
38 
39     armnn::GatherNdQueueDescriptor data;
40     armnn::WorkloadInfo info;
41     AddInputToWorkload(data, info, paramsInfo, paramsHandle.get());
42     AddInputToWorkload(data, info, indicesInfo, indicesHandle.get());
43     AddOutputToWorkload(data, info, outputInfo, outputHandle.get());
44 
45     std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateWorkload(armnn::LayerType::GatherNd,
46                                                                                 data,
47                                                                                 info);
48 
49     paramsHandle->Allocate();
50     indicesHandle->Allocate();
51     outputHandle->Allocate();
52 
53     CopyDataToITensorHandle(paramsHandle.get(), paramsData.data());
54     CopyDataToITensorHandle(indicesHandle.get(), indicesData.data());
55 
56     workload->Execute();
57 
58     CopyDataFromITensorHandle(actualOutput.data(), outputHandle.get());
59 
60     return LayerTestResult<T, OutputDim>(actualOutput,
61                                          outputData,
62                                          outputHandle->GetShape(),
63                                          outputInfo.GetShape());
64 }
65 } // anonymous namespace
66 
67 template<armnn::DataType ArmnnType, typename T>
SimpleGatherNd2dTest(armnn::IWorkloadFactory & workloadFactory,const armnn::IBackendInternal::IMemoryManagerSharedPtr & memoryManager,const armnn::ITensorHandleFactory & tensorHandleFactory)68 LayerTestResult<T, 2> SimpleGatherNd2dTest(
69         armnn::IWorkloadFactory& workloadFactory,
70         const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
71         const armnn::ITensorHandleFactory& tensorHandleFactory)
72 {
73     armnn::TensorInfo paramsInfo({ 5, 2 }, ArmnnType);
74     armnn::TensorInfo indicesInfo({ 3, 1 }, armnn::DataType::Signed32);
75     armnn::TensorInfo outputInfo({ 3, 2 }, ArmnnType);
76     if (armnn::IsQuantizedType<T>())
77     {
78         paramsInfo.SetQuantizationScale(1.0f);
79         paramsInfo.SetQuantizationOffset(1);
80         outputInfo.SetQuantizationScale(1.0f);
81         outputInfo.SetQuantizationOffset(1);
82     }
83     const std::vector<T> params = ConvertToDataType<ArmnnType>(
84             { 1, 2,
85               3, 4,
86               5, 6,
87               7, 8,
88               9, 10},
89             paramsInfo);
90     const std::vector<int32_t> indices  = ConvertToDataType<armnn::DataType::Signed32>(
91             { 1, 0, 4},
92             indicesInfo);
93     const std::vector<T> expectedOutput = ConvertToDataType<ArmnnType>(
94             { 3, 4,
95               1, 2,
96               9, 10},
97             outputInfo);
98     return GatherNdTestImpl<ArmnnType, T, 2, 2, 2>(
99             workloadFactory,
100             memoryManager,
101             tensorHandleFactory,
102             paramsInfo,
103             indicesInfo,
104             outputInfo,
105             params,
106             indices,
107             expectedOutput);
108 }
109 
110 template<armnn::DataType ArmnnType, typename T>
SimpleGatherNd3dTest(armnn::IWorkloadFactory & workloadFactory,const armnn::IBackendInternal::IMemoryManagerSharedPtr & memoryManager,const armnn::ITensorHandleFactory & tensorHandleFactory)111 LayerTestResult<T, 3> SimpleGatherNd3dTest(
112         armnn::IWorkloadFactory& workloadFactory,
113         const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
114         const armnn::ITensorHandleFactory& tensorHandleFactory)
115 {
116     armnn::TensorInfo paramsInfo({ 2, 3, 8, 4 }, ArmnnType);
117     armnn::TensorInfo indicesInfo({ 2, 2 }, armnn::DataType::Signed32);
118     armnn::TensorInfo outputInfo({ 2, 8, 4 }, ArmnnType);
119 
120     if (armnn::IsQuantizedType<T>())
121     {
122         paramsInfo.SetQuantizationScale(1.0f);
123         paramsInfo.SetQuantizationOffset(0);
124         outputInfo.SetQuantizationScale(1.0f);
125         outputInfo.SetQuantizationOffset(0);
126     }
127     const std::vector<T> params = ConvertToDataType<ArmnnType>(
128             { 0,   1,   2,   3, 4,   5,   6,   7, 8,   9,  10,  11, 12,  13,  14,  15,
129              16,  17,  18,  19, 20,  21,  22,  23, 24,  25,  26,  27, 28,  29,  30,  31,
130 
131              32,  33,  34,  35, 36,  37,  38,  39, 40,  41,  42,  43, 44,  45,  46,  47,
132              48,  49,  50,  51, 52,  53,  54,  55, 56,  57,  58,  59, 60,  61,  62,  63,
133 
134              64,  65,  66,  67, 68,  69,  70,  71, 72,  73,  74,  75, 76,  77,  78,  79,
135              80,  81,  82,  83, 84,  85,  86,  87, 88,  89,  90,  91, 92,  93,  94,  95,
136 
137              96,  97,  98,  99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111,
138             112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127,
139 
140             128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143,
141             144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159,
142 
143             160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175,
144             176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191 },
145             paramsInfo);
146 
147     const std::vector<int32_t> indices  = ConvertToDataType<armnn::DataType::Signed32>(
148             { 1, 2, 1, 1},
149             indicesInfo);
150 
151     const std::vector<T> expectedOutput = ConvertToDataType<ArmnnType>(
152             { 160, 161, 162, 163, 164, 165, 166, 167, 168, 169, 170, 171, 172, 173, 174, 175,
153             176, 177, 178, 179, 180, 181, 182, 183, 184, 185, 186, 187, 188, 189, 190, 191,
154 
155             128, 129, 130, 131, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143,
156             144, 145, 146, 147, 148, 149, 150, 151, 152, 153, 154, 155, 156, 157, 158, 159},
157             outputInfo);
158 
159     return GatherNdTestImpl<ArmnnType, T, 4, 2, 3>(
160             workloadFactory,
161             memoryManager,
162             tensorHandleFactory,
163             paramsInfo,
164             indicesInfo,
165             outputInfo,
166             params,
167             indices,
168             expectedOutput);
169 }
170 
171 template<armnn::DataType ArmnnType, typename T>
SimpleGatherNd4dTest(armnn::IWorkloadFactory & workloadFactory,const armnn::IBackendInternal::IMemoryManagerSharedPtr & memoryManager,const armnn::ITensorHandleFactory & tensorHandleFactory)172 LayerTestResult<T, 4> SimpleGatherNd4dTest(
173         armnn::IWorkloadFactory& workloadFactory,
174         const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
175         const armnn::ITensorHandleFactory& tensorHandleFactory)
176 {
177     armnn::TensorInfo paramsInfo({ 5, 5, 2 }, ArmnnType);
178     armnn::TensorInfo indicesInfo({ 2, 2, 3, 2 }, armnn::DataType::Signed32);
179     armnn::TensorInfo outputInfo({ 2, 2, 3, 2 }, ArmnnType);
180 
181     if (armnn::IsQuantizedType<T>())
182     {
183         paramsInfo.SetQuantizationScale(1.0f);
184         paramsInfo.SetQuantizationOffset(0);
185         outputInfo.SetQuantizationScale(1.0f);
186         outputInfo.SetQuantizationOffset(0);
187     }
188     const std::vector<T> params = ConvertToDataType<ArmnnType>(
189         { 0,  1,    2,  3,    4,  5,    6,  7,    8,  9,
190          10, 11,   12,  13,   14, 15,   16, 17,   18, 19,
191          20, 21,   22,  23,   24, 25,   26, 27,   28, 29,
192          30, 31,   32,  33,   34, 35,   36, 37,   38, 39,
193          40, 41,   42,  43,   44, 45,   46, 47,   48, 49 },
194         paramsInfo);
195 
196     const std::vector<int32_t> indices  = ConvertToDataType<armnn::DataType::Signed32>(
197         { 0, 0,
198           3, 3,
199           4, 4,
200 
201           0, 0,
202           1, 1,
203           2, 2,
204 
205           4, 4,
206           3, 3,
207           0, 0,
208 
209           2, 2,
210           1, 1,
211           0, 0 },
212         indicesInfo);
213 
214     const std::vector<T> expectedOutput = ConvertToDataType<ArmnnType>(
215         {  0,  1,
216           36, 37,
217           48, 49,
218 
219            0,  1,
220           12, 13,
221           24, 25,
222 
223           48, 49,
224           36, 37,
225            0,  1,
226 
227           24, 25,
228           12, 13,
229            0,  1 },
230         outputInfo);
231 
232     return GatherNdTestImpl<ArmnnType, T, 3, 4, 4>(
233             workloadFactory,
234             memoryManager,
235             tensorHandleFactory,
236             paramsInfo,
237             indicesInfo,
238             outputInfo,
239             params,
240             indices,
241             expectedOutput);
242 }
243 
244 //
245 // Explicit template specializations
246 //
247 
248 template LayerTestResult<armnn::ResolveType<armnn::DataType::Float32>, 2>
249 SimpleGatherNd2dTest<armnn::DataType::Float32>(
250         armnn::IWorkloadFactory& workloadFactory,
251         const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
252         const armnn::ITensorHandleFactory& tensorHandleFactory);
253 
254 template LayerTestResult<armnn::ResolveType<armnn::DataType::Float32>, 3>
255 SimpleGatherNd3dTest<armnn::DataType::Float32>(
256         armnn::IWorkloadFactory& workloadFactory,
257         const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
258         const armnn::ITensorHandleFactory& tensorHandleFactory);
259 
260 template LayerTestResult<armnn::ResolveType<armnn::DataType::Float32>, 4>
261 SimpleGatherNd4dTest<armnn::DataType::Float32>(
262         armnn::IWorkloadFactory& workloadFactory,
263         const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
264         const armnn::ITensorHandleFactory& tensorHandleFactory);
265 
266 template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmS8>, 2>
267 SimpleGatherNd2dTest<armnn::DataType::QAsymmS8>(
268         armnn::IWorkloadFactory& workloadFactory,
269         const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
270         const armnn::ITensorHandleFactory& tensorHandleFactory);
271 
272 template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmS8>, 3>
273 SimpleGatherNd3dTest<armnn::DataType::QAsymmS8>(
274         armnn::IWorkloadFactory& workloadFactory,
275         const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
276         const armnn::ITensorHandleFactory& tensorHandleFactory);
277 
278 template LayerTestResult<armnn::ResolveType<armnn::DataType::QAsymmS8>, 4>
279 SimpleGatherNd4dTest<armnn::DataType::QAsymmS8>(
280         armnn::IWorkloadFactory& workloadFactory,
281         const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
282         const armnn::ITensorHandleFactory& tensorHandleFactory);
283 
284 template LayerTestResult<armnn::ResolveType<armnn::DataType::Signed32>, 2>
285 SimpleGatherNd2dTest<armnn::DataType::Signed32>(
286         armnn::IWorkloadFactory& workloadFactory,
287         const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
288         const armnn::ITensorHandleFactory& tensorHandleFactory);
289 
290 template LayerTestResult<armnn::ResolveType<armnn::DataType::Signed32>, 3>
291 SimpleGatherNd3dTest<armnn::DataType::Signed32>(
292         armnn::IWorkloadFactory& workloadFactory,
293         const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
294         const armnn::ITensorHandleFactory& tensorHandleFactory);
295 
296 template LayerTestResult<armnn::ResolveType<armnn::DataType::Signed32>, 4>
297 SimpleGatherNd4dTest<armnn::DataType::Signed32>(
298         armnn::IWorkloadFactory& workloadFactory,
299         const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
300         const armnn::ITensorHandleFactory& tensorHandleFactory);