1 // 2 // Copyright © 2021 Arm Ltd and Contributors. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #include "ParserFlatbuffersFixture.hpp" 7 #include "../TfLiteParser.hpp" 8 9 #include <string> 10 11 TEST_SUITE("TensorflowLiteParser_Comparison") 12 { 13 struct ComparisonFixture : public ParserFlatbuffersFixture 14 { ComparisonFixtureComparisonFixture15 explicit ComparisonFixture(const std::string& operatorCode, 16 const std::string& dataType, 17 const std::string& inputShape, 18 const std::string& inputShape2, 19 const std::string& outputShape) 20 { 21 m_JsonString = R"( 22 { 23 "version": 3, 24 "operator_codes": [ { "builtin_code": )" + operatorCode + R"( } ], 25 "subgraphs": [ { 26 "tensors": [ 27 { 28 "shape": )" + inputShape + R"(, 29 "type": )" + dataType + R"( , 30 "buffer": 0, 31 "name": "inputTensor", 32 "quantization": { 33 "min": [ 0.0 ], 34 "max": [ 255.0 ], 35 "scale": [ 1.0 ], 36 "zero_point": [ 0 ], 37 } 38 }, 39 { 40 "shape": )" + inputShape2 + R"(, 41 "type": )" + dataType + R"( , 42 "buffer": 1, 43 "name": "inputTensor2", 44 "quantization": { 45 "min": [ 0.0 ], 46 "max": [ 255.0 ], 47 "scale": [ 1.0 ], 48 "zero_point": [ 0 ], 49 } 50 }, 51 { 52 "shape": )" + outputShape + R"( , 53 "type": "BOOL", 54 "buffer": 2, 55 "name": "outputTensor", 56 "quantization": { 57 "min": [ 0.0 ], 58 "max": [ 255.0 ], 59 "scale": [ 1.0 ], 60 "zero_point": [ 0 ], 61 } 62 } 63 ], 64 "inputs": [ 0, 1 ], 65 "outputs": [ 2 ], 66 "operators": [ 67 { 68 "opcode_index": 0, 69 "inputs": [ 0, 1 ], 70 "outputs": [ 2 ], 71 "custom_options_format": "FLEXBUFFERS" 72 } 73 ], 74 } ], 75 "buffers" : [ 76 { }, 77 { } 78 ] 79 } 80 )"; 81 Setup(); 82 } 83 }; 84 85 struct SimpleEqualFixture : public ComparisonFixture 86 { SimpleEqualFixtureSimpleEqualFixture87 SimpleEqualFixture() : ComparisonFixture("EQUAL", "UINT8", "[ 2, 2 ]", "[ 2, 2 ]", "[ 2, 2 ]") {} 88 }; 89 90 TEST_CASE_FIXTURE(SimpleEqualFixture, "SimpleEqual") 91 { 92 RunTest<2, armnn::DataType::QAsymmU8, 93 armnn::DataType::Boolean>( 94 0, 95 {{"inputTensor", { 0, 1, 2, 3 }}, 96 {"inputTensor2", { 0, 1, 5, 6 }}}, 97 {{"outputTensor", { 1, 1, 0, 0 }}}); 98 } 99 100 struct BroadcastEqualFixture : public ComparisonFixture 101 { BroadcastEqualFixtureBroadcastEqualFixture102 BroadcastEqualFixture() : ComparisonFixture("EQUAL", "UINT8", "[ 2, 2 ]", "[ 1, 2 ]", "[ 2, 2 ]") {} 103 }; 104 105 TEST_CASE_FIXTURE(BroadcastEqualFixture, "BroadcastEqual") 106 { 107 RunTest<2, armnn::DataType::QAsymmU8, 108 armnn::DataType::Boolean>( 109 0, 110 {{"inputTensor", { 0, 1, 2, 3 }}, 111 {"inputTensor2", { 0, 1 }}}, 112 {{"outputTensor", { 1, 1, 0, 0 }}}); 113 } 114 115 struct SimpleNotEqualFixture : public ComparisonFixture 116 { SimpleNotEqualFixtureSimpleNotEqualFixture117 SimpleNotEqualFixture() : ComparisonFixture("NOT_EQUAL", "UINT8", "[ 2, 2 ]", "[ 2, 2 ]", "[ 2, 2 ]") {} 118 }; 119 120 TEST_CASE_FIXTURE(SimpleNotEqualFixture, "SimpleNotEqual") 121 { 122 RunTest<2, armnn::DataType::QAsymmU8, 123 armnn::DataType::Boolean>( 124 0, 125 {{"inputTensor", { 0, 1, 2, 3 }}, 126 {"inputTensor2", { 0, 1, 5, 6 }}}, 127 {{"outputTensor", { 0, 0, 1, 1 }}}); 128 } 129 130 struct BroadcastNotEqualFixture : public ComparisonFixture 131 { BroadcastNotEqualFixtureBroadcastNotEqualFixture132 BroadcastNotEqualFixture() : ComparisonFixture("NOT_EQUAL", "UINT8", "[ 2, 2 ]", "[ 1, 2 ]", "[ 2, 2 ]") {} 133 }; 134 135 TEST_CASE_FIXTURE(BroadcastNotEqualFixture, "BroadcastNotEqual") 136 { 137 RunTest<2, armnn::DataType::QAsymmU8, 138 armnn::DataType::Boolean>( 139 0, 140 {{"inputTensor", { 0, 1, 2, 3 }}, 141 {"inputTensor2", { 0, 1 }}}, 142 {{"outputTensor", { 0, 0, 1, 1 }}}); 143 } 144 145 struct SimpleGreaterFixture : public ComparisonFixture 146 { SimpleGreaterFixtureSimpleGreaterFixture147 SimpleGreaterFixture() : ComparisonFixture("GREATER", "UINT8", "[ 2, 2 ]", "[ 2, 2 ]", "[ 2, 2 ]") {} 148 }; 149 150 TEST_CASE_FIXTURE(SimpleGreaterFixture, "SimpleGreater") 151 { 152 RunTest<2, armnn::DataType::QAsymmU8, 153 armnn::DataType::Boolean>( 154 0, 155 {{"inputTensor", { 0, 2, 3, 6 }}, 156 {"inputTensor2", { 0, 1, 5, 3 }}}, 157 {{"outputTensor", { 0, 1, 0, 1 }}}); 158 } 159 160 struct BroadcastGreaterFixture : public ComparisonFixture 161 { BroadcastGreaterFixtureBroadcastGreaterFixture162 BroadcastGreaterFixture() : ComparisonFixture("GREATER", "UINT8", "[ 2, 2 ]", "[ 1, 2 ]", "[ 2, 2 ]") {} 163 }; 164 165 TEST_CASE_FIXTURE(BroadcastGreaterFixture, "BroadcastGreater") 166 { 167 RunTest<2, armnn::DataType::QAsymmU8, 168 armnn::DataType::Boolean>( 169 0, 170 {{"inputTensor", { 5, 4, 1, 0 }}, 171 {"inputTensor2", { 2, 3 }}}, 172 {{"outputTensor", { 1, 1, 0, 0 }}}); 173 } 174 175 struct SimpleGreaterOrEqualFixture : public ComparisonFixture 176 { SimpleGreaterOrEqualFixtureSimpleGreaterOrEqualFixture177 SimpleGreaterOrEqualFixture() : ComparisonFixture("GREATER_EQUAL", "UINT8", "[ 2, 2 ]", "[ 2, 2 ]", "[ 2, 2 ]") {} 178 }; 179 180 TEST_CASE_FIXTURE(SimpleGreaterOrEqualFixture, "SimpleGreaterOrEqual") 181 { 182 RunTest<2, armnn::DataType::QAsymmU8, 183 armnn::DataType::Boolean>( 184 0, 185 {{"inputTensor", { 0, 2, 3, 6 }}, 186 {"inputTensor2", { 0, 1, 5, 3 }}}, 187 {{"outputTensor", { 1, 1, 0, 1 }}}); 188 } 189 190 struct BroadcastGreaterOrEqualFixture : public ComparisonFixture 191 { BroadcastGreaterOrEqualFixtureBroadcastGreaterOrEqualFixture192 BroadcastGreaterOrEqualFixture() : ComparisonFixture("GREATER_EQUAL", "UINT8", 193 "[ 2, 2 ]", "[ 1, 2 ]", "[ 2, 2 ]") {} 194 }; 195 196 TEST_CASE_FIXTURE(BroadcastGreaterOrEqualFixture, "BroadcastGreaterOrEqual") 197 { 198 RunTest<2, armnn::DataType::QAsymmU8, 199 armnn::DataType::Boolean>( 200 0, 201 {{"inputTensor", { 5, 4, 1, 0 }}, 202 {"inputTensor2", { 2, 4 }}}, 203 {{"outputTensor", { 1, 1, 0, 0 }}}); 204 } 205 206 struct SimpleLessFixture : public ComparisonFixture 207 { SimpleLessFixtureSimpleLessFixture208 SimpleLessFixture() : ComparisonFixture("LESS", "UINT8", "[ 2, 2 ]", "[ 2, 2 ]", "[ 2, 2 ]") {} 209 }; 210 211 TEST_CASE_FIXTURE(SimpleLessFixture, "SimpleLess") 212 { 213 RunTest<2, armnn::DataType::QAsymmU8, 214 armnn::DataType::Boolean>( 215 0, 216 {{"inputTensor", { 0, 2, 3, 6 }}, 217 {"inputTensor2", { 0, 1, 5, 3 }}}, 218 {{"outputTensor", { 0, 0, 1, 0 }}}); 219 } 220 221 struct BroadcastLessFixture : public ComparisonFixture 222 { BroadcastLessFixtureBroadcastLessFixture223 BroadcastLessFixture() : ComparisonFixture("LESS", "UINT8", "[ 2, 2 ]", "[ 1, 2 ]", "[ 2, 2 ]") {} 224 }; 225 226 TEST_CASE_FIXTURE(BroadcastLessFixture, "BroadcastLess") 227 { 228 RunTest<2, armnn::DataType::QAsymmU8, 229 armnn::DataType::Boolean>( 230 0, 231 {{"inputTensor", { 5, 4, 1, 0 }}, 232 {"inputTensor2", { 2, 3 }}}, 233 {{"outputTensor", { 0, 0, 1, 1 }}}); 234 } 235 236 struct SimpleLessOrEqualFixture : public ComparisonFixture 237 { SimpleLessOrEqualFixtureSimpleLessOrEqualFixture238 SimpleLessOrEqualFixture() : ComparisonFixture("LESS_EQUAL", "UINT8", "[ 2, 2 ]", "[ 2, 2 ]", "[ 2, 2 ]") {} 239 }; 240 241 TEST_CASE_FIXTURE(SimpleLessOrEqualFixture, "SimpleLessOrEqual") 242 { 243 RunTest<2, armnn::DataType::QAsymmU8, 244 armnn::DataType::Boolean>( 245 0, 246 {{"inputTensor", { 0, 2, 3, 6 }}, 247 {"inputTensor2", { 0, 1, 5, 3 }}}, 248 {{"outputTensor", { 1, 0, 1, 0 }}}); 249 } 250 251 struct BroadcastLessOrEqualFixture : public ComparisonFixture 252 { BroadcastLessOrEqualFixtureBroadcastLessOrEqualFixture253 BroadcastLessOrEqualFixture() : ComparisonFixture("LESS_EQUAL", "UINT8", "[ 2, 2 ]", "[ 1, 2 ]", "[ 2, 2 ]") {} 254 }; 255 256 TEST_CASE_FIXTURE(BroadcastLessOrEqualFixture, "BroadcastLessOrEqual") 257 { 258 RunTest<2, armnn::DataType::QAsymmU8, 259 armnn::DataType::Boolean>( 260 0, 261 {{"inputTensor", { 5, 4, 1, 0 }}, 262 {"inputTensor2", { 1, 3 }}}, 263 {{"outputTensor", { 0, 0, 1, 1 }}}); 264 } 265 266 } 267