1 // 2 // Copyright © 2017 Arm Ltd. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #include "ParserFlatbuffersSerializeFixture.hpp" 7 #include <armnnDeserializer/IDeserializer.hpp> 8 9 #include <string> 10 11 TEST_SUITE("Deserializer_StridedSlice") 12 { 13 struct StridedSliceFixture : public ParserFlatbuffersSerializeFixture 14 { StridedSliceFixtureStridedSliceFixture15 explicit StridedSliceFixture(const std::string& inputShape, 16 const std::string& begin, 17 const std::string& end, 18 const std::string& stride, 19 const std::string& beginMask, 20 const std::string& endMask, 21 const std::string& shrinkAxisMask, 22 const std::string& ellipsisMask, 23 const std::string& newAxisMask, 24 const std::string& dataLayout, 25 const std::string& outputShape, 26 const std::string& dataType) 27 { 28 m_JsonString = R"( 29 { 30 inputIds: [0], 31 outputIds: [2], 32 layers: [ 33 { 34 layer_type: "InputLayer", 35 layer: { 36 base: { 37 layerBindingId: 0, 38 base: { 39 index: 0, 40 layerName: "InputLayer", 41 layerType: "Input", 42 inputSlots: [{ 43 index: 0, 44 connection: {sourceLayerIndex:0, outputSlotIndex:0 }, 45 }], 46 outputSlots: [{ 47 index: 0, 48 tensorInfo: { 49 dimensions: )" + inputShape + R"(, 50 dataType: )" + dataType + R"( 51 } 52 }] 53 } 54 } 55 } 56 }, 57 { 58 layer_type: "StridedSliceLayer", 59 layer: { 60 base: { 61 index: 1, 62 layerName: "StridedSliceLayer", 63 layerType: "StridedSlice", 64 inputSlots: [{ 65 index: 0, 66 connection: {sourceLayerIndex:0, outputSlotIndex:0 }, 67 }], 68 outputSlots: [{ 69 index: 0, 70 tensorInfo: { 71 dimensions: )" + outputShape + R"(, 72 dataType: )" + dataType + R"( 73 } 74 }] 75 }, 76 descriptor: { 77 begin: )" + begin + R"(, 78 end: )" + end + R"(, 79 stride: )" + stride + R"(, 80 beginMask: )" + beginMask + R"(, 81 endMask: )" + endMask + R"(, 82 shrinkAxisMask: )" + shrinkAxisMask + R"(, 83 ellipsisMask: )" + ellipsisMask + R"(, 84 newAxisMask: )" + newAxisMask + R"(, 85 dataLayout: )" + dataLayout + R"(, 86 } 87 } 88 }, 89 { 90 layer_type: "OutputLayer", 91 layer: { 92 base:{ 93 layerBindingId: 2, 94 base: { 95 index: 2, 96 layerName: "OutputLayer", 97 layerType: "Output", 98 inputSlots: [{ 99 index: 0, 100 connection: {sourceLayerIndex:1, outputSlotIndex:0 }, 101 }], 102 outputSlots: [{ 103 index: 0, 104 tensorInfo: { 105 dimensions: )" + outputShape + R"(, 106 dataType: )" + dataType + R"( 107 }, 108 }], 109 } 110 } 111 }, 112 } 113 ] 114 } 115 )"; 116 SetupSingleInputSingleOutput("InputLayer", "OutputLayer"); 117 } 118 }; 119 120 struct SimpleStridedSliceFixture : StridedSliceFixture 121 { SimpleStridedSliceFixtureSimpleStridedSliceFixture122 SimpleStridedSliceFixture() : StridedSliceFixture("[ 3, 2, 3, 1 ]", 123 "[ 0, 0, 0, 0 ]", 124 "[ 3, 2, 3, 1 ]", 125 "[ 2, 2, 2, 1 ]", 126 "0", 127 "0", 128 "0", 129 "0", 130 "0", 131 "NCHW", 132 "[ 2, 1, 2, 1 ]", 133 "Float32") {} 134 }; 135 136 TEST_CASE_FIXTURE(SimpleStridedSliceFixture, "SimpleStridedSliceFloat32") 137 { 138 RunTest<4, armnn::DataType::Float32>(0, 139 { 140 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f, 141 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f, 142 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f 143 }, 144 { 145 1.0f, 1.0f, 5.0f, 5.0f 146 }); 147 } 148 149 struct StridedSliceMaskFixture : StridedSliceFixture 150 { StridedSliceMaskFixtureStridedSliceMaskFixture151 StridedSliceMaskFixture() : StridedSliceFixture("[ 3, 2, 3, 1 ]", 152 "[ 1, 1, 1, 1 ]", 153 "[ 1, 1, 1, 1 ]", 154 "[ 1, 1, 1, 1 ]", 155 "15", 156 "15", 157 "0", 158 "0", 159 "0", 160 "NCHW", 161 "[ 3, 2, 3, 1 ]", 162 "Float32") {} 163 }; 164 165 TEST_CASE_FIXTURE(StridedSliceMaskFixture, "StridedSliceMaskFloat32") 166 { 167 RunTest<4, armnn::DataType::Float32>(0, 168 { 169 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f, 170 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f, 171 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f 172 }, 173 { 174 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f, 175 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f, 176 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f 177 }); 178 } 179 180 } 181