1 // 2 // Copyright © 2017 Arm Ltd and Contributors. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #include "ParserFlatbuffersFixture.hpp" 7 8 9 TEST_SUITE("TensorflowLiteParser_Split") 10 { 11 struct SplitFixture : public ParserFlatbuffersFixture 12 { SplitFixtureSplitFixture13 explicit SplitFixture(const std::string& inputShape, 14 const std::string& axisShape, 15 const std::string& numSplits, 16 const std::string& outputShape1, 17 const std::string& outputShape2, 18 const std::string& axisData, 19 const std::string& dataType) 20 { 21 m_JsonString = R"( 22 { 23 "version": 3, 24 "operator_codes": [ { "builtin_code": "SPLIT" } ], 25 "subgraphs": [ { 26 "tensors": [ 27 { 28 "shape": )" + inputShape + R"(, 29 "type": )" + dataType + R"(, 30 "buffer": 0, 31 "name": "inputTensor", 32 "quantization": { 33 "min": [ 0.0 ], 34 "max": [ 255.0 ], 35 "scale": [ 1.0 ], 36 "zero_point": [ 0 ], 37 } 38 }, 39 { 40 "shape": )" + axisShape + R"(, 41 "type": "INT32", 42 "buffer": 1, 43 "name": "axis", 44 "quantization": { 45 "min": [ 0.0 ], 46 "max": [ 255.0 ], 47 "scale": [ 1.0 ], 48 "zero_point": [ 0 ], 49 } 50 }, 51 { 52 "shape": )" + outputShape1 + R"( , 53 "type":)" + dataType + R"(, 54 "buffer": 2, 55 "name": "outputTensor1", 56 "quantization": { 57 "min": [ 0.0 ], 58 "max": [ 255.0 ], 59 "scale": [ 1.0 ], 60 "zero_point": [ 0 ], 61 } 62 }, 63 { 64 "shape": )" + outputShape2 + R"( , 65 "type":)" + dataType + R"(, 66 "buffer": 3, 67 "name": "outputTensor2", 68 "quantization": { 69 "min": [ 0.0 ], 70 "max": [ 255.0 ], 71 "scale": [ 1.0 ], 72 "zero_point": [ 0 ], 73 } 74 } 75 ], 76 "inputs": [ 0 ], 77 "outputs": [ 2, 3 ], 78 "operators": [ 79 { 80 "opcode_index": 0, 81 "inputs": [ 1, 0 ], 82 "outputs": [ 2, 3 ], 83 "builtin_options_type": "SplitOptions", 84 "builtin_options": { 85 "num_splits": )" + numSplits + R"( 86 }, 87 "custom_options_format": "FLEXBUFFERS" 88 } 89 ], 90 } ], 91 "buffers" : [ {}, {"data": )" + axisData + R"( }, {}, {} ] 92 } 93 )"; 94 95 Setup(); 96 } 97 }; 98 99 100 struct SimpleSplitFixtureFloat32 : SplitFixture 101 { SimpleSplitFixtureFloat32SimpleSplitFixtureFloat32102 SimpleSplitFixtureFloat32() 103 : SplitFixture( "[ 2, 2, 2, 2 ]", "[ ]", "2", "[ 2, 1, 2, 2 ]", "[ 2, 1, 2, 2 ]", "[ 1, 0, 0, 0 ]", "FLOAT32") 104 {} 105 }; 106 107 TEST_CASE_FIXTURE(SimpleSplitFixtureFloat32, "ParseAxisOneSplitTwoFloat32") 108 { 109 110 RunTest<4, armnn::DataType::Float32>( 111 0, 112 { {"inputTensor", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 113 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f } } }, 114 { {"outputTensor1", { 1.0f, 2.0f, 3.0f, 4.0f, 9.0f, 10.0f, 11.0f, 12.0f } }, 115 {"outputTensor2", { 5.0f, 6.0f, 7.0f, 8.0f, 13.0f, 14.0f, 15.0f, 16.0f } } }); 116 } 117 118 struct SimpleSplitAxisThreeFixtureFloat32 : SplitFixture 119 { SimpleSplitAxisThreeFixtureFloat32SimpleSplitAxisThreeFixtureFloat32120 SimpleSplitAxisThreeFixtureFloat32() 121 : SplitFixture( "[ 2, 2, 2, 2 ]", "[ ]", "2", "[ 2, 2, 2, 1 ]", "[ 2, 2, 2, 1 ]", "[ 3, 0, 0, 0 ]", "FLOAT32") 122 {} 123 }; 124 125 TEST_CASE_FIXTURE(SimpleSplitAxisThreeFixtureFloat32, "ParseAxisThreeSplitTwoFloat32") 126 { 127 RunTest<4, armnn::DataType::Float32>( 128 0, 129 { {"inputTensor", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 130 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f } } }, 131 { {"outputTensor1", { 1.0f, 3.0f, 5.0f, 7.0f, 9.0f, 11.0f, 13.0f, 15.0f } }, 132 {"outputTensor2", { 2.0f, 4.0f, 6.0f, 8.0f, 10.0f, 12.0f, 14.0f, 16.0f } } } ); 133 } 134 135 struct SimpleSplit2DFixtureFloat32 : SplitFixture 136 { SimpleSplit2DFixtureFloat32SimpleSplit2DFixtureFloat32137 SimpleSplit2DFixtureFloat32() 138 : SplitFixture( "[ 1, 8 ]", "[ ]", "2", "[ 1, 4 ]", "[ 1, 4 ]", "[ 1, 0, 0, 0 ]", "FLOAT32") 139 {} 140 }; 141 142 TEST_CASE_FIXTURE(SimpleSplit2DFixtureFloat32, "SimpleSplit2DFloat32") 143 { 144 RunTest<2, armnn::DataType::Float32>( 145 0, 146 { {"inputTensor", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f } } }, 147 { {"outputTensor1", { 1.0f, 2.0f, 3.0f, 4.0f } }, 148 {"outputTensor2", { 5.0f, 6.0f, 7.0f, 8.0f } } } ); 149 } 150 151 struct SimpleSplit3DFixtureFloat32 : SplitFixture 152 { SimpleSplit3DFixtureFloat32SimpleSplit3DFixtureFloat32153 SimpleSplit3DFixtureFloat32() 154 : SplitFixture( "[ 1, 8, 2 ]", "[ ]", "2", "[ 1, 4, 2 ]", "[ 1, 4, 2 ]", "[ 1, 0, 0, 0 ]", "FLOAT32") 155 {} 156 }; 157 158 TEST_CASE_FIXTURE(SimpleSplit3DFixtureFloat32, "SimpleSplit3DFloat32") 159 { 160 RunTest<3, armnn::DataType::Float32>( 161 0, 162 { {"inputTensor", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 163 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f } } }, 164 { {"outputTensor1", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f } }, 165 {"outputTensor2", { 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f } } } ); 166 } 167 168 struct SimpleSplitFixtureUint8 : SplitFixture 169 { SimpleSplitFixtureUint8SimpleSplitFixtureUint8170 SimpleSplitFixtureUint8() 171 : SplitFixture( "[ 2, 2, 2, 2 ]", "[ ]", "2", "[ 2, 1, 2, 2 ]", "[ 2, 1, 2, 2 ]", "[ 1, 0, 0, 0 ]", "UINT8") 172 {} 173 }; 174 175 TEST_CASE_FIXTURE(SimpleSplitFixtureUint8, "ParseAxisOneSplitTwoUint8") 176 { 177 178 RunTest<4, armnn::DataType::QAsymmU8>( 179 0, 180 { {"inputTensor", { 1, 2, 3, 4, 5, 6, 7, 8, 181 9, 10, 11, 12, 13, 14, 15, 16 } } }, 182 { {"outputTensor1", { 1, 2, 3, 4, 9, 10, 11, 12 } }, 183 {"outputTensor2", { 5, 6, 7, 8, 13, 14, 15, 16 } } }); 184 } 185 186 struct SimpleSplitAxisThreeFixtureUint8 : SplitFixture 187 { SimpleSplitAxisThreeFixtureUint8SimpleSplitAxisThreeFixtureUint8188 SimpleSplitAxisThreeFixtureUint8() 189 : SplitFixture( "[ 2, 2, 2, 2 ]", "[ ]", "2", "[ 2, 2, 2, 1 ]", "[ 2, 2, 2, 1 ]", "[ 3, 0, 0, 0 ]", "UINT8") 190 {} 191 }; 192 193 TEST_CASE_FIXTURE(SimpleSplitAxisThreeFixtureUint8, "ParseAxisThreeSplitTwoUint8") 194 { 195 RunTest<4, armnn::DataType::QAsymmU8>( 196 0, 197 { {"inputTensor", { 1, 2, 3, 4, 5, 6, 7, 8, 198 9, 10, 11, 12, 13, 14, 15, 16 } } }, 199 { {"outputTensor1", { 1, 3, 5, 7, 9, 11, 13, 15 } }, 200 {"outputTensor2", { 2, 4, 6, 8, 10, 12, 14, 16 } } } ); 201 } 202 203 struct SimpleSplit2DFixtureUint8 : SplitFixture 204 { SimpleSplit2DFixtureUint8SimpleSplit2DFixtureUint8205 SimpleSplit2DFixtureUint8() 206 : SplitFixture( "[ 1, 8 ]", "[ ]", "2", "[ 1, 4 ]", "[ 1, 4 ]", "[ 1, 0, 0, 0 ]", "UINT8") 207 {} 208 }; 209 210 TEST_CASE_FIXTURE(SimpleSplit2DFixtureUint8, "SimpleSplit2DUint8") 211 { 212 RunTest<2, armnn::DataType::QAsymmU8>( 213 0, 214 { {"inputTensor", { 1, 2, 3, 4, 5, 6, 7, 8 } } }, 215 { {"outputTensor1", { 1, 2, 3, 4 } }, 216 {"outputTensor2", { 5, 6, 7, 8 } } } ); 217 } 218 219 struct SimpleSplit3DFixtureUint8 : SplitFixture 220 { SimpleSplit3DFixtureUint8SimpleSplit3DFixtureUint8221 SimpleSplit3DFixtureUint8() 222 : SplitFixture( "[ 1, 8, 2 ]", "[ ]", "2", "[ 1, 4, 2 ]", "[ 1, 4, 2 ]", "[ 1, 0, 0, 0 ]", "UINT8") 223 {} 224 }; 225 226 TEST_CASE_FIXTURE(SimpleSplit3DFixtureUint8, "SimpleSplit3DUint8") 227 { 228 RunTest<3, armnn::DataType::QAsymmU8>( 229 0, 230 { {"inputTensor", { 1, 2, 3, 4, 5, 6, 7, 8, 231 9, 10, 11, 12, 13, 14, 15, 16 } } }, 232 { {"outputTensor1", { 1, 2, 3, 4, 5, 6, 7, 8 } }, 233 {"outputTensor2", { 9, 10, 11, 12, 13, 14, 15, 16 } } } ); 234 } 235 236 }