1*89c4ff92SAndroid Build Coastguard Worker // 2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd. All rights reserved. 3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT 4*89c4ff92SAndroid Build Coastguard Worker // 5*89c4ff92SAndroid Build Coastguard Worker 6*89c4ff92SAndroid Build Coastguard Worker #include "ParserFlatbuffersSerializeFixture.hpp" 7*89c4ff92SAndroid Build Coastguard Worker #include <armnnDeserializer/IDeserializer.hpp> 8*89c4ff92SAndroid Build Coastguard Worker 9*89c4ff92SAndroid Build Coastguard Worker #include <string> 10*89c4ff92SAndroid Build Coastguard Worker 11*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("Deserializer_StridedSlice") 12*89c4ff92SAndroid Build Coastguard Worker { 13*89c4ff92SAndroid Build Coastguard Worker struct StridedSliceFixture : public ParserFlatbuffersSerializeFixture 14*89c4ff92SAndroid Build Coastguard Worker { StridedSliceFixtureStridedSliceFixture15*89c4ff92SAndroid Build Coastguard Worker explicit StridedSliceFixture(const std::string& inputShape, 16*89c4ff92SAndroid Build Coastguard Worker const std::string& begin, 17*89c4ff92SAndroid Build Coastguard Worker const std::string& end, 18*89c4ff92SAndroid Build Coastguard Worker const std::string& stride, 19*89c4ff92SAndroid Build Coastguard Worker const std::string& beginMask, 20*89c4ff92SAndroid Build Coastguard Worker const std::string& endMask, 21*89c4ff92SAndroid Build Coastguard Worker const std::string& shrinkAxisMask, 22*89c4ff92SAndroid Build Coastguard Worker const std::string& ellipsisMask, 23*89c4ff92SAndroid Build Coastguard Worker const std::string& newAxisMask, 24*89c4ff92SAndroid Build Coastguard Worker const std::string& dataLayout, 25*89c4ff92SAndroid Build Coastguard Worker const std::string& outputShape, 26*89c4ff92SAndroid Build Coastguard Worker const std::string& dataType) 27*89c4ff92SAndroid Build Coastguard Worker { 28*89c4ff92SAndroid Build Coastguard Worker m_JsonString = R"( 29*89c4ff92SAndroid Build Coastguard Worker { 30*89c4ff92SAndroid Build Coastguard Worker inputIds: [0], 31*89c4ff92SAndroid Build Coastguard Worker outputIds: [2], 32*89c4ff92SAndroid Build Coastguard Worker layers: [ 33*89c4ff92SAndroid Build Coastguard Worker { 34*89c4ff92SAndroid Build Coastguard Worker layer_type: "InputLayer", 35*89c4ff92SAndroid Build Coastguard Worker layer: { 36*89c4ff92SAndroid Build Coastguard Worker base: { 37*89c4ff92SAndroid Build Coastguard Worker layerBindingId: 0, 38*89c4ff92SAndroid Build Coastguard Worker base: { 39*89c4ff92SAndroid Build Coastguard Worker index: 0, 40*89c4ff92SAndroid Build Coastguard Worker layerName: "InputLayer", 41*89c4ff92SAndroid Build Coastguard Worker layerType: "Input", 42*89c4ff92SAndroid Build Coastguard Worker inputSlots: [{ 43*89c4ff92SAndroid Build Coastguard Worker index: 0, 44*89c4ff92SAndroid Build Coastguard Worker connection: {sourceLayerIndex:0, outputSlotIndex:0 }, 45*89c4ff92SAndroid Build Coastguard Worker }], 46*89c4ff92SAndroid Build Coastguard Worker outputSlots: [{ 47*89c4ff92SAndroid Build Coastguard Worker index: 0, 48*89c4ff92SAndroid Build Coastguard Worker tensorInfo: { 49*89c4ff92SAndroid Build Coastguard Worker dimensions: )" + inputShape + R"(, 50*89c4ff92SAndroid Build Coastguard Worker dataType: )" + dataType + R"( 51*89c4ff92SAndroid Build Coastguard Worker } 52*89c4ff92SAndroid Build Coastguard Worker }] 53*89c4ff92SAndroid Build Coastguard Worker } 54*89c4ff92SAndroid Build Coastguard Worker } 55*89c4ff92SAndroid Build Coastguard Worker } 56*89c4ff92SAndroid Build Coastguard Worker }, 57*89c4ff92SAndroid Build Coastguard Worker { 58*89c4ff92SAndroid Build Coastguard Worker layer_type: "StridedSliceLayer", 59*89c4ff92SAndroid Build Coastguard Worker layer: { 60*89c4ff92SAndroid Build Coastguard Worker base: { 61*89c4ff92SAndroid Build Coastguard Worker index: 1, 62*89c4ff92SAndroid Build Coastguard Worker layerName: "StridedSliceLayer", 63*89c4ff92SAndroid Build Coastguard Worker layerType: "StridedSlice", 64*89c4ff92SAndroid Build Coastguard Worker inputSlots: [{ 65*89c4ff92SAndroid Build Coastguard Worker index: 0, 66*89c4ff92SAndroid Build Coastguard Worker connection: {sourceLayerIndex:0, outputSlotIndex:0 }, 67*89c4ff92SAndroid Build Coastguard Worker }], 68*89c4ff92SAndroid Build Coastguard Worker outputSlots: [{ 69*89c4ff92SAndroid Build Coastguard Worker index: 0, 70*89c4ff92SAndroid Build Coastguard Worker tensorInfo: { 71*89c4ff92SAndroid Build Coastguard Worker dimensions: )" + outputShape + R"(, 72*89c4ff92SAndroid Build Coastguard Worker dataType: )" + dataType + R"( 73*89c4ff92SAndroid Build Coastguard Worker } 74*89c4ff92SAndroid Build Coastguard Worker }] 75*89c4ff92SAndroid Build Coastguard Worker }, 76*89c4ff92SAndroid Build Coastguard Worker descriptor: { 77*89c4ff92SAndroid Build Coastguard Worker begin: )" + begin + R"(, 78*89c4ff92SAndroid Build Coastguard Worker end: )" + end + R"(, 79*89c4ff92SAndroid Build Coastguard Worker stride: )" + stride + R"(, 80*89c4ff92SAndroid Build Coastguard Worker beginMask: )" + beginMask + R"(, 81*89c4ff92SAndroid Build Coastguard Worker endMask: )" + endMask + R"(, 82*89c4ff92SAndroid Build Coastguard Worker shrinkAxisMask: )" + shrinkAxisMask + R"(, 83*89c4ff92SAndroid Build Coastguard Worker ellipsisMask: )" + ellipsisMask + R"(, 84*89c4ff92SAndroid Build Coastguard Worker newAxisMask: )" + newAxisMask + R"(, 85*89c4ff92SAndroid Build Coastguard Worker dataLayout: )" + dataLayout + R"(, 86*89c4ff92SAndroid Build Coastguard Worker } 87*89c4ff92SAndroid Build Coastguard Worker } 88*89c4ff92SAndroid Build Coastguard Worker }, 89*89c4ff92SAndroid Build Coastguard Worker { 90*89c4ff92SAndroid Build Coastguard Worker layer_type: "OutputLayer", 91*89c4ff92SAndroid Build Coastguard Worker layer: { 92*89c4ff92SAndroid Build Coastguard Worker base:{ 93*89c4ff92SAndroid Build Coastguard Worker layerBindingId: 2, 94*89c4ff92SAndroid Build Coastguard Worker base: { 95*89c4ff92SAndroid Build Coastguard Worker index: 2, 96*89c4ff92SAndroid Build Coastguard Worker layerName: "OutputLayer", 97*89c4ff92SAndroid Build Coastguard Worker layerType: "Output", 98*89c4ff92SAndroid Build Coastguard Worker inputSlots: [{ 99*89c4ff92SAndroid Build Coastguard Worker index: 0, 100*89c4ff92SAndroid Build Coastguard Worker connection: {sourceLayerIndex:1, outputSlotIndex:0 }, 101*89c4ff92SAndroid Build Coastguard Worker }], 102*89c4ff92SAndroid Build Coastguard Worker outputSlots: [{ 103*89c4ff92SAndroid Build Coastguard Worker index: 0, 104*89c4ff92SAndroid Build Coastguard Worker tensorInfo: { 105*89c4ff92SAndroid Build Coastguard Worker dimensions: )" + outputShape + R"(, 106*89c4ff92SAndroid Build Coastguard Worker dataType: )" + dataType + R"( 107*89c4ff92SAndroid Build Coastguard Worker }, 108*89c4ff92SAndroid Build Coastguard Worker }], 109*89c4ff92SAndroid Build Coastguard Worker } 110*89c4ff92SAndroid Build Coastguard Worker } 111*89c4ff92SAndroid Build Coastguard Worker }, 112*89c4ff92SAndroid Build Coastguard Worker } 113*89c4ff92SAndroid Build Coastguard Worker ] 114*89c4ff92SAndroid Build Coastguard Worker } 115*89c4ff92SAndroid Build Coastguard Worker )"; 116*89c4ff92SAndroid Build Coastguard Worker SetupSingleInputSingleOutput("InputLayer", "OutputLayer"); 117*89c4ff92SAndroid Build Coastguard Worker } 118*89c4ff92SAndroid Build Coastguard Worker }; 119*89c4ff92SAndroid Build Coastguard Worker 120*89c4ff92SAndroid Build Coastguard Worker struct SimpleStridedSliceFixture : StridedSliceFixture 121*89c4ff92SAndroid Build Coastguard Worker { SimpleStridedSliceFixtureSimpleStridedSliceFixture122*89c4ff92SAndroid Build Coastguard Worker SimpleStridedSliceFixture() : StridedSliceFixture("[ 3, 2, 3, 1 ]", 123*89c4ff92SAndroid Build Coastguard Worker "[ 0, 0, 0, 0 ]", 124*89c4ff92SAndroid Build Coastguard Worker "[ 3, 2, 3, 1 ]", 125*89c4ff92SAndroid Build Coastguard Worker "[ 2, 2, 2, 1 ]", 126*89c4ff92SAndroid Build Coastguard Worker "0", 127*89c4ff92SAndroid Build Coastguard Worker "0", 128*89c4ff92SAndroid Build Coastguard Worker "0", 129*89c4ff92SAndroid Build Coastguard Worker "0", 130*89c4ff92SAndroid Build Coastguard Worker "0", 131*89c4ff92SAndroid Build Coastguard Worker "NCHW", 132*89c4ff92SAndroid Build Coastguard Worker "[ 2, 1, 2, 1 ]", 133*89c4ff92SAndroid Build Coastguard Worker "Float32") {} 134*89c4ff92SAndroid Build Coastguard Worker }; 135*89c4ff92SAndroid Build Coastguard Worker 136*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(SimpleStridedSliceFixture, "SimpleStridedSliceFloat32") 137*89c4ff92SAndroid Build Coastguard Worker { 138*89c4ff92SAndroid Build Coastguard Worker RunTest<4, armnn::DataType::Float32>(0, 139*89c4ff92SAndroid Build Coastguard Worker { 140*89c4ff92SAndroid Build Coastguard Worker 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f, 141*89c4ff92SAndroid Build Coastguard Worker 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f, 142*89c4ff92SAndroid Build Coastguard Worker 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f 143*89c4ff92SAndroid Build Coastguard Worker }, 144*89c4ff92SAndroid Build Coastguard Worker { 145*89c4ff92SAndroid Build Coastguard Worker 1.0f, 1.0f, 5.0f, 5.0f 146*89c4ff92SAndroid Build Coastguard Worker }); 147*89c4ff92SAndroid Build Coastguard Worker } 148*89c4ff92SAndroid Build Coastguard Worker 149*89c4ff92SAndroid Build Coastguard Worker struct StridedSliceMaskFixture : StridedSliceFixture 150*89c4ff92SAndroid Build Coastguard Worker { StridedSliceMaskFixtureStridedSliceMaskFixture151*89c4ff92SAndroid Build Coastguard Worker StridedSliceMaskFixture() : StridedSliceFixture("[ 3, 2, 3, 1 ]", 152*89c4ff92SAndroid Build Coastguard Worker "[ 1, 1, 1, 1 ]", 153*89c4ff92SAndroid Build Coastguard Worker "[ 1, 1, 1, 1 ]", 154*89c4ff92SAndroid Build Coastguard Worker "[ 1, 1, 1, 1 ]", 155*89c4ff92SAndroid Build Coastguard Worker "15", 156*89c4ff92SAndroid Build Coastguard Worker "15", 157*89c4ff92SAndroid Build Coastguard Worker "0", 158*89c4ff92SAndroid Build Coastguard Worker "0", 159*89c4ff92SAndroid Build Coastguard Worker "0", 160*89c4ff92SAndroid Build Coastguard Worker "NCHW", 161*89c4ff92SAndroid Build Coastguard Worker "[ 3, 2, 3, 1 ]", 162*89c4ff92SAndroid Build Coastguard Worker "Float32") {} 163*89c4ff92SAndroid Build Coastguard Worker }; 164*89c4ff92SAndroid Build Coastguard Worker 165*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(StridedSliceMaskFixture, "StridedSliceMaskFloat32") 166*89c4ff92SAndroid Build Coastguard Worker { 167*89c4ff92SAndroid Build Coastguard Worker RunTest<4, armnn::DataType::Float32>(0, 168*89c4ff92SAndroid Build Coastguard Worker { 169*89c4ff92SAndroid Build Coastguard Worker 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f, 170*89c4ff92SAndroid Build Coastguard Worker 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f, 171*89c4ff92SAndroid Build Coastguard Worker 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f 172*89c4ff92SAndroid Build Coastguard Worker }, 173*89c4ff92SAndroid Build Coastguard Worker { 174*89c4ff92SAndroid Build Coastguard Worker 1.0f, 1.0f, 1.0f, 2.0f, 2.0f, 2.0f, 175*89c4ff92SAndroid Build Coastguard Worker 3.0f, 3.0f, 3.0f, 4.0f, 4.0f, 4.0f, 176*89c4ff92SAndroid Build Coastguard Worker 5.0f, 5.0f, 5.0f, 6.0f, 6.0f, 6.0f 177*89c4ff92SAndroid Build Coastguard Worker }); 178*89c4ff92SAndroid Build Coastguard Worker } 179*89c4ff92SAndroid Build Coastguard Worker 180*89c4ff92SAndroid Build Coastguard Worker } 181