1 // 2 // Copyright © 2017 Arm Ltd and Contributors. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #include "ParserFlatbuffersFixture.hpp" 7 8 9 TEST_SUITE("TensorflowLiteParser_StridedSlice") 10 { 11 struct StridedSliceFixture : public ParserFlatbuffersFixture 12 { StridedSliceFixtureStridedSliceFixture13 explicit StridedSliceFixture(const std::string & inputShape, 14 const std::string & outputShape, 15 const std::string & beginData, 16 const std::string & endData, 17 const std::string & stridesData, 18 int beginMask = 0, 19 int endMask = 0) 20 { 21 m_JsonString = R"( 22 { 23 "version": 3, 24 "operator_codes": [ { "builtin_code": "STRIDED_SLICE" } ], 25 "subgraphs": [ { 26 "tensors": [ 27 { 28 "shape": )" + inputShape + R"(, 29 "type": "FLOAT32", 30 "buffer": 0, 31 "name": "inputTensor", 32 "quantization": { 33 "min": [ 0.0 ], 34 "max": [ 255.0 ], 35 "scale": [ 1.0 ], 36 "zero_point": [ 0 ], 37 } 38 }, 39 { 40 "shape": [ 4 ], 41 "type": "INT32", 42 "buffer": 1, 43 "name": "beginTensor", 44 "quantization": { 45 } 46 }, 47 { 48 "shape": [ 4 ], 49 "type": "INT32", 50 "buffer": 2, 51 "name": "endTensor", 52 "quantization": { 53 } 54 }, 55 { 56 "shape": [ 4 ], 57 "type": "INT32", 58 "buffer": 3, 59 "name": "stridesTensor", 60 "quantization": { 61 } 62 }, 63 { 64 "shape": )" + outputShape + R"( , 65 "type": "FLOAT32", 66 "buffer": 4, 67 "name": "outputTensor", 68 "quantization": { 69 "min": [ 0.0 ], 70 "max": [ 255.0 ], 71 "scale": [ 1.0 ], 72 "zero_point": [ 0 ], 73 } 74 } 75 ], 76 "inputs": [ 0, 1, 2, 3 ], 77 "outputs": [ 4 ], 78 "operators": [ 79 { 80 "opcode_index": 0, 81 "inputs": [ 0, 1, 2, 3 ], 82 "outputs": [ 4 ], 83 "builtin_options_type": "StridedSliceOptions", 84 "builtin_options": { 85 "begin_mask": )" + std::to_string(beginMask) + R"(, 86 "end_mask": )" + std::to_string(endMask) + R"( 87 }, 88 "custom_options_format": "FLEXBUFFERS" 89 } 90 ], 91 } ], 92 "buffers" : [ 93 { }, 94 { "data": )" + beginData + R"(, }, 95 { "data": )" + endData + R"(, }, 96 { "data": )" + stridesData + R"(, }, 97 { } 98 ] 99 } 100 )"; 101 Setup(); 102 } 103 }; 104 105 struct StridedSlice4DFixture : StridedSliceFixture 106 { StridedSlice4DFixtureStridedSlice4DFixture107 StridedSlice4DFixture() : StridedSliceFixture("[ 3, 2, 3, 1 ]", // inputShape 108 "[ 1, 2, 3, 1 ]", // outputShape 109 "[ 1,0,0,0, 0,0,0,0, 0,0,0,0, 0,0,0,0 ]", // beginData 110 "[ 2,0,0,0, 2,0,0,0, 3,0,0,0, 1,0,0,0 ]", // endData 111 "[ 1,0,0,0, 1,0,0,0, 1,0,0,0, 1,0,0,0 ]" // stridesData 112 ) {} 113 }; 114 115 TEST_CASE_FIXTURE(StridedSlice4DFixture, "StridedSlice4D") 116 { 117 RunTest<4, armnn::DataType::Float32>( 118 0, 119 {{"inputTensor", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f, 120 121 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f, 122 123 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f }}}, 124 125 {{"outputTensor", { 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f }}}); 126 } 127 128 struct StridedSlice4DReverseFixture : StridedSliceFixture 129 { StridedSlice4DReverseFixtureStridedSlice4DReverseFixture130 StridedSlice4DReverseFixture() : StridedSliceFixture("[ 3, 2, 3, 1 ]", // inputShape 131 "[ 1, 2, 3, 1 ]", // outputShape 132 "[ 1,0,0,0, " 133 "255,255,255,255, " 134 "0,0,0,0, " 135 "0,0,0,0 ]", // beginData [ 1 -1 0 0 ] 136 "[ 2,0,0,0, " 137 "253,255,255,255, " 138 "3,0,0,0, " 139 "1,0,0,0 ]", // endData [ 2 -3 3 1 ] 140 "[ 1,0,0,0, " 141 "255,255,255,255, " 142 "1,0,0,0, " 143 "1,0,0,0 ]" // stridesData [ 1 -1 1 1 ] 144 ) {} 145 }; 146 147 TEST_CASE_FIXTURE(StridedSlice4DReverseFixture, "StridedSlice4DReverse") 148 { 149 RunTest<4, armnn::DataType::Float32>( 150 0, 151 {{"inputTensor", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f, 152 153 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f, 154 155 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f }}}, 156 157 {{"outputTensor", { 4.0f, 4.0f, 4.0f, 3.0f, 3.0f, 3.0f }}}); 158 } 159 160 struct StridedSliceSimpleStrideFixture : StridedSliceFixture 161 { StridedSliceSimpleStrideFixtureStridedSliceSimpleStrideFixture162 StridedSliceSimpleStrideFixture() : StridedSliceFixture("[ 3, 2, 3, 1 ]", // inputShape 163 "[ 2, 1, 2, 1 ]", // outputShape 164 "[ 0,0,0,0, 0,0,0,0, 0,0,0,0, 0,0,0,0 ]", // beginData 165 "[ 3,0,0,0, 2,0,0,0, 3,0,0,0, 1,0,0,0 ]", // endData 166 "[ 2,0,0,0, 2,0,0,0, 2,0,0,0, 1,0,0,0 ]" // stridesData 167 ) {} 168 }; 169 170 TEST_CASE_FIXTURE(StridedSliceSimpleStrideFixture, "StridedSliceSimpleStride") 171 { 172 RunTest<4, armnn::DataType::Float32>( 173 0, 174 {{"inputTensor", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f, 175 176 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f, 177 178 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f }}}, 179 180 {{"outputTensor", { 1.0f, 1.0f, 181 182 5.0f, 5.0f }}}); 183 } 184 185 struct StridedSliceSimpleRangeMaskFixture : StridedSliceFixture 186 { StridedSliceSimpleRangeMaskFixtureStridedSliceSimpleRangeMaskFixture187 StridedSliceSimpleRangeMaskFixture() : StridedSliceFixture("[ 3, 2, 3, 1 ]", // inputShape 188 "[ 3, 2, 3, 1 ]", // outputShape 189 "[ 1,0,0,0, 1,0,0,0, 1,0,0,0, 1,0,0,0 ]", // beginData 190 "[ 1,0,0,0, 1,0,0,0, 1,0,0,0, 1,0,0,0 ]", // endData 191 "[ 1,0,0,0, 1,0,0,0, 1,0,0,0, 1,0,0,0 ]", // stridesData 192 (1 << 4) - 1, // beginMask 193 (1 << 4) - 1 // endMask 194 ) {} 195 }; 196 197 TEST_CASE_FIXTURE(StridedSliceSimpleRangeMaskFixture, "StridedSliceSimpleRangeMask") 198 { 199 RunTest<4, armnn::DataType::Float32>( 200 0, 201 {{"inputTensor", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f, 202 203 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f, 204 205 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f }}}, 206 207 {{"outputTensor", { 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f, 208 209 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f, 210 211 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f }}}); 212 } 213 214 } 215