1 // 2 // Copyright © 2020 Arm Ltd and Contributors. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #include "ParserFlatbuffersFixture.hpp" 7 8 9 TEST_SUITE("TensorflowLiteParser") 10 { 11 struct SplitVFixture : public ParserFlatbuffersFixture 12 { SplitVFixtureSplitVFixture13 explicit SplitVFixture(const std::string& inputShape, 14 const std::string& splitValues, 15 const std::string& sizeSplitsShape, 16 const std::string& axisShape, 17 const std::string& numSplits, 18 const std::string& outputShape1, 19 const std::string& outputShape2, 20 const std::string& axisData, 21 const std::string& dataType) 22 { 23 m_JsonString = R"( 24 { 25 "version": 3, 26 "operator_codes": [ { "builtin_code": "SPLIT_V" } ], 27 "subgraphs": [ { 28 "tensors": [ 29 { 30 "shape": )" + inputShape + R"(, 31 "type": )" + dataType + R"(, 32 "buffer": 0, 33 "name": "inputTensor", 34 "quantization": { 35 "min": [ 0.0 ], 36 "max": [ 255.0 ], 37 "scale": [ 1.0 ], 38 "zero_point": [ 0 ], 39 } 40 }, 41 { 42 "shape": )" + sizeSplitsShape + R"(, 43 "type": "INT32", 44 "buffer": 1, 45 "name": "sizeSplits", 46 "quantization": { 47 "min": [ 0.0 ], 48 "max": [ 255.0 ], 49 "scale": [ 1.0 ], 50 "zero_point": [ 0 ], 51 } 52 }, 53 { 54 "shape": )" + axisShape + R"(, 55 "type": "INT32", 56 "buffer": 2, 57 "name": "axis", 58 "quantization": { 59 "min": [ 0.0 ], 60 "max": [ 255.0 ], 61 "scale": [ 1.0 ], 62 "zero_point": [ 0 ], 63 } 64 }, 65 { 66 "shape": )" + outputShape1 + R"( , 67 "type":)" + dataType + R"(, 68 "buffer": 3, 69 "name": "outputTensor1", 70 "quantization": { 71 "min": [ 0.0 ], 72 "max": [ 255.0 ], 73 "scale": [ 1.0 ], 74 "zero_point": [ 0 ], 75 } 76 }, 77 { 78 "shape": )" + outputShape2 + R"( , 79 "type":)" + dataType + R"(, 80 "buffer": 4, 81 "name": "outputTensor2", 82 "quantization": { 83 "min": [ 0.0 ], 84 "max": [ 255.0 ], 85 "scale": [ 1.0 ], 86 "zero_point": [ 0 ], 87 } 88 } 89 ], 90 "inputs": [ 0, 1, 2 ], 91 "outputs": [ 3, 4 ], 92 "operators": [ 93 { 94 "opcode_index": 0, 95 "inputs": [ 0, 1, 2 ], 96 "outputs": [ 3, 4 ], 97 "builtin_options_type": "SplitVOptions", 98 "builtin_options": { 99 "num_splits": )" + numSplits + R"( 100 }, 101 "custom_options_format": "FLEXBUFFERS" 102 } 103 ], 104 } ], 105 "buffers" : [ {}, { "data": )" + splitValues + R"( }, { "data": )" + axisData + R"( }, {}, {}] 106 } 107 )"; 108 109 Setup(); 110 } 111 }; 112 113 /* 114 * Tested inferred splitSizes with splitValues [-1, 1] locally. 115 */ 116 117 struct SimpleSplitVAxisOneFixture : SplitVFixture 118 { SimpleSplitVAxisOneFixtureSimpleSplitVAxisOneFixture119 SimpleSplitVAxisOneFixture() 120 : SplitVFixture( "[ 4, 2, 2, 2 ]", "[ 1, 0, 0, 0, 3, 0, 0, 0 ]", "[ 2 ]","[ ]", "2", 121 "[ 1, 2, 2, 2 ]", "[ 3, 2, 2, 2 ]", "[ 0, 0, 0, 0 ]", "FLOAT32") 122 {} 123 }; 124 125 TEST_CASE_FIXTURE(SimpleSplitVAxisOneFixture, "ParseAxisOneSplitVTwo") 126 { 127 RunTest<4, armnn::DataType::Float32>( 128 0, 129 { {"inputTensor", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 130 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 131 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 132 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f } } }, 133 { {"outputTensor1", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f } }, 134 {"outputTensor2", { 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 135 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 136 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f } } } ); 137 } 138 139 struct SimpleSplitVAxisTwoFixture : SplitVFixture 140 { SimpleSplitVAxisTwoFixtureSimpleSplitVAxisTwoFixture141 SimpleSplitVAxisTwoFixture() 142 : SplitVFixture( "[ 2, 4, 2, 2 ]", "[ 3, 0, 0, 0, 1, 0, 0, 0 ]", "[ 2 ]","[ ]", "2", 143 "[ 2, 3, 2, 2 ]", "[ 2, 1, 2, 2 ]", "[ 1, 0, 0, 0 ]", "FLOAT32") 144 {} 145 }; 146 147 TEST_CASE_FIXTURE(SimpleSplitVAxisTwoFixture, "ParseAxisTwoSplitVTwo") 148 { 149 RunTest<4, armnn::DataType::Float32>( 150 0, 151 { {"inputTensor", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 152 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 153 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 154 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f } } }, 155 { {"outputTensor1", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 156 9.0f, 10.0f, 11.0f, 12.0f, 17.0f, 18.0f, 19.0f, 20.0f, 157 21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 26.0f, 27.0f, 28.0f } }, 158 {"outputTensor2", { 13.0f, 14.0f, 15.0f, 16.0f, 29.0f, 30.0f, 31.0f, 32.0f } } } ); 159 } 160 161 struct SimpleSplitVAxisThreeFixture : SplitVFixture 162 { SimpleSplitVAxisThreeFixtureSimpleSplitVAxisThreeFixture163 SimpleSplitVAxisThreeFixture() 164 : SplitVFixture( "[ 2, 2, 4, 2 ]", "[ 1, 0, 0, 0, 3, 0, 0, 0 ]", "[ 2 ]","[ ]", "2", 165 "[ 2, 2, 1, 2 ]", "[ 2, 2, 3, 2 ]", "[ 2, 0, 0, 0 ]", "FLOAT32") 166 {} 167 }; 168 169 TEST_CASE_FIXTURE(SimpleSplitVAxisThreeFixture, "ParseAxisThreeSplitVTwo") 170 { 171 RunTest<4, armnn::DataType::Float32>( 172 0, 173 { {"inputTensor", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 174 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 175 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 176 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f } } }, 177 { {"outputTensor1", { 1.0f, 2.0f, 9.0f, 10.0f, 17.0f, 18.0f, 25.0f, 26.0f } }, 178 {"outputTensor2", { 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 11.0f, 12.0f, 179 13.0f, 14.0f, 15.0f, 16.0f, 19.0f, 20.0f, 21.0f, 22.0f, 180 23.0f, 24.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f } } } ); 181 } 182 183 struct SimpleSplitVAxisFourFixture : SplitVFixture 184 { SimpleSplitVAxisFourFixtureSimpleSplitVAxisFourFixture185 SimpleSplitVAxisFourFixture() 186 : SplitVFixture( "[ 2, 2, 2, 4 ]", "[ 3, 0, 0, 0, 1, 0, 0, 0 ]", "[ 2 ]","[ ]", "2", 187 "[ 2, 2, 2, 3 ]", "[ 2, 2, 2, 1 ]", "[ 3, 0, 0, 0 ]", "FLOAT32") 188 {} 189 }; 190 191 TEST_CASE_FIXTURE(SimpleSplitVAxisFourFixture, "ParseAxisFourSplitVTwo") 192 { 193 RunTest<4, armnn::DataType::Float32>( 194 0, 195 { {"inputTensor", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 196 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f, 197 17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f, 198 25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f } } }, 199 { {"outputTensor1", { 1.0f, 2.0f, 3.0f, 5.0f, 6.0f, 7.0f, 9.0f, 10.0f, 200 11.0f, 13.0f, 14.0f, 15.0f, 17.0f, 18.0f, 19.0f, 21.0f, 201 22.0f, 23.0f, 25.0f, 26.0f, 27.0f, 29.0f, 30.0f, 31.0f} }, 202 {"outputTensor2", { 4.0f, 8.0f, 12.0f, 16.0f, 20.0f, 24.0f, 28.0f, 32.0f } } } ); 203 } 204 205 } 206