1*89c4ff92SAndroid Build Coastguard Worker // 2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd. All rights reserved. 3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT 4*89c4ff92SAndroid Build Coastguard Worker // 5*89c4ff92SAndroid Build Coastguard Worker 6*89c4ff92SAndroid Build Coastguard Worker #include "ParserFlatbuffersSerializeFixture.hpp" 7*89c4ff92SAndroid Build Coastguard Worker #include <armnnDeserializer/IDeserializer.hpp> 8*89c4ff92SAndroid Build Coastguard Worker 9*89c4ff92SAndroid Build Coastguard Worker #include <string> 10*89c4ff92SAndroid Build Coastguard Worker 11*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("Deserializer_Mean") 12*89c4ff92SAndroid Build Coastguard Worker { 13*89c4ff92SAndroid Build Coastguard Worker struct MeanFixture : public ParserFlatbuffersSerializeFixture 14*89c4ff92SAndroid Build Coastguard Worker { MeanFixtureMeanFixture15*89c4ff92SAndroid Build Coastguard Worker explicit MeanFixture(const std::string &inputShape, 16*89c4ff92SAndroid Build Coastguard Worker const std::string &outputShape, 17*89c4ff92SAndroid Build Coastguard Worker const std::string &axis, 18*89c4ff92SAndroid Build Coastguard Worker const std::string &dataType) 19*89c4ff92SAndroid Build Coastguard Worker { 20*89c4ff92SAndroid Build Coastguard Worker m_JsonString = R"( 21*89c4ff92SAndroid Build Coastguard Worker { 22*89c4ff92SAndroid Build Coastguard Worker inputIds: [0], 23*89c4ff92SAndroid Build Coastguard Worker outputIds: [2], 24*89c4ff92SAndroid Build Coastguard Worker layers: [ 25*89c4ff92SAndroid Build Coastguard Worker { 26*89c4ff92SAndroid Build Coastguard Worker layer_type: "InputLayer", 27*89c4ff92SAndroid Build Coastguard Worker layer: { 28*89c4ff92SAndroid Build Coastguard Worker base: { 29*89c4ff92SAndroid Build Coastguard Worker layerBindingId: 0, 30*89c4ff92SAndroid Build Coastguard Worker base: { 31*89c4ff92SAndroid Build Coastguard Worker index: 0, 32*89c4ff92SAndroid Build Coastguard Worker layerName: "InputLayer", 33*89c4ff92SAndroid Build Coastguard Worker layerType: "Input", 34*89c4ff92SAndroid Build Coastguard Worker inputSlots: [{ 35*89c4ff92SAndroid Build Coastguard Worker index: 0, 36*89c4ff92SAndroid Build Coastguard Worker connection: {sourceLayerIndex:0, outputSlotIndex:0 }, 37*89c4ff92SAndroid Build Coastguard Worker }], 38*89c4ff92SAndroid Build Coastguard Worker outputSlots: [{ 39*89c4ff92SAndroid Build Coastguard Worker index: 0, 40*89c4ff92SAndroid Build Coastguard Worker tensorInfo: { 41*89c4ff92SAndroid Build Coastguard Worker dimensions: )" + inputShape + R"(, 42*89c4ff92SAndroid Build Coastguard Worker dataType: )" + dataType + R"( 43*89c4ff92SAndroid Build Coastguard Worker } 44*89c4ff92SAndroid Build Coastguard Worker }] 45*89c4ff92SAndroid Build Coastguard Worker } 46*89c4ff92SAndroid Build Coastguard Worker } 47*89c4ff92SAndroid Build Coastguard Worker } 48*89c4ff92SAndroid Build Coastguard Worker }, 49*89c4ff92SAndroid Build Coastguard Worker { 50*89c4ff92SAndroid Build Coastguard Worker layer_type: "MeanLayer", 51*89c4ff92SAndroid Build Coastguard Worker layer: { 52*89c4ff92SAndroid Build Coastguard Worker base: { 53*89c4ff92SAndroid Build Coastguard Worker index: 1, 54*89c4ff92SAndroid Build Coastguard Worker layerName: "MeanLayer", 55*89c4ff92SAndroid Build Coastguard Worker layerType: "Mean", 56*89c4ff92SAndroid Build Coastguard Worker inputSlots: [{ 57*89c4ff92SAndroid Build Coastguard Worker index: 0, 58*89c4ff92SAndroid Build Coastguard Worker connection: {sourceLayerIndex:0, outputSlotIndex:0 }, 59*89c4ff92SAndroid Build Coastguard Worker }], 60*89c4ff92SAndroid Build Coastguard Worker outputSlots: [{ 61*89c4ff92SAndroid Build Coastguard Worker index: 0, 62*89c4ff92SAndroid Build Coastguard Worker tensorInfo: { 63*89c4ff92SAndroid Build Coastguard Worker dimensions: )" + outputShape + R"(, 64*89c4ff92SAndroid Build Coastguard Worker dataType: )" + dataType + R"( 65*89c4ff92SAndroid Build Coastguard Worker } 66*89c4ff92SAndroid Build Coastguard Worker }] 67*89c4ff92SAndroid Build Coastguard Worker }, 68*89c4ff92SAndroid Build Coastguard Worker descriptor: { 69*89c4ff92SAndroid Build Coastguard Worker axis: )" + axis + R"(, 70*89c4ff92SAndroid Build Coastguard Worker keepDims: true 71*89c4ff92SAndroid Build Coastguard Worker } 72*89c4ff92SAndroid Build Coastguard Worker } 73*89c4ff92SAndroid Build Coastguard Worker }, 74*89c4ff92SAndroid Build Coastguard Worker { 75*89c4ff92SAndroid Build Coastguard Worker layer_type: "OutputLayer", 76*89c4ff92SAndroid Build Coastguard Worker layer: { 77*89c4ff92SAndroid Build Coastguard Worker base:{ 78*89c4ff92SAndroid Build Coastguard Worker layerBindingId: 2, 79*89c4ff92SAndroid Build Coastguard Worker base: { 80*89c4ff92SAndroid Build Coastguard Worker index: 2, 81*89c4ff92SAndroid Build Coastguard Worker layerName: "OutputLayer", 82*89c4ff92SAndroid Build Coastguard Worker layerType: "Output", 83*89c4ff92SAndroid Build Coastguard Worker inputSlots: [{ 84*89c4ff92SAndroid Build Coastguard Worker index: 0, 85*89c4ff92SAndroid Build Coastguard Worker connection: {sourceLayerIndex:1, outputSlotIndex:0 }, 86*89c4ff92SAndroid Build Coastguard Worker }], 87*89c4ff92SAndroid Build Coastguard Worker outputSlots: [{ 88*89c4ff92SAndroid Build Coastguard Worker index: 0, 89*89c4ff92SAndroid Build Coastguard Worker tensorInfo: { 90*89c4ff92SAndroid Build Coastguard Worker dimensions: )" + outputShape + R"(, 91*89c4ff92SAndroid Build Coastguard Worker dataType: )" + dataType + R"( 92*89c4ff92SAndroid Build Coastguard Worker }, 93*89c4ff92SAndroid Build Coastguard Worker }], 94*89c4ff92SAndroid Build Coastguard Worker } 95*89c4ff92SAndroid Build Coastguard Worker } 96*89c4ff92SAndroid Build Coastguard Worker }, 97*89c4ff92SAndroid Build Coastguard Worker } 98*89c4ff92SAndroid Build Coastguard Worker ] 99*89c4ff92SAndroid Build Coastguard Worker } 100*89c4ff92SAndroid Build Coastguard Worker )"; 101*89c4ff92SAndroid Build Coastguard Worker Setup(); 102*89c4ff92SAndroid Build Coastguard Worker } 103*89c4ff92SAndroid Build Coastguard Worker }; 104*89c4ff92SAndroid Build Coastguard Worker 105*89c4ff92SAndroid Build Coastguard Worker struct SimpleMeanFixture : MeanFixture 106*89c4ff92SAndroid Build Coastguard Worker { SimpleMeanFixtureSimpleMeanFixture107*89c4ff92SAndroid Build Coastguard Worker SimpleMeanFixture() 108*89c4ff92SAndroid Build Coastguard Worker : MeanFixture("[ 1, 1, 3, 2 ]", // inputShape 109*89c4ff92SAndroid Build Coastguard Worker "[ 1, 1, 1, 2 ]", // outputShape 110*89c4ff92SAndroid Build Coastguard Worker "[ 2 ]", // axis 111*89c4ff92SAndroid Build Coastguard Worker "Float32") // dataType 112*89c4ff92SAndroid Build Coastguard Worker {} 113*89c4ff92SAndroid Build Coastguard Worker }; 114*89c4ff92SAndroid Build Coastguard Worker 115*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(SimpleMeanFixture, "SimpleMean") 116*89c4ff92SAndroid Build Coastguard Worker { 117*89c4ff92SAndroid Build Coastguard Worker RunTest<4, armnn::DataType::Float32>( 118*89c4ff92SAndroid Build Coastguard Worker 0, 119*89c4ff92SAndroid Build Coastguard Worker {{"InputLayer", { 1.0f, 1.0f, 2.0f, 2.0f, 3.0f, 3.0f }}}, 120*89c4ff92SAndroid Build Coastguard Worker {{"OutputLayer", { 2.0f, 2.0f }}}); 121*89c4ff92SAndroid Build Coastguard Worker } 122*89c4ff92SAndroid Build Coastguard Worker 123*89c4ff92SAndroid Build Coastguard Worker }