xref: /aosp_15_r20/external/armnn/delegate/test/GatherTest.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2020, 2023 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "GatherTestHelper.hpp"
7 
8 #include <armnn_delegate.hpp>
9 
10 #include <flatbuffers/flatbuffers.h>
11 #include <schema_generated.h>
12 
13 #include <doctest/doctest.h>
14 
15 namespace armnnDelegate
16 {
17 
18 // GATHER Operator
GatherUint8Test(std::vector<armnn::BackendId> & backends)19 void GatherUint8Test(std::vector<armnn::BackendId>& backends)
20 {
21 
22     std::vector<int32_t> paramsShape{8};
23     std::vector<int32_t> indicesShape{3};
24     std::vector<int32_t> expectedOutputShape{3};
25 
26     int32_t              axis = 0;
27     std::vector<uint8_t> paramsValues{1, 2, 3, 4, 5, 6, 7, 8};
28     std::vector<int32_t> indicesValues{7, 6, 5};
29     std::vector<uint8_t> expectedOutputValues{8, 7, 6};
30 
31     GatherTest<uint8_t>(::tflite::TensorType_UINT8,
32                         backends,
33                         paramsShape,
34                         indicesShape,
35                         expectedOutputShape,
36                         axis,
37                         paramsValues,
38                         indicesValues,
39                         expectedOutputValues);
40 }
41 
GatherFp32Test(std::vector<armnn::BackendId> & backends)42 void GatherFp32Test(std::vector<armnn::BackendId>& backends)
43 {
44     std::vector<int32_t> paramsShape{8};
45     std::vector<int32_t> indicesShape{3};
46     std::vector<int32_t> expectedOutputShape{3};
47 
48     int32_t              axis = 0;
49     std::vector<float>   paramsValues{1.1f, 2.2f, 3.3f, 4.4f, 5.5f, 6.6f, 7.7f, 8.8f};
50     std::vector<int32_t> indicesValues{7, 6, 5};
51     std::vector<float>   expectedOutputValues{8.8f, 7.7f, 6.6f};
52 
53     GatherTest<float>(::tflite::TensorType_FLOAT32,
54                       backends,
55                       paramsShape,
56                       indicesShape,
57                       expectedOutputShape,
58                       axis,
59                       paramsValues,
60                       indicesValues,
61                       expectedOutputValues);
62 }
63 
64 // GATHER Test Suite
65 TEST_SUITE("GATHER_CpuRefTests")
66 {
67 
68 TEST_CASE ("GATHER_Uint8_CpuRef_Test")
69 {
70     std::vector<armnn::BackendId> backends = {armnn::Compute::CpuRef};
71     GatherUint8Test(backends);
72 }
73 
74 TEST_CASE ("GATHER_Fp32_CpuRef_Test")
75 {
76     std::vector<armnn::BackendId> backends = {armnn::Compute::CpuRef};
77     GatherFp32Test(backends);
78 }
79 
80 }
81 
82 TEST_SUITE("GATHER_CpuAccTests")
83 {
84 
85 TEST_CASE ("GATHER_Uint8_CpuAcc_Test")
86 {
87     std::vector<armnn::BackendId> backends = {armnn::Compute::CpuAcc};
88     GatherUint8Test(backends);
89 }
90 
91 TEST_CASE ("GATHER_Fp32_CpuAcc_Test")
92 {
93     std::vector<armnn::BackendId> backends = {armnn::Compute::CpuAcc};
94     GatherFp32Test(backends);
95 }
96 
97 }
98 
99 TEST_SUITE("GATHER_GpuAccTests")
100 {
101 
102 TEST_CASE ("GATHER_Uint8_GpuAcc_Test")
103 {
104     std::vector<armnn::BackendId> backends = {armnn::Compute::GpuAcc};
105     GatherUint8Test(backends);
106 }
107 
108 TEST_CASE ("GATHER_Fp32_GpuAcc_Test")
109 {
110     std::vector<armnn::BackendId> backends = {armnn::Compute::GpuAcc};
111     GatherFp32Test(backends);
112 }
113 
114 }
115 // End of GATHER Test Suite
116 
117 } // namespace armnnDelegate