xref: /aosp_15_r20/external/armnn/src/backends/backendsCommon/test/layerTests/GatherTestImpl.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
//
// Copyright © 2017,2022 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5 
#include "GatherTestImpl.hpp"

#include <ResolveType.hpp>

#include <armnnTestUtils/TensorCopyUtils.hpp>
#include <armnnTestUtils/WorkloadTestUtils.hpp>
#include <armnnTestUtils/TensorHelpers.hpp>

#include <stdexcept>
#include <utility>
14 
15 namespace
16 {
17 
18 template <armnn::DataType ArmnnType,
19           typename T = armnn::ResolveType<ArmnnType>,
20           size_t ParamsDim,
21           size_t IndicesDim,
22           size_t OutputDim>
GatherTestImpl(armnn::IWorkloadFactory & workloadFactory,const armnn::IBackendInternal::IMemoryManagerSharedPtr & memoryManager,const armnn::ITensorHandleFactory & tensorHandleFactory,const armnn::TensorInfo & paramsInfo,const armnn::TensorInfo & indicesInfo,const armnn::TensorInfo & outputInfo,const std::vector<T> & paramsData,const std::vector<int32_t> & indicesData,const std::vector<T> & outputData,armnn::GatherDescriptor descriptor=armnn::GatherDescriptor ())23 LayerTestResult<T, OutputDim> GatherTestImpl(
24     armnn::IWorkloadFactory& workloadFactory,
25     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
26     const armnn::ITensorHandleFactory& tensorHandleFactory,
27     const armnn::TensorInfo& paramsInfo,
28     const armnn::TensorInfo& indicesInfo,
29     const armnn::TensorInfo& outputInfo,
30     const std::vector<T>& paramsData,
31     const std::vector<int32_t>& indicesData,
32     const std::vector<T>& outputData,
33     armnn::GatherDescriptor descriptor= armnn::GatherDescriptor())
34 {
35     IgnoreUnused(memoryManager);
36 
37     std::vector<T> actualOutput(outputInfo.GetNumElements());
38 
39     std::unique_ptr<armnn::ITensorHandle> paramsHandle = tensorHandleFactory.CreateTensorHandle(paramsInfo);
40     std::unique_ptr<armnn::ITensorHandle> indicesHandle = tensorHandleFactory.CreateTensorHandle(indicesInfo);
41     std::unique_ptr<armnn::ITensorHandle> outputHandle = tensorHandleFactory.CreateTensorHandle(outputInfo);
42 
43     armnn::GatherQueueDescriptor data;
44     data.m_Parameters = std::move(descriptor);
45     armnn::WorkloadInfo info;
46     AddInputToWorkload(data,  info, paramsInfo, paramsHandle.get());
47     AddInputToWorkload(data, info, indicesInfo, indicesHandle.get());
48     AddOutputToWorkload(data, info, outputInfo, outputHandle.get());
49 
50     std::unique_ptr<armnn::IWorkload> workload = workloadFactory.CreateWorkload(armnn::LayerType::Gather, data, info);
51 
52     paramsHandle->Allocate();
53     indicesHandle->Allocate();
54     outputHandle->Allocate();
55 
56     CopyDataToITensorHandle(paramsHandle.get(), paramsData.data());
57     CopyDataToITensorHandle(indicesHandle.get(), indicesData.data());
58 
59     workload->Execute();
60 
61     CopyDataFromITensorHandle(actualOutput.data(), outputHandle.get());
62 
63     return LayerTestResult<T, OutputDim>(actualOutput,
64                                          outputData,
65                                          outputHandle->GetShape(),
66                                          outputInfo.GetShape());
67 }
68 
69 template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
70 struct GatherTestHelper
71 {
Gather1dParamsTestImpl__anon0dc3d3520111::GatherTestHelper72     static LayerTestResult<T, 1> Gather1dParamsTestImpl(
73         armnn::IWorkloadFactory& workloadFactory,
74         const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
75         const armnn::ITensorHandleFactory& tensorHandleFactory)
76     {
77         armnn::TensorInfo paramsInfo({ 8 }, ArmnnType);
78         armnn::TensorInfo indicesInfo({ 4 }, armnn::DataType::Signed32);
79         armnn::TensorInfo outputInfo({ 4 }, ArmnnType);
80 
81         if (armnn::IsQuantizedType<T>())
82         {
83             paramsInfo.SetQuantizationScale(1.0f);
84             paramsInfo.SetQuantizationOffset(1);
85             outputInfo.SetQuantizationScale(1.0f);
86             outputInfo.SetQuantizationOffset(1);
87         }
88         const std::vector<T> params         = std::vector<T>({ 1, 2, 3, 4, 5, 6, 7, 8 });
89         const std::vector<int32_t> indices  = std::vector<int32_t>({ 0, 2, 1, 5 });
90         const std::vector<T> expectedOutput = std::vector<T>({ 1, 3, 2, 6 });
91 
92         return GatherTestImpl<ArmnnType, T, 1, 1, 1>(
93             workloadFactory,
94             memoryManager,
95             tensorHandleFactory,
96             paramsInfo,
97             indicesInfo,
98             outputInfo,
99             params,
100             indices,
101             expectedOutput);
102     }
103 
Gather1dParamsAxisTestImpl__anon0dc3d3520111::GatherTestHelper104     static LayerTestResult<T, 1> Gather1dParamsAxisTestImpl(
105         armnn::IWorkloadFactory& workloadFactory,
106         const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
107         const armnn::ITensorHandleFactory& tensorHandleFactory)
108     {
109         armnn::GatherDescriptor descriptor;
110         descriptor.m_Axis=1;
111         armnn::TensorInfo paramsInfo({ 4, 3 }, ArmnnType);
112         armnn::TensorInfo indicesInfo({ 2 }, armnn::DataType::Signed32);
113         armnn::TensorInfo outputInfo({ 4, 2 }, ArmnnType);
114 
115         if (armnn::IsQuantizedType<T>())
116         {
117             paramsInfo.SetQuantizationScale(1.0f);
118             paramsInfo.SetQuantizationOffset(1);
119             outputInfo.SetQuantizationScale(1.0f);
120             outputInfo.SetQuantizationOffset(1);
121         }
122         const std::vector<T> params         ={  10,  11,  12,
123                                                110, 111, 112,
124                                                120, 121, 122,
125                                                130, 131, 132 };
126         const std::vector<int32_t> indices  = std::vector<int32_t>({ 2, 1 });
127         const std::vector<T> expectedOutput = {  12,  11,
128                                                 112, 111,
129                                                 122, 121,
130                                                 132, 131 } ;
131 
132         return GatherTestImpl<ArmnnType, T, 1, 1, 1>(
133                 workloadFactory,
134                 memoryManager,
135                 tensorHandleFactory,
136                 paramsInfo,
137                 indicesInfo,
138                 outputInfo,
139                 params,
140                 indices,
141                 expectedOutput,
142                 descriptor);
143     }
144 
GatherMultiDimParamsTestImpl__anon0dc3d3520111::GatherTestHelper145     static LayerTestResult<T, 2> GatherMultiDimParamsTestImpl(
146         armnn::IWorkloadFactory& workloadFactory,
147         const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
148         const armnn::ITensorHandleFactory& tensorHandleFactory)
149     {
150         armnn::TensorInfo paramsInfo({ 5, 2 }, ArmnnType);
151         armnn::TensorInfo indicesInfo({ 3 }, armnn::DataType::Signed32);
152         armnn::TensorInfo outputInfo({ 3, 2 }, ArmnnType);
153 
154         if (armnn::IsQuantizedType<T>())
155         {
156             paramsInfo.SetQuantizationScale(1.0f);
157             paramsInfo.SetQuantizationOffset(1);
158             outputInfo.SetQuantizationScale(1.0f);
159             outputInfo.SetQuantizationOffset(1);
160         }
161 
162         const std::vector<T> params         = std::vector<T>({ 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 });
163         const std::vector<int32_t> indices  = std::vector<int32_t>({ 1, 3, 4 });
164         const std::vector<T> expectedOutput = std::vector<T>({ 3, 4, 7, 8, 9, 10 });
165 
166         return GatherTestImpl<ArmnnType, T, 2, 1, 2>(
167             workloadFactory,
168             memoryManager,
169             tensorHandleFactory,
170             paramsInfo,
171             indicesInfo,
172             outputInfo,
173             params,
174             indices,
175             expectedOutput);
176     }
177 
GatherMultiDimParamsMultiDimIndicesTestImpl__anon0dc3d3520111::GatherTestHelper178     static LayerTestResult<T, 4> GatherMultiDimParamsMultiDimIndicesTestImpl(
179         armnn::IWorkloadFactory& workloadFactory,
180         const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
181         const armnn::ITensorHandleFactory& tensorHandleFactory)
182     {
183         armnn::TensorInfo paramsInfo({ 3, 2, 3 }, ArmnnType);
184         armnn::TensorInfo indicesInfo({ 2, 3 }, armnn::DataType::Signed32);
185         armnn::TensorInfo outputInfo({ 2, 3, 2, 3 }, ArmnnType);
186 
187         if (armnn::IsQuantizedType<T>())
188         {
189             paramsInfo.SetQuantizationScale(1.0f);
190             paramsInfo.SetQuantizationOffset(1);
191             outputInfo.SetQuantizationScale(1.0f);
192             outputInfo.SetQuantizationOffset(1);
193         }
194 
195         const std::vector<T> params =
196         {
197             1,  2,  3,
198             4,  5,  6,
199 
200             7,  8,  9,
201             10, 11, 12,
202 
203             13, 14, 15,
204             16, 17, 18
205         };
206 
207         const std::vector<int32_t> indices = { 1, 2, 1, 2, 1, 0 };
208 
209         const std::vector<T> expectedOutput =
210         {
211             7,  8,  9,
212             10, 11, 12,
213             13, 14, 15,
214             16, 17, 18,
215             7,  8,  9,
216             10, 11, 12,
217 
218             13, 14, 15,
219             16, 17, 18,
220             7,  8,  9,
221             10, 11, 12,
222             1,  2,  3,
223             4,  5,  6
224         };
225 
226         return GatherTestImpl<ArmnnType, T, 3, 2, 4>(
227             workloadFactory,
228             memoryManager,
229             tensorHandleFactory,
230             paramsInfo,
231             indicesInfo,
232             outputInfo,
233             params,
234             indices,
235             expectedOutput);
236     }
237 
GatherMultiDimParamsMultiDimIndicesAxis1TestImpl__anon0dc3d3520111::GatherTestHelper238     static LayerTestResult<T, 4> GatherMultiDimParamsMultiDimIndicesAxis1TestImpl(
239             armnn::IWorkloadFactory& workloadFactory,
240             const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
241             const armnn::ITensorHandleFactory& tensorHandleFactory)
242     {
243         armnn::GatherDescriptor descriptor;
244         descriptor.m_Axis=1;
245         armnn::TensorInfo paramsInfo({ 3, 2, 3 }, ArmnnType);
246         armnn::TensorInfo indicesInfo({ 2, 3 }, armnn::DataType::Signed32);
247         armnn::TensorInfo outputInfo({ 3, 2, 3, 3 }, ArmnnType);
248 
249         if (armnn::IsQuantizedType<T>())
250         {
251             paramsInfo.SetQuantizationScale(1.0f);
252             paramsInfo.SetQuantizationOffset(1);
253             outputInfo.SetQuantizationScale(1.0f);
254             outputInfo.SetQuantizationOffset(1);
255         }
256 
257         const std::vector<T> params =
258                 {
259                         1,  2,  3,
260                         4,  5,  6,
261 
262                         7,  8,  9,
263                         10, 11, 12,
264 
265                         13, 14, 15,
266                         16, 17, 18
267                 };
268 
269         const std::vector<int32_t> indices = { 1, 0, 1, 0, 1, 0 };
270 
271         const std::vector<T> expectedOutput =
272                 {
273                         4, 5, 6,
274                         1, 2, 3,
275                         4, 5, 6,
276 
277                         1, 2, 3,
278                         4, 5, 6,
279                         1, 2, 3,
280 
281                         10, 11, 12,
282                         7,  8,  9,
283                         10, 11, 12,
284 
285                         7,  8,  9,
286                         10, 11, 12,
287                          7,  8,  9,
288 
289                         16, 17, 18,
290                         13, 14, 15,
291                         16, 17, 18,
292 
293                         13, 14, 15,
294                         16, 17, 18,
295                         13, 14, 15
296                 };
297 
298         return GatherTestImpl<ArmnnType, T, 3, 2, 4>(
299                 workloadFactory,
300                 memoryManager,
301                 tensorHandleFactory,
302                 paramsInfo,
303                 indicesInfo,
304                 outputInfo,
305                 params,
306                 indices,
307                 expectedOutput,
308                 descriptor);
309     }
310 
GatherMultiDimParamsMultiDimIndicesAxis2TestImpl__anon0dc3d3520111::GatherTestHelper311     static LayerTestResult<T, 4> GatherMultiDimParamsMultiDimIndicesAxis2TestImpl(
312         armnn::IWorkloadFactory& workloadFactory,
313         const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
314         const armnn::ITensorHandleFactory& tensorHandleFactory)
315     {
316         armnn::GatherDescriptor descriptor;
317         descriptor.m_Axis=2;
318         armnn::TensorInfo paramsInfo({ 3, 2, 3 }, ArmnnType);
319         armnn::TensorInfo indicesInfo({ 2, 3 }, armnn::DataType::Signed32);
320         armnn::TensorInfo outputInfo({ 3, 2, 2, 3 }, ArmnnType);
321 
322         if (armnn::IsQuantizedType<T>())
323         {
324             paramsInfo.SetQuantizationScale(1.0f);
325             paramsInfo.SetQuantizationOffset(1);
326             outputInfo.SetQuantizationScale(1.0f);
327             outputInfo.SetQuantizationOffset(1);
328         }
329 
330         const std::vector<T> params =
331                 {
332                         1,  2,  3,
333                         4,  5,  6,
334 
335                         7,  8,  9,
336                         10, 11, 12,
337 
338                         13, 14, 15,
339                         16, 17, 18
340                 };
341 
342         const std::vector<int32_t> indices = { 1, 2, 1, 2, 1, 0 };
343 
344         const std::vector<T> expectedOutput =
345                 {
346                         2, 3, 2,
347                         3, 2, 1,
348 
349                         5, 6, 5,
350                         6, 5, 4,
351 
352                         8, 9, 8,
353                         9, 8, 7,
354 
355                         11, 12, 11,
356                         12, 11, 10,
357 
358                         14, 15, 14,
359                         15, 14, 13,
360 
361                         17, 18, 17,
362                         18, 17, 16
363                 };
364 
365         return GatherTestImpl<ArmnnType, T, 3, 2, 4>(
366                 workloadFactory,
367                 memoryManager,
368                 tensorHandleFactory,
369                 paramsInfo,
370                 indicesInfo,
371                 outputInfo,
372                 params,
373                 indices,
374                 expectedOutput,
375                 descriptor);
376     }
377 };
378 
379 template<typename T>
380 struct GatherTestHelper<armnn::DataType::Float16, T>
381 {
Gather1dParamsTestImpl__anon0dc3d3520111::GatherTestHelper382     static LayerTestResult<T, 1> Gather1dParamsTestImpl(
383         armnn::IWorkloadFactory& workloadFactory,
384         const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
385         const armnn::ITensorHandleFactory& tensorHandleFactory)
386     {
387         using namespace half_float::literal;
388 
389         armnn::TensorInfo paramsInfo({ 8 }, armnn::DataType::Float16);
390         armnn::TensorInfo indicesInfo({ 4 }, armnn::DataType::Signed32);
391         armnn::TensorInfo outputInfo({ 4 }, armnn::DataType::Float16);
392 
393         const std::vector<T> params = std::vector<T>({ 1._h, 2._h, 3._h, 4._h, 5._h, 6._h, 7._h, 8._h });
394         const std::vector<int32_t> indices  = std::vector<int32_t>({ 0, 2, 1, 5 });
395         const std::vector<T> expectedOutput = std::vector<T>({ 1._h, 3._h, 2._h, 6._h });
396 
397         return GatherTestImpl<armnn::DataType::Float16, T, 1, 1, 1>(
398             workloadFactory,
399             memoryManager,
400             tensorHandleFactory,
401             paramsInfo,
402             indicesInfo,
403             outputInfo,
404             params,
405             indices,
406             expectedOutput);
407     }
408 
GatherMultiDimParamsTestImpl__anon0dc3d3520111::GatherTestHelper409     static LayerTestResult<T, 2> GatherMultiDimParamsTestImpl(
410         armnn::IWorkloadFactory& workloadFactory,
411         const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
412         const armnn::ITensorHandleFactory& tensorHandleFactory)
413     {
414         using namespace half_float::literal;
415 
416         armnn::TensorInfo paramsInfo({ 5, 2 }, armnn::DataType::Float16);
417         armnn::TensorInfo indicesInfo({ 3 }, armnn::DataType::Signed32);
418         armnn::TensorInfo outputInfo({ 3, 2 }, armnn::DataType::Float16);
419 
420         const std::vector<T> params = std::vector<T>({ 1._h, 2._h, 3._h, 4._h, 5._h, 6._h, 7._h, 8._h, 9._h, 10._h });
421 
422         const std::vector<int32_t> indices  = std::vector<int32_t>({ 1, 3, 4 });
423         const std::vector<T> expectedOutput = std::vector<T>({ 3._h, 4._h, 7._h, 8._h, 9._h, 10._h });
424 
425         return GatherTestImpl<armnn::DataType::Float16, T, 2, 1, 2>(
426             workloadFactory,
427             memoryManager,
428             tensorHandleFactory,
429             paramsInfo,
430             indicesInfo,
431             outputInfo,
432             params,
433             indices,
434             expectedOutput);
435     }
436 
GatherMultiDimParamsMultiDimIndicesTestImpl__anon0dc3d3520111::GatherTestHelper437     static LayerTestResult<T, 4> GatherMultiDimParamsMultiDimIndicesTestImpl(
438         armnn::IWorkloadFactory& workloadFactory,
439         const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
440         const armnn::ITensorHandleFactory& tensorHandleFactory)
441     {
442         using namespace half_float::literal;
443 
444         armnn::TensorInfo paramsInfo({ 3, 2, 3 }, armnn::DataType::Float16);
445         armnn::TensorInfo indicesInfo({ 2, 3 }, armnn::DataType::Signed32);
446         armnn::TensorInfo outputInfo({ 2, 3, 2, 3 }, armnn::DataType::Float16);
447 
448         const std::vector<T> params =
449         {
450             1._h,  2._h,  3._h,
451             4._h,  5._h,  6._h,
452 
453             7._h,  8._h,  9._h,
454             10._h, 11._h, 12._h,
455 
456             13._h, 14._h, 15._h,
457             16._h, 17._h, 18._h
458         };
459 
460         const std::vector<int32_t> indices = { 1, 2, 1, 2, 1, 0 };
461 
462         const std::vector<T> expectedOutput =
463         {
464             7._h,  8._h,  9._h,
465             10._h, 11._h, 12._h,
466             13._h, 14._h, 15._h,
467             16._h, 17._h, 18._h,
468             7._h,  8._h,  9._h,
469             10._h, 11._h, 12._h,
470 
471             13._h, 14._h, 15._h,
472             16._h, 17._h, 18._h,
473             7._h,  8._h,  9._h,
474             10._h, 11._h, 12._h,
475             1._h,  2._h,  3._h,
476             4._h,  5._h,  6._h
477         };
478 
479         return GatherTestImpl<armnn::DataType::Float16, T, 3, 2, 4>(
480             workloadFactory,
481             memoryManager,
482             tensorHandleFactory,
483             paramsInfo,
484             indicesInfo,
485             outputInfo,
486             params,
487             indices,
488             expectedOutput);
489     }
490 };
491 
492 } // anonymous namespace
493 
Gather1dParamsFloat32Test(armnn::IWorkloadFactory & workloadFactory,const armnn::IBackendInternal::IMemoryManagerSharedPtr & memoryManager,const armnn::ITensorHandleFactory & tensorHandleFactory)494 LayerTestResult<float, 1> Gather1dParamsFloat32Test(
495     armnn::IWorkloadFactory& workloadFactory,
496     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
497     const armnn::ITensorHandleFactory& tensorHandleFactory)
498 {
499     return GatherTestHelper<armnn::DataType::Float32>::Gather1dParamsTestImpl(
500             workloadFactory, memoryManager, tensorHandleFactory);
501 }
502 
Gather1dParamsAxisTest(armnn::IWorkloadFactory & workloadFactory,const armnn::IBackendInternal::IMemoryManagerSharedPtr & memoryManager,const armnn::ITensorHandleFactory & tensorHandleFactory)503 LayerTestResult<float, 1> Gather1dParamsAxisTest(
504     armnn::IWorkloadFactory& workloadFactory,
505     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
506     const armnn::ITensorHandleFactory& tensorHandleFactory)
507 {
508     return GatherTestHelper<armnn::DataType::Float32>::Gather1dParamsAxisTestImpl(
509             workloadFactory, memoryManager, tensorHandleFactory);
510 }
511 
Gather1dParamsFloat16Test(armnn::IWorkloadFactory & workloadFactory,const armnn::IBackendInternal::IMemoryManagerSharedPtr & memoryManager,const armnn::ITensorHandleFactory & tensorHandleFactory)512 LayerTestResult<armnn::Half, 1> Gather1dParamsFloat16Test(
513     armnn::IWorkloadFactory& workloadFactory,
514     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
515     const armnn::ITensorHandleFactory& tensorHandleFactory)
516 {
517     return GatherTestHelper<armnn::DataType::Float16>::Gather1dParamsTestImpl(
518             workloadFactory, memoryManager, tensorHandleFactory);
519 }
520 
Gather1dParamsUint8Test(armnn::IWorkloadFactory & workloadFactory,const armnn::IBackendInternal::IMemoryManagerSharedPtr & memoryManager,const armnn::ITensorHandleFactory & tensorHandleFactory)521 LayerTestResult<uint8_t, 1> Gather1dParamsUint8Test(
522     armnn::IWorkloadFactory& workloadFactory,
523     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
524     const armnn::ITensorHandleFactory& tensorHandleFactory)
525 {
526     return GatherTestHelper<armnn::DataType::QAsymmU8>::Gather1dParamsTestImpl(
527             workloadFactory, memoryManager, tensorHandleFactory);
528 }
529 
Gather1dParamsInt16Test(armnn::IWorkloadFactory & workloadFactory,const armnn::IBackendInternal::IMemoryManagerSharedPtr & memoryManager,const armnn::ITensorHandleFactory & tensorHandleFactory)530 LayerTestResult<int16_t, 1> Gather1dParamsInt16Test(
531         armnn::IWorkloadFactory& workloadFactory,
532     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
533     const armnn::ITensorHandleFactory& tensorHandleFactory)
534 {
535     return GatherTestHelper<armnn::DataType::QSymmS16>::Gather1dParamsTestImpl(
536             workloadFactory, memoryManager, tensorHandleFactory);
537 }
538 
Gather1dParamsInt32Test(armnn::IWorkloadFactory & workloadFactory,const armnn::IBackendInternal::IMemoryManagerSharedPtr & memoryManager,const armnn::ITensorHandleFactory & tensorHandleFactory)539 LayerTestResult<int32_t, 1> Gather1dParamsInt32Test(
540     armnn::IWorkloadFactory& workloadFactory,
541     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
542     const armnn::ITensorHandleFactory& tensorHandleFactory)
543 {
544     return GatherTestHelper<armnn::DataType::Signed32>::Gather1dParamsTestImpl(
545             workloadFactory, memoryManager, tensorHandleFactory);
546 }
547 
GatherMultiDimParamsFloat32Test(armnn::IWorkloadFactory & workloadFactory,const armnn::IBackendInternal::IMemoryManagerSharedPtr & memoryManager,const armnn::ITensorHandleFactory & tensorHandleFactory)548 LayerTestResult<float, 2> GatherMultiDimParamsFloat32Test(
549     armnn::IWorkloadFactory& workloadFactory,
550     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
551     const armnn::ITensorHandleFactory& tensorHandleFactory)
552 {
553     return GatherTestHelper<armnn::DataType::Float32>::GatherMultiDimParamsTestImpl(
554             workloadFactory, memoryManager, tensorHandleFactory);
555 }
556 
GatherMultiDimParamsFloat16Test(armnn::IWorkloadFactory & workloadFactory,const armnn::IBackendInternal::IMemoryManagerSharedPtr & memoryManager,const armnn::ITensorHandleFactory & tensorHandleFactory)557 LayerTestResult<armnn::Half, 2> GatherMultiDimParamsFloat16Test(
558     armnn::IWorkloadFactory& workloadFactory,
559     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
560     const armnn::ITensorHandleFactory& tensorHandleFactory)
561 {
562     return GatherTestHelper<armnn::DataType::Float16>::GatherMultiDimParamsTestImpl(
563             workloadFactory, memoryManager, tensorHandleFactory);
564 }
565 
GatherMultiDimParamsUint8Test(armnn::IWorkloadFactory & workloadFactory,const armnn::IBackendInternal::IMemoryManagerSharedPtr & memoryManager,const armnn::ITensorHandleFactory & tensorHandleFactory)566 LayerTestResult<uint8_t, 2> GatherMultiDimParamsUint8Test(
567     armnn::IWorkloadFactory& workloadFactory,
568     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
569     const armnn::ITensorHandleFactory& tensorHandleFactory)
570 {
571     return GatherTestHelper<armnn::DataType::QAsymmU8>::GatherMultiDimParamsTestImpl(
572         workloadFactory, memoryManager, tensorHandleFactory);
573 }
574 
GatherMultiDimParamsInt16Test(armnn::IWorkloadFactory & workloadFactory,const armnn::IBackendInternal::IMemoryManagerSharedPtr & memoryManager,const armnn::ITensorHandleFactory & tensorHandleFactory)575 LayerTestResult<int16_t, 2> GatherMultiDimParamsInt16Test(
576     armnn::IWorkloadFactory& workloadFactory,
577     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
578     const armnn::ITensorHandleFactory& tensorHandleFactory)
579 {
580     return GatherTestHelper<armnn::DataType::QSymmS16>::GatherMultiDimParamsTestImpl(
581         workloadFactory, memoryManager, tensorHandleFactory);
582 }
583 
GatherMultiDimParamsInt32Test(armnn::IWorkloadFactory & workloadFactory,const armnn::IBackendInternal::IMemoryManagerSharedPtr & memoryManager,const armnn::ITensorHandleFactory & tensorHandleFactory)584 LayerTestResult<int32_t, 2> GatherMultiDimParamsInt32Test(
585     armnn::IWorkloadFactory& workloadFactory,
586     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
587     const armnn::ITensorHandleFactory& tensorHandleFactory)
588 {
589     return GatherTestHelper<armnn::DataType::Signed32>::GatherMultiDimParamsTestImpl(
590             workloadFactory, memoryManager, tensorHandleFactory);
591 }
592 
GatherMultiDimParamsMultiDimIndicesFloat32Test(armnn::IWorkloadFactory & workloadFactory,const armnn::IBackendInternal::IMemoryManagerSharedPtr & memoryManager,const armnn::ITensorHandleFactory & tensorHandleFactory)593 LayerTestResult<float, 4> GatherMultiDimParamsMultiDimIndicesFloat32Test(
594     armnn::IWorkloadFactory& workloadFactory,
595     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
596     const armnn::ITensorHandleFactory& tensorHandleFactory)
597 {
598     return GatherTestHelper<armnn::DataType::Float32>::GatherMultiDimParamsMultiDimIndicesTestImpl(
599         workloadFactory, memoryManager, tensorHandleFactory);
600 }
601 
GatherMultiDimParamsMultiDimIndicesAxis1Test(armnn::IWorkloadFactory & workloadFactory,const armnn::IBackendInternal::IMemoryManagerSharedPtr & memoryManager,const armnn::ITensorHandleFactory & tensorHandleFactory)602 LayerTestResult<float, 4> GatherMultiDimParamsMultiDimIndicesAxis1Test(
603     armnn::IWorkloadFactory& workloadFactory,
604     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
605     const armnn::ITensorHandleFactory& tensorHandleFactory)
606 {
607     return GatherTestHelper<armnn::DataType::Float32>::GatherMultiDimParamsMultiDimIndicesAxis1TestImpl(
608             workloadFactory, memoryManager, tensorHandleFactory);
609 }
610 
GatherMultiDimParamsMultiDimIndicesAxis2Test(armnn::IWorkloadFactory & workloadFactory,const armnn::IBackendInternal::IMemoryManagerSharedPtr & memoryManager,const armnn::ITensorHandleFactory & tensorHandleFactory)611 LayerTestResult<float, 4> GatherMultiDimParamsMultiDimIndicesAxis2Test(
612     armnn::IWorkloadFactory& workloadFactory,
613     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
614     const armnn::ITensorHandleFactory& tensorHandleFactory)
615 {
616     return GatherTestHelper<armnn::DataType::Float32>::GatherMultiDimParamsMultiDimIndicesAxis2TestImpl(
617             workloadFactory, memoryManager, tensorHandleFactory);
618 }
619 
GatherMultiDimParamsMultiDimIndicesFloat16Test(armnn::IWorkloadFactory & workloadFactory,const armnn::IBackendInternal::IMemoryManagerSharedPtr & memoryManager,const armnn::ITensorHandleFactory & tensorHandleFactory)620 LayerTestResult<armnn::Half, 4> GatherMultiDimParamsMultiDimIndicesFloat16Test(
621     armnn::IWorkloadFactory& workloadFactory,
622     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
623     const armnn::ITensorHandleFactory& tensorHandleFactory)
624 {
625     return GatherTestHelper<armnn::DataType::Float16>::GatherMultiDimParamsMultiDimIndicesTestImpl(
626         workloadFactory, memoryManager, tensorHandleFactory);
627 }
628 
GatherMultiDimParamsMultiDimIndicesUint8Test(armnn::IWorkloadFactory & workloadFactory,const armnn::IBackendInternal::IMemoryManagerSharedPtr & memoryManager,const armnn::ITensorHandleFactory & tensorHandleFactory)629 LayerTestResult<uint8_t, 4> GatherMultiDimParamsMultiDimIndicesUint8Test(
630     armnn::IWorkloadFactory& workloadFactory,
631     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
632     const armnn::ITensorHandleFactory& tensorHandleFactory)
633 {
634     return GatherTestHelper<armnn::DataType::QAsymmU8>::GatherMultiDimParamsMultiDimIndicesTestImpl(
635         workloadFactory, memoryManager, tensorHandleFactory);
636 }
637 
GatherMultiDimParamsMultiDimIndicesInt16Test(armnn::IWorkloadFactory & workloadFactory,const armnn::IBackendInternal::IMemoryManagerSharedPtr & memoryManager,const armnn::ITensorHandleFactory & tensorHandleFactory)638 LayerTestResult<int16_t, 4> GatherMultiDimParamsMultiDimIndicesInt16Test(
639     armnn::IWorkloadFactory& workloadFactory,
640     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
641     const armnn::ITensorHandleFactory& tensorHandleFactory)
642 {
643     return GatherTestHelper<armnn::DataType::QSymmS16>::GatherMultiDimParamsMultiDimIndicesTestImpl(
644         workloadFactory, memoryManager, tensorHandleFactory);
645 }
646 
GatherMultiDimParamsMultiDimIndicesInt32Test(armnn::IWorkloadFactory & workloadFactory,const armnn::IBackendInternal::IMemoryManagerSharedPtr & memoryManager,const armnn::ITensorHandleFactory & tensorHandleFactory)647 LayerTestResult<int32_t, 4> GatherMultiDimParamsMultiDimIndicesInt32Test(
648     armnn::IWorkloadFactory& workloadFactory,
649     const armnn::IBackendInternal::IMemoryManagerSharedPtr& memoryManager,
650     const armnn::ITensorHandleFactory& tensorHandleFactory)
651 {
652     return GatherTestHelper<armnn::DataType::Signed32>::GatherMultiDimParamsMultiDimIndicesTestImpl(
653             workloadFactory, memoryManager, tensorHandleFactory);
654 }