1 // 2 // Copyright © 2017 Arm Ltd. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 6 #include "ParserFlatbuffersSerializeFixture.hpp" 7 #include <armnnDeserializer/IDeserializer.hpp> 8 9 #include <string> 10 11 TEST_SUITE("Deserializer_Mean") 12 { 13 struct MeanFixture : public ParserFlatbuffersSerializeFixture 14 { MeanFixtureMeanFixture15 explicit MeanFixture(const std::string &inputShape, 16 const std::string &outputShape, 17 const std::string &axis, 18 const std::string &dataType) 19 { 20 m_JsonString = R"( 21 { 22 inputIds: [0], 23 outputIds: [2], 24 layers: [ 25 { 26 layer_type: "InputLayer", 27 layer: { 28 base: { 29 layerBindingId: 0, 30 base: { 31 index: 0, 32 layerName: "InputLayer", 33 layerType: "Input", 34 inputSlots: [{ 35 index: 0, 36 connection: {sourceLayerIndex:0, outputSlotIndex:0 }, 37 }], 38 outputSlots: [{ 39 index: 0, 40 tensorInfo: { 41 dimensions: )" + inputShape + R"(, 42 dataType: )" + dataType + R"( 43 } 44 }] 45 } 46 } 47 } 48 }, 49 { 50 layer_type: "MeanLayer", 51 layer: { 52 base: { 53 index: 1, 54 layerName: "MeanLayer", 55 layerType: "Mean", 56 inputSlots: [{ 57 index: 0, 58 connection: {sourceLayerIndex:0, outputSlotIndex:0 }, 59 }], 60 outputSlots: [{ 61 index: 0, 62 tensorInfo: { 63 dimensions: )" + outputShape + R"(, 64 dataType: )" + dataType + R"( 65 } 66 }] 67 }, 68 descriptor: { 69 axis: )" + axis + R"(, 70 keepDims: true 71 } 72 } 73 }, 74 { 75 layer_type: "OutputLayer", 76 layer: { 77 base:{ 78 layerBindingId: 2, 79 base: { 80 index: 2, 81 layerName: "OutputLayer", 82 layerType: "Output", 83 inputSlots: [{ 84 index: 0, 85 connection: {sourceLayerIndex:1, outputSlotIndex:0 }, 86 }], 87 outputSlots: [{ 88 index: 0, 89 tensorInfo: { 90 dimensions: )" + outputShape + R"(, 91 dataType: )" + dataType + R"( 92 }, 93 }], 94 } 95 } 96 }, 97 } 98 ] 99 } 100 )"; 101 Setup(); 102 } 103 }; 104 105 struct SimpleMeanFixture : MeanFixture 106 { SimpleMeanFixtureSimpleMeanFixture107 SimpleMeanFixture() 108 : MeanFixture("[ 1, 1, 3, 2 ]", // inputShape 109 "[ 1, 1, 1, 2 ]", // outputShape 110 "[ 2 ]", // axis 111 "Float32") // dataType 112 {} 113 }; 114 115 TEST_CASE_FIXTURE(SimpleMeanFixture, "SimpleMean") 116 { 117 RunTest<4, armnn::DataType::Float32>( 118 0, 119 {{"InputLayer", { 1.0f, 1.0f, 2.0f, 2.0f, 3.0f, 3.0f }}}, 120 {{"OutputLayer", { 2.0f, 2.0f }}}); 121 } 122 123 }