xref: /aosp_15_r20/external/armnn/delegate/test/FillTest.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2021, 2023 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include "FillTestHelper.hpp"
7*89c4ff92SAndroid Build Coastguard Worker 
8*89c4ff92SAndroid Build Coastguard Worker #include <armnn_delegate.hpp>
9*89c4ff92SAndroid Build Coastguard Worker 
10*89c4ff92SAndroid Build Coastguard Worker #include <flatbuffers/flatbuffers.h>
11*89c4ff92SAndroid Build Coastguard Worker #include <schema_generated.h>
12*89c4ff92SAndroid Build Coastguard Worker 
13*89c4ff92SAndroid Build Coastguard Worker #include <doctest/doctest.h>
14*89c4ff92SAndroid Build Coastguard Worker 
15*89c4ff92SAndroid Build Coastguard Worker namespace armnnDelegate
16*89c4ff92SAndroid Build Coastguard Worker {
17*89c4ff92SAndroid Build Coastguard Worker 
Fill2dTest(std::vector<armnn::BackendId> & backends,tflite::BuiltinOperator fillOperatorCode=tflite::BuiltinOperator_FILL,float fill=2.0f)18*89c4ff92SAndroid Build Coastguard Worker void Fill2dTest(std::vector<armnn::BackendId>& backends,
19*89c4ff92SAndroid Build Coastguard Worker                tflite::BuiltinOperator fillOperatorCode = tflite::BuiltinOperator_FILL,
20*89c4ff92SAndroid Build Coastguard Worker                float fill = 2.0f )
21*89c4ff92SAndroid Build Coastguard Worker {
22*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> inputShape { 2 };
23*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> tensorShape { 2, 2 };
24*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> expectedOutputValues = { fill, fill,
25*89c4ff92SAndroid Build Coastguard Worker                                                 fill, fill };
26*89c4ff92SAndroid Build Coastguard Worker 
27*89c4ff92SAndroid Build Coastguard Worker     FillTest<float>(fillOperatorCode,
28*89c4ff92SAndroid Build Coastguard Worker                     ::tflite::TensorType_FLOAT32,
29*89c4ff92SAndroid Build Coastguard Worker                     backends,
30*89c4ff92SAndroid Build Coastguard Worker                     inputShape,
31*89c4ff92SAndroid Build Coastguard Worker                     tensorShape,
32*89c4ff92SAndroid Build Coastguard Worker                     expectedOutputValues,
33*89c4ff92SAndroid Build Coastguard Worker                     fill);
34*89c4ff92SAndroid Build Coastguard Worker }
35*89c4ff92SAndroid Build Coastguard Worker 
Fill3dTest(std::vector<armnn::BackendId> & backends,tflite::BuiltinOperator fillOperatorCode=tflite::BuiltinOperator_FILL,float fill=5.0f)36*89c4ff92SAndroid Build Coastguard Worker void Fill3dTest(std::vector<armnn::BackendId>& backends,
37*89c4ff92SAndroid Build Coastguard Worker                tflite::BuiltinOperator fillOperatorCode = tflite::BuiltinOperator_FILL,
38*89c4ff92SAndroid Build Coastguard Worker                float fill = 5.0f )
39*89c4ff92SAndroid Build Coastguard Worker {
40*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> inputShape { 3 };
41*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> tensorShape { 3, 3, 3 };
42*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> expectedOutputValues = { fill, fill, fill,
43*89c4ff92SAndroid Build Coastguard Worker                                                 fill, fill, fill,
44*89c4ff92SAndroid Build Coastguard Worker                                                 fill, fill, fill,
45*89c4ff92SAndroid Build Coastguard Worker 
46*89c4ff92SAndroid Build Coastguard Worker                                                 fill, fill, fill,
47*89c4ff92SAndroid Build Coastguard Worker                                                 fill, fill, fill,
48*89c4ff92SAndroid Build Coastguard Worker                                                 fill, fill, fill,
49*89c4ff92SAndroid Build Coastguard Worker 
50*89c4ff92SAndroid Build Coastguard Worker                                                 fill, fill, fill,
51*89c4ff92SAndroid Build Coastguard Worker                                                 fill, fill, fill,
52*89c4ff92SAndroid Build Coastguard Worker                                                 fill, fill, fill };
53*89c4ff92SAndroid Build Coastguard Worker 
54*89c4ff92SAndroid Build Coastguard Worker     FillTest<float>(fillOperatorCode,
55*89c4ff92SAndroid Build Coastguard Worker                     ::tflite::TensorType_FLOAT32,
56*89c4ff92SAndroid Build Coastguard Worker                     backends,
57*89c4ff92SAndroid Build Coastguard Worker                     inputShape,
58*89c4ff92SAndroid Build Coastguard Worker                     tensorShape,
59*89c4ff92SAndroid Build Coastguard Worker                     expectedOutputValues,
60*89c4ff92SAndroid Build Coastguard Worker                     fill);
61*89c4ff92SAndroid Build Coastguard Worker }
62*89c4ff92SAndroid Build Coastguard Worker 
Fill4dTest(std::vector<armnn::BackendId> & backends,tflite::BuiltinOperator fillOperatorCode=tflite::BuiltinOperator_FILL,float fill=3.0f)63*89c4ff92SAndroid Build Coastguard Worker void Fill4dTest(std::vector<armnn::BackendId>& backends,
64*89c4ff92SAndroid Build Coastguard Worker                tflite::BuiltinOperator fillOperatorCode = tflite::BuiltinOperator_FILL,
65*89c4ff92SAndroid Build Coastguard Worker                float fill = 3.0f )
66*89c4ff92SAndroid Build Coastguard Worker {
67*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> inputShape { 4 };
68*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> tensorShape { 2, 2, 4, 4 };
69*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> expectedOutputValues = { fill, fill, fill, fill,
70*89c4ff92SAndroid Build Coastguard Worker                                                 fill, fill, fill, fill,
71*89c4ff92SAndroid Build Coastguard Worker                                                 fill, fill, fill, fill,
72*89c4ff92SAndroid Build Coastguard Worker                                                 fill, fill, fill, fill,
73*89c4ff92SAndroid Build Coastguard Worker 
74*89c4ff92SAndroid Build Coastguard Worker                                                 fill, fill, fill, fill,
75*89c4ff92SAndroid Build Coastguard Worker                                                 fill, fill, fill, fill,
76*89c4ff92SAndroid Build Coastguard Worker                                                 fill, fill, fill, fill,
77*89c4ff92SAndroid Build Coastguard Worker                                                 fill, fill, fill, fill,
78*89c4ff92SAndroid Build Coastguard Worker 
79*89c4ff92SAndroid Build Coastguard Worker                                                 fill, fill, fill, fill,
80*89c4ff92SAndroid Build Coastguard Worker                                                 fill, fill, fill, fill,
81*89c4ff92SAndroid Build Coastguard Worker                                                 fill, fill, fill, fill,
82*89c4ff92SAndroid Build Coastguard Worker                                                 fill, fill, fill, fill,
83*89c4ff92SAndroid Build Coastguard Worker 
84*89c4ff92SAndroid Build Coastguard Worker                                                 fill, fill, fill, fill,
85*89c4ff92SAndroid Build Coastguard Worker                                                 fill, fill, fill, fill,
86*89c4ff92SAndroid Build Coastguard Worker                                                 fill, fill, fill, fill,
87*89c4ff92SAndroid Build Coastguard Worker                                                 fill, fill, fill, fill };
88*89c4ff92SAndroid Build Coastguard Worker 
89*89c4ff92SAndroid Build Coastguard Worker     FillTest<float>(fillOperatorCode,
90*89c4ff92SAndroid Build Coastguard Worker                     ::tflite::TensorType_FLOAT32,
91*89c4ff92SAndroid Build Coastguard Worker                     backends,
92*89c4ff92SAndroid Build Coastguard Worker                     inputShape,
93*89c4ff92SAndroid Build Coastguard Worker                     tensorShape,
94*89c4ff92SAndroid Build Coastguard Worker                     expectedOutputValues,
95*89c4ff92SAndroid Build Coastguard Worker                     fill);
96*89c4ff92SAndroid Build Coastguard Worker }
97*89c4ff92SAndroid Build Coastguard Worker 
FillInt32Test(std::vector<armnn::BackendId> & backends,tflite::BuiltinOperator fillOperatorCode=tflite::BuiltinOperator_FILL,int32_t fill=2)98*89c4ff92SAndroid Build Coastguard Worker void FillInt32Test(std::vector<armnn::BackendId>& backends,
99*89c4ff92SAndroid Build Coastguard Worker                   tflite::BuiltinOperator fillOperatorCode = tflite::BuiltinOperator_FILL,
100*89c4ff92SAndroid Build Coastguard Worker                   int32_t fill = 2 )
101*89c4ff92SAndroid Build Coastguard Worker {
102*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> inputShape { 2 };
103*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> tensorShape { 2, 2 };
104*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> expectedOutputValues = { fill, fill,
105*89c4ff92SAndroid Build Coastguard Worker                                                   fill, fill };
106*89c4ff92SAndroid Build Coastguard Worker 
107*89c4ff92SAndroid Build Coastguard Worker     FillTest<int32_t>(fillOperatorCode,
108*89c4ff92SAndroid Build Coastguard Worker                       ::tflite::TensorType_INT32,
109*89c4ff92SAndroid Build Coastguard Worker                       backends,
110*89c4ff92SAndroid Build Coastguard Worker                       inputShape,
111*89c4ff92SAndroid Build Coastguard Worker                       tensorShape,
112*89c4ff92SAndroid Build Coastguard Worker                       expectedOutputValues,
113*89c4ff92SAndroid Build Coastguard Worker                       fill);
114*89c4ff92SAndroid Build Coastguard Worker }
115*89c4ff92SAndroid Build Coastguard Worker 
116*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("Fill_CpuRefTests")
117*89c4ff92SAndroid Build Coastguard Worker {
118*89c4ff92SAndroid Build Coastguard Worker 
119*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("Fill2d_CpuRef_Test")
120*89c4ff92SAndroid Build Coastguard Worker {
121*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BackendId> backends = { armnn::Compute::CpuRef };
122*89c4ff92SAndroid Build Coastguard Worker     Fill2dTest(backends);
123*89c4ff92SAndroid Build Coastguard Worker }
124*89c4ff92SAndroid Build Coastguard Worker 
125*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("Fill3d_CpuRef_Test")
126*89c4ff92SAndroid Build Coastguard Worker {
127*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BackendId> backends = { armnn::Compute::CpuRef };
128*89c4ff92SAndroid Build Coastguard Worker     Fill3dTest(backends);
129*89c4ff92SAndroid Build Coastguard Worker }
130*89c4ff92SAndroid Build Coastguard Worker 
131*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("Fill3d_CpuRef_Test")
132*89c4ff92SAndroid Build Coastguard Worker {
133*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BackendId> backends = { armnn::Compute::CpuRef };
134*89c4ff92SAndroid Build Coastguard Worker     Fill3dTest(backends);
135*89c4ff92SAndroid Build Coastguard Worker }
136*89c4ff92SAndroid Build Coastguard Worker 
137*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("Fill4d_CpuRef_Test")
138*89c4ff92SAndroid Build Coastguard Worker {
139*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BackendId> backends = { armnn::Compute::CpuRef };
140*89c4ff92SAndroid Build Coastguard Worker     Fill4dTest(backends);
141*89c4ff92SAndroid Build Coastguard Worker }
142*89c4ff92SAndroid Build Coastguard Worker 
143*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("FillInt32_CpuRef_Test")
144*89c4ff92SAndroid Build Coastguard Worker {
145*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BackendId> backends = { armnn::Compute::CpuRef };
146*89c4ff92SAndroid Build Coastguard Worker     FillInt32Test(backends);
147*89c4ff92SAndroid Build Coastguard Worker }
148*89c4ff92SAndroid Build Coastguard Worker 
149*89c4ff92SAndroid Build Coastguard Worker }
150*89c4ff92SAndroid Build Coastguard Worker 
151*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("Fill_CpuAccTests")
152*89c4ff92SAndroid Build Coastguard Worker {
153*89c4ff92SAndroid Build Coastguard Worker 
154*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("Fill2d_CpuAcc_Test")
155*89c4ff92SAndroid Build Coastguard Worker {
156*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BackendId> backends = { armnn::Compute::CpuAcc };
157*89c4ff92SAndroid Build Coastguard Worker     Fill2dTest(backends);
158*89c4ff92SAndroid Build Coastguard Worker }
159*89c4ff92SAndroid Build Coastguard Worker 
160*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("Fill3d_CpuAcc_Test")
161*89c4ff92SAndroid Build Coastguard Worker {
162*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BackendId> backends = { armnn::Compute::CpuAcc };
163*89c4ff92SAndroid Build Coastguard Worker     Fill3dTest(backends);
164*89c4ff92SAndroid Build Coastguard Worker }
165*89c4ff92SAndroid Build Coastguard Worker 
166*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("Fill3d_CpuAcc_Test")
167*89c4ff92SAndroid Build Coastguard Worker {
168*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BackendId> backends = { armnn::Compute::CpuAcc };
169*89c4ff92SAndroid Build Coastguard Worker     Fill3dTest(backends);
170*89c4ff92SAndroid Build Coastguard Worker }
171*89c4ff92SAndroid Build Coastguard Worker 
172*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("Fill4d_CpuAcc_Test")
173*89c4ff92SAndroid Build Coastguard Worker {
174*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BackendId> backends = { armnn::Compute::CpuAcc };
175*89c4ff92SAndroid Build Coastguard Worker     Fill4dTest(backends);
176*89c4ff92SAndroid Build Coastguard Worker }
177*89c4ff92SAndroid Build Coastguard Worker 
178*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("FillInt32_CpuAcc_Test")
179*89c4ff92SAndroid Build Coastguard Worker {
180*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BackendId> backends = { armnn::Compute::CpuAcc };
181*89c4ff92SAndroid Build Coastguard Worker     FillInt32Test(backends);
182*89c4ff92SAndroid Build Coastguard Worker }
183*89c4ff92SAndroid Build Coastguard Worker 
184*89c4ff92SAndroid Build Coastguard Worker }
185*89c4ff92SAndroid Build Coastguard Worker 
186*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("Fill_GpuAccTests")
187*89c4ff92SAndroid Build Coastguard Worker {
188*89c4ff92SAndroid Build Coastguard Worker 
189*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("Fill2d_GpuAcc_Test")
190*89c4ff92SAndroid Build Coastguard Worker {
191*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BackendId> backends = { armnn::Compute::GpuAcc };
192*89c4ff92SAndroid Build Coastguard Worker     Fill2dTest(backends);
193*89c4ff92SAndroid Build Coastguard Worker }
194*89c4ff92SAndroid Build Coastguard Worker 
195*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("Fill3d_GpuAcc_Test")
196*89c4ff92SAndroid Build Coastguard Worker {
197*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BackendId> backends = { armnn::Compute::GpuAcc };
198*89c4ff92SAndroid Build Coastguard Worker     Fill3dTest(backends);
199*89c4ff92SAndroid Build Coastguard Worker }
200*89c4ff92SAndroid Build Coastguard Worker 
201*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("Fill3d_GpuAcc_Test")
202*89c4ff92SAndroid Build Coastguard Worker {
203*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BackendId> backends = { armnn::Compute::GpuAcc };
204*89c4ff92SAndroid Build Coastguard Worker     Fill3dTest(backends);
205*89c4ff92SAndroid Build Coastguard Worker }
206*89c4ff92SAndroid Build Coastguard Worker 
207*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("Fill4d_GpuAcc_Test")
208*89c4ff92SAndroid Build Coastguard Worker {
209*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BackendId> backends = { armnn::Compute::GpuAcc };
210*89c4ff92SAndroid Build Coastguard Worker     Fill4dTest(backends);
211*89c4ff92SAndroid Build Coastguard Worker }
212*89c4ff92SAndroid Build Coastguard Worker 
213*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("FillInt32_GpuAcc_Test")
214*89c4ff92SAndroid Build Coastguard Worker {
215*89c4ff92SAndroid Build Coastguard Worker     std::vector<armnn::BackendId> backends = { armnn::Compute::GpuAcc };
216*89c4ff92SAndroid Build Coastguard Worker     FillInt32Test(backends);
217*89c4ff92SAndroid Build Coastguard Worker }
218*89c4ff92SAndroid Build Coastguard Worker 
219*89c4ff92SAndroid Build Coastguard Worker }
220*89c4ff92SAndroid Build Coastguard Worker 
221*89c4ff92SAndroid Build Coastguard Worker } // namespace armnnDelegate