1 // 2 // Copyright © 2021 Arm Ltd and Contributors. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #include "ParserFlatbuffersFixture.hpp" 7 8 9 TEST_SUITE("TensorflowLiteParser_Gather") 10 { 11 struct GatherFixture : public ParserFlatbuffersFixture 12 { GatherFixtureGatherFixture13 explicit GatherFixture(const std::string& paramsShape, 14 const std::string& outputShape, 15 const std::string& indicesShape, 16 const std::string& dataType = "FLOAT32", 17 const std::string& scale = "1.0", 18 const std::string& offset = "0") 19 { 20 m_JsonString = R"( 21 { 22 "version": 3, 23 "operator_codes": [ { "builtin_code": "GATHER" } ], 24 "subgraphs": [ { 25 "tensors": [ 26 { 27 "shape": )" + paramsShape + R"(, 28 "type": )" + dataType + R"(, 29 "buffer": 0, 30 "name": "inputTensor", 31 "quantization": { 32 "min": [ 0.0 ], 33 "max": [ 255.0 ], 34 "scale": [ )" + scale + R"( ], 35 "zero_point": [ )" + offset + R"( ], 36 } 37 }, 38 { 39 "shape": )" + indicesShape + R"( , 40 "type": "INT32", 41 "buffer": 1, 42 "name": "indices", 43 "quantization": { 44 "min": [ 0.0 ], 45 "max": [ 255.0 ], 46 "scale": [ 1.0 ], 47 "zero_point": [ 0 ], 48 } 49 }, 50 { 51 "shape": )" + outputShape + R"(, 52 "type": )" + dataType + R"(, 53 "buffer": 2, 54 "name": "outputTensor", 55 "quantization": { 56 "min": [ 0.0 ], 57 "max": [ 255.0 ], 58 "scale": [ )" + scale + R"( ], 59 "zero_point": [ )" + offset + R"( ], 60 } 61 } 62 ], 63 "inputs": [ 0, 1 ], 64 "outputs": [ 2 ], 65 "operators": [ 66 { 67 "opcode_index": 0, 68 "inputs": [ 0, 1 ], 69 "outputs": [ 2 ], 70 "builtin_options_type": "GatherOptions", 71 "builtin_options": { 72 "axis": 0 73 }, 74 "custom_options_format": "FLEXBUFFERS" 75 } 76 ], 77 } ], 78 "buffers" : [ 79 { }, 80 { }, 81 { }, 82 ] 83 } 84 )"; 85 Setup(); 86 } 87 }; 88 89 struct SimpleGatherFixture : public GatherFixture 90 { SimpleGatherFixtureSimpleGatherFixture91 SimpleGatherFixture() : GatherFixture("[ 5, 2 ]", "[ 3, 2 ]", "[ 3 ]") {} 92 }; 93 94 TEST_CASE_FIXTURE(SimpleGatherFixture, "ParseGather") 95 { 96 RunTest<2, armnn::DataType::Float32, armnn::DataType::Signed32, armnn::DataType::Float32> 97 (0, 98 {{ "inputTensor", { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 }}}, 99 {{ "indices", { 1, 3, 4 }}}, 100 {{ "outputTensor", { 3, 4, 7, 8, 9, 10 }}}); 101 } 102 103 struct GatherUint8Fixture : public GatherFixture 104 { GatherUint8FixtureGatherUint8Fixture105 GatherUint8Fixture() : GatherFixture("[ 8 ]", "[ 3 ]", "[ 3 ]", "UINT8") {} 106 }; 107 108 TEST_CASE_FIXTURE(GatherUint8Fixture, "ParseGatherUint8") 109 { 110 RunTest<1, armnn::DataType::QAsymmU8, armnn::DataType::Signed32, armnn::DataType::QAsymmU8> 111 (0, 112 {{ "inputTensor", { 1, 2, 3, 4, 5, 6, 7, 8 }}}, 113 {{ "indices", { 7, 6, 5 }}}, 114 {{ "outputTensor", { 8, 7, 6 }}}); 115 } 116 117 } 118