1 // 2 // Copyright © 2019 Arm Ltd and Contributors. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #include "ParserFlatbuffersFixture.hpp" 7 8 TEST_SUITE("TensorflowLiteParser_Slice") 9 { 10 struct SliceFixture : public ParserFlatbuffersFixture 11 { SliceFixtureSliceFixture12 explicit SliceFixture(const std::string & inputShape, 13 const std::string & outputShape, 14 const std::string & beginData, 15 const std::string & sizeData) 16 { 17 m_JsonString = R"( 18 { 19 "version": 3, 20 "operator_codes": [ 21 { 22 "builtin_code": "SLICE", 23 "version": 1 24 } 25 ], 26 "subgraphs": [ 27 { 28 "tensors": [ 29 { 30 "shape": )" + inputShape + R"(, 31 "type": "FLOAT32", 32 "buffer": 0, 33 "name": "inputTensor", 34 "quantization": { 35 "min": [ 36 0.0 37 ], 38 "max": [ 39 255.0 40 ], 41 "details_type": 0, 42 "quantized_dimension": 0 43 }, 44 "is_variable": false 45 }, 46 { 47 "shape": )" + outputShape + R"(, 48 "type": "FLOAT32", 49 "buffer": 1, 50 "name": "outputTensor", 51 "quantization": { 52 "details_type": 0, 53 "quantized_dimension": 0 54 }, 55 "is_variable": false 56 })"; 57 m_JsonString += R"(, 58 { 59 "shape": [ 60 3 61 ], 62 "type": "INT32", 63 "buffer": 2, 64 "name": "beginTensor", 65 "quantization": { 66 } 67 })"; 68 m_JsonString += R"(, 69 { 70 "shape": [ 71 3 72 ], 73 "type": "INT32", 74 "buffer": 3, 75 "name": "sizeTensor", 76 "quantization": { 77 } 78 })"; 79 m_JsonString += R"(], 80 "inputs": [ 81 0 82 ], 83 "outputs": [ 84 1 85 ], 86 "operators": [ 87 { 88 "opcode_index": 0, 89 "inputs": [ 90 0, 91 2, 92 3)"; 93 m_JsonString += R"(], 94 "outputs": [ 95 1 96 ], 97 mutating_variable_inputs: [ 98 ] 99 } 100 ] 101 } 102 ], 103 "description": "TOCO Converted.", 104 "buffers": [ 105 { }, 106 { })"; 107 m_JsonString += R"(,{"data": )" + beginData + R"( })"; 108 m_JsonString += R"(,{"data": )" + sizeData + R"( })"; 109 m_JsonString += R"( 110 ] 111 } 112 )"; 113 SetupSingleInputSingleOutput("inputTensor", "outputTensor"); 114 } 115 }; 116 117 struct SliceFixtureSingleDim : SliceFixture 118 { SliceFixtureSingleDimSliceFixtureSingleDim119 SliceFixtureSingleDim() : SliceFixture("[ 3, 2, 3 ]", 120 "[ 1, 1, 3 ]", 121 "[ 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]", 122 "[ 1, 0, 0, 0, 1, 0, 0, 0, 3, 0, 0, 0 ]") {} 123 }; 124 125 TEST_CASE_FIXTURE(SliceFixtureSingleDim, "SliceSingleDim") 126 { 127 RunTest<3, armnn::DataType::Float32>( 128 0, 129 {{"inputTensor", { 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6 }}}, 130 {{"outputTensor", { 3, 3, 3 }}}); 131 132 CHECK((m_Parser->GetNetworkOutputBindingInfo(0, "outputTensor").second.GetShape() 133 == armnn::TensorShape({1,1,3}))); 134 } 135 136 struct SliceFixtureD123 : SliceFixture 137 { SliceFixtureD123SliceFixtureD123138 SliceFixtureD123() : SliceFixture("[ 3, 2, 3 ]", 139 "[ 1, 2, 3 ]", 140 "[ 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]", 141 "[ 1, 0, 0, 0, 2, 0, 0, 0, 3, 0, 0, 0 ]") {} 142 }; 143 144 TEST_CASE_FIXTURE(SliceFixtureD123, "SliceD123") 145 { 146 RunTest<3, armnn::DataType::Float32>( 147 0, 148 {{"inputTensor", { 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6 }}}, 149 {{"outputTensor", { 3, 3, 3, 4, 4, 4 }}}); 150 151 CHECK((m_Parser->GetNetworkOutputBindingInfo(0, "outputTensor").second.GetShape() 152 == armnn::TensorShape({1,2,3}))); 153 } 154 155 struct SliceFixtureD213 : SliceFixture 156 { SliceFixtureD213SliceFixtureD213157 SliceFixtureD213() : SliceFixture("[ 3, 2, 3 ]", 158 "[ 2, 1, 3 ]", 159 "[ 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]", 160 "[ 2, 0, 0, 0, 1, 0, 0, 0, 3, 0, 0, 0 ]") {} 161 }; 162 163 TEST_CASE_FIXTURE(SliceFixtureD213, "SliceD213") 164 { 165 RunTest<3, armnn::DataType::Float32>( 166 0, 167 {{"inputTensor", { 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6 }}}, 168 {{"outputTensor", { 3, 3, 3, 5, 5, 5 }}}); 169 170 CHECK((m_Parser->GetNetworkOutputBindingInfo(0, "outputTensor").second.GetShape() 171 == armnn::TensorShape({2,1,3}))); 172 } 173 174 struct DynamicSliceFixtureD213 : SliceFixture 175 { DynamicSliceFixtureD213DynamicSliceFixtureD213176 DynamicSliceFixtureD213() : SliceFixture("[ 3, 2, 3 ]", 177 "[ ]", 178 "[ 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0 ]", 179 "[ 255, 255, 255, 255, 1, 0, 0, 0, 255, 255, 255, 255 ]") {} 180 }; 181 182 TEST_CASE_FIXTURE(DynamicSliceFixtureD213, "DynamicSliceD213") 183 { 184 RunTest<3, armnn::DataType::Float32, armnn::DataType::Float32>( 185 0, 186 {{"inputTensor", { 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6 }}}, 187 {{"outputTensor", { 3, 3, 3, 5, 5, 5 }}}, 188 true); 189 } 190 }