1*89c4ff92SAndroid Build Coastguard Worker // 2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd. All rights reserved. 3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT 4*89c4ff92SAndroid Build Coastguard Worker // 5*89c4ff92SAndroid Build Coastguard Worker 6*89c4ff92SAndroid Build Coastguard Worker #include "../OnnxParser.hpp" 7*89c4ff92SAndroid Build Coastguard Worker #include "ParserPrototxtFixture.hpp" 8*89c4ff92SAndroid Build Coastguard Worker #include <onnx/onnx.pb.h> 9*89c4ff92SAndroid Build Coastguard Worker #include "google/protobuf/stubs/logging.h" 10*89c4ff92SAndroid Build Coastguard Worker 11*89c4ff92SAndroid Build Coastguard Worker using ModelPtr = std::unique_ptr<onnx::ModelProto>; 12*89c4ff92SAndroid Build Coastguard Worker 13*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("OnnxParser_GetInputsOutputs") 14*89c4ff92SAndroid Build Coastguard Worker { 15*89c4ff92SAndroid Build Coastguard Worker struct GetInputsOutputsMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser> 16*89c4ff92SAndroid Build Coastguard Worker { GetInputsOutputsMainFixtureGetInputsOutputsMainFixture17*89c4ff92SAndroid Build Coastguard Worker explicit GetInputsOutputsMainFixture() 18*89c4ff92SAndroid Build Coastguard Worker { 19*89c4ff92SAndroid Build Coastguard Worker m_Prototext = R"( 20*89c4ff92SAndroid Build Coastguard Worker ir_version: 3 21*89c4ff92SAndroid Build Coastguard Worker producer_name: "CNTK" 22*89c4ff92SAndroid Build Coastguard Worker producer_version: "2.5.1" 23*89c4ff92SAndroid Build Coastguard Worker domain: "ai.cntk" 24*89c4ff92SAndroid Build Coastguard Worker model_version: 1 25*89c4ff92SAndroid Build Coastguard Worker graph { 26*89c4ff92SAndroid Build Coastguard Worker name: "CNTKGraph" 27*89c4ff92SAndroid Build Coastguard Worker input { 28*89c4ff92SAndroid Build Coastguard Worker name: "Input" 29*89c4ff92SAndroid Build Coastguard Worker type { 30*89c4ff92SAndroid Build Coastguard Worker tensor_type { 31*89c4ff92SAndroid Build Coastguard Worker elem_type: 1 32*89c4ff92SAndroid Build Coastguard Worker shape { 33*89c4ff92SAndroid Build Coastguard Worker dim { 34*89c4ff92SAndroid Build Coastguard Worker dim_value: 4 35*89c4ff92SAndroid Build Coastguard Worker } 36*89c4ff92SAndroid Build Coastguard Worker } 37*89c4ff92SAndroid Build Coastguard Worker } 38*89c4ff92SAndroid Build Coastguard Worker } 39*89c4ff92SAndroid Build Coastguard Worker } 40*89c4ff92SAndroid Build Coastguard Worker node { 41*89c4ff92SAndroid Build Coastguard Worker input: "Input" 42*89c4ff92SAndroid Build Coastguard Worker output: "Output" 43*89c4ff92SAndroid Build Coastguard Worker name: "ActivationLayer" 44*89c4ff92SAndroid Build Coastguard Worker op_type: "Relu" 45*89c4ff92SAndroid Build Coastguard Worker } 46*89c4ff92SAndroid Build Coastguard Worker output { 47*89c4ff92SAndroid Build Coastguard Worker name: "Output" 48*89c4ff92SAndroid Build Coastguard Worker type { 49*89c4ff92SAndroid Build Coastguard Worker tensor_type { 50*89c4ff92SAndroid Build Coastguard Worker elem_type: 1 51*89c4ff92SAndroid Build Coastguard Worker shape { 52*89c4ff92SAndroid Build Coastguard Worker dim { 53*89c4ff92SAndroid Build Coastguard Worker dim_value: 4 54*89c4ff92SAndroid Build Coastguard Worker } 55*89c4ff92SAndroid Build Coastguard Worker } 56*89c4ff92SAndroid Build Coastguard Worker } 57*89c4ff92SAndroid Build Coastguard Worker } 58*89c4ff92SAndroid Build Coastguard Worker } 59*89c4ff92SAndroid Build Coastguard Worker } 60*89c4ff92SAndroid Build Coastguard Worker opset_import { 61*89c4ff92SAndroid Build Coastguard Worker version: 7 62*89c4ff92SAndroid Build Coastguard Worker })"; 63*89c4ff92SAndroid Build Coastguard Worker Setup(); 64*89c4ff92SAndroid Build Coastguard Worker } 65*89c4ff92SAndroid Build Coastguard Worker }; 66*89c4ff92SAndroid Build Coastguard Worker 67*89c4ff92SAndroid Build Coastguard Worker 68*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(GetInputsOutputsMainFixture, "GetInput") 69*89c4ff92SAndroid Build Coastguard Worker { 70*89c4ff92SAndroid Build Coastguard Worker ModelPtr model = armnnOnnxParser::OnnxParserImpl::LoadModelFromString(m_Prototext.c_str()); 71*89c4ff92SAndroid Build Coastguard Worker std::vector<std::string> tensors = armnnOnnxParser::OnnxParserImpl::GetInputs(model); 72*89c4ff92SAndroid Build Coastguard Worker CHECK_EQ(1, tensors.size()); 73*89c4ff92SAndroid Build Coastguard Worker CHECK_EQ("Input", tensors[0]); 74*89c4ff92SAndroid Build Coastguard Worker 75*89c4ff92SAndroid Build Coastguard Worker } 76*89c4ff92SAndroid Build Coastguard Worker 77*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(GetInputsOutputsMainFixture, "GetOutput") 78*89c4ff92SAndroid Build Coastguard Worker { 79*89c4ff92SAndroid Build Coastguard Worker ModelPtr model = armnnOnnxParser::OnnxParserImpl::LoadModelFromString(m_Prototext.c_str()); 80*89c4ff92SAndroid Build Coastguard Worker std::vector<std::string> tensors = armnnOnnxParser::OnnxParserImpl::GetOutputs(model); 81*89c4ff92SAndroid Build Coastguard Worker CHECK_EQ(1, tensors.size()); 82*89c4ff92SAndroid Build Coastguard Worker CHECK_EQ("Output", tensors[0]); 83*89c4ff92SAndroid Build Coastguard Worker } 84*89c4ff92SAndroid Build Coastguard Worker 85*89c4ff92SAndroid Build Coastguard Worker struct GetEmptyInputsOutputsFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser> 86*89c4ff92SAndroid Build Coastguard Worker { GetEmptyInputsOutputsFixtureGetEmptyInputsOutputsFixture87*89c4ff92SAndroid Build Coastguard Worker GetEmptyInputsOutputsFixture() 88*89c4ff92SAndroid Build Coastguard Worker { 89*89c4ff92SAndroid Build Coastguard Worker m_Prototext = R"( 90*89c4ff92SAndroid Build Coastguard Worker ir_version: 3 91*89c4ff92SAndroid Build Coastguard Worker producer_name: "CNTK " 92*89c4ff92SAndroid Build Coastguard Worker producer_version: "2.5.1 " 93*89c4ff92SAndroid Build Coastguard Worker domain: "ai.cntk " 94*89c4ff92SAndroid Build Coastguard Worker model_version: 1 95*89c4ff92SAndroid Build Coastguard Worker graph { 96*89c4ff92SAndroid Build Coastguard Worker name: "CNTKGraph " 97*89c4ff92SAndroid Build Coastguard Worker node { 98*89c4ff92SAndroid Build Coastguard Worker output: "Output" 99*89c4ff92SAndroid Build Coastguard Worker attribute { 100*89c4ff92SAndroid Build Coastguard Worker name: "value" 101*89c4ff92SAndroid Build Coastguard Worker t { 102*89c4ff92SAndroid Build Coastguard Worker dims: 7 103*89c4ff92SAndroid Build Coastguard Worker data_type: 1 104*89c4ff92SAndroid Build Coastguard Worker float_data: 0.0 105*89c4ff92SAndroid Build Coastguard Worker float_data: 1.0 106*89c4ff92SAndroid Build Coastguard Worker float_data: 2.0 107*89c4ff92SAndroid Build Coastguard Worker float_data: 3.0 108*89c4ff92SAndroid Build Coastguard Worker float_data: 4.0 109*89c4ff92SAndroid Build Coastguard Worker float_data: 5.0 110*89c4ff92SAndroid Build Coastguard Worker float_data: 6.0 111*89c4ff92SAndroid Build Coastguard Worker 112*89c4ff92SAndroid Build Coastguard Worker } 113*89c4ff92SAndroid Build Coastguard Worker type: 1 114*89c4ff92SAndroid Build Coastguard Worker } 115*89c4ff92SAndroid Build Coastguard Worker name: "constantNode" 116*89c4ff92SAndroid Build Coastguard Worker op_type: "Constant" 117*89c4ff92SAndroid Build Coastguard Worker } 118*89c4ff92SAndroid Build Coastguard Worker output { 119*89c4ff92SAndroid Build Coastguard Worker name: "Output" 120*89c4ff92SAndroid Build Coastguard Worker type { 121*89c4ff92SAndroid Build Coastguard Worker tensor_type { 122*89c4ff92SAndroid Build Coastguard Worker elem_type: 1 123*89c4ff92SAndroid Build Coastguard Worker shape { 124*89c4ff92SAndroid Build Coastguard Worker dim { 125*89c4ff92SAndroid Build Coastguard Worker dim_value: 7 126*89c4ff92SAndroid Build Coastguard Worker } 127*89c4ff92SAndroid Build Coastguard Worker } 128*89c4ff92SAndroid Build Coastguard Worker } 129*89c4ff92SAndroid Build Coastguard Worker } 130*89c4ff92SAndroid Build Coastguard Worker } 131*89c4ff92SAndroid Build Coastguard Worker } 132*89c4ff92SAndroid Build Coastguard Worker opset_import { 133*89c4ff92SAndroid Build Coastguard Worker version: 7 134*89c4ff92SAndroid Build Coastguard Worker })"; 135*89c4ff92SAndroid Build Coastguard Worker Setup(); 136*89c4ff92SAndroid Build Coastguard Worker } 137*89c4ff92SAndroid Build Coastguard Worker }; 138*89c4ff92SAndroid Build Coastguard Worker 139*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(GetEmptyInputsOutputsFixture, "GetEmptyInputs") 140*89c4ff92SAndroid Build Coastguard Worker { 141*89c4ff92SAndroid Build Coastguard Worker ModelPtr model = armnnOnnxParser::OnnxParserImpl::LoadModelFromString(m_Prototext.c_str()); 142*89c4ff92SAndroid Build Coastguard Worker std::vector<std::string> tensors = armnnOnnxParser::OnnxParserImpl::GetInputs(model); 143*89c4ff92SAndroid Build Coastguard Worker CHECK_EQ(0, tensors.size()); 144*89c4ff92SAndroid Build Coastguard Worker } 145*89c4ff92SAndroid Build Coastguard Worker 146*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("GetInputsNullModel") 147*89c4ff92SAndroid Build Coastguard Worker { 148*89c4ff92SAndroid Build Coastguard Worker CHECK_THROWS_AS(armnnOnnxParser::OnnxParserImpl::LoadModelFromString(""), armnn::InvalidArgumentException); 149*89c4ff92SAndroid Build Coastguard Worker } 150*89c4ff92SAndroid Build Coastguard Worker 151*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("GetOutputsNullModel") 152*89c4ff92SAndroid Build Coastguard Worker { 153*89c4ff92SAndroid Build Coastguard Worker auto silencer = google::protobuf::LogSilencer(); //get rid of errors from protobuf 154*89c4ff92SAndroid Build Coastguard Worker CHECK_THROWS_AS(armnnOnnxParser::OnnxParserImpl::LoadModelFromString("nknnk"), armnn::ParseException); 155*89c4ff92SAndroid Build Coastguard Worker } 156*89c4ff92SAndroid Build Coastguard Worker 157*89c4ff92SAndroid Build Coastguard Worker struct GetInputsMultipleFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser> 158*89c4ff92SAndroid Build Coastguard Worker { GetInputsMultipleFixtureGetInputsMultipleFixture159*89c4ff92SAndroid Build Coastguard Worker GetInputsMultipleFixture() { 160*89c4ff92SAndroid Build Coastguard Worker 161*89c4ff92SAndroid Build Coastguard Worker m_Prototext = R"( 162*89c4ff92SAndroid Build Coastguard Worker ir_version: 3 163*89c4ff92SAndroid Build Coastguard Worker producer_name: "CNTK" 164*89c4ff92SAndroid Build Coastguard Worker producer_version: "2.5.1" 165*89c4ff92SAndroid Build Coastguard Worker domain: "ai.cntk" 166*89c4ff92SAndroid Build Coastguard Worker model_version: 1 167*89c4ff92SAndroid Build Coastguard Worker graph { 168*89c4ff92SAndroid Build Coastguard Worker name: "CNTKGraph" 169*89c4ff92SAndroid Build Coastguard Worker input { 170*89c4ff92SAndroid Build Coastguard Worker name: "Input0" 171*89c4ff92SAndroid Build Coastguard Worker type { 172*89c4ff92SAndroid Build Coastguard Worker tensor_type { 173*89c4ff92SAndroid Build Coastguard Worker elem_type: 1 174*89c4ff92SAndroid Build Coastguard Worker shape { 175*89c4ff92SAndroid Build Coastguard Worker dim { 176*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 177*89c4ff92SAndroid Build Coastguard Worker } 178*89c4ff92SAndroid Build Coastguard Worker dim { 179*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 180*89c4ff92SAndroid Build Coastguard Worker } 181*89c4ff92SAndroid Build Coastguard Worker dim { 182*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 183*89c4ff92SAndroid Build Coastguard Worker } 184*89c4ff92SAndroid Build Coastguard Worker dim { 185*89c4ff92SAndroid Build Coastguard Worker dim_value: 4 186*89c4ff92SAndroid Build Coastguard Worker } 187*89c4ff92SAndroid Build Coastguard Worker } 188*89c4ff92SAndroid Build Coastguard Worker } 189*89c4ff92SAndroid Build Coastguard Worker } 190*89c4ff92SAndroid Build Coastguard Worker } 191*89c4ff92SAndroid Build Coastguard Worker input { 192*89c4ff92SAndroid Build Coastguard Worker name: "Input1" 193*89c4ff92SAndroid Build Coastguard Worker type { 194*89c4ff92SAndroid Build Coastguard Worker tensor_type { 195*89c4ff92SAndroid Build Coastguard Worker elem_type: 1 196*89c4ff92SAndroid Build Coastguard Worker shape { 197*89c4ff92SAndroid Build Coastguard Worker dim { 198*89c4ff92SAndroid Build Coastguard Worker dim_value: 4 199*89c4ff92SAndroid Build Coastguard Worker } 200*89c4ff92SAndroid Build Coastguard Worker } 201*89c4ff92SAndroid Build Coastguard Worker } 202*89c4ff92SAndroid Build Coastguard Worker } 203*89c4ff92SAndroid Build Coastguard Worker } 204*89c4ff92SAndroid Build Coastguard Worker node { 205*89c4ff92SAndroid Build Coastguard Worker input: "Input0" 206*89c4ff92SAndroid Build Coastguard Worker input: "Input1" 207*89c4ff92SAndroid Build Coastguard Worker output: "Output" 208*89c4ff92SAndroid Build Coastguard Worker name: "addition" 209*89c4ff92SAndroid Build Coastguard Worker op_type: "Add" 210*89c4ff92SAndroid Build Coastguard Worker doc_string: "" 211*89c4ff92SAndroid Build Coastguard Worker domain: "" 212*89c4ff92SAndroid Build Coastguard Worker } 213*89c4ff92SAndroid Build Coastguard Worker output { 214*89c4ff92SAndroid Build Coastguard Worker name: "Output" 215*89c4ff92SAndroid Build Coastguard Worker type { 216*89c4ff92SAndroid Build Coastguard Worker tensor_type { 217*89c4ff92SAndroid Build Coastguard Worker elem_type: 1 218*89c4ff92SAndroid Build Coastguard Worker shape { 219*89c4ff92SAndroid Build Coastguard Worker dim { 220*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 221*89c4ff92SAndroid Build Coastguard Worker } 222*89c4ff92SAndroid Build Coastguard Worker dim { 223*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 224*89c4ff92SAndroid Build Coastguard Worker } 225*89c4ff92SAndroid Build Coastguard Worker dim { 226*89c4ff92SAndroid Build Coastguard Worker dim_value: 1 227*89c4ff92SAndroid Build Coastguard Worker } 228*89c4ff92SAndroid Build Coastguard Worker dim { 229*89c4ff92SAndroid Build Coastguard Worker dim_value: 4 230*89c4ff92SAndroid Build Coastguard Worker } 231*89c4ff92SAndroid Build Coastguard Worker } 232*89c4ff92SAndroid Build Coastguard Worker } 233*89c4ff92SAndroid Build Coastguard Worker } 234*89c4ff92SAndroid Build Coastguard Worker } 235*89c4ff92SAndroid Build Coastguard Worker } 236*89c4ff92SAndroid Build Coastguard Worker opset_import { 237*89c4ff92SAndroid Build Coastguard Worker version: 7 238*89c4ff92SAndroid Build Coastguard Worker })"; 239*89c4ff92SAndroid Build Coastguard Worker Setup(); 240*89c4ff92SAndroid Build Coastguard Worker } 241*89c4ff92SAndroid Build Coastguard Worker }; 242*89c4ff92SAndroid Build Coastguard Worker 243*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(GetInputsMultipleFixture, "GetInputsMultipleInputs") 244*89c4ff92SAndroid Build Coastguard Worker { 245*89c4ff92SAndroid Build Coastguard Worker ModelPtr model = armnnOnnxParser::OnnxParserImpl::LoadModelFromString(m_Prototext.c_str()); 246*89c4ff92SAndroid Build Coastguard Worker std::vector<std::string> tensors = armnnOnnxParser::OnnxParserImpl::GetInputs(model); 247*89c4ff92SAndroid Build Coastguard Worker CHECK_EQ(2, tensors.size()); 248*89c4ff92SAndroid Build Coastguard Worker CHECK_EQ("Input0", tensors[0]); 249*89c4ff92SAndroid Build Coastguard Worker CHECK_EQ("Input1", tensors[1]); 250*89c4ff92SAndroid Build Coastguard Worker } 251*89c4ff92SAndroid Build Coastguard Worker 252*89c4ff92SAndroid Build Coastguard Worker } 253