xref: /aosp_15_r20/external/armnn/src/backends/backendsCommon/test/ArgMinMaxEndToEndTestImpl.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2019 Arm Ltd. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #pragma once
7*89c4ff92SAndroid Build Coastguard Worker 
8*89c4ff92SAndroid Build Coastguard Worker #include <CommonTestUtils.hpp>
9*89c4ff92SAndroid Build Coastguard Worker 
10*89c4ff92SAndroid Build Coastguard Worker #include <armnnUtils/QuantizeHelper.hpp>
11*89c4ff92SAndroid Build Coastguard Worker #include <ResolveType.hpp>
12*89c4ff92SAndroid Build Coastguard Worker 
13*89c4ff92SAndroid Build Coastguard Worker 
14*89c4ff92SAndroid Build Coastguard Worker namespace
15*89c4ff92SAndroid Build Coastguard Worker {
16*89c4ff92SAndroid Build Coastguard Worker 
// Builds a three-layer network: Input -> ArgMinMax -> Output.
//
// inputTensorInfo  - tensor info set on the Input -> ArgMinMax connection.
// outputTensorInfo - tensor info set on the ArgMinMax -> Output connection.
// function         - reduction to perform, stored in the layer descriptor.
// axis             - axis to reduce over, stored in the layer descriptor.
//
// Returns the owning pointer to the newly created network.
armnn::INetworkPtr CreateArgMinMaxNetwork(const armnn::TensorInfo& inputTensorInfo,
                                          const armnn::TensorInfo& outputTensorInfo,
                                          armnn::ArgMinMaxFunction function,
                                          int axis)
{
    armnn::ArgMinMaxDescriptor argMinMaxDesc;
    argMinMaxDesc.m_Axis     = axis;
    argMinMaxDesc.m_Function = function;

    armnn::INetworkPtr net(armnn::INetwork::Create());

    // Layers are added in data-flow order: input, ArgMinMax, output.
    armnn::IConnectableLayer* const input     = net->AddInputLayer(0, "Input");
    armnn::IConnectableLayer* const argMinMax = net->AddArgMinMaxLayer(argMinMaxDesc, "ArgMinMax");
    armnn::IConnectableLayer* const output    = net->AddOutputLayer(0, "Output");

    // Wire slot 0 of each producer to slot 0 of its consumer.
    Connect(input, argMinMax, inputTensorInfo, 0, 0);
    Connect(argMinMax, output, outputTensorInfo, 0, 0);

    return net;
}
37*89c4ff92SAndroid Build Coastguard Worker 
38*89c4ff92SAndroid Build Coastguard Worker template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
ArgMinMaxEndToEndImpl(const armnn::TensorShape & inputShape,const armnn::TensorShape & outputShape,const std::vector<float> & inputData,const std::vector<int32_t> & expectedOutputData,armnn::ArgMinMaxFunction function,int axis,const std::vector<armnn::BackendId> & backends)39*89c4ff92SAndroid Build Coastguard Worker void ArgMinMaxEndToEndImpl(const armnn::TensorShape& inputShape,
40*89c4ff92SAndroid Build Coastguard Worker                            const armnn::TensorShape& outputShape,
41*89c4ff92SAndroid Build Coastguard Worker                            const std::vector<float>& inputData,
42*89c4ff92SAndroid Build Coastguard Worker                            const std::vector<int32_t>& expectedOutputData,
43*89c4ff92SAndroid Build Coastguard Worker                            armnn::ArgMinMaxFunction function,
44*89c4ff92SAndroid Build Coastguard Worker                            int axis,
45*89c4ff92SAndroid Build Coastguard Worker                            const std::vector<armnn::BackendId>& backends)
46*89c4ff92SAndroid Build Coastguard Worker {
47*89c4ff92SAndroid Build Coastguard Worker     const float qScale  = armnn::IsQuantizedType<T>() ? 2.0f : 1.0f;
48*89c4ff92SAndroid Build Coastguard Worker     const int32_t qOffset = armnn::IsQuantizedType<T>() ? 2 : 0;
49*89c4ff92SAndroid Build Coastguard Worker 
50*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo inputTensorInfo(inputShape, ArmnnType, qScale, qOffset, true);
51*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo outputTensorInfo(outputShape, armnn::DataType::Signed32);
52*89c4ff92SAndroid Build Coastguard Worker 
53*89c4ff92SAndroid Build Coastguard Worker     // quantize data
54*89c4ff92SAndroid Build Coastguard Worker     std::vector<T> qInputData = armnnUtils::QuantizedVector<T>(inputData, qScale, qOffset);
55*89c4ff92SAndroid Build Coastguard Worker 
56*89c4ff92SAndroid Build Coastguard Worker     armnn::INetworkPtr network = CreateArgMinMaxNetwork(inputTensorInfo,
57*89c4ff92SAndroid Build Coastguard Worker                                                         outputTensorInfo,
58*89c4ff92SAndroid Build Coastguard Worker                                                         function,
59*89c4ff92SAndroid Build Coastguard Worker                                                         axis);
60*89c4ff92SAndroid Build Coastguard Worker 
61*89c4ff92SAndroid Build Coastguard Worker     EndToEndLayerTestImpl<ArmnnType, armnn::DataType::Signed32>(std::move(network),
62*89c4ff92SAndroid Build Coastguard Worker                                                                 { { 0, qInputData } },
63*89c4ff92SAndroid Build Coastguard Worker                                                                 { { 0, expectedOutputData } },
64*89c4ff92SAndroid Build Coastguard Worker                                                                 backends);
65*89c4ff92SAndroid Build Coastguard Worker }
66*89c4ff92SAndroid Build Coastguard Worker 
67*89c4ff92SAndroid Build Coastguard Worker template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
ArgMaxEndToEndSimple(const std::vector<armnn::BackendId> & backends)68*89c4ff92SAndroid Build Coastguard Worker void ArgMaxEndToEndSimple(const std::vector<armnn::BackendId>& backends)
69*89c4ff92SAndroid Build Coastguard Worker {
70*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorShape inputShape{ 1, 1, 1, 5 };
71*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorShape outputShape{ 1, 1, 1 };
72*89c4ff92SAndroid Build Coastguard Worker 
73*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputData({ 6.0f, 2.0f, 8.0f, 10.0f, 9.0f });
74*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> expectedOutputData({ 3 });
75*89c4ff92SAndroid Build Coastguard Worker 
76*89c4ff92SAndroid Build Coastguard Worker     ArgMinMaxEndToEndImpl<ArmnnType>(inputShape,
77*89c4ff92SAndroid Build Coastguard Worker                                      outputShape,
78*89c4ff92SAndroid Build Coastguard Worker                                      inputData,
79*89c4ff92SAndroid Build Coastguard Worker                                      expectedOutputData,
80*89c4ff92SAndroid Build Coastguard Worker                                      armnn::ArgMinMaxFunction::Max,
81*89c4ff92SAndroid Build Coastguard Worker                                      -1,
82*89c4ff92SAndroid Build Coastguard Worker                                      backends);
83*89c4ff92SAndroid Build Coastguard Worker }
84*89c4ff92SAndroid Build Coastguard Worker 
85*89c4ff92SAndroid Build Coastguard Worker template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
ArgMinEndToEndSimple(const std::vector<armnn::BackendId> & backends)86*89c4ff92SAndroid Build Coastguard Worker void ArgMinEndToEndSimple(const std::vector<armnn::BackendId>& backends)
87*89c4ff92SAndroid Build Coastguard Worker {
88*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorShape inputShape{ 1, 1, 1, 5 };
89*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorShape outputShape{ 1, 1, 1 };
90*89c4ff92SAndroid Build Coastguard Worker 
91*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputData({ 6.0f, 2.0f, 8.0f, 10.0f, 9.0f });
92*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> expectedOutputData({ 1 });
93*89c4ff92SAndroid Build Coastguard Worker 
94*89c4ff92SAndroid Build Coastguard Worker     ArgMinMaxEndToEndImpl<ArmnnType>(inputShape,
95*89c4ff92SAndroid Build Coastguard Worker                                      outputShape,
96*89c4ff92SAndroid Build Coastguard Worker                                      inputData,
97*89c4ff92SAndroid Build Coastguard Worker                                      expectedOutputData,
98*89c4ff92SAndroid Build Coastguard Worker                                      armnn::ArgMinMaxFunction::Min,
99*89c4ff92SAndroid Build Coastguard Worker                                      3,
100*89c4ff92SAndroid Build Coastguard Worker                                      backends);
101*89c4ff92SAndroid Build Coastguard Worker }
102*89c4ff92SAndroid Build Coastguard Worker 
103*89c4ff92SAndroid Build Coastguard Worker template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
ArgMaxAxis0EndToEnd(const std::vector<armnn::BackendId> & backends)104*89c4ff92SAndroid Build Coastguard Worker void ArgMaxAxis0EndToEnd(const std::vector<armnn::BackendId>& backends)
105*89c4ff92SAndroid Build Coastguard Worker {
106*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorShape inputShape{ 3, 2, 1, 4 };
107*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorShape outputShape{ 2, 1, 4 };
108*89c4ff92SAndroid Build Coastguard Worker 
109*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputData({    1.0f,   2.0f,   3.0f,   4.0f,
110*89c4ff92SAndroid Build Coastguard Worker                                       8.0f,   7.0f,   6.0f,   6.0f,
111*89c4ff92SAndroid Build Coastguard Worker                                     100.0f,  20.0f, 300.0f,  40.0f,
112*89c4ff92SAndroid Build Coastguard Worker                                     500.0f, 476.0f, 450.0f, 426.0f,
113*89c4ff92SAndroid Build Coastguard Worker                                      50.0f,  60.0f,  70.0f,  80.0f,
114*89c4ff92SAndroid Build Coastguard Worker                                      10.0f, 200.0f,  30.0f, 400.0f });
115*89c4ff92SAndroid Build Coastguard Worker 
116*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> expectedOutputData({ 1, 2, 1, 2,
117*89c4ff92SAndroid Build Coastguard Worker                                               1, 1, 1, 1 });
118*89c4ff92SAndroid Build Coastguard Worker 
119*89c4ff92SAndroid Build Coastguard Worker     ArgMinMaxEndToEndImpl<ArmnnType>(inputShape,
120*89c4ff92SAndroid Build Coastguard Worker                                      outputShape,
121*89c4ff92SAndroid Build Coastguard Worker                                      inputData,
122*89c4ff92SAndroid Build Coastguard Worker                                      expectedOutputData,
123*89c4ff92SAndroid Build Coastguard Worker                                      armnn::ArgMinMaxFunction::Max,
124*89c4ff92SAndroid Build Coastguard Worker                                      0,
125*89c4ff92SAndroid Build Coastguard Worker                                      backends);
126*89c4ff92SAndroid Build Coastguard Worker }
127*89c4ff92SAndroid Build Coastguard Worker 
128*89c4ff92SAndroid Build Coastguard Worker template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
ArgMinAxis0EndToEnd(const std::vector<armnn::BackendId> & backends)129*89c4ff92SAndroid Build Coastguard Worker void ArgMinAxis0EndToEnd(const std::vector<armnn::BackendId>& backends)
130*89c4ff92SAndroid Build Coastguard Worker {
131*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorShape inputShape{ 3, 2, 1, 4 };
132*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorShape outputShape{ 2, 1, 4 };
133*89c4ff92SAndroid Build Coastguard Worker 
134*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputData({    1.0f,   2.0f,   3.0f,   4.0f,
135*89c4ff92SAndroid Build Coastguard Worker                                       8.0f,   7.0f,   6.0f,   6.0f,
136*89c4ff92SAndroid Build Coastguard Worker                                     100.0f,  20.0f, 300.0f,  40.0f,
137*89c4ff92SAndroid Build Coastguard Worker                                     500.0f, 476.0f, 450.0f, 426.0f,
138*89c4ff92SAndroid Build Coastguard Worker                                      50.0f,  60.0f,  70.0f,  80.0f,
139*89c4ff92SAndroid Build Coastguard Worker                                      10.0f, 200.0f,  30.0f, 400.0f });
140*89c4ff92SAndroid Build Coastguard Worker 
141*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> expectedOutputData({ 0, 0, 0, 0,
142*89c4ff92SAndroid Build Coastguard Worker                                               0, 0, 0, 0 });
143*89c4ff92SAndroid Build Coastguard Worker 
144*89c4ff92SAndroid Build Coastguard Worker     ArgMinMaxEndToEndImpl<ArmnnType>(inputShape,
145*89c4ff92SAndroid Build Coastguard Worker                                      outputShape,
146*89c4ff92SAndroid Build Coastguard Worker                                      inputData,
147*89c4ff92SAndroid Build Coastguard Worker                                      expectedOutputData,
148*89c4ff92SAndroid Build Coastguard Worker                                      armnn::ArgMinMaxFunction::Min,
149*89c4ff92SAndroid Build Coastguard Worker                                      0,
150*89c4ff92SAndroid Build Coastguard Worker                                      backends);
151*89c4ff92SAndroid Build Coastguard Worker }
152*89c4ff92SAndroid Build Coastguard Worker 
153*89c4ff92SAndroid Build Coastguard Worker template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
ArgMaxAxis1EndToEnd(const std::vector<armnn::BackendId> & backends)154*89c4ff92SAndroid Build Coastguard Worker void ArgMaxAxis1EndToEnd(const std::vector<armnn::BackendId>& backends)
155*89c4ff92SAndroid Build Coastguard Worker {
156*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorShape inputShape{ 1, 3, 2, 4 };
157*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorShape outputShape{ 1, 2, 4 };
158*89c4ff92SAndroid Build Coastguard Worker 
159*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputData({    1.0f,   2.0f,   3.0f,   4.0f,
160*89c4ff92SAndroid Build Coastguard Worker                                       8.0f,   7.0f,   6.0f,   6.0f,
161*89c4ff92SAndroid Build Coastguard Worker                                     100.0f,  20.0f, 300.0f,  40.0f,
162*89c4ff92SAndroid Build Coastguard Worker                                     500.0f, 476.0f, 450.0f, 426.0f,
163*89c4ff92SAndroid Build Coastguard Worker                                      50.0f,  60.0f,  70.0f,  80.0f,
164*89c4ff92SAndroid Build Coastguard Worker                                      10.0f, 200.0f,  30.0f, 400.0f });
165*89c4ff92SAndroid Build Coastguard Worker 
166*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> expectedOutputData({ 1, 2, 1, 2,
167*89c4ff92SAndroid Build Coastguard Worker                                               1, 1, 1, 1 });
168*89c4ff92SAndroid Build Coastguard Worker 
169*89c4ff92SAndroid Build Coastguard Worker     ArgMinMaxEndToEndImpl<ArmnnType>(inputShape,
170*89c4ff92SAndroid Build Coastguard Worker                                      outputShape,
171*89c4ff92SAndroid Build Coastguard Worker                                      inputData,
172*89c4ff92SAndroid Build Coastguard Worker                                      expectedOutputData,
173*89c4ff92SAndroid Build Coastguard Worker                                      armnn::ArgMinMaxFunction::Max,
174*89c4ff92SAndroid Build Coastguard Worker                                      1,
175*89c4ff92SAndroid Build Coastguard Worker                                      backends);
176*89c4ff92SAndroid Build Coastguard Worker }
177*89c4ff92SAndroid Build Coastguard Worker 
178*89c4ff92SAndroid Build Coastguard Worker template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
ArgMinAxis1EndToEnd(const std::vector<armnn::BackendId> & backends)179*89c4ff92SAndroid Build Coastguard Worker void ArgMinAxis1EndToEnd(const std::vector<armnn::BackendId>& backends)
180*89c4ff92SAndroid Build Coastguard Worker {
181*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorShape inputShape{ 1, 3, 2, 4 };
182*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorShape outputShape{ 1, 2, 4 };
183*89c4ff92SAndroid Build Coastguard Worker 
184*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputData({    1.0f,   2.0f,   3.0f,   4.0f,
185*89c4ff92SAndroid Build Coastguard Worker                                       8.0f,   7.0f,   6.0f,   6.0f,
186*89c4ff92SAndroid Build Coastguard Worker                                     100.0f,  20.0f, 300.0f,  40.0f,
187*89c4ff92SAndroid Build Coastguard Worker                                     500.0f, 476.0f, 450.0f, 426.0f,
188*89c4ff92SAndroid Build Coastguard Worker                                      50.0f,  60.0f,  70.0f,  80.0f,
189*89c4ff92SAndroid Build Coastguard Worker                                      10.0f, 200.0f,  30.0f, 400.0f });
190*89c4ff92SAndroid Build Coastguard Worker 
191*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> expectedOutputData({ 0, 0, 0, 0,
192*89c4ff92SAndroid Build Coastguard Worker                                               0, 0, 0, 0 });
193*89c4ff92SAndroid Build Coastguard Worker 
194*89c4ff92SAndroid Build Coastguard Worker     ArgMinMaxEndToEndImpl<ArmnnType>(inputShape,
195*89c4ff92SAndroid Build Coastguard Worker                                      outputShape,
196*89c4ff92SAndroid Build Coastguard Worker                                      inputData,
197*89c4ff92SAndroid Build Coastguard Worker                                      expectedOutputData,
198*89c4ff92SAndroid Build Coastguard Worker                                      armnn::ArgMinMaxFunction::Min,
199*89c4ff92SAndroid Build Coastguard Worker                                      1,
200*89c4ff92SAndroid Build Coastguard Worker                                      backends);
201*89c4ff92SAndroid Build Coastguard Worker }
202*89c4ff92SAndroid Build Coastguard Worker 
203*89c4ff92SAndroid Build Coastguard Worker template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
ArgMaxAxis2EndToEnd(const std::vector<armnn::BackendId> & backends)204*89c4ff92SAndroid Build Coastguard Worker void ArgMaxAxis2EndToEnd(const std::vector<armnn::BackendId>& backends)
205*89c4ff92SAndroid Build Coastguard Worker {
206*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorShape inputShape{ 1, 3, 2, 4 };
207*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorShape outputShape{ 1, 3, 4 };
208*89c4ff92SAndroid Build Coastguard Worker 
209*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputData({    1.0f,   2.0f,   3.0f,   4.0f,
210*89c4ff92SAndroid Build Coastguard Worker                                       8.0f,   7.0f,   6.0f,   6.0f,
211*89c4ff92SAndroid Build Coastguard Worker                                     100.0f,  20.0f, 300.0f,  40.0f,
212*89c4ff92SAndroid Build Coastguard Worker                                     500.0f, 476.0f, 450.0f, 426.0f,
213*89c4ff92SAndroid Build Coastguard Worker                                      10.0f, 200.0f,  30.0f, 400.0f,
214*89c4ff92SAndroid Build Coastguard Worker                                      50.0f,  60.0f,  70.0f,  80.0f });
215*89c4ff92SAndroid Build Coastguard Worker 
216*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> expectedOutputData({ 1, 1, 1, 1,
217*89c4ff92SAndroid Build Coastguard Worker                                               1, 1, 1, 1,
218*89c4ff92SAndroid Build Coastguard Worker                                               1, 0, 1, 0});
219*89c4ff92SAndroid Build Coastguard Worker 
220*89c4ff92SAndroid Build Coastguard Worker     ArgMinMaxEndToEndImpl<ArmnnType>(inputShape,
221*89c4ff92SAndroid Build Coastguard Worker                                      outputShape,
222*89c4ff92SAndroid Build Coastguard Worker                                      inputData,
223*89c4ff92SAndroid Build Coastguard Worker                                      expectedOutputData,
224*89c4ff92SAndroid Build Coastguard Worker                                      armnn::ArgMinMaxFunction::Max,
225*89c4ff92SAndroid Build Coastguard Worker                                      2,
226*89c4ff92SAndroid Build Coastguard Worker                                      backends);
227*89c4ff92SAndroid Build Coastguard Worker }
228*89c4ff92SAndroid Build Coastguard Worker 
229*89c4ff92SAndroid Build Coastguard Worker template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
ArgMinAxis2EndToEnd(const std::vector<armnn::BackendId> & backends)230*89c4ff92SAndroid Build Coastguard Worker void ArgMinAxis2EndToEnd(const std::vector<armnn::BackendId>& backends)
231*89c4ff92SAndroid Build Coastguard Worker {
232*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorShape inputShape{ 1, 3, 2, 4 };
233*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorShape outputShape{ 1, 3, 4 };
234*89c4ff92SAndroid Build Coastguard Worker 
235*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputData({    1.0f,   2.0f,   3.0f,   4.0f,
236*89c4ff92SAndroid Build Coastguard Worker                                       8.0f,   7.0f,   6.0f,   6.0f,
237*89c4ff92SAndroid Build Coastguard Worker                                     100.0f,  20.0f, 300.0f,  40.0f,
238*89c4ff92SAndroid Build Coastguard Worker                                     500.0f, 476.0f, 450.0f, 426.0f,
239*89c4ff92SAndroid Build Coastguard Worker                                      10.0f, 200.0f,  30.0f, 400.0f,
240*89c4ff92SAndroid Build Coastguard Worker                                      50.0f,  60.0f,  70.0f,  80.0f });
241*89c4ff92SAndroid Build Coastguard Worker 
242*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> expectedOutputData({ 0, 0, 0, 0,
243*89c4ff92SAndroid Build Coastguard Worker                                               0, 0, 0, 0,
244*89c4ff92SAndroid Build Coastguard Worker                                               0, 1, 0, 1 });
245*89c4ff92SAndroid Build Coastguard Worker 
246*89c4ff92SAndroid Build Coastguard Worker     ArgMinMaxEndToEndImpl<ArmnnType>(inputShape,
247*89c4ff92SAndroid Build Coastguard Worker                                      outputShape,
248*89c4ff92SAndroid Build Coastguard Worker                                      inputData,
249*89c4ff92SAndroid Build Coastguard Worker                                      expectedOutputData,
250*89c4ff92SAndroid Build Coastguard Worker                                      armnn::ArgMinMaxFunction::Min,
251*89c4ff92SAndroid Build Coastguard Worker                                      2,
252*89c4ff92SAndroid Build Coastguard Worker                                      backends);
253*89c4ff92SAndroid Build Coastguard Worker }
254*89c4ff92SAndroid Build Coastguard Worker 
255*89c4ff92SAndroid Build Coastguard Worker template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
ArgMaxAxis3EndToEnd(const std::vector<armnn::BackendId> & backends)256*89c4ff92SAndroid Build Coastguard Worker void ArgMaxAxis3EndToEnd(const std::vector<armnn::BackendId>& backends)
257*89c4ff92SAndroid Build Coastguard Worker {
258*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorShape inputShape{ 1, 3, 2, 4 };
259*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorShape outputShape{ 1, 3, 2 };
260*89c4ff92SAndroid Build Coastguard Worker 
261*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputData({    1.0f,   3.0f,   6.0f,   7.0f,
262*89c4ff92SAndroid Build Coastguard Worker                                       8.0f,   7.0f,   6.0f,   6.0f,
263*89c4ff92SAndroid Build Coastguard Worker                                     100.0f,  20.0f, 300.0f,  40.0f,
264*89c4ff92SAndroid Build Coastguard Worker                                     500.0f, 476.0f, 450.0f, 426.0f,
265*89c4ff92SAndroid Build Coastguard Worker                                      10.0f, 200.0f,  30.0f, 400.0f,
266*89c4ff92SAndroid Build Coastguard Worker                                      50.0f,  60.0f,  70.0f,  80.0f });
267*89c4ff92SAndroid Build Coastguard Worker 
268*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> expectedOutputData({ 3, 0,
269*89c4ff92SAndroid Build Coastguard Worker                                               2, 0,
270*89c4ff92SAndroid Build Coastguard Worker                                               3, 3});
271*89c4ff92SAndroid Build Coastguard Worker 
272*89c4ff92SAndroid Build Coastguard Worker     ArgMinMaxEndToEndImpl<ArmnnType>(inputShape,
273*89c4ff92SAndroid Build Coastguard Worker                                      outputShape,
274*89c4ff92SAndroid Build Coastguard Worker                                      inputData,
275*89c4ff92SAndroid Build Coastguard Worker                                      expectedOutputData,
276*89c4ff92SAndroid Build Coastguard Worker                                      armnn::ArgMinMaxFunction::Max,
277*89c4ff92SAndroid Build Coastguard Worker                                      3,
278*89c4ff92SAndroid Build Coastguard Worker                                      backends);
279*89c4ff92SAndroid Build Coastguard Worker }
280*89c4ff92SAndroid Build Coastguard Worker 
281*89c4ff92SAndroid Build Coastguard Worker template<armnn::DataType ArmnnType, typename T = armnn::ResolveType<ArmnnType>>
ArgMinAxis3EndToEnd(const std::vector<armnn::BackendId> & backends)282*89c4ff92SAndroid Build Coastguard Worker void ArgMinAxis3EndToEnd(const std::vector<armnn::BackendId>& backends)
283*89c4ff92SAndroid Build Coastguard Worker {
284*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorShape inputShape{ 1, 3, 2, 4 };
285*89c4ff92SAndroid Build Coastguard Worker     const armnn::TensorShape outputShape{ 1, 3, 2 };
286*89c4ff92SAndroid Build Coastguard Worker 
287*89c4ff92SAndroid Build Coastguard Worker     std::vector<float> inputData({    1.0f,   3.0f,   6.0f,   7.0f,
288*89c4ff92SAndroid Build Coastguard Worker                                      18.0f,  16.0f,  14.0f,  12.0f,
289*89c4ff92SAndroid Build Coastguard Worker                                     100.0f,  20.0f, 300.0f,  40.0f,
290*89c4ff92SAndroid Build Coastguard Worker                                     500.0f, 476.0f, 450.0f, 426.0f,
291*89c4ff92SAndroid Build Coastguard Worker                                      10.0f, 200.0f,  30.0f, 400.0f,
292*89c4ff92SAndroid Build Coastguard Worker                                      50.0f,  60.0f,  70.0f,  80.0f });
293*89c4ff92SAndroid Build Coastguard Worker 
294*89c4ff92SAndroid Build Coastguard Worker     std::vector<int32_t> expectedOutputData({ 0, 3,
295*89c4ff92SAndroid Build Coastguard Worker                                               1, 3,
296*89c4ff92SAndroid Build Coastguard Worker                                               0, 0 });
297*89c4ff92SAndroid Build Coastguard Worker 
298*89c4ff92SAndroid Build Coastguard Worker     ArgMinMaxEndToEndImpl<ArmnnType>(inputShape,
299*89c4ff92SAndroid Build Coastguard Worker                                      outputShape,
300*89c4ff92SAndroid Build Coastguard Worker                                      inputData,
301*89c4ff92SAndroid Build Coastguard Worker                                      expectedOutputData,
302*89c4ff92SAndroid Build Coastguard Worker                                      armnn::ArgMinMaxFunction::Min,
303*89c4ff92SAndroid Build Coastguard Worker                                      3,
304*89c4ff92SAndroid Build Coastguard Worker                                      backends);
305*89c4ff92SAndroid Build Coastguard Worker }
306*89c4ff92SAndroid Build Coastguard Worker 
307*89c4ff92SAndroid Build Coastguard Worker } // anonymous namespace
308