1*89c4ff92SAndroid Build Coastguard Worker // 2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2019 Arm Ltd. All rights reserved. 3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT 4*89c4ff92SAndroid Build Coastguard Worker // 5*89c4ff92SAndroid Build Coastguard Worker 6*89c4ff92SAndroid Build Coastguard Worker #include "ParserFlatbuffersSerializeFixture.hpp" 7*89c4ff92SAndroid Build Coastguard Worker #include <armnnDeserializer/IDeserializer.hpp> 8*89c4ff92SAndroid Build Coastguard Worker 9*89c4ff92SAndroid Build Coastguard Worker #include <armnnUtils/QuantizeHelper.hpp> 10*89c4ff92SAndroid Build Coastguard Worker #include <ResolveType.hpp> 11*89c4ff92SAndroid Build Coastguard Worker 12*89c4ff92SAndroid Build Coastguard Worker #include <string> 13*89c4ff92SAndroid Build Coastguard Worker 14*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("Deserializer_Comparison") 15*89c4ff92SAndroid Build Coastguard Worker { 16*89c4ff92SAndroid Build Coastguard Worker #define DECLARE_SIMPLE_COMPARISON_FIXTURE(operation, dataType) \ 17*89c4ff92SAndroid Build Coastguard Worker struct Simple##operation##dataType##Fixture : public SimpleComparisonFixture \ 18*89c4ff92SAndroid Build Coastguard Worker { \ 19*89c4ff92SAndroid Build Coastguard Worker Simple##operation##dataType##Fixture() \ 20*89c4ff92SAndroid Build Coastguard Worker : SimpleComparisonFixture(#dataType, #operation) {} \ 21*89c4ff92SAndroid Build Coastguard Worker }; 22*89c4ff92SAndroid Build Coastguard Worker 23*89c4ff92SAndroid Build Coastguard Worker #define DECLARE_SIMPLE_COMPARISON_TEST_CASE(operation, dataType) \ 24*89c4ff92SAndroid Build Coastguard Worker DECLARE_SIMPLE_COMPARISON_FIXTURE(operation, dataType) \ 25*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(Simple##operation##dataType##Fixture, #operation#dataType) \ 26*89c4ff92SAndroid Build Coastguard Worker { \ 27*89c4ff92SAndroid Build Coastguard Worker using T = armnn::ResolveType<armnn::DataType::dataType>; \ 28*89c4ff92SAndroid Build Coastguard Worker constexpr float qScale = 1.f; \ 29*89c4ff92SAndroid Build Coastguard Worker constexpr int32_t qOffset = 0; \ 30*89c4ff92SAndroid Build Coastguard Worker RunTest<4, armnn::DataType::dataType, armnn::DataType::Boolean>( \ 31*89c4ff92SAndroid Build Coastguard Worker 0, \ 32*89c4ff92SAndroid Build Coastguard Worker {{ "InputLayer0", armnnUtils::QuantizedVector<T>(s_TestData.m_InputData0, qScale, qOffset) }, \ 33*89c4ff92SAndroid Build Coastguard Worker { "InputLayer1", armnnUtils::QuantizedVector<T>(s_TestData.m_InputData1, qScale, qOffset) }}, \ 34*89c4ff92SAndroid Build Coastguard Worker {{ "OutputLayer", s_TestData.m_Output##operation }}); \ 35*89c4ff92SAndroid Build Coastguard Worker } 36*89c4ff92SAndroid Build Coastguard Worker 37*89c4ff92SAndroid Build Coastguard Worker struct ComparisonFixture : public ParserFlatbuffersSerializeFixture 38*89c4ff92SAndroid Build Coastguard Worker { ComparisonFixtureComparisonFixture39*89c4ff92SAndroid Build Coastguard Worker explicit ComparisonFixture(const std::string& inputShape0, 40*89c4ff92SAndroid Build Coastguard Worker const std::string& inputShape1, 41*89c4ff92SAndroid Build Coastguard Worker const std::string& outputShape, 42*89c4ff92SAndroid Build Coastguard Worker const std::string& inputDataType, 43*89c4ff92SAndroid Build Coastguard Worker const std::string& comparisonOperation) 44*89c4ff92SAndroid Build Coastguard Worker { 45*89c4ff92SAndroid Build Coastguard Worker m_JsonString = R"( 46*89c4ff92SAndroid Build Coastguard Worker { 47*89c4ff92SAndroid Build Coastguard Worker inputIds: [0, 1], 48*89c4ff92SAndroid Build Coastguard Worker outputIds: [3], 49*89c4ff92SAndroid Build Coastguard Worker layers: [ 50*89c4ff92SAndroid Build Coastguard Worker { 51*89c4ff92SAndroid Build Coastguard Worker layer_type: "InputLayer", 52*89c4ff92SAndroid Build Coastguard Worker layer: { 53*89c4ff92SAndroid Build Coastguard Worker base: { 54*89c4ff92SAndroid Build Coastguard Worker layerBindingId: 0, 55*89c4ff92SAndroid Build Coastguard Worker base: { 56*89c4ff92SAndroid Build Coastguard Worker index: 0, 57*89c4ff92SAndroid Build Coastguard Worker layerName: "InputLayer0", 58*89c4ff92SAndroid Build Coastguard Worker layerType: "Input", 59*89c4ff92SAndroid Build Coastguard Worker inputSlots: [{ 60*89c4ff92SAndroid Build Coastguard Worker index: 0, 61*89c4ff92SAndroid Build Coastguard Worker connection: { sourceLayerIndex:0, outputSlotIndex:0 }, 62*89c4ff92SAndroid Build Coastguard Worker }], 63*89c4ff92SAndroid Build Coastguard Worker outputSlots: [{ 64*89c4ff92SAndroid Build Coastguard Worker index: 0, 65*89c4ff92SAndroid Build Coastguard Worker tensorInfo: { 66*89c4ff92SAndroid Build Coastguard Worker dimensions: )" + inputShape0 + R"(, 67*89c4ff92SAndroid Build Coastguard Worker dataType: )" + inputDataType + R"( 68*89c4ff92SAndroid Build Coastguard Worker }, 69*89c4ff92SAndroid Build Coastguard Worker }], 70*89c4ff92SAndroid Build Coastguard Worker }, 71*89c4ff92SAndroid Build Coastguard Worker } 72*89c4ff92SAndroid Build Coastguard Worker }, 73*89c4ff92SAndroid Build Coastguard Worker }, 74*89c4ff92SAndroid Build Coastguard Worker { 75*89c4ff92SAndroid Build Coastguard Worker layer_type: "InputLayer", 76*89c4ff92SAndroid Build Coastguard Worker layer: { 77*89c4ff92SAndroid Build Coastguard Worker base: { 78*89c4ff92SAndroid Build Coastguard Worker layerBindingId: 1, 79*89c4ff92SAndroid Build Coastguard Worker base: { 80*89c4ff92SAndroid Build Coastguard Worker index:1, 81*89c4ff92SAndroid Build Coastguard Worker layerName: "InputLayer1", 82*89c4ff92SAndroid Build Coastguard Worker layerType: "Input", 83*89c4ff92SAndroid Build Coastguard Worker inputSlots: [{ 84*89c4ff92SAndroid Build Coastguard Worker index: 0, 85*89c4ff92SAndroid Build Coastguard Worker connection: { sourceLayerIndex:0, outputSlotIndex:0 }, 86*89c4ff92SAndroid Build Coastguard Worker }], 87*89c4ff92SAndroid Build Coastguard Worker outputSlots: [{ 88*89c4ff92SAndroid Build Coastguard Worker index: 0, 89*89c4ff92SAndroid Build Coastguard Worker tensorInfo: { 90*89c4ff92SAndroid Build Coastguard Worker dimensions: )" + inputShape1 + R"(, 91*89c4ff92SAndroid Build Coastguard Worker dataType: )" + inputDataType + R"( 92*89c4ff92SAndroid Build Coastguard Worker }, 93*89c4ff92SAndroid Build Coastguard Worker }], 94*89c4ff92SAndroid Build Coastguard Worker }, 95*89c4ff92SAndroid Build Coastguard Worker } 96*89c4ff92SAndroid Build Coastguard Worker }, 97*89c4ff92SAndroid Build Coastguard Worker }, 98*89c4ff92SAndroid Build Coastguard Worker { 99*89c4ff92SAndroid Build Coastguard Worker layer_type: "ComparisonLayer", 100*89c4ff92SAndroid Build Coastguard Worker layer: { 101*89c4ff92SAndroid Build Coastguard Worker base: { 102*89c4ff92SAndroid Build Coastguard Worker index:2, 103*89c4ff92SAndroid Build Coastguard Worker layerName: "ComparisonLayer", 104*89c4ff92SAndroid Build Coastguard Worker layerType: "Comparison", 105*89c4ff92SAndroid Build Coastguard Worker inputSlots: [{ 106*89c4ff92SAndroid Build Coastguard Worker index: 0, 107*89c4ff92SAndroid Build Coastguard Worker connection: { sourceLayerIndex:0, outputSlotIndex:0 }, 108*89c4ff92SAndroid Build Coastguard Worker }, 109*89c4ff92SAndroid Build Coastguard Worker { 110*89c4ff92SAndroid Build Coastguard Worker index: 1, 111*89c4ff92SAndroid Build Coastguard Worker connection: { sourceLayerIndex:1, outputSlotIndex:0 }, 112*89c4ff92SAndroid Build Coastguard Worker }], 113*89c4ff92SAndroid Build Coastguard Worker outputSlots: [{ 114*89c4ff92SAndroid Build Coastguard Worker index: 0, 115*89c4ff92SAndroid Build Coastguard Worker tensorInfo: { 116*89c4ff92SAndroid Build Coastguard Worker dimensions: )" + outputShape + R"(, 117*89c4ff92SAndroid Build Coastguard Worker dataType: Boolean 118*89c4ff92SAndroid Build Coastguard Worker }, 119*89c4ff92SAndroid Build Coastguard Worker }], 120*89c4ff92SAndroid Build Coastguard Worker }, 121*89c4ff92SAndroid Build Coastguard Worker descriptor: { 122*89c4ff92SAndroid Build Coastguard Worker operation: )" + comparisonOperation + R"( 123*89c4ff92SAndroid Build Coastguard Worker } 124*89c4ff92SAndroid Build Coastguard Worker }, 125*89c4ff92SAndroid Build Coastguard Worker }, 126*89c4ff92SAndroid Build Coastguard Worker { 127*89c4ff92SAndroid Build Coastguard Worker layer_type: "OutputLayer", 128*89c4ff92SAndroid Build Coastguard Worker layer: { 129*89c4ff92SAndroid Build Coastguard Worker base:{ 130*89c4ff92SAndroid Build Coastguard Worker layerBindingId: 0, 131*89c4ff92SAndroid Build Coastguard Worker base: { 132*89c4ff92SAndroid Build Coastguard Worker index: 3, 133*89c4ff92SAndroid Build Coastguard Worker layerName: "OutputLayer", 134*89c4ff92SAndroid Build Coastguard Worker layerType: "Output", 135*89c4ff92SAndroid Build Coastguard Worker inputSlots: [{ 136*89c4ff92SAndroid Build Coastguard Worker index: 0, 137*89c4ff92SAndroid Build Coastguard Worker connection: { sourceLayerIndex:2, outputSlotIndex:0 }, 138*89c4ff92SAndroid Build Coastguard Worker }], 139*89c4ff92SAndroid Build Coastguard Worker outputSlots: [{ 140*89c4ff92SAndroid Build Coastguard Worker index: 0, 141*89c4ff92SAndroid Build Coastguard Worker tensorInfo: { 142*89c4ff92SAndroid Build Coastguard Worker dimensions: )" + outputShape + R"(, 143*89c4ff92SAndroid Build Coastguard Worker dataType: Boolean 144*89c4ff92SAndroid Build Coastguard Worker }, 145*89c4ff92SAndroid Build Coastguard Worker }], 146*89c4ff92SAndroid Build Coastguard Worker } 147*89c4ff92SAndroid Build Coastguard Worker } 148*89c4ff92SAndroid Build Coastguard Worker }, 149*89c4ff92SAndroid Build Coastguard Worker } 150*89c4ff92SAndroid Build Coastguard Worker ] 151*89c4ff92SAndroid Build Coastguard Worker } 152*89c4ff92SAndroid Build Coastguard Worker )"; 153*89c4ff92SAndroid Build Coastguard Worker Setup(); 154*89c4ff92SAndroid Build Coastguard Worker } 155*89c4ff92SAndroid Build Coastguard Worker }; 156*89c4ff92SAndroid Build Coastguard Worker 157*89c4ff92SAndroid Build Coastguard Worker struct SimpleComparisonTestData 158*89c4ff92SAndroid Build Coastguard Worker { SimpleComparisonTestDataSimpleComparisonTestData159*89c4ff92SAndroid Build Coastguard Worker SimpleComparisonTestData() 160*89c4ff92SAndroid Build Coastguard Worker { 161*89c4ff92SAndroid Build Coastguard Worker m_InputData0 = 162*89c4ff92SAndroid Build Coastguard Worker { 163*89c4ff92SAndroid Build Coastguard Worker 1.f, 1.f, 1.f, 1.f, 5.f, 5.f, 5.f, 5.f, 164*89c4ff92SAndroid Build Coastguard Worker 3.f, 3.f, 3.f, 3.f, 4.f, 4.f, 4.f, 4.f 165*89c4ff92SAndroid Build Coastguard Worker }; 166*89c4ff92SAndroid Build Coastguard Worker 167*89c4ff92SAndroid Build Coastguard Worker m_InputData1 = 168*89c4ff92SAndroid Build Coastguard Worker { 169*89c4ff92SAndroid Build Coastguard Worker 1.f, 1.f, 1.f, 1.f, 3.f, 3.f, 3.f, 3.f, 170*89c4ff92SAndroid Build Coastguard Worker 5.f, 5.f, 5.f, 5.f, 4.f, 4.f, 4.f, 4.f 171*89c4ff92SAndroid Build Coastguard Worker }; 172*89c4ff92SAndroid Build Coastguard Worker 173*89c4ff92SAndroid Build Coastguard Worker m_OutputEqual = 174*89c4ff92SAndroid Build Coastguard Worker { 175*89c4ff92SAndroid Build Coastguard Worker 1, 1, 1, 1, 0, 0, 0, 0, 176*89c4ff92SAndroid Build Coastguard Worker 0, 0, 0, 0, 1, 1, 1, 1 177*89c4ff92SAndroid Build Coastguard Worker }; 178*89c4ff92SAndroid Build Coastguard Worker 179*89c4ff92SAndroid Build Coastguard Worker m_OutputGreater = 180*89c4ff92SAndroid Build Coastguard Worker { 181*89c4ff92SAndroid Build Coastguard Worker 0, 0, 0, 0, 1, 1, 1, 1, 182*89c4ff92SAndroid Build Coastguard Worker 0, 0, 0, 0, 0, 0, 0, 0 183*89c4ff92SAndroid Build Coastguard Worker }; 184*89c4ff92SAndroid Build Coastguard Worker 185*89c4ff92SAndroid Build Coastguard Worker m_OutputGreaterOrEqual = 186*89c4ff92SAndroid Build Coastguard Worker { 187*89c4ff92SAndroid Build Coastguard Worker 1, 1, 1, 1, 1, 1, 1, 1, 188*89c4ff92SAndroid Build Coastguard Worker 0, 0, 0, 0, 1, 1, 1, 1 189*89c4ff92SAndroid Build Coastguard Worker }; 190*89c4ff92SAndroid Build Coastguard Worker 191*89c4ff92SAndroid Build Coastguard Worker m_OutputLess = 192*89c4ff92SAndroid Build Coastguard Worker { 193*89c4ff92SAndroid Build Coastguard Worker 0, 0, 0, 0, 0, 0, 0, 0, 194*89c4ff92SAndroid Build Coastguard Worker 1, 1, 1, 1, 0, 0, 0, 0 195*89c4ff92SAndroid Build Coastguard Worker }; 196*89c4ff92SAndroid Build Coastguard Worker 197*89c4ff92SAndroid Build Coastguard Worker m_OutputLessOrEqual = 198*89c4ff92SAndroid Build Coastguard Worker { 199*89c4ff92SAndroid Build Coastguard Worker 1, 1, 1, 1, 0, 0, 0, 0, 200*89c4ff92SAndroid Build Coastguard Worker 1, 1, 1, 1, 1, 1, 1, 1 201*89c4ff92SAndroid Build Coastguard Worker }; 202*89c4ff92SAndroid Build Coastguard Worker 203*89c4ff92SAndroid Build Coastguard Worker m_OutputNotEqual = 204*89c4ff92SAndroid Build Coastguard Worker { 205*89c4ff92SAndroid Build Coastguard Worker 0, 0, 0, 0, 1, 1, 1, 1, 206*89c4ff92SAndroid Build Coastguard Worker 1, 1, 1, 1, 0, 0, 0, 0 207*89c4ff92SAndroid Build Coastguard Worker }; 208*89c4ff92SAndroid Build Coastguard Worker } 209*89c4ff92SAndroid Build Coastguard Worker 210*89c4ff92SAndroid Build Coastguard Worker std::vector<float> m_InputData0; 211*89c4ff92SAndroid Build Coastguard Worker std::vector<float> m_InputData1; 212*89c4ff92SAndroid Build Coastguard Worker 213*89c4ff92SAndroid Build Coastguard Worker std::vector<uint8_t> m_OutputEqual; 214*89c4ff92SAndroid Build Coastguard Worker std::vector<uint8_t> m_OutputGreater; 215*89c4ff92SAndroid Build Coastguard Worker std::vector<uint8_t> m_OutputGreaterOrEqual; 216*89c4ff92SAndroid Build Coastguard Worker std::vector<uint8_t> m_OutputLess; 217*89c4ff92SAndroid Build Coastguard Worker std::vector<uint8_t> m_OutputLessOrEqual; 218*89c4ff92SAndroid Build Coastguard Worker std::vector<uint8_t> m_OutputNotEqual; 219*89c4ff92SAndroid Build Coastguard Worker }; 220*89c4ff92SAndroid Build Coastguard Worker 221*89c4ff92SAndroid Build Coastguard Worker struct SimpleComparisonFixture : public ComparisonFixture 222*89c4ff92SAndroid Build Coastguard Worker { SimpleComparisonFixtureSimpleComparisonFixture223*89c4ff92SAndroid Build Coastguard Worker SimpleComparisonFixture(const std::string& inputDataType, 224*89c4ff92SAndroid Build Coastguard Worker const std::string& comparisonOperation) 225*89c4ff92SAndroid Build Coastguard Worker : ComparisonFixture("[ 2, 2, 2, 2 ]", // inputShape0 226*89c4ff92SAndroid Build Coastguard Worker "[ 2, 2, 2, 2 ]", // inputShape1 227*89c4ff92SAndroid Build Coastguard Worker "[ 2, 2, 2, 2 ]", // outputShape, 228*89c4ff92SAndroid Build Coastguard Worker inputDataType, 229*89c4ff92SAndroid Build Coastguard Worker comparisonOperation) {} 230*89c4ff92SAndroid Build Coastguard Worker 231*89c4ff92SAndroid Build Coastguard Worker static SimpleComparisonTestData s_TestData; 232*89c4ff92SAndroid Build Coastguard Worker }; 233*89c4ff92SAndroid Build Coastguard Worker 234*89c4ff92SAndroid Build Coastguard Worker SimpleComparisonTestData SimpleComparisonFixture::s_TestData; 235*89c4ff92SAndroid Build Coastguard Worker 236*89c4ff92SAndroid Build Coastguard Worker DECLARE_SIMPLE_COMPARISON_TEST_CASE(Equal, Float32) 237*89c4ff92SAndroid Build Coastguard Worker DECLARE_SIMPLE_COMPARISON_TEST_CASE(Greater, Float32) 238*89c4ff92SAndroid Build Coastguard Worker DECLARE_SIMPLE_COMPARISON_TEST_CASE(GreaterOrEqual, Float32) 239*89c4ff92SAndroid Build Coastguard Worker DECLARE_SIMPLE_COMPARISON_TEST_CASE(Less, Float32) 240*89c4ff92SAndroid Build Coastguard Worker DECLARE_SIMPLE_COMPARISON_TEST_CASE(LessOrEqual, Float32) 241*89c4ff92SAndroid Build Coastguard Worker DECLARE_SIMPLE_COMPARISON_TEST_CASE(NotEqual, Float32) 242*89c4ff92SAndroid Build Coastguard Worker 243*89c4ff92SAndroid Build Coastguard Worker 244*89c4ff92SAndroid Build Coastguard Worker DECLARE_SIMPLE_COMPARISON_TEST_CASE(Equal, QAsymmU8) 245*89c4ff92SAndroid Build Coastguard Worker DECLARE_SIMPLE_COMPARISON_TEST_CASE(Greater, QAsymmU8) 246*89c4ff92SAndroid Build Coastguard Worker DECLARE_SIMPLE_COMPARISON_TEST_CASE(GreaterOrEqual, QAsymmU8) 247*89c4ff92SAndroid Build Coastguard Worker DECLARE_SIMPLE_COMPARISON_TEST_CASE(Less, QAsymmU8) 248*89c4ff92SAndroid Build Coastguard Worker DECLARE_SIMPLE_COMPARISON_TEST_CASE(LessOrEqual, QAsymmU8) 249*89c4ff92SAndroid Build Coastguard Worker DECLARE_SIMPLE_COMPARISON_TEST_CASE(NotEqual, QAsymmU8) 250*89c4ff92SAndroid Build Coastguard Worker 251*89c4ff92SAndroid Build Coastguard Worker } 252