xref: /aosp_15_r20/external/armnn/delegate/test/Convolution3dTest.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2021, 2023 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include "ConvolutionTestHelper.hpp"
7*89c4ff92SAndroid Build Coastguard Worker 
8*89c4ff92SAndroid Build Coastguard Worker #include <armnn_delegate.hpp>
9*89c4ff92SAndroid Build Coastguard Worker 
10*89c4ff92SAndroid Build Coastguard Worker #include <flatbuffers/flatbuffers.h>
11*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/interpreter.h>
12*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/kernels/register.h>
13*89c4ff92SAndroid Build Coastguard Worker #include <tensorflow/lite/model.h>
14*89c4ff92SAndroid Build Coastguard Worker #include <schema_generated.h>
15*89c4ff92SAndroid Build Coastguard Worker 
16*89c4ff92SAndroid Build Coastguard Worker #include <doctest/doctest.h>
17*89c4ff92SAndroid Build Coastguard Worker 
18*89c4ff92SAndroid Build Coastguard Worker namespace armnnDelegate
19*89c4ff92SAndroid Build Coastguard Worker {
20*89c4ff92SAndroid Build Coastguard Worker 
// Conv3d currently only supports Float32 inputs, filter, bias and outputs in TFLite.
22*89c4ff92SAndroid Build Coastguard Worker // Conv3d is only correctly supported for external delegates from TF Lite v2.6, as there was a breaking bug in v2.5.
23*89c4ff92SAndroid Build Coastguard Worker #if defined(ARMNN_POST_TFLITE_2_5)
24*89c4ff92SAndroid Build Coastguard Worker 
// Create a vector holding 0, 1, ..., size - 1, each divided by `divisor`, so tests can use
// small, predictable floating point data.
// Each value is computed in float (the divisor's precision) and then converted to T.
// Returns an empty vector when size <= 0.
template <typename T>
std::vector<T> CreateFloatData(int32_t size, float divisor)
{
    std::vector<T> data;
    if (size <= 0)
    {
        return data;
    }
    data.reserve(static_cast<std::size_t>(size));
    for (int32_t i = 0; i < size; ++i)
    {
        // Convert explicitly so the element type really is T (previously this built a
        // std::vector<float>, which only compiled for T == float).
        data.push_back(static_cast<T>(static_cast<float>(i) / divisor));
    }
    return data;
}
37*89c4ff92SAndroid Build Coastguard Worker 
Conv3DWithBiasesSimpleWithPaddingFp32Test(std::vector<armnn::BackendId> & backends)38*89c4ff92SAndroid Build Coastguard Worker void Conv3DWithBiasesSimpleWithPaddingFp32Test(std::vector<armnn::BackendId>& backends)
39*89c4ff92SAndroid Build Coastguard Worker {
40*89c4ff92SAndroid Build Coastguard Worker     // Set input data
41*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> inputShape { 1, 2, 2, 2, 1 };
42*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> filterShape { 2, 2, 2, 1, 1 };
43*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> biasShape { 1 };
44*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> outputShape { 1, 2, 2, 2, 1 };
45*89c4ff92SAndroid Build Coastguard Worker 
46*89c4ff92SAndroid Build Coastguard Worker     static std::vector<float> inputValues =
47*89c4ff92SAndroid Build Coastguard Worker     {
48*89c4ff92SAndroid Build Coastguard Worker         1.f, 2.f, 3.f, 4.f, 5.f, 6.f, 7.f, 8.f
49*89c4ff92SAndroid Build Coastguard Worker     };
50*89c4ff92SAndroid Build Coastguard Worker 
51*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> filterValues =
52*89c4ff92SAndroid Build Coastguard Worker     {
53*89c4ff92SAndroid Build Coastguard Worker         2.f,1.f, 1.f,0.f, 0.f,1.f, 1.f,1.f
54*89c4ff92SAndroid Build Coastguard Worker     };
55*89c4ff92SAndroid Build Coastguard Worker 
56*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> biasValues = { 5.f };
57*89c4ff92SAndroid Build Coastguard Worker 
58*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> expectedOutputValues =
59*89c4ff92SAndroid Build Coastguard Worker     {
60*89c4ff92SAndroid Build Coastguard Worker        33.f, 21.f, 23.f, 13.f, 28.f, 25.f, 27.f, 21.f
61*89c4ff92SAndroid Build Coastguard Worker     };
62*89c4ff92SAndroid Build Coastguard Worker 
63*89c4ff92SAndroid Build Coastguard Worker     Convolution3dTest<float>(tflite::BuiltinOperator_CONV_3D,
64*89c4ff92SAndroid Build Coastguard Worker                              ::tflite::TensorType_FLOAT32,
65*89c4ff92SAndroid Build Coastguard Worker                              { 1, 1, 1 }, // strideX, strideY, strideZ
66*89c4ff92SAndroid Build Coastguard Worker                              { 1, 1, 1 }, // dilationX, dilationY, dilationZ
67*89c4ff92SAndroid Build Coastguard Worker                              tflite::Padding_SAME,
68*89c4ff92SAndroid Build Coastguard Worker                              tflite::ActivationFunctionType_NONE,
69*89c4ff92SAndroid Build Coastguard Worker                              backends,
70*89c4ff92SAndroid Build Coastguard Worker                              inputShape,
71*89c4ff92SAndroid Build Coastguard Worker                              filterShape,
72*89c4ff92SAndroid Build Coastguard Worker                              outputShape,
73*89c4ff92SAndroid Build Coastguard Worker                              inputValues,
74*89c4ff92SAndroid Build Coastguard Worker                              filterValues,
75*89c4ff92SAndroid Build Coastguard Worker                              expectedOutputValues,
76*89c4ff92SAndroid Build Coastguard Worker                              biasShape,
77*89c4ff92SAndroid Build Coastguard Worker                              biasValues);
78*89c4ff92SAndroid Build Coastguard Worker }
79*89c4ff92SAndroid Build Coastguard Worker 
Conv3DWithBiasesStridesFp32Test(std::vector<armnn::BackendId> & backends)80*89c4ff92SAndroid Build Coastguard Worker void Conv3DWithBiasesStridesFp32Test(std::vector<armnn::BackendId>& backends)
81*89c4ff92SAndroid Build Coastguard Worker {
82*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> inputShape { 1, 3, 10, 10, 1 };
83*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> filterShape { 3, 5, 5, 1, 1 };
84*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> biasShape { 1 };
85*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> outputShape { 1, 1, 3, 3, 1 };
86*89c4ff92SAndroid Build Coastguard Worker 
87*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputValues = CreateFloatData<float>(300, 1.0f);
88*89c4ff92SAndroid Build Coastguard Worker 
89*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> filterValues =
90*89c4ff92SAndroid Build Coastguard Worker     {
91*89c4ff92SAndroid Build Coastguard Worker         1.f, 1.f, 1.f, 1.f, 1.f,
92*89c4ff92SAndroid Build Coastguard Worker         1.f, 1.f, 1.f, 1.f, 1.f,
93*89c4ff92SAndroid Build Coastguard Worker         1.f, 1.f, 1.f, 1.f, 1.f,
94*89c4ff92SAndroid Build Coastguard Worker         1.f, 1.f, 1.f, 1.f, 1.f,
95*89c4ff92SAndroid Build Coastguard Worker         1.f, 1.f, 1.f, 1.f, 1.f,
96*89c4ff92SAndroid Build Coastguard Worker 
97*89c4ff92SAndroid Build Coastguard Worker         0.f, 0.f, 0.f, 0.f, 0.f,
98*89c4ff92SAndroid Build Coastguard Worker         0.f, 0.f, 0.f, 0.f, 0.f,
99*89c4ff92SAndroid Build Coastguard Worker         0.f, 0.f, 0.f, 0.f, 0.f,
100*89c4ff92SAndroid Build Coastguard Worker         0.f, 0.f, 0.f, 0.f, 0.f,
101*89c4ff92SAndroid Build Coastguard Worker         0.f, 0.f, 0.f, 0.f, 0.f,
102*89c4ff92SAndroid Build Coastguard Worker 
103*89c4ff92SAndroid Build Coastguard Worker         2.f, 2.f, 2.f, 2.f, 2.f,
104*89c4ff92SAndroid Build Coastguard Worker         2.f, 2.f, 2.f, 2.f, 2.f,
105*89c4ff92SAndroid Build Coastguard Worker         2.f, 2.f, 2.f, 2.f, 2.f,
106*89c4ff92SAndroid Build Coastguard Worker         2.f, 2.f, 2.f, 2.f, 2.f,
107*89c4ff92SAndroid Build Coastguard Worker         2.f, 2.f, 2.f, 2.f, 2.f
108*89c4ff92SAndroid Build Coastguard Worker     };
109*89c4ff92SAndroid Build Coastguard Worker 
110*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> biasValues = { 10.f };
111*89c4ff92SAndroid Build Coastguard Worker 
112*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> expectedOutputValues =
113*89c4ff92SAndroid Build Coastguard Worker     {
114*89c4ff92SAndroid Build Coastguard Worker         11660.f, 11810.f, 11960.f,
115*89c4ff92SAndroid Build Coastguard Worker 
116*89c4ff92SAndroid Build Coastguard Worker         13160.f, 13310.f, 13460.f,
117*89c4ff92SAndroid Build Coastguard Worker 
118*89c4ff92SAndroid Build Coastguard Worker         14660.f, 14810.f, 14960.f
119*89c4ff92SAndroid Build Coastguard Worker     };
120*89c4ff92SAndroid Build Coastguard Worker 
121*89c4ff92SAndroid Build Coastguard Worker     Convolution3dTest<float>(tflite::BuiltinOperator_CONV_3D,
122*89c4ff92SAndroid Build Coastguard Worker                              ::tflite::TensorType_FLOAT32,
123*89c4ff92SAndroid Build Coastguard Worker                              { 2, 2, 2 }, // strideX, strideY, strideZ
124*89c4ff92SAndroid Build Coastguard Worker                              { 1, 1, 1 }, // dilationX, dilationY, dilationZ
125*89c4ff92SAndroid Build Coastguard Worker                              tflite::Padding_VALID,
126*89c4ff92SAndroid Build Coastguard Worker                              tflite::ActivationFunctionType_NONE,
127*89c4ff92SAndroid Build Coastguard Worker                              backends,
128*89c4ff92SAndroid Build Coastguard Worker                              inputShape,
129*89c4ff92SAndroid Build Coastguard Worker                              filterShape,
130*89c4ff92SAndroid Build Coastguard Worker                              outputShape,
131*89c4ff92SAndroid Build Coastguard Worker                              inputValues,
132*89c4ff92SAndroid Build Coastguard Worker                              filterValues,
133*89c4ff92SAndroid Build Coastguard Worker                              expectedOutputValues,
134*89c4ff92SAndroid Build Coastguard Worker                              biasShape,
135*89c4ff92SAndroid Build Coastguard Worker                              biasValues);
136*89c4ff92SAndroid Build Coastguard Worker }
137*89c4ff92SAndroid Build Coastguard Worker 
138*89c4ff92SAndroid Build Coastguard Worker 
Conv3DWithBiasesDilationFp32Test(std::vector<armnn::BackendId> & backends)139*89c4ff92SAndroid Build Coastguard Worker void Conv3DWithBiasesDilationFp32Test(std::vector<armnn::BackendId>& backends)
140*89c4ff92SAndroid Build Coastguard Worker {
141*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> inputShape { 1, 5, 5, 5, 2 };
142*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> filterShape { 2, 2, 2, 2, 2 };
143*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> biasShape { 2 };
144*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> outputShape { 1, 2, 2, 2, 2 };
145*89c4ff92SAndroid Build Coastguard Worker 
146*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputValues = CreateFloatData<float>(250, 1.0f);
147*89c4ff92SAndroid Build Coastguard Worker 
148*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> filterValues =
149*89c4ff92SAndroid Build Coastguard Worker     {
150*89c4ff92SAndroid Build Coastguard Worker         -1.f, -1.f,  -1.f, -1.f,  -1.f, -1.f,  -1.f, -1.f,  -1.f, -1.f,  -1.f,  1.f,   1.f,  1.f,  -1.f, -1.f,
151*89c4ff92SAndroid Build Coastguard Worker          1.f,  1.f,  -1.f,  1.f,  -1.f,  1.f,  -1.f,  1.f,  -1.f, -1.f,  -1.f,  1.f,  -1.f,  1.f,  -1.f,  1.f,
152*89c4ff92SAndroid Build Coastguard Worker     };
153*89c4ff92SAndroid Build Coastguard Worker 
154*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> biasValues = { 0.f, 2.f };
155*89c4ff92SAndroid Build Coastguard Worker 
156*89c4ff92SAndroid Build Coastguard Worker     // Since the dilation rate is 3 this will dilate the kernel to be 4x4,
157*89c4ff92SAndroid Build Coastguard Worker     // therefore the output will be 2x2
158*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> expectedOutputValues =
159*89c4ff92SAndroid Build Coastguard Worker     {
160*89c4ff92SAndroid Build Coastguard Worker         -1124.f, 976.f,
161*89c4ff92SAndroid Build Coastguard Worker         -1148.f, 980.f,
162*89c4ff92SAndroid Build Coastguard Worker 
163*89c4ff92SAndroid Build Coastguard Worker         -1244.f, 996.f,
164*89c4ff92SAndroid Build Coastguard Worker         -1268.f, 1000.f,
165*89c4ff92SAndroid Build Coastguard Worker 
166*89c4ff92SAndroid Build Coastguard Worker         -1724.f, 1076.f,
167*89c4ff92SAndroid Build Coastguard Worker         -1748.f, 1080.f,
168*89c4ff92SAndroid Build Coastguard Worker 
169*89c4ff92SAndroid Build Coastguard Worker         -1844.f, 1096.f,
170*89c4ff92SAndroid Build Coastguard Worker         -1868.f, 1100.f
171*89c4ff92SAndroid Build Coastguard Worker     };
172*89c4ff92SAndroid Build Coastguard Worker 
173*89c4ff92SAndroid Build Coastguard Worker     Convolution3dTest<float>(tflite::BuiltinOperator_CONV_3D,
174*89c4ff92SAndroid Build Coastguard Worker                              ::tflite::TensorType_FLOAT32,
175*89c4ff92SAndroid Build Coastguard Worker                              { 1, 1, 1 }, // strideX, strideY, strideZ
176*89c4ff92SAndroid Build Coastguard Worker                              { 3, 3, 3 }, // dilationX, dilationY, dilationZ
177*89c4ff92SAndroid Build Coastguard Worker                              tflite::Padding_VALID,
178*89c4ff92SAndroid Build Coastguard Worker                              tflite::ActivationFunctionType_NONE,
179*89c4ff92SAndroid Build Coastguard Worker                              backends,
180*89c4ff92SAndroid Build Coastguard Worker                              inputShape,
181*89c4ff92SAndroid Build Coastguard Worker                              filterShape,
182*89c4ff92SAndroid Build Coastguard Worker                              outputShape,
183*89c4ff92SAndroid Build Coastguard Worker                              inputValues,
184*89c4ff92SAndroid Build Coastguard Worker                              filterValues,
185*89c4ff92SAndroid Build Coastguard Worker                              expectedOutputValues,
186*89c4ff92SAndroid Build Coastguard Worker                              biasShape,
187*89c4ff92SAndroid Build Coastguard Worker                              biasValues);
188*89c4ff92SAndroid Build Coastguard Worker }
189*89c4ff92SAndroid Build Coastguard Worker 
Conv3DFp32SmallTest(std::vector<armnn::BackendId> & backends)190*89c4ff92SAndroid Build Coastguard Worker void Conv3DFp32SmallTest(std::vector<armnn::BackendId>& backends)
191*89c4ff92SAndroid Build Coastguard Worker {
192*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> inputShape { 1, 3, 10, 10, 1 };
193*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> filterShape { 3, 3, 3, 1, 1 };
194*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> biasShape { 1 };
195*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> outputShape { 1, 1, 4, 4, 1 };
196*89c4ff92SAndroid Build Coastguard Worker 
197*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputValues = CreateFloatData<float>(300, 100.0f);
198*89c4ff92SAndroid Build Coastguard Worker 
199*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> filterValues =
200*89c4ff92SAndroid Build Coastguard Worker     {
201*89c4ff92SAndroid Build Coastguard Worker          0.125977f,  0.150391f,  0.101562f,
202*89c4ff92SAndroid Build Coastguard Worker          0.0585938f, 0.0864258f, 0.043457f,
203*89c4ff92SAndroid Build Coastguard Worker          0.034668f,  0.0322266f, 0.0385742f,
204*89c4ff92SAndroid Build Coastguard Worker 
205*89c4ff92SAndroid Build Coastguard Worker          0.125977f,  0.150391f, -0.101562f,
206*89c4ff92SAndroid Build Coastguard Worker         -0.0585938f,-0.0864258f,-0.043457f,
207*89c4ff92SAndroid Build Coastguard Worker         -0.0104630f, 0.0154114f, 0.0013768f,
208*89c4ff92SAndroid Build Coastguard Worker 
209*89c4ff92SAndroid Build Coastguard Worker          0.0344238f, 0.035644f,  0.0495605f,
210*89c4ff92SAndroid Build Coastguard Worker          0.0683594f, 0.099121f, -0.0461426f,
211*89c4ff92SAndroid Build Coastguard Worker         -0.0996094f,-0.126953f, -0.043457f,
212*89c4ff92SAndroid Build Coastguard Worker     };
213*89c4ff92SAndroid Build Coastguard Worker 
214*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> biasValues = { 0 };
215*89c4ff92SAndroid Build Coastguard Worker 
216*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> expectedOutputValues =
217*89c4ff92SAndroid Build Coastguard Worker     {
218*89c4ff92SAndroid Build Coastguard Worker         -0.08156067f, -0.06891209f, -0.05589598f, -0.04310101f,
219*89c4ff92SAndroid Build Coastguard Worker          0.04584253f,  0.05855697f,  0.07129729f,  0.08325434f,
220*89c4ff92SAndroid Build Coastguard Worker          0.17304349f,  0.18521416f,  0.19818866f,  0.21096253f,
221*89c4ff92SAndroid Build Coastguard Worker          0.29965734f,  0.312698f,    0.32547557f,  0.33818722f
222*89c4ff92SAndroid Build Coastguard Worker     };
223*89c4ff92SAndroid Build Coastguard Worker 
224*89c4ff92SAndroid Build Coastguard Worker     Convolution3dTest<float>(tflite::BuiltinOperator_CONV_3D,
225*89c4ff92SAndroid Build Coastguard Worker                              ::tflite::TensorType_FLOAT32,
226*89c4ff92SAndroid Build Coastguard Worker                              { 2, 2, 2 }, // strideX, strideY, strideZ
227*89c4ff92SAndroid Build Coastguard Worker                              { 1, 1, 1 }, // dilationX, dilationY, dilationZ
228*89c4ff92SAndroid Build Coastguard Worker                              tflite::Padding_VALID,
229*89c4ff92SAndroid Build Coastguard Worker                              tflite::ActivationFunctionType_NONE,
230*89c4ff92SAndroid Build Coastguard Worker                              backends,
231*89c4ff92SAndroid Build Coastguard Worker                              inputShape,
232*89c4ff92SAndroid Build Coastguard Worker                              filterShape,
233*89c4ff92SAndroid Build Coastguard Worker                              outputShape,
234*89c4ff92SAndroid Build Coastguard Worker                              inputValues,
235*89c4ff92SAndroid Build Coastguard Worker                              filterValues,
236*89c4ff92SAndroid Build Coastguard Worker                              expectedOutputValues,
237*89c4ff92SAndroid Build Coastguard Worker                              biasShape,
238*89c4ff92SAndroid Build Coastguard Worker                              biasValues);
239*89c4ff92SAndroid Build Coastguard Worker }
240*89c4ff92SAndroid Build Coastguard Worker 
241*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("Convolution3dTest_CpuRefTests")
242*89c4ff92SAndroid Build Coastguard Worker {
243*89c4ff92SAndroid Build Coastguard Worker 
244*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("Conv3DWithBiasesSimpleWithPadding_Fp32_CpuRef_Test")
245*89c4ff92SAndroid Build Coastguard Worker {
246*89c4ff92SAndroid Build Coastguard Worker     std::vector <armnn::BackendId> backends = {armnn::Compute::CpuRef};
247*89c4ff92SAndroid Build Coastguard Worker     Conv3DWithBiasesSimpleWithPaddingFp32Test(backends);
248*89c4ff92SAndroid Build Coastguard Worker }
249*89c4ff92SAndroid Build Coastguard Worker 
250*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("Conv3DWithBiasesStrides_Fp32_CpuRef_Test")
251*89c4ff92SAndroid Build Coastguard Worker {
252*89c4ff92SAndroid Build Coastguard Worker     std::vector <armnn::BackendId> backends = {armnn::Compute::CpuRef};
253*89c4ff92SAndroid Build Coastguard Worker     Conv3DWithBiasesStridesFp32Test(backends);
254*89c4ff92SAndroid Build Coastguard Worker }
255*89c4ff92SAndroid Build Coastguard Worker 
256*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("Conv3DWithBiasesDilation_Fp32_CpuRef_Test")
257*89c4ff92SAndroid Build Coastguard Worker {
258*89c4ff92SAndroid Build Coastguard Worker     std::vector <armnn::BackendId> backends = {armnn::Compute::CpuRef};
259*89c4ff92SAndroid Build Coastguard Worker     Conv3DWithBiasesDilationFp32Test(backends);
260*89c4ff92SAndroid Build Coastguard Worker }
261*89c4ff92SAndroid Build Coastguard Worker 
262*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("Conv3DFp32Small_Fp32_CpuRef_Test")
263*89c4ff92SAndroid Build Coastguard Worker {
264*89c4ff92SAndroid Build Coastguard Worker     std::vector <armnn::BackendId> backends = {armnn::Compute::CpuRef};
265*89c4ff92SAndroid Build Coastguard Worker     Conv3DFp32SmallTest(backends);
266*89c4ff92SAndroid Build Coastguard Worker }
267*89c4ff92SAndroid Build Coastguard Worker 
268*89c4ff92SAndroid Build Coastguard Worker } //End of TEST_SUITE("Convolution3dTest_CpuRefTests")
269*89c4ff92SAndroid Build Coastguard Worker 
270*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("Convolution3dTest_CpuAccTests")
271*89c4ff92SAndroid Build Coastguard Worker {
272*89c4ff92SAndroid Build Coastguard Worker 
273*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("Conv3DWithBiasesSimpleWithPadding_Fp32_CpuAcc_Test")
274*89c4ff92SAndroid Build Coastguard Worker {
275*89c4ff92SAndroid Build Coastguard Worker     std::vector <armnn::BackendId> backends = {armnn::Compute::CpuAcc};
276*89c4ff92SAndroid Build Coastguard Worker     Conv3DWithBiasesSimpleWithPaddingFp32Test(backends);
277*89c4ff92SAndroid Build Coastguard Worker }
278*89c4ff92SAndroid Build Coastguard Worker 
279*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("Conv3DWithBiasesStrides_Fp32_CpuAcc_Test")
280*89c4ff92SAndroid Build Coastguard Worker {
281*89c4ff92SAndroid Build Coastguard Worker     std::vector <armnn::BackendId> backends = {armnn::Compute::CpuAcc};
282*89c4ff92SAndroid Build Coastguard Worker     Conv3DWithBiasesStridesFp32Test(backends);
283*89c4ff92SAndroid Build Coastguard Worker }
284*89c4ff92SAndroid Build Coastguard Worker 
285*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("Conv3DFp32Small_Fp32_CpuAcc_Test")
286*89c4ff92SAndroid Build Coastguard Worker {
287*89c4ff92SAndroid Build Coastguard Worker     std::vector <armnn::BackendId> backends = {armnn::Compute::CpuAcc};
288*89c4ff92SAndroid Build Coastguard Worker     Conv3DFp32SmallTest(backends);
289*89c4ff92SAndroid Build Coastguard Worker }
290*89c4ff92SAndroid Build Coastguard Worker 
291*89c4ff92SAndroid Build Coastguard Worker } //End of TEST_SUITE("Convolution3dTest_CpuAccTests")
292*89c4ff92SAndroid Build Coastguard Worker 
293*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("Convolution3dTest_GpuAccTests")
294*89c4ff92SAndroid Build Coastguard Worker {
295*89c4ff92SAndroid Build Coastguard Worker 
296*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("Conv3DWithBiasesSimpleWithPadding_Fp32_GpuAcc_Test")
297*89c4ff92SAndroid Build Coastguard Worker {
298*89c4ff92SAndroid Build Coastguard Worker     std::vector <armnn::BackendId> backends = {armnn::Compute::GpuAcc};
299*89c4ff92SAndroid Build Coastguard Worker     Conv3DWithBiasesSimpleWithPaddingFp32Test(backends);
300*89c4ff92SAndroid Build Coastguard Worker }
301*89c4ff92SAndroid Build Coastguard Worker 
302*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("Conv3DWithBiasesStrides_Fp32_GpuAcc_Test")
303*89c4ff92SAndroid Build Coastguard Worker {
304*89c4ff92SAndroid Build Coastguard Worker     std::vector <armnn::BackendId> backends = {armnn::Compute::GpuAcc};
305*89c4ff92SAndroid Build Coastguard Worker     Conv3DWithBiasesStridesFp32Test(backends);
306*89c4ff92SAndroid Build Coastguard Worker }
307*89c4ff92SAndroid Build Coastguard Worker 
308*89c4ff92SAndroid Build Coastguard Worker TEST_CASE ("Conv3DFp32Small_Fp32_GpuAcc_Test")
309*89c4ff92SAndroid Build Coastguard Worker {
310*89c4ff92SAndroid Build Coastguard Worker     std::vector <armnn::BackendId> backends = {armnn::Compute::GpuAcc};
311*89c4ff92SAndroid Build Coastguard Worker     Conv3DFp32SmallTest(backends);
312*89c4ff92SAndroid Build Coastguard Worker }
313*89c4ff92SAndroid Build Coastguard Worker 
314*89c4ff92SAndroid Build Coastguard Worker } //End of TEST_SUITE("Convolution3dTest_GpuAccTests")
315*89c4ff92SAndroid Build Coastguard Worker 
316*89c4ff92SAndroid Build Coastguard Worker #endif
317*89c4ff92SAndroid Build Coastguard Worker 
318*89c4ff92SAndroid Build Coastguard Worker } // namespace armnnDelegate