// xref: /aosp_15_r20/external/armnn/src/backends/tosaReference/test/TosaRefEndToEndTests.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
//
// Copyright © 2022-2023 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "backendsCommon/test/EndToEndTestImpl.hpp"

#include "backendsCommon/test/AdditionEndToEndTestImpl.hpp"
#include "backendsCommon/test/Convolution2dEndToEndTestImpl.hpp"
#include "backendsCommon/test/ConcatEndToEndTestImpl.hpp"
#include "backendsCommon/test/MultiplicationEndToEndTestImpl.hpp"
#include "backendsCommon/test/Pooling2dEndToEndTestImpl.hpp"
#include "backendsCommon/test/ReshapeEndToEndTestImpl.hpp"
#include "backendsCommon/test/ElementwiseUnaryEndToEndTestImpl.hpp"
#include "backendsCommon/test/SliceEndToEndTestImpl.hpp"
#include "backendsCommon/test/SubtractionEndToEndTestImpl.hpp"
#include "backendsCommon/test/TransposeConvolution2dEndToEndTestImpl.hpp"
#include "backendsCommon/test/TransposeEndToEndTestImpl.hpp"

#include <doctest/doctest.h>

#include <vector>

22*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("TosaRefEndToEnd")
23*89c4ff92SAndroid Build Coastguard Worker {
24*89c4ff92SAndroid Build Coastguard Worker std::vector<BackendId> tosaDefaultBackends = { "TosaRef" };
25*89c4ff92SAndroid Build Coastguard Worker 
26*89c4ff92SAndroid Build Coastguard Worker // Addition
27*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("TosaRefAdditionEndtoEndTestFloat32")
28*89c4ff92SAndroid Build Coastguard Worker {
29*89c4ff92SAndroid Build Coastguard Worker     AdditionEndToEnd<DataType::Float32>(tosaDefaultBackends);
30*89c4ff92SAndroid Build Coastguard Worker }
31*89c4ff92SAndroid Build Coastguard Worker 
32*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("TosaRefAdditionEndtoEndTestInt32")
33*89c4ff92SAndroid Build Coastguard Worker {
34*89c4ff92SAndroid Build Coastguard Worker     AdditionEndToEnd<DataType::Signed32>(tosaDefaultBackends);
35*89c4ff92SAndroid Build Coastguard Worker }
36*89c4ff92SAndroid Build Coastguard Worker 
37*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("TosaRefAdditionEndtoEndTestFloat16")
38*89c4ff92SAndroid Build Coastguard Worker {
39*89c4ff92SAndroid Build Coastguard Worker     AdditionEndToEndFloat16<DataType::Float16>(tosaDefaultBackends);
40*89c4ff92SAndroid Build Coastguard Worker }
41*89c4ff92SAndroid Build Coastguard Worker 
42*89c4ff92SAndroid Build Coastguard Worker // Concat
43*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("TosaRefConcatEndToEndDim0TestFloat32")
44*89c4ff92SAndroid Build Coastguard Worker {
45*89c4ff92SAndroid Build Coastguard Worker     ConcatDim0EndToEnd<armnn::DataType::Float32>(tosaDefaultBackends);
46*89c4ff92SAndroid Build Coastguard Worker }
47*89c4ff92SAndroid Build Coastguard Worker 
48*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("TosaRefConcatEndToEndDim0TestInt32")
49*89c4ff92SAndroid Build Coastguard Worker {
50*89c4ff92SAndroid Build Coastguard Worker     ConcatDim0EndToEnd<armnn::DataType::Signed32>(tosaDefaultBackends);
51*89c4ff92SAndroid Build Coastguard Worker }
52*89c4ff92SAndroid Build Coastguard Worker 
53*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("TosaRefConcatEndToEndDim1TestFloat32")
54*89c4ff92SAndroid Build Coastguard Worker {
55*89c4ff92SAndroid Build Coastguard Worker     ConcatDim1EndToEnd<armnn::DataType::Float32>(tosaDefaultBackends);
56*89c4ff92SAndroid Build Coastguard Worker }
57*89c4ff92SAndroid Build Coastguard Worker 
58*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("TosaRefConcatEndToEndDim1TestInt32")
59*89c4ff92SAndroid Build Coastguard Worker {
60*89c4ff92SAndroid Build Coastguard Worker     ConcatDim1EndToEnd<armnn::DataType::Signed32>(tosaDefaultBackends);
61*89c4ff92SAndroid Build Coastguard Worker }
62*89c4ff92SAndroid Build Coastguard Worker 
63*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("TosaRefConcatEndToEndDim2TestFloat32")
64*89c4ff92SAndroid Build Coastguard Worker {
65*89c4ff92SAndroid Build Coastguard Worker     ConcatDim2EndToEnd<armnn::DataType::Float32>(tosaDefaultBackends);
66*89c4ff92SAndroid Build Coastguard Worker }
67*89c4ff92SAndroid Build Coastguard Worker 
68*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("TosaRefConcatEndToEndDim2TestInt32")
69*89c4ff92SAndroid Build Coastguard Worker {
70*89c4ff92SAndroid Build Coastguard Worker     ConcatDim2EndToEnd<armnn::DataType::Signed32>(tosaDefaultBackends);
71*89c4ff92SAndroid Build Coastguard Worker }
72*89c4ff92SAndroid Build Coastguard Worker 
73*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("TosaRefConcatEndToEndDim3TestFloat32")
74*89c4ff92SAndroid Build Coastguard Worker {
75*89c4ff92SAndroid Build Coastguard Worker     ConcatDim3EndToEnd<armnn::DataType::Float32>(tosaDefaultBackends);
76*89c4ff92SAndroid Build Coastguard Worker }
77*89c4ff92SAndroid Build Coastguard Worker 
78*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("TosaRefConcatEndToEndDim3TestInt32")
79*89c4ff92SAndroid Build Coastguard Worker {
80*89c4ff92SAndroid Build Coastguard Worker     ConcatDim3EndToEnd<armnn::DataType::Signed32>(tosaDefaultBackends);
81*89c4ff92SAndroid Build Coastguard Worker }
82*89c4ff92SAndroid Build Coastguard Worker 
83*89c4ff92SAndroid Build Coastguard Worker // Conv2d
84*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("TosaRefConv2dEndtoEndTestFloat32")
85*89c4ff92SAndroid Build Coastguard Worker {
86*89c4ff92SAndroid Build Coastguard Worker     Convolution2dEndToEnd<armnn::DataType::Float32>(tosaDefaultBackends, armnn::DataLayout::NHWC);
87*89c4ff92SAndroid Build Coastguard Worker }
88*89c4ff92SAndroid Build Coastguard Worker 
89*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("TosaRefConv2dWithoutBiasEndtoEndTestFloat32")
90*89c4ff92SAndroid Build Coastguard Worker {
91*89c4ff92SAndroid Build Coastguard Worker     Convolution2dEndToEnd<armnn::DataType::Float32>(tosaDefaultBackends, armnn::DataLayout::NHWC, false);
92*89c4ff92SAndroid Build Coastguard Worker }
93*89c4ff92SAndroid Build Coastguard Worker 
94*89c4ff92SAndroid Build Coastguard Worker // Average Pool 2D
95*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("TosaRefAvgPool2DEndtoEndTestFloat32")
96*89c4ff92SAndroid Build Coastguard Worker {
97*89c4ff92SAndroid Build Coastguard Worker     AvgPool2dEndToEnd<DataType::Float32>(tosaDefaultBackends);
98*89c4ff92SAndroid Build Coastguard Worker }
99*89c4ff92SAndroid Build Coastguard Worker 
100*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("TosaRefAvgPool2DEndtoEndTestFloat16")
101*89c4ff92SAndroid Build Coastguard Worker {
102*89c4ff92SAndroid Build Coastguard Worker     AvgPool2dEndToEndFloat16<DataType::Float16>(tosaDefaultBackends);
103*89c4ff92SAndroid Build Coastguard Worker }
104*89c4ff92SAndroid Build Coastguard Worker 
105*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("TosaRefAvgPool2DIgnoreValueEndtoEndTestFloat32")
106*89c4ff92SAndroid Build Coastguard Worker {
107*89c4ff92SAndroid Build Coastguard Worker     AvgPool2dEndToEnd<DataType::Float32>(tosaDefaultBackends, PaddingMethod::IgnoreValue);
108*89c4ff92SAndroid Build Coastguard Worker }
109*89c4ff92SAndroid Build Coastguard Worker 
110*89c4ff92SAndroid Build Coastguard Worker // Max Pool 2D
111*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("TosaRefMaxPool2DEndtoEndTestFloat32")
112*89c4ff92SAndroid Build Coastguard Worker {
113*89c4ff92SAndroid Build Coastguard Worker     MaxPool2dEndToEnd<DataType::Float32>(tosaDefaultBackends);
114*89c4ff92SAndroid Build Coastguard Worker }
115*89c4ff92SAndroid Build Coastguard Worker 
116*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("TosaRefMaxPool2DEndtoEndTestFloat16")
117*89c4ff92SAndroid Build Coastguard Worker {
118*89c4ff92SAndroid Build Coastguard Worker     MaxPool2dEndToEndFloat16<DataType::Float16>(tosaDefaultBackends);
119*89c4ff92SAndroid Build Coastguard Worker }
120*89c4ff92SAndroid Build Coastguard Worker 
121*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("TosaRefMaxPool2DIgnoreValueEndtoEndTestFloat32")
122*89c4ff92SAndroid Build Coastguard Worker {
123*89c4ff92SAndroid Build Coastguard Worker     MaxPool2dEndToEnd<DataType::Float32>(tosaDefaultBackends, PaddingMethod::IgnoreValue);
124*89c4ff92SAndroid Build Coastguard Worker }
125*89c4ff92SAndroid Build Coastguard Worker 
126*89c4ff92SAndroid Build Coastguard Worker // Reshape
127*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("TosaRefReshapeEndtoEndTestFloat32")
128*89c4ff92SAndroid Build Coastguard Worker {
129*89c4ff92SAndroid Build Coastguard Worker     ReshapeEndToEnd<DataType::Float32>(tosaDefaultBackends);
130*89c4ff92SAndroid Build Coastguard Worker }
131*89c4ff92SAndroid Build Coastguard Worker 
132*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("TosaRefReshapeEndtoEndTestInt32")
133*89c4ff92SAndroid Build Coastguard Worker {
134*89c4ff92SAndroid Build Coastguard Worker     ReshapeEndToEnd<DataType::Signed32>(tosaDefaultBackends);
135*89c4ff92SAndroid Build Coastguard Worker }
136*89c4ff92SAndroid Build Coastguard Worker 
137*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("TosaRefReshapeEndtoEndTestFloat16")
138*89c4ff92SAndroid Build Coastguard Worker {
139*89c4ff92SAndroid Build Coastguard Worker     ReshapeEndToEndFloat16<DataType::Float16>(tosaDefaultBackends);
140*89c4ff92SAndroid Build Coastguard Worker }
141*89c4ff92SAndroid Build Coastguard Worker 
142*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("TosaRefRsqrtEndtoEndTestFloat32")
143*89c4ff92SAndroid Build Coastguard Worker {
144*89c4ff92SAndroid Build Coastguard Worker     ElementwiseUnarySimpleEndToEnd<armnn::DataType::Float32>(tosaDefaultBackends,
145*89c4ff92SAndroid Build Coastguard Worker                                                              UnaryOperation::Rsqrt);
146*89c4ff92SAndroid Build Coastguard Worker }
147*89c4ff92SAndroid Build Coastguard Worker 
148*89c4ff92SAndroid Build Coastguard Worker // Slice
149*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("TosaRefSliceEndtoEndTestFloat32")
150*89c4ff92SAndroid Build Coastguard Worker {
151*89c4ff92SAndroid Build Coastguard Worker     SliceEndToEnd<DataType::Float32>(tosaDefaultBackends);
152*89c4ff92SAndroid Build Coastguard Worker }
153*89c4ff92SAndroid Build Coastguard Worker 
154*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("TosaRefSliceEndtoEndTestInt32")
155*89c4ff92SAndroid Build Coastguard Worker {
156*89c4ff92SAndroid Build Coastguard Worker     SliceEndToEnd<DataType::Signed32>(tosaDefaultBackends);
157*89c4ff92SAndroid Build Coastguard Worker }
158*89c4ff92SAndroid Build Coastguard Worker 
159*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("TosaRefSliceEndtoEndTestFloat16")
160*89c4ff92SAndroid Build Coastguard Worker {
161*89c4ff92SAndroid Build Coastguard Worker     SliceEndToEndFloat16<DataType::Float16>(tosaDefaultBackends);
162*89c4ff92SAndroid Build Coastguard Worker }
163*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("TosaRefSubtractionEndtoEndTestFloat32")
164*89c4ff92SAndroid Build Coastguard Worker {
165*89c4ff92SAndroid Build Coastguard Worker     SubtractionEndToEnd<DataType::Float32>(tosaDefaultBackends);
166*89c4ff92SAndroid Build Coastguard Worker }
167*89c4ff92SAndroid Build Coastguard Worker 
168*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("TosaRefSubtractionEndtoEndTestInt32")
169*89c4ff92SAndroid Build Coastguard Worker {
170*89c4ff92SAndroid Build Coastguard Worker     SubtractionEndToEnd<DataType::Signed32>(tosaDefaultBackends);
171*89c4ff92SAndroid Build Coastguard Worker }
172*89c4ff92SAndroid Build Coastguard Worker 
173*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("TosaRefSubtractionEndtoEndTestFloat16")
174*89c4ff92SAndroid Build Coastguard Worker {
175*89c4ff92SAndroid Build Coastguard Worker     SubtractionEndToEndFloat16<DataType::Float16>(tosaDefaultBackends);
176*89c4ff92SAndroid Build Coastguard Worker }
177*89c4ff92SAndroid Build Coastguard Worker 
178*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("TosaRefMultiplicationEndtoEndTestFloat32")
179*89c4ff92SAndroid Build Coastguard Worker {
180*89c4ff92SAndroid Build Coastguard Worker     MultiplicationEndToEnd<DataType::Float32>(tosaDefaultBackends);
181*89c4ff92SAndroid Build Coastguard Worker }
182*89c4ff92SAndroid Build Coastguard Worker 
183*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("TosaRefMultiplicationEndtoEndTestInt32")
184*89c4ff92SAndroid Build Coastguard Worker {
185*89c4ff92SAndroid Build Coastguard Worker     MultiplicationEndToEnd<DataType::Signed32>(tosaDefaultBackends);
186*89c4ff92SAndroid Build Coastguard Worker }
187*89c4ff92SAndroid Build Coastguard Worker 
188*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("TosaRefMultiplicationEndtoEndTestFloat16")
189*89c4ff92SAndroid Build Coastguard Worker {
190*89c4ff92SAndroid Build Coastguard Worker     MultiplicationEndToEndFloat16<DataType::Float16>(tosaDefaultBackends);
191*89c4ff92SAndroid Build Coastguard Worker }
192*89c4ff92SAndroid Build Coastguard Worker 
193*89c4ff92SAndroid Build Coastguard Worker // TransposeConvolution2d
194*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("TosaRefTransposeConvolution2dEndToEndFloatNhwcTest")
195*89c4ff92SAndroid Build Coastguard Worker {
196*89c4ff92SAndroid Build Coastguard Worker     TransposeConvolution2dEndToEnd<armnn::DataType::Float32, armnn::DataType::Float32>(
197*89c4ff92SAndroid Build Coastguard Worker         tosaDefaultBackends, armnn::DataLayout::NHWC);
198*89c4ff92SAndroid Build Coastguard Worker }
199*89c4ff92SAndroid Build Coastguard Worker 
200*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("TosaRefSimpleTransposeConvolution2dEndToEndFloatNhwcTest")
201*89c4ff92SAndroid Build Coastguard Worker {
202*89c4ff92SAndroid Build Coastguard Worker     SimpleTransposeConvolution2dEndToEnd<armnn::DataType::Float32, armnn::DataType::Float32>(
203*89c4ff92SAndroid Build Coastguard Worker         tosaDefaultBackends, armnn::DataLayout::NHWC);
204*89c4ff92SAndroid Build Coastguard Worker }
205*89c4ff92SAndroid Build Coastguard Worker 
206*89c4ff92SAndroid Build Coastguard Worker // Transpose
207*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("TosaRefTransposeEndtoEndTestFloat32")
208*89c4ff92SAndroid Build Coastguard Worker {
209*89c4ff92SAndroid Build Coastguard Worker     TransposeEndToEnd<armnn::DataType::Float32>(tosaDefaultBackends);
210*89c4ff92SAndroid Build Coastguard Worker }
211*89c4ff92SAndroid Build Coastguard Worker 
212*89c4ff92SAndroid Build Coastguard Worker }