xref: /aosp_15_r20/external/armnn/src/armnnDeserializer/test/DeserializeArgMinMax.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
//
// Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include "ParserFlatbuffersSerializeFixture.hpp"
7*89c4ff92SAndroid Build Coastguard Worker #include "../Deserializer.hpp"
8*89c4ff92SAndroid Build Coastguard Worker 
9*89c4ff92SAndroid Build Coastguard Worker #include <string>
10*89c4ff92SAndroid Build Coastguard Worker #include <iostream>
11*89c4ff92SAndroid Build Coastguard Worker 
12*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("DeserializeParser_ArgMinMax")
13*89c4ff92SAndroid Build Coastguard Worker {
14*89c4ff92SAndroid Build Coastguard Worker struct ArgMinMaxFixture : public ParserFlatbuffersSerializeFixture
15*89c4ff92SAndroid Build Coastguard Worker {
ArgMinMaxFixtureArgMinMaxFixture16*89c4ff92SAndroid Build Coastguard Worker     explicit ArgMinMaxFixture(const std::string& inputShape,
17*89c4ff92SAndroid Build Coastguard Worker                               const std::string& outputShape,
18*89c4ff92SAndroid Build Coastguard Worker                               const std::string& axis,
19*89c4ff92SAndroid Build Coastguard Worker                               const std::string& argMinMaxFunction)
20*89c4ff92SAndroid Build Coastguard Worker     {
21*89c4ff92SAndroid Build Coastguard Worker         m_JsonString = R"(
22*89c4ff92SAndroid Build Coastguard Worker         {
23*89c4ff92SAndroid Build Coastguard Worker           layers: [
24*89c4ff92SAndroid Build Coastguard Worker             {
25*89c4ff92SAndroid Build Coastguard Worker               layer_type: "InputLayer",
26*89c4ff92SAndroid Build Coastguard Worker               layer: {
27*89c4ff92SAndroid Build Coastguard Worker                 base: {
28*89c4ff92SAndroid Build Coastguard Worker                   base: {
29*89c4ff92SAndroid Build Coastguard Worker                     layerName: "InputLayer",
30*89c4ff92SAndroid Build Coastguard Worker                     layerType: "Input",
31*89c4ff92SAndroid Build Coastguard Worker                     inputSlots: [
32*89c4ff92SAndroid Build Coastguard Worker 
33*89c4ff92SAndroid Build Coastguard Worker                     ],
34*89c4ff92SAndroid Build Coastguard Worker                     outputSlots: [
35*89c4ff92SAndroid Build Coastguard Worker                       {
36*89c4ff92SAndroid Build Coastguard Worker                         tensorInfo: {
37*89c4ff92SAndroid Build Coastguard Worker                           dimensions: )" + inputShape + R"(,
38*89c4ff92SAndroid Build Coastguard Worker                           dataType: "Float32",
39*89c4ff92SAndroid Build Coastguard Worker                           quantizationScale: 0.0
40*89c4ff92SAndroid Build Coastguard Worker                         }
41*89c4ff92SAndroid Build Coastguard Worker                       }
42*89c4ff92SAndroid Build Coastguard Worker                     ]
43*89c4ff92SAndroid Build Coastguard Worker                   }
44*89c4ff92SAndroid Build Coastguard Worker                 }
45*89c4ff92SAndroid Build Coastguard Worker               }
46*89c4ff92SAndroid Build Coastguard Worker             },
47*89c4ff92SAndroid Build Coastguard Worker             {
48*89c4ff92SAndroid Build Coastguard Worker               layer_type: "ArgMinMaxLayer",
49*89c4ff92SAndroid Build Coastguard Worker               layer: {
50*89c4ff92SAndroid Build Coastguard Worker                 base: {
51*89c4ff92SAndroid Build Coastguard Worker                   index: 1,
52*89c4ff92SAndroid Build Coastguard Worker                   layerName: "ArgMinMaxLayer",
53*89c4ff92SAndroid Build Coastguard Worker                   layerType: "ArgMinMax",
54*89c4ff92SAndroid Build Coastguard Worker                   inputSlots: [
55*89c4ff92SAndroid Build Coastguard Worker                     {
56*89c4ff92SAndroid Build Coastguard Worker                       connection: {
57*89c4ff92SAndroid Build Coastguard Worker                         sourceLayerIndex: 0,
58*89c4ff92SAndroid Build Coastguard Worker                         outputSlotIndex: 0
59*89c4ff92SAndroid Build Coastguard Worker                       }
60*89c4ff92SAndroid Build Coastguard Worker                     }
61*89c4ff92SAndroid Build Coastguard Worker                   ],
62*89c4ff92SAndroid Build Coastguard Worker                   outputSlots: [
63*89c4ff92SAndroid Build Coastguard Worker                     {
64*89c4ff92SAndroid Build Coastguard Worker                       tensorInfo: {
65*89c4ff92SAndroid Build Coastguard Worker                         dimensions: )" + outputShape + R"(,
66*89c4ff92SAndroid Build Coastguard Worker                         dataType: "Signed64",
67*89c4ff92SAndroid Build Coastguard Worker                         quantizationScale: 0.0
68*89c4ff92SAndroid Build Coastguard Worker                       }
69*89c4ff92SAndroid Build Coastguard Worker                     }
70*89c4ff92SAndroid Build Coastguard Worker                   ]
71*89c4ff92SAndroid Build Coastguard Worker                 },
72*89c4ff92SAndroid Build Coastguard Worker                 descriptor: {
73*89c4ff92SAndroid Build Coastguard Worker                   axis: )" + axis + R"(,
74*89c4ff92SAndroid Build Coastguard Worker                   argMinMaxFunction: )" + argMinMaxFunction + R"(
75*89c4ff92SAndroid Build Coastguard Worker                 }
76*89c4ff92SAndroid Build Coastguard Worker               }
77*89c4ff92SAndroid Build Coastguard Worker             },
78*89c4ff92SAndroid Build Coastguard Worker             {
79*89c4ff92SAndroid Build Coastguard Worker               layer_type: "OutputLayer",
80*89c4ff92SAndroid Build Coastguard Worker               layer: {
81*89c4ff92SAndroid Build Coastguard Worker                 base: {
82*89c4ff92SAndroid Build Coastguard Worker                   base: {
83*89c4ff92SAndroid Build Coastguard Worker                     index: 2,
84*89c4ff92SAndroid Build Coastguard Worker                     layerName: "OutputLayer",
85*89c4ff92SAndroid Build Coastguard Worker                     layerType: "Output",
86*89c4ff92SAndroid Build Coastguard Worker                     inputSlots: [
87*89c4ff92SAndroid Build Coastguard Worker                       {
88*89c4ff92SAndroid Build Coastguard Worker                         connection: {
89*89c4ff92SAndroid Build Coastguard Worker                           sourceLayerIndex: 1,
90*89c4ff92SAndroid Build Coastguard Worker                           outputSlotIndex: 0
91*89c4ff92SAndroid Build Coastguard Worker                         }
92*89c4ff92SAndroid Build Coastguard Worker                       }
93*89c4ff92SAndroid Build Coastguard Worker                     ],
94*89c4ff92SAndroid Build Coastguard Worker                     outputSlots: [
95*89c4ff92SAndroid Build Coastguard Worker 
96*89c4ff92SAndroid Build Coastguard Worker                     ]
97*89c4ff92SAndroid Build Coastguard Worker                   }
98*89c4ff92SAndroid Build Coastguard Worker                 }
99*89c4ff92SAndroid Build Coastguard Worker               }
100*89c4ff92SAndroid Build Coastguard Worker             }
101*89c4ff92SAndroid Build Coastguard Worker           ],
102*89c4ff92SAndroid Build Coastguard Worker           inputIds: [
103*89c4ff92SAndroid Build Coastguard Worker             0
104*89c4ff92SAndroid Build Coastguard Worker           ],
105*89c4ff92SAndroid Build Coastguard Worker           outputIds: [
106*89c4ff92SAndroid Build Coastguard Worker             0
107*89c4ff92SAndroid Build Coastguard Worker           ],
108*89c4ff92SAndroid Build Coastguard Worker           featureVersions: {
109*89c4ff92SAndroid Build Coastguard Worker             bindingIdsScheme: 1
110*89c4ff92SAndroid Build Coastguard Worker           }
111*89c4ff92SAndroid Build Coastguard Worker         }
112*89c4ff92SAndroid Build Coastguard Worker     )";
113*89c4ff92SAndroid Build Coastguard Worker         Setup();
114*89c4ff92SAndroid Build Coastguard Worker     }
115*89c4ff92SAndroid Build Coastguard Worker };
116*89c4ff92SAndroid Build Coastguard Worker 
117*89c4ff92SAndroid Build Coastguard Worker struct SimpleArgMinMaxFixture : public ArgMinMaxFixture
118*89c4ff92SAndroid Build Coastguard Worker {
SimpleArgMinMaxFixtureSimpleArgMinMaxFixture119*89c4ff92SAndroid Build Coastguard Worker     SimpleArgMinMaxFixture() : ArgMinMaxFixture("[ 1, 1, 1, 5 ]",
120*89c4ff92SAndroid Build Coastguard Worker                                                 "[ 1, 1, 1 ]",
121*89c4ff92SAndroid Build Coastguard Worker                                                 "-1",
122*89c4ff92SAndroid Build Coastguard Worker                                                 "Max") {}
123*89c4ff92SAndroid Build Coastguard Worker };
124*89c4ff92SAndroid Build Coastguard Worker 
125*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(SimpleArgMinMaxFixture, "ArgMinMax")
126*89c4ff92SAndroid Build Coastguard Worker {
127*89c4ff92SAndroid Build Coastguard Worker     RunTest<3, armnn::DataType::Float32, armnn::DataType::Signed64>(
128*89c4ff92SAndroid Build Coastguard Worker             0,
129*89c4ff92SAndroid Build Coastguard Worker             {{"InputLayer", { 6.0f, 2.0f, 8.0f, 10.0f, 9.0f}}},
130*89c4ff92SAndroid Build Coastguard Worker             {{"OutputLayer",{ 3l }}});
131*89c4ff92SAndroid Build Coastguard Worker }
132*89c4ff92SAndroid Build Coastguard Worker 
133*89c4ff92SAndroid Build Coastguard Worker }
134