xref: /aosp_15_r20/external/armnn/src/armnnTfLiteParser/test/Comparison.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2021 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "ParserFlatbuffersFixture.hpp"
7 #include "../TfLiteParser.hpp"
8 
9 #include <string>
10 
11 TEST_SUITE("TensorflowLiteParser_Comparison")
12 {
13 struct ComparisonFixture : public ParserFlatbuffersFixture
14 {
ComparisonFixtureComparisonFixture15     explicit ComparisonFixture(const std::string& operatorCode,
16                                      const std::string& dataType,
17                                      const std::string& inputShape,
18                                      const std::string& inputShape2,
19                                      const std::string& outputShape)
20     {
21         m_JsonString = R"(
22             {
23                 "version": 3,
24                 "operator_codes": [ { "builtin_code": )" + operatorCode + R"( } ],
25                 "subgraphs": [ {
26                     "tensors": [
27                         {
28                             "shape": )" + inputShape + R"(,
29                             "type": )" + dataType + R"( ,
30                             "buffer": 0,
31                             "name": "inputTensor",
32                             "quantization": {
33                                 "min": [ 0.0 ],
34                                 "max": [ 255.0 ],
35                                 "scale": [ 1.0 ],
36                                 "zero_point": [ 0 ],
37                             }
38                         },
39                         {
40                             "shape": )" + inputShape2 + R"(,
41                             "type": )" + dataType + R"( ,
42                             "buffer": 1,
43                             "name": "inputTensor2",
44                             "quantization": {
45                                 "min": [ 0.0 ],
46                                 "max": [ 255.0 ],
47                                 "scale": [ 1.0 ],
48                                 "zero_point": [ 0 ],
49                             }
50                         },
51                         {
52                             "shape": )" + outputShape + R"( ,
53                             "type": "BOOL",
54                             "buffer": 2,
55                             "name": "outputTensor",
56                             "quantization": {
57                                 "min": [ 0.0 ],
58                                 "max": [ 255.0 ],
59                                 "scale": [ 1.0 ],
60                                 "zero_point": [ 0 ],
61                             }
62                         }
63                     ],
64                     "inputs": [ 0, 1 ],
65                     "outputs": [ 2 ],
66                     "operators": [
67                         {
68                             "opcode_index": 0,
69                             "inputs": [ 0, 1 ],
70                             "outputs": [ 2 ],
71                             "custom_options_format": "FLEXBUFFERS"
72                         }
73                     ],
74                 } ],
75                 "buffers" : [
76                     { },
77                     { }
78                 ]
79             }
80         )";
81         Setup();
82     }
83 };
84 
85 struct SimpleEqualFixture : public ComparisonFixture
86 {
SimpleEqualFixtureSimpleEqualFixture87     SimpleEqualFixture() : ComparisonFixture("EQUAL", "UINT8", "[ 2, 2 ]", "[ 2, 2 ]", "[ 2, 2 ]") {}
88 };
89 
90 TEST_CASE_FIXTURE(SimpleEqualFixture, "SimpleEqual")
91 {
92     RunTest<2, armnn::DataType::QAsymmU8,
93                armnn::DataType::Boolean>(
94                    0,
95                    {{"inputTensor",  { 0, 1, 2, 3 }},
96                     {"inputTensor2", { 0, 1, 5, 6 }}},
97                    {{"outputTensor", { 1, 1, 0, 0 }}});
98 }
99 
100 struct BroadcastEqualFixture : public ComparisonFixture
101 {
BroadcastEqualFixtureBroadcastEqualFixture102     BroadcastEqualFixture() : ComparisonFixture("EQUAL", "UINT8", "[ 2, 2 ]", "[ 1, 2 ]", "[ 2, 2 ]") {}
103 };
104 
105 TEST_CASE_FIXTURE(BroadcastEqualFixture, "BroadcastEqual")
106 {
107     RunTest<2, armnn::DataType::QAsymmU8,
108                armnn::DataType::Boolean>(
109                    0,
110                    {{"inputTensor",  { 0, 1, 2, 3 }},
111                     {"inputTensor2", { 0, 1 }}},
112                    {{"outputTensor", { 1, 1, 0, 0 }}});
113 }
114 
115 struct SimpleNotEqualFixture : public ComparisonFixture
116 {
SimpleNotEqualFixtureSimpleNotEqualFixture117     SimpleNotEqualFixture() : ComparisonFixture("NOT_EQUAL", "UINT8", "[ 2, 2 ]", "[ 2, 2 ]", "[ 2, 2 ]") {}
118 };
119 
120 TEST_CASE_FIXTURE(SimpleNotEqualFixture, "SimpleNotEqual")
121 {
122     RunTest<2, armnn::DataType::QAsymmU8,
123                armnn::DataType::Boolean>(
124                    0,
125                    {{"inputTensor",  { 0, 1, 2, 3 }},
126                     {"inputTensor2", { 0, 1, 5, 6 }}},
127                    {{"outputTensor", { 0, 0, 1, 1 }}});
128 }
129 
130 struct BroadcastNotEqualFixture : public ComparisonFixture
131 {
BroadcastNotEqualFixtureBroadcastNotEqualFixture132     BroadcastNotEqualFixture() : ComparisonFixture("NOT_EQUAL", "UINT8", "[ 2, 2 ]", "[ 1, 2 ]", "[ 2, 2 ]") {}
133 };
134 
135 TEST_CASE_FIXTURE(BroadcastNotEqualFixture, "BroadcastNotEqual")
136 {
137     RunTest<2, armnn::DataType::QAsymmU8,
138                armnn::DataType::Boolean>(
139                    0,
140                    {{"inputTensor",  { 0, 1, 2, 3 }},
141                     {"inputTensor2", { 0, 1 }}},
142                    {{"outputTensor", { 0, 0, 1, 1 }}});
143 }
144 
145 struct SimpleGreaterFixture : public ComparisonFixture
146 {
SimpleGreaterFixtureSimpleGreaterFixture147     SimpleGreaterFixture() : ComparisonFixture("GREATER", "UINT8", "[ 2, 2 ]", "[ 2, 2 ]", "[ 2, 2 ]") {}
148 };
149 
150 TEST_CASE_FIXTURE(SimpleGreaterFixture, "SimpleGreater")
151 {
152     RunTest<2, armnn::DataType::QAsymmU8,
153                armnn::DataType::Boolean>(
154                    0,
155                    {{"inputTensor",  { 0, 2, 3, 6 }},
156                     {"inputTensor2", { 0, 1, 5, 3 }}},
157                    {{"outputTensor", { 0, 1, 0, 1 }}});
158 }
159 
160 struct BroadcastGreaterFixture : public ComparisonFixture
161 {
BroadcastGreaterFixtureBroadcastGreaterFixture162     BroadcastGreaterFixture() : ComparisonFixture("GREATER", "UINT8", "[ 2, 2 ]", "[ 1, 2 ]", "[ 2, 2 ]") {}
163 };
164 
165 TEST_CASE_FIXTURE(BroadcastGreaterFixture, "BroadcastGreater")
166 {
167     RunTest<2, armnn::DataType::QAsymmU8,
168                armnn::DataType::Boolean>(
169                    0,
170                    {{"inputTensor",  { 5, 4, 1, 0 }},
171                     {"inputTensor2", { 2, 3 }}},
172                    {{"outputTensor", { 1, 1, 0, 0 }}});
173 }
174 
175 struct SimpleGreaterOrEqualFixture : public ComparisonFixture
176 {
SimpleGreaterOrEqualFixtureSimpleGreaterOrEqualFixture177     SimpleGreaterOrEqualFixture() : ComparisonFixture("GREATER_EQUAL", "UINT8", "[ 2, 2 ]", "[ 2, 2 ]", "[ 2, 2 ]") {}
178 };
179 
180 TEST_CASE_FIXTURE(SimpleGreaterOrEqualFixture, "SimpleGreaterOrEqual")
181 {
182     RunTest<2, armnn::DataType::QAsymmU8,
183                armnn::DataType::Boolean>(
184                    0,
185                    {{"inputTensor",  { 0, 2, 3, 6 }},
186                     {"inputTensor2", { 0, 1, 5, 3 }}},
187                    {{"outputTensor", { 1, 1, 0, 1 }}});
188 }
189 
190 struct BroadcastGreaterOrEqualFixture : public ComparisonFixture
191 {
BroadcastGreaterOrEqualFixtureBroadcastGreaterOrEqualFixture192     BroadcastGreaterOrEqualFixture() : ComparisonFixture("GREATER_EQUAL", "UINT8",
193                                                          "[ 2, 2 ]", "[ 1, 2 ]", "[ 2, 2 ]") {}
194 };
195 
196 TEST_CASE_FIXTURE(BroadcastGreaterOrEqualFixture, "BroadcastGreaterOrEqual")
197 {
198     RunTest<2, armnn::DataType::QAsymmU8,
199                armnn::DataType::Boolean>(
200                    0,
201                    {{"inputTensor",  { 5, 4, 1, 0 }},
202                     {"inputTensor2", { 2, 4 }}},
203                    {{"outputTensor", { 1, 1, 0, 0 }}});
204 }
205 
206 struct SimpleLessFixture : public ComparisonFixture
207 {
SimpleLessFixtureSimpleLessFixture208     SimpleLessFixture() : ComparisonFixture("LESS", "UINT8", "[ 2, 2 ]", "[ 2, 2 ]", "[ 2, 2 ]") {}
209 };
210 
211 TEST_CASE_FIXTURE(SimpleLessFixture, "SimpleLess")
212 {
213     RunTest<2, armnn::DataType::QAsymmU8,
214                armnn::DataType::Boolean>(
215                    0,
216                    {{"inputTensor",  { 0, 2, 3, 6 }},
217                     {"inputTensor2", { 0, 1, 5, 3 }}},
218                    {{"outputTensor", { 0, 0, 1, 0 }}});
219 }
220 
221 struct BroadcastLessFixture : public ComparisonFixture
222 {
BroadcastLessFixtureBroadcastLessFixture223     BroadcastLessFixture() : ComparisonFixture("LESS", "UINT8", "[ 2, 2 ]", "[ 1, 2 ]", "[ 2, 2 ]") {}
224 };
225 
226 TEST_CASE_FIXTURE(BroadcastLessFixture, "BroadcastLess")
227 {
228     RunTest<2, armnn::DataType::QAsymmU8,
229                armnn::DataType::Boolean>(
230                    0,
231                    {{"inputTensor",  { 5, 4, 1, 0 }},
232                     {"inputTensor2", { 2, 3 }}},
233                    {{"outputTensor", { 0, 0, 1, 1 }}});
234 }
235 
236 struct SimpleLessOrEqualFixture : public ComparisonFixture
237 {
SimpleLessOrEqualFixtureSimpleLessOrEqualFixture238     SimpleLessOrEqualFixture() : ComparisonFixture("LESS_EQUAL", "UINT8", "[ 2, 2 ]", "[ 2, 2 ]", "[ 2, 2 ]") {}
239 };
240 
241 TEST_CASE_FIXTURE(SimpleLessOrEqualFixture, "SimpleLessOrEqual")
242 {
243     RunTest<2, armnn::DataType::QAsymmU8,
244                armnn::DataType::Boolean>(
245                    0,
246                    {{"inputTensor",  { 0, 2, 3, 6 }},
247                     {"inputTensor2", { 0, 1, 5, 3 }}},
248                    {{"outputTensor", { 1, 0, 1, 0 }}});
249 }
250 
251 struct BroadcastLessOrEqualFixture : public ComparisonFixture
252 {
BroadcastLessOrEqualFixtureBroadcastLessOrEqualFixture253     BroadcastLessOrEqualFixture() : ComparisonFixture("LESS_EQUAL", "UINT8", "[ 2, 2 ]", "[ 1, 2 ]", "[ 2, 2 ]") {}
254 };
255 
256 TEST_CASE_FIXTURE(BroadcastLessOrEqualFixture, "BroadcastLessOrEqual")
257 {
258     RunTest<2, armnn::DataType::QAsymmU8,
259                armnn::DataType::Boolean>(
260                    0,
261                    {{"inputTensor",  { 5, 4, 1, 0 }},
262                     {"inputTensor2", { 1, 3 }}},
263                    {{"outputTensor", { 0, 0, 1, 1 }}});
264 }
265 
266 }
267