xref: /aosp_15_r20/external/armnn/src/armnnDeserializer/test/DeserializeArgMinMax.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "ParserFlatbuffersSerializeFixture.hpp"
7 #include "../Deserializer.hpp"
8 
9 #include <string>
10 #include <iostream>
11 
12 TEST_SUITE("DeserializeParser_ArgMinMax")
13 {
14 struct ArgMinMaxFixture : public ParserFlatbuffersSerializeFixture
15 {
ArgMinMaxFixtureArgMinMaxFixture16     explicit ArgMinMaxFixture(const std::string& inputShape,
17                               const std::string& outputShape,
18                               const std::string& axis,
19                               const std::string& argMinMaxFunction)
20     {
21         m_JsonString = R"(
22         {
23           layers: [
24             {
25               layer_type: "InputLayer",
26               layer: {
27                 base: {
28                   base: {
29                     layerName: "InputLayer",
30                     layerType: "Input",
31                     inputSlots: [
32 
33                     ],
34                     outputSlots: [
35                       {
36                         tensorInfo: {
37                           dimensions: )" + inputShape + R"(,
38                           dataType: "Float32",
39                           quantizationScale: 0.0
40                         }
41                       }
42                     ]
43                   }
44                 }
45               }
46             },
47             {
48               layer_type: "ArgMinMaxLayer",
49               layer: {
50                 base: {
51                   index: 1,
52                   layerName: "ArgMinMaxLayer",
53                   layerType: "ArgMinMax",
54                   inputSlots: [
55                     {
56                       connection: {
57                         sourceLayerIndex: 0,
58                         outputSlotIndex: 0
59                       }
60                     }
61                   ],
62                   outputSlots: [
63                     {
64                       tensorInfo: {
65                         dimensions: )" + outputShape + R"(,
66                         dataType: "Signed64",
67                         quantizationScale: 0.0
68                       }
69                     }
70                   ]
71                 },
72                 descriptor: {
73                   axis: )" + axis + R"(,
74                   argMinMaxFunction: )" + argMinMaxFunction + R"(
75                 }
76               }
77             },
78             {
79               layer_type: "OutputLayer",
80               layer: {
81                 base: {
82                   base: {
83                     index: 2,
84                     layerName: "OutputLayer",
85                     layerType: "Output",
86                     inputSlots: [
87                       {
88                         connection: {
89                           sourceLayerIndex: 1,
90                           outputSlotIndex: 0
91                         }
92                       }
93                     ],
94                     outputSlots: [
95 
96                     ]
97                   }
98                 }
99               }
100             }
101           ],
102           inputIds: [
103             0
104           ],
105           outputIds: [
106             0
107           ],
108           featureVersions: {
109             bindingIdsScheme: 1
110           }
111         }
112     )";
113         Setup();
114     }
115 };
116 
117 struct SimpleArgMinMaxFixture : public ArgMinMaxFixture
118 {
SimpleArgMinMaxFixtureSimpleArgMinMaxFixture119     SimpleArgMinMaxFixture() : ArgMinMaxFixture("[ 1, 1, 1, 5 ]",
120                                                 "[ 1, 1, 1 ]",
121                                                 "-1",
122                                                 "Max") {}
123 };
124 
125 TEST_CASE_FIXTURE(SimpleArgMinMaxFixture, "ArgMinMax")
126 {
127     RunTest<3, armnn::DataType::Float32, armnn::DataType::Signed64>(
128             0,
129             {{"InputLayer", { 6.0f, 2.0f, 8.0f, 10.0f, 9.0f}}},
130             {{"OutputLayer",{ 3l }}});
131 }
132 
133 }
134