1 // 2 // Copyright © 2022 Arm Ltd and Contributors. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #include "ParserFlatbuffersSerializeFixture.hpp" 7 #include <armnnDeserializer/IDeserializer.hpp> 8 9 #include <string> 10 11 TEST_SUITE("Deserializer_GatherNd") 12 { 13 struct GatherNdFixture : public ParserFlatbuffersSerializeFixture 14 { GatherNdFixtureGatherNdFixture15 explicit GatherNdFixture(const std::string& paramsShape, 16 const std::string& indicesShape, 17 const std::string& outputShape, 18 const std::string& indicesData, 19 const std::string dataType, 20 const std::string constDataType) 21 { 22 m_JsonString = R"( 23 { 24 inputIds: [0], 25 outputIds: [3], 26 layers: [ 27 { 28 layer_type: "InputLayer", 29 layer: { 30 base: { 31 layerBindingId: 0, 32 base: { 33 index: 0, 34 layerName: "InputLayer", 35 layerType: "Input", 36 inputSlots: [{ 37 index: 0, 38 connection: {sourceLayerIndex:0, outputSlotIndex:0 }, 39 }], 40 outputSlots: [ { 41 index: 0, 42 tensorInfo: { 43 dimensions: )" + paramsShape + R"(, 44 dataType: )" + dataType + R"( 45 }}] 46 } 47 }}}, 48 { 49 layer_type: "ConstantLayer", 50 layer: { 51 base: { 52 index:1, 53 layerName: "ConstantLayer", 54 layerType: "Constant", 55 outputSlots: [ { 56 index: 0, 57 tensorInfo: { 58 dimensions: )" + indicesShape + R"(, 59 dataType: "Signed32", 60 }, 61 }], 62 }, 63 input: { 64 info: { 65 dimensions: )" + indicesShape + R"(, 66 dataType: )" + dataType + R"( 67 }, 68 data_type: )" + constDataType + R"(, 69 data: { 70 data: )" + indicesData + R"(, 71 } } 72 },}, 73 { 74 layer_type: "GatherNdLayer", 75 layer: { 76 base: { 77 index: 2, 78 layerName: "GatherNdLayer", 79 layerType: "GatherNd", 80 inputSlots: [ 81 { 82 index: 0, 83 connection: {sourceLayerIndex:0, outputSlotIndex:0 }, 84 }, 85 { 86 index: 1, 87 connection: {sourceLayerIndex:1, outputSlotIndex:0 } 88 }], 89 outputSlots: [ { 90 index: 0, 91 tensorInfo: { 92 dimensions: )" + outputShape + R"(, 93 dataType: )" + dataType + R"( 94 95 }}]}, 96 }}, 97 { 98 layer_type: "OutputLayer", 99 layer: { 100 base:{ 101 layerBindingId: 0, 102 base: { 103 index: 3, 104 layerName: "OutputLayer", 105 layerType: "Output", 106 inputSlots: [{ 107 index: 0, 108 connection: {sourceLayerIndex:2, outputSlotIndex:0 }, 109 }], 110 outputSlots: [ { 111 index: 0, 112 tensorInfo: { 113 dimensions: )" + outputShape + R"(, 114 dataType: )" + dataType + R"( 115 }, 116 }], 117 }}}, 118 }], 119 featureVersions: { 120 weightsLayoutScheme: 1, 121 } 122 } )"; 123 124 Setup(); 125 } 126 }; 127 128 struct SimpleGatherNdFixtureFloat32 : GatherNdFixture 129 { SimpleGatherNdFixtureFloat32SimpleGatherNdFixtureFloat32130 SimpleGatherNdFixtureFloat32() : GatherNdFixture("[ 6, 3 ]", "[ 3, 1 ]", "[ 3, 3 ]", 131 "[ 5, 1, 0 ]", "Float32", "IntData") {} 132 }; 133 134 TEST_CASE_FIXTURE(SimpleGatherNdFixtureFloat32, "GatherNdFloat32") 135 { 136 RunTest<4, armnn::DataType::Float32>(0, 137 {{"InputLayer", { 1, 2, 3, 138 4, 5, 6, 139 7, 8, 9, 140 10, 11, 12, 141 13, 14, 15, 142 16, 17, 18 }}}, 143 {{"OutputLayer", { 16, 17, 18, 144 4, 5, 6, 145 1, 2, 3}}}); 146 } 147 148 } 149 150