xref: /aosp_15_r20/external/armnn/src/armnnDeserializer/test/DeserializeComparison.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2019 Arm Ltd. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include "ParserFlatbuffersSerializeFixture.hpp"
7*89c4ff92SAndroid Build Coastguard Worker #include <armnnDeserializer/IDeserializer.hpp>
8*89c4ff92SAndroid Build Coastguard Worker 
9*89c4ff92SAndroid Build Coastguard Worker #include <armnnUtils/QuantizeHelper.hpp>
10*89c4ff92SAndroid Build Coastguard Worker #include <ResolveType.hpp>
11*89c4ff92SAndroid Build Coastguard Worker 
12*89c4ff92SAndroid Build Coastguard Worker #include <string>
13*89c4ff92SAndroid Build Coastguard Worker 
14*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("Deserializer_Comparison")
15*89c4ff92SAndroid Build Coastguard Worker {
// Declares Simple<comparisonOp><tensorType>Fixture, a SimpleComparisonFixture whose
// constructor passes the stringified macro arguments as (inputDataType, comparisonOperation).
// The expansion already ends in ';', so an invocation needs no trailing semicolon.
#define DECLARE_SIMPLE_COMPARISON_FIXTURE(comparisonOp, tensorType) \
struct Simple##comparisonOp##tensorType##Fixture : public SimpleComparisonFixture \
{ \
    Simple##comparisonOp##tensorType##Fixture() : SimpleComparisonFixture(#tensorType, #comparisonOp) \
    { \
    } \
};
22*89c4ff92SAndroid Build Coastguard Worker 
// Declares the matching Simple*Fixture and registers a TEST_CASE_FIXTURE named
// "<comparisonOp><tensorType>" (e.g. "GreaterFloat32"). The test body:
//   - quantises both shared float inputs to tensorType with scale 1 and offset 0;
//   - calls RunTest with them bound to "InputLayer0"/"InputLayer1";
//   - expects s_TestData.m_Output<comparisonOp> as the Boolean "OutputLayer" result.
#define DECLARE_SIMPLE_COMPARISON_TEST_CASE(comparisonOp, tensorType) \
DECLARE_SIMPLE_COMPARISON_FIXTURE(comparisonOp, tensorType) \
TEST_CASE_FIXTURE(Simple##comparisonOp##tensorType##Fixture, #comparisonOp #tensorType) \
{ \
    using InputValueType = armnn::ResolveType<armnn::DataType::tensorType>; \
    constexpr float   quantScale  = 1.f; \
    constexpr int32_t quantOffset = 0; \
    const auto quantizedInput0 = \
        armnnUtils::QuantizedVector<InputValueType>(s_TestData.m_InputData0, quantScale, quantOffset); \
    const auto quantizedInput1 = \
        armnnUtils::QuantizedVector<InputValueType>(s_TestData.m_InputData1, quantScale, quantOffset); \
    RunTest<4, armnn::DataType::tensorType, armnn::DataType::Boolean>( \
        0, \
        {{ "InputLayer0", quantizedInput0 }, { "InputLayer1", quantizedInput1 }}, \
        {{ "OutputLayer", s_TestData.m_Output##comparisonOp }}); \
}
36*89c4ff92SAndroid Build Coastguard Worker 
37*89c4ff92SAndroid Build Coastguard Worker struct ComparisonFixture : public ParserFlatbuffersSerializeFixture
38*89c4ff92SAndroid Build Coastguard Worker {
ComparisonFixtureComparisonFixture39*89c4ff92SAndroid Build Coastguard Worker     explicit ComparisonFixture(const std::string& inputShape0,
40*89c4ff92SAndroid Build Coastguard Worker                                const std::string& inputShape1,
41*89c4ff92SAndroid Build Coastguard Worker                                const std::string& outputShape,
42*89c4ff92SAndroid Build Coastguard Worker                                const std::string& inputDataType,
43*89c4ff92SAndroid Build Coastguard Worker                                const std::string& comparisonOperation)
44*89c4ff92SAndroid Build Coastguard Worker     {
45*89c4ff92SAndroid Build Coastguard Worker         m_JsonString = R"(
46*89c4ff92SAndroid Build Coastguard Worker             {
47*89c4ff92SAndroid Build Coastguard Worker                 inputIds: [0, 1],
48*89c4ff92SAndroid Build Coastguard Worker                 outputIds: [3],
49*89c4ff92SAndroid Build Coastguard Worker                 layers: [
50*89c4ff92SAndroid Build Coastguard Worker                     {
51*89c4ff92SAndroid Build Coastguard Worker                         layer_type: "InputLayer",
52*89c4ff92SAndroid Build Coastguard Worker                         layer: {
53*89c4ff92SAndroid Build Coastguard Worker                             base: {
54*89c4ff92SAndroid Build Coastguard Worker                                 layerBindingId: 0,
55*89c4ff92SAndroid Build Coastguard Worker                                 base: {
56*89c4ff92SAndroid Build Coastguard Worker                                     index: 0,
57*89c4ff92SAndroid Build Coastguard Worker                                     layerName: "InputLayer0",
58*89c4ff92SAndroid Build Coastguard Worker                                     layerType: "Input",
59*89c4ff92SAndroid Build Coastguard Worker                                     inputSlots: [{
60*89c4ff92SAndroid Build Coastguard Worker                                         index: 0,
61*89c4ff92SAndroid Build Coastguard Worker                                         connection: { sourceLayerIndex:0, outputSlotIndex:0 },
62*89c4ff92SAndroid Build Coastguard Worker                                     }],
63*89c4ff92SAndroid Build Coastguard Worker                                     outputSlots: [{
64*89c4ff92SAndroid Build Coastguard Worker                                         index: 0,
65*89c4ff92SAndroid Build Coastguard Worker                                         tensorInfo: {
66*89c4ff92SAndroid Build Coastguard Worker                                             dimensions: )" + inputShape0 + R"(,
67*89c4ff92SAndroid Build Coastguard Worker                                             dataType: )" + inputDataType + R"(
68*89c4ff92SAndroid Build Coastguard Worker                                         },
69*89c4ff92SAndroid Build Coastguard Worker                                     }],
70*89c4ff92SAndroid Build Coastguard Worker                                 },
71*89c4ff92SAndroid Build Coastguard Worker                             }
72*89c4ff92SAndroid Build Coastguard Worker                         },
73*89c4ff92SAndroid Build Coastguard Worker                     },
74*89c4ff92SAndroid Build Coastguard Worker                     {
75*89c4ff92SAndroid Build Coastguard Worker                         layer_type: "InputLayer",
76*89c4ff92SAndroid Build Coastguard Worker                         layer: {
77*89c4ff92SAndroid Build Coastguard Worker                             base: {
78*89c4ff92SAndroid Build Coastguard Worker                                 layerBindingId: 1,
79*89c4ff92SAndroid Build Coastguard Worker                                 base: {
80*89c4ff92SAndroid Build Coastguard Worker                                       index:1,
81*89c4ff92SAndroid Build Coastguard Worker                                       layerName: "InputLayer1",
82*89c4ff92SAndroid Build Coastguard Worker                                       layerType: "Input",
83*89c4ff92SAndroid Build Coastguard Worker                                       inputSlots: [{
84*89c4ff92SAndroid Build Coastguard Worker                                           index: 0,
85*89c4ff92SAndroid Build Coastguard Worker                                           connection: { sourceLayerIndex:0, outputSlotIndex:0 },
86*89c4ff92SAndroid Build Coastguard Worker                                       }],
87*89c4ff92SAndroid Build Coastguard Worker                                       outputSlots: [{
88*89c4ff92SAndroid Build Coastguard Worker                                           index: 0,
89*89c4ff92SAndroid Build Coastguard Worker                                           tensorInfo: {
90*89c4ff92SAndroid Build Coastguard Worker                                               dimensions: )" + inputShape1 + R"(,
91*89c4ff92SAndroid Build Coastguard Worker                                               dataType: )" + inputDataType + R"(
92*89c4ff92SAndroid Build Coastguard Worker                                           },
93*89c4ff92SAndroid Build Coastguard Worker                                       }],
94*89c4ff92SAndroid Build Coastguard Worker                                 },
95*89c4ff92SAndroid Build Coastguard Worker                             }
96*89c4ff92SAndroid Build Coastguard Worker                         },
97*89c4ff92SAndroid Build Coastguard Worker                     },
98*89c4ff92SAndroid Build Coastguard Worker                     {
99*89c4ff92SAndroid Build Coastguard Worker                         layer_type: "ComparisonLayer",
100*89c4ff92SAndroid Build Coastguard Worker                         layer: {
101*89c4ff92SAndroid Build Coastguard Worker                             base: {
102*89c4ff92SAndroid Build Coastguard Worker                                  index:2,
103*89c4ff92SAndroid Build Coastguard Worker                                  layerName: "ComparisonLayer",
104*89c4ff92SAndroid Build Coastguard Worker                                  layerType: "Comparison",
105*89c4ff92SAndroid Build Coastguard Worker                                  inputSlots: [{
106*89c4ff92SAndroid Build Coastguard Worker                                      index: 0,
107*89c4ff92SAndroid Build Coastguard Worker                                      connection: { sourceLayerIndex:0, outputSlotIndex:0 },
108*89c4ff92SAndroid Build Coastguard Worker                                  },
109*89c4ff92SAndroid Build Coastguard Worker                                  {
110*89c4ff92SAndroid Build Coastguard Worker                                      index: 1,
111*89c4ff92SAndroid Build Coastguard Worker                                      connection: { sourceLayerIndex:1, outputSlotIndex:0 },
112*89c4ff92SAndroid Build Coastguard Worker                                  }],
113*89c4ff92SAndroid Build Coastguard Worker                                  outputSlots: [{
114*89c4ff92SAndroid Build Coastguard Worker                                      index: 0,
115*89c4ff92SAndroid Build Coastguard Worker                                      tensorInfo: {
116*89c4ff92SAndroid Build Coastguard Worker                                          dimensions: )" + outputShape + R"(,
117*89c4ff92SAndroid Build Coastguard Worker                                          dataType: Boolean
118*89c4ff92SAndroid Build Coastguard Worker                                      },
119*89c4ff92SAndroid Build Coastguard Worker                                  }],
120*89c4ff92SAndroid Build Coastguard Worker                             },
121*89c4ff92SAndroid Build Coastguard Worker                             descriptor: {
122*89c4ff92SAndroid Build Coastguard Worker                                 operation: )" + comparisonOperation + R"(
123*89c4ff92SAndroid Build Coastguard Worker                             }
124*89c4ff92SAndroid Build Coastguard Worker                         },
125*89c4ff92SAndroid Build Coastguard Worker                     },
126*89c4ff92SAndroid Build Coastguard Worker                     {
127*89c4ff92SAndroid Build Coastguard Worker                         layer_type: "OutputLayer",
128*89c4ff92SAndroid Build Coastguard Worker                         layer: {
129*89c4ff92SAndroid Build Coastguard Worker                             base:{
130*89c4ff92SAndroid Build Coastguard Worker                                 layerBindingId: 0,
131*89c4ff92SAndroid Build Coastguard Worker                                 base: {
132*89c4ff92SAndroid Build Coastguard Worker                                     index: 3,
133*89c4ff92SAndroid Build Coastguard Worker                                     layerName: "OutputLayer",
134*89c4ff92SAndroid Build Coastguard Worker                                     layerType: "Output",
135*89c4ff92SAndroid Build Coastguard Worker                                     inputSlots: [{
136*89c4ff92SAndroid Build Coastguard Worker                                         index: 0,
137*89c4ff92SAndroid Build Coastguard Worker                                         connection: { sourceLayerIndex:2, outputSlotIndex:0 },
138*89c4ff92SAndroid Build Coastguard Worker                                     }],
139*89c4ff92SAndroid Build Coastguard Worker                                     outputSlots: [{
140*89c4ff92SAndroid Build Coastguard Worker                                         index: 0,
141*89c4ff92SAndroid Build Coastguard Worker                                         tensorInfo: {
142*89c4ff92SAndroid Build Coastguard Worker                                             dimensions: )" + outputShape + R"(,
143*89c4ff92SAndroid Build Coastguard Worker                                             dataType: Boolean
144*89c4ff92SAndroid Build Coastguard Worker                                         },
145*89c4ff92SAndroid Build Coastguard Worker                                     }],
146*89c4ff92SAndroid Build Coastguard Worker                                 }
147*89c4ff92SAndroid Build Coastguard Worker                             }
148*89c4ff92SAndroid Build Coastguard Worker                         },
149*89c4ff92SAndroid Build Coastguard Worker                     }
150*89c4ff92SAndroid Build Coastguard Worker                 ]
151*89c4ff92SAndroid Build Coastguard Worker             }
152*89c4ff92SAndroid Build Coastguard Worker         )";
153*89c4ff92SAndroid Build Coastguard Worker         Setup();
154*89c4ff92SAndroid Build Coastguard Worker     }
155*89c4ff92SAndroid Build Coastguard Worker };
156*89c4ff92SAndroid Build Coastguard Worker 
// Shared data for the element-wise comparison tests:
//   - two 16-element float inputs;
//   - for each comparison operation, the expected 0/1 result of comparing
//     m_InputData0[i] with m_InputData1[i].
// Element-wise, input 0 relative to input 1 is:
//   - equal in the first and last groups of four;
//   - larger in the second group;
//   - smaller in the third group.
struct SimpleComparisonTestData
{
    std::vector<float> m_InputData0 =
    {
        1.f, 1.f, 1.f, 1.f, 5.f, 5.f, 5.f, 5.f,
        3.f, 3.f, 3.f, 3.f, 4.f, 4.f, 4.f, 4.f
    };
    std::vector<float> m_InputData1 =
    {
        1.f, 1.f, 1.f, 1.f, 3.f, 3.f, 3.f, 3.f,
        5.f, 5.f, 5.f, 5.f, 4.f, 4.f, 4.f, 4.f
    };

    // Expected outputs, one flag per element (1 = comparison holds).
    std::vector<uint8_t> m_OutputEqual          = { 1, 1, 1, 1, 0, 0, 0, 0,   0, 0, 0, 0, 1, 1, 1, 1 };
    std::vector<uint8_t> m_OutputGreater        = { 0, 0, 0, 0, 1, 1, 1, 1,   0, 0, 0, 0, 0, 0, 0, 0 };
    std::vector<uint8_t> m_OutputGreaterOrEqual = { 1, 1, 1, 1, 1, 1, 1, 1,   0, 0, 0, 0, 1, 1, 1, 1 };
    std::vector<uint8_t> m_OutputLess           = { 0, 0, 0, 0, 0, 0, 0, 0,   1, 1, 1, 1, 0, 0, 0, 0 };
    std::vector<uint8_t> m_OutputLessOrEqual    = { 1, 1, 1, 1, 0, 0, 0, 0,   1, 1, 1, 1, 1, 1, 1, 1 };
    std::vector<uint8_t> m_OutputNotEqual       = { 0, 0, 0, 0, 1, 1, 1, 1,   1, 1, 1, 1, 0, 0, 0, 0 };
};
220*89c4ff92SAndroid Build Coastguard Worker 
221*89c4ff92SAndroid Build Coastguard Worker struct SimpleComparisonFixture : public ComparisonFixture
222*89c4ff92SAndroid Build Coastguard Worker {
SimpleComparisonFixtureSimpleComparisonFixture223*89c4ff92SAndroid Build Coastguard Worker     SimpleComparisonFixture(const std::string& inputDataType,
224*89c4ff92SAndroid Build Coastguard Worker                             const std::string& comparisonOperation)
225*89c4ff92SAndroid Build Coastguard Worker         : ComparisonFixture("[ 2, 2, 2, 2 ]", // inputShape0
226*89c4ff92SAndroid Build Coastguard Worker                             "[ 2, 2, 2, 2 ]", // inputShape1
227*89c4ff92SAndroid Build Coastguard Worker                             "[ 2, 2, 2, 2 ]", // outputShape,
228*89c4ff92SAndroid Build Coastguard Worker                             inputDataType,
229*89c4ff92SAndroid Build Coastguard Worker                             comparisonOperation) {}
230*89c4ff92SAndroid Build Coastguard Worker 
231*89c4ff92SAndroid Build Coastguard Worker     static SimpleComparisonTestData s_TestData;
232*89c4ff92SAndroid Build Coastguard Worker };
233*89c4ff92SAndroid Build Coastguard Worker 
234*89c4ff92SAndroid Build Coastguard Worker SimpleComparisonTestData SimpleComparisonFixture::s_TestData;
235*89c4ff92SAndroid Build Coastguard Worker 
236*89c4ff92SAndroid Build Coastguard Worker DECLARE_SIMPLE_COMPARISON_TEST_CASE(Equal,          Float32)
237*89c4ff92SAndroid Build Coastguard Worker DECLARE_SIMPLE_COMPARISON_TEST_CASE(Greater,        Float32)
238*89c4ff92SAndroid Build Coastguard Worker DECLARE_SIMPLE_COMPARISON_TEST_CASE(GreaterOrEqual, Float32)
239*89c4ff92SAndroid Build Coastguard Worker DECLARE_SIMPLE_COMPARISON_TEST_CASE(Less,           Float32)
240*89c4ff92SAndroid Build Coastguard Worker DECLARE_SIMPLE_COMPARISON_TEST_CASE(LessOrEqual,    Float32)
241*89c4ff92SAndroid Build Coastguard Worker DECLARE_SIMPLE_COMPARISON_TEST_CASE(NotEqual,       Float32)
242*89c4ff92SAndroid Build Coastguard Worker 
243*89c4ff92SAndroid Build Coastguard Worker 
244*89c4ff92SAndroid Build Coastguard Worker DECLARE_SIMPLE_COMPARISON_TEST_CASE(Equal,          QAsymmU8)
245*89c4ff92SAndroid Build Coastguard Worker DECLARE_SIMPLE_COMPARISON_TEST_CASE(Greater,        QAsymmU8)
246*89c4ff92SAndroid Build Coastguard Worker DECLARE_SIMPLE_COMPARISON_TEST_CASE(GreaterOrEqual, QAsymmU8)
247*89c4ff92SAndroid Build Coastguard Worker DECLARE_SIMPLE_COMPARISON_TEST_CASE(Less,           QAsymmU8)
248*89c4ff92SAndroid Build Coastguard Worker DECLARE_SIMPLE_COMPARISON_TEST_CASE(LessOrEqual,    QAsymmU8)
249*89c4ff92SAndroid Build Coastguard Worker DECLARE_SIMPLE_COMPARISON_TEST_CASE(NotEqual,       QAsymmU8)
250*89c4ff92SAndroid Build Coastguard Worker 
251*89c4ff92SAndroid Build Coastguard Worker }
252