xref: /aosp_15_r20/external/armnn/src/armnnOnnxParser/test/GetInputsOutputs.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include "../OnnxParser.hpp"
7*89c4ff92SAndroid Build Coastguard Worker #include  "ParserPrototxtFixture.hpp"
8*89c4ff92SAndroid Build Coastguard Worker #include <onnx/onnx.pb.h>
9*89c4ff92SAndroid Build Coastguard Worker #include "google/protobuf/stubs/logging.h"
10*89c4ff92SAndroid Build Coastguard Worker 
// Owning handle to a parsed ONNX ModelProto. The tests in this file obtain one from
// OnnxParserImpl::LoadModelFromString and pass it to OnnxParserImpl::GetInputs/GetOutputs.
using ModelPtr = std::unique_ptr<onnx::ModelProto>;
12*89c4ff92SAndroid Build Coastguard Worker 
13*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("OnnxParser_GetInputsOutputs")
14*89c4ff92SAndroid Build Coastguard Worker {
15*89c4ff92SAndroid Build Coastguard Worker struct GetInputsOutputsMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
16*89c4ff92SAndroid Build Coastguard Worker {
GetInputsOutputsMainFixtureGetInputsOutputsMainFixture17*89c4ff92SAndroid Build Coastguard Worker     explicit GetInputsOutputsMainFixture()
18*89c4ff92SAndroid Build Coastguard Worker     {
19*89c4ff92SAndroid Build Coastguard Worker         m_Prototext = R"(
20*89c4ff92SAndroid Build Coastguard Worker                    ir_version: 3
21*89c4ff92SAndroid Build Coastguard Worker                    producer_name:  "CNTK"
22*89c4ff92SAndroid Build Coastguard Worker                    producer_version:  "2.5.1"
23*89c4ff92SAndroid Build Coastguard Worker                    domain:  "ai.cntk"
24*89c4ff92SAndroid Build Coastguard Worker                    model_version: 1
25*89c4ff92SAndroid Build Coastguard Worker                    graph {
26*89c4ff92SAndroid Build Coastguard Worker                      name:  "CNTKGraph"
27*89c4ff92SAndroid Build Coastguard Worker                      input {
28*89c4ff92SAndroid Build Coastguard Worker                         name: "Input"
29*89c4ff92SAndroid Build Coastguard Worker                         type {
30*89c4ff92SAndroid Build Coastguard Worker                           tensor_type {
31*89c4ff92SAndroid Build Coastguard Worker                             elem_type: 1
32*89c4ff92SAndroid Build Coastguard Worker                             shape {
33*89c4ff92SAndroid Build Coastguard Worker                               dim {
34*89c4ff92SAndroid Build Coastguard Worker                                 dim_value: 4
35*89c4ff92SAndroid Build Coastguard Worker                               }
36*89c4ff92SAndroid Build Coastguard Worker                             }
37*89c4ff92SAndroid Build Coastguard Worker                           }
38*89c4ff92SAndroid Build Coastguard Worker                         }
39*89c4ff92SAndroid Build Coastguard Worker                       }
40*89c4ff92SAndroid Build Coastguard Worker                      node {
41*89c4ff92SAndroid Build Coastguard Worker                          input: "Input"
42*89c4ff92SAndroid Build Coastguard Worker                          output: "Output"
43*89c4ff92SAndroid Build Coastguard Worker                          name: "ActivationLayer"
44*89c4ff92SAndroid Build Coastguard Worker                          op_type: "Relu"
45*89c4ff92SAndroid Build Coastguard Worker                     }
46*89c4ff92SAndroid Build Coastguard Worker                       output {
47*89c4ff92SAndroid Build Coastguard Worker                           name: "Output"
48*89c4ff92SAndroid Build Coastguard Worker                           type {
49*89c4ff92SAndroid Build Coastguard Worker                              tensor_type {
50*89c4ff92SAndroid Build Coastguard Worker                                elem_type: 1
51*89c4ff92SAndroid Build Coastguard Worker                                shape {
52*89c4ff92SAndroid Build Coastguard Worker                                    dim {
53*89c4ff92SAndroid Build Coastguard Worker                                        dim_value: 4
54*89c4ff92SAndroid Build Coastguard Worker                                    }
55*89c4ff92SAndroid Build Coastguard Worker                                }
56*89c4ff92SAndroid Build Coastguard Worker                             }
57*89c4ff92SAndroid Build Coastguard Worker                          }
58*89c4ff92SAndroid Build Coastguard Worker                       }
59*89c4ff92SAndroid Build Coastguard Worker                     }
60*89c4ff92SAndroid Build Coastguard Worker                    opset_import {
61*89c4ff92SAndroid Build Coastguard Worker                       version: 7
62*89c4ff92SAndroid Build Coastguard Worker                     })";
63*89c4ff92SAndroid Build Coastguard Worker         Setup();
64*89c4ff92SAndroid Build Coastguard Worker     }
65*89c4ff92SAndroid Build Coastguard Worker };
66*89c4ff92SAndroid Build Coastguard Worker 
67*89c4ff92SAndroid Build Coastguard Worker 
68*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(GetInputsOutputsMainFixture, "GetInput")
69*89c4ff92SAndroid Build Coastguard Worker {
70*89c4ff92SAndroid Build Coastguard Worker     ModelPtr model = armnnOnnxParser::OnnxParserImpl::LoadModelFromString(m_Prototext.c_str());
71*89c4ff92SAndroid Build Coastguard Worker     std::vector<std::string> tensors = armnnOnnxParser::OnnxParserImpl::GetInputs(model);
72*89c4ff92SAndroid Build Coastguard Worker     CHECK_EQ(1, tensors.size());
73*89c4ff92SAndroid Build Coastguard Worker     CHECK_EQ("Input", tensors[0]);
74*89c4ff92SAndroid Build Coastguard Worker 
75*89c4ff92SAndroid Build Coastguard Worker }
76*89c4ff92SAndroid Build Coastguard Worker 
77*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(GetInputsOutputsMainFixture, "GetOutput")
78*89c4ff92SAndroid Build Coastguard Worker {
79*89c4ff92SAndroid Build Coastguard Worker     ModelPtr model = armnnOnnxParser::OnnxParserImpl::LoadModelFromString(m_Prototext.c_str());
80*89c4ff92SAndroid Build Coastguard Worker     std::vector<std::string> tensors = armnnOnnxParser::OnnxParserImpl::GetOutputs(model);
81*89c4ff92SAndroid Build Coastguard Worker     CHECK_EQ(1, tensors.size());
82*89c4ff92SAndroid Build Coastguard Worker     CHECK_EQ("Output", tensors[0]);
83*89c4ff92SAndroid Build Coastguard Worker }
84*89c4ff92SAndroid Build Coastguard Worker 
85*89c4ff92SAndroid Build Coastguard Worker struct GetEmptyInputsOutputsFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
86*89c4ff92SAndroid Build Coastguard Worker {
GetEmptyInputsOutputsFixtureGetEmptyInputsOutputsFixture87*89c4ff92SAndroid Build Coastguard Worker     GetEmptyInputsOutputsFixture()
88*89c4ff92SAndroid Build Coastguard Worker     {
89*89c4ff92SAndroid Build Coastguard Worker         m_Prototext = R"(
90*89c4ff92SAndroid Build Coastguard Worker                    ir_version: 3
91*89c4ff92SAndroid Build Coastguard Worker                    producer_name:  "CNTK "
92*89c4ff92SAndroid Build Coastguard Worker                    producer_version:  "2.5.1 "
93*89c4ff92SAndroid Build Coastguard Worker                    domain:  "ai.cntk "
94*89c4ff92SAndroid Build Coastguard Worker                    model_version: 1
95*89c4ff92SAndroid Build Coastguard Worker                    graph {
96*89c4ff92SAndroid Build Coastguard Worker                      name:  "CNTKGraph "
97*89c4ff92SAndroid Build Coastguard Worker                      node {
98*89c4ff92SAndroid Build Coastguard Worker                         output:  "Output"
99*89c4ff92SAndroid Build Coastguard Worker                         attribute {
100*89c4ff92SAndroid Build Coastguard Worker                           name: "value"
101*89c4ff92SAndroid Build Coastguard Worker                           t {
102*89c4ff92SAndroid Build Coastguard Worker                               dims: 7
103*89c4ff92SAndroid Build Coastguard Worker                               data_type: 1
104*89c4ff92SAndroid Build Coastguard Worker                               float_data: 0.0
105*89c4ff92SAndroid Build Coastguard Worker                               float_data: 1.0
106*89c4ff92SAndroid Build Coastguard Worker                               float_data: 2.0
107*89c4ff92SAndroid Build Coastguard Worker                               float_data: 3.0
108*89c4ff92SAndroid Build Coastguard Worker                               float_data: 4.0
109*89c4ff92SAndroid Build Coastguard Worker                               float_data: 5.0
110*89c4ff92SAndroid Build Coastguard Worker                               float_data: 6.0
111*89c4ff92SAndroid Build Coastguard Worker 
112*89c4ff92SAndroid Build Coastguard Worker                           }
113*89c4ff92SAndroid Build Coastguard Worker                           type: 1
114*89c4ff92SAndroid Build Coastguard Worker                         }
115*89c4ff92SAndroid Build Coastguard Worker                         name:  "constantNode"
116*89c4ff92SAndroid Build Coastguard Worker                         op_type:  "Constant"
117*89c4ff92SAndroid Build Coastguard Worker                       }
118*89c4ff92SAndroid Build Coastguard Worker                       output {
119*89c4ff92SAndroid Build Coastguard Worker                           name:  "Output"
120*89c4ff92SAndroid Build Coastguard Worker                           type {
121*89c4ff92SAndroid Build Coastguard Worker                              tensor_type {
122*89c4ff92SAndroid Build Coastguard Worker                                elem_type: 1
123*89c4ff92SAndroid Build Coastguard Worker                                shape {
124*89c4ff92SAndroid Build Coastguard Worker                                  dim {
125*89c4ff92SAndroid Build Coastguard Worker                                     dim_value: 7
126*89c4ff92SAndroid Build Coastguard Worker                                  }
127*89c4ff92SAndroid Build Coastguard Worker                                }
128*89c4ff92SAndroid Build Coastguard Worker                              }
129*89c4ff92SAndroid Build Coastguard Worker                           }
130*89c4ff92SAndroid Build Coastguard Worker                       }
131*89c4ff92SAndroid Build Coastguard Worker                    }
132*89c4ff92SAndroid Build Coastguard Worker                    opset_import {
133*89c4ff92SAndroid Build Coastguard Worker                       version: 7
134*89c4ff92SAndroid Build Coastguard Worker                     })";
135*89c4ff92SAndroid Build Coastguard Worker         Setup();
136*89c4ff92SAndroid Build Coastguard Worker     }
137*89c4ff92SAndroid Build Coastguard Worker };
138*89c4ff92SAndroid Build Coastguard Worker 
139*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(GetEmptyInputsOutputsFixture, "GetEmptyInputs")
140*89c4ff92SAndroid Build Coastguard Worker {
141*89c4ff92SAndroid Build Coastguard Worker     ModelPtr model = armnnOnnxParser::OnnxParserImpl::LoadModelFromString(m_Prototext.c_str());
142*89c4ff92SAndroid Build Coastguard Worker     std::vector<std::string> tensors = armnnOnnxParser::OnnxParserImpl::GetInputs(model);
143*89c4ff92SAndroid Build Coastguard Worker     CHECK_EQ(0, tensors.size());
144*89c4ff92SAndroid Build Coastguard Worker }
145*89c4ff92SAndroid Build Coastguard Worker 
146*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("GetInputsNullModel")
147*89c4ff92SAndroid Build Coastguard Worker {
148*89c4ff92SAndroid Build Coastguard Worker     CHECK_THROWS_AS(armnnOnnxParser::OnnxParserImpl::LoadModelFromString(""), armnn::InvalidArgumentException);
149*89c4ff92SAndroid Build Coastguard Worker }
150*89c4ff92SAndroid Build Coastguard Worker 
151*89c4ff92SAndroid Build Coastguard Worker TEST_CASE("GetOutputsNullModel")
152*89c4ff92SAndroid Build Coastguard Worker {
153*89c4ff92SAndroid Build Coastguard Worker     auto silencer = google::protobuf::LogSilencer(); //get rid of errors from protobuf
154*89c4ff92SAndroid Build Coastguard Worker     CHECK_THROWS_AS(armnnOnnxParser::OnnxParserImpl::LoadModelFromString("nknnk"), armnn::ParseException);
155*89c4ff92SAndroid Build Coastguard Worker }
156*89c4ff92SAndroid Build Coastguard Worker 
157*89c4ff92SAndroid Build Coastguard Worker struct GetInputsMultipleFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
158*89c4ff92SAndroid Build Coastguard Worker {
GetInputsMultipleFixtureGetInputsMultipleFixture159*89c4ff92SAndroid Build Coastguard Worker     GetInputsMultipleFixture() {
160*89c4ff92SAndroid Build Coastguard Worker 
161*89c4ff92SAndroid Build Coastguard Worker         m_Prototext = R"(
162*89c4ff92SAndroid Build Coastguard Worker                    ir_version: 3
163*89c4ff92SAndroid Build Coastguard Worker                    producer_name:  "CNTK"
164*89c4ff92SAndroid Build Coastguard Worker                    producer_version:  "2.5.1"
165*89c4ff92SAndroid Build Coastguard Worker                    domain:  "ai.cntk"
166*89c4ff92SAndroid Build Coastguard Worker                    model_version: 1
167*89c4ff92SAndroid Build Coastguard Worker                    graph {
168*89c4ff92SAndroid Build Coastguard Worker                      name:  "CNTKGraph"
169*89c4ff92SAndroid Build Coastguard Worker                      input {
170*89c4ff92SAndroid Build Coastguard Worker                         name: "Input0"
171*89c4ff92SAndroid Build Coastguard Worker                         type {
172*89c4ff92SAndroid Build Coastguard Worker                           tensor_type {
173*89c4ff92SAndroid Build Coastguard Worker                             elem_type: 1
174*89c4ff92SAndroid Build Coastguard Worker                             shape {
175*89c4ff92SAndroid Build Coastguard Worker                               dim {
176*89c4ff92SAndroid Build Coastguard Worker                                 dim_value: 1
177*89c4ff92SAndroid Build Coastguard Worker                               }
178*89c4ff92SAndroid Build Coastguard Worker                               dim {
179*89c4ff92SAndroid Build Coastguard Worker                                 dim_value: 1
180*89c4ff92SAndroid Build Coastguard Worker                               }
181*89c4ff92SAndroid Build Coastguard Worker                               dim {
182*89c4ff92SAndroid Build Coastguard Worker                                 dim_value: 1
183*89c4ff92SAndroid Build Coastguard Worker                               }
184*89c4ff92SAndroid Build Coastguard Worker                               dim {
185*89c4ff92SAndroid Build Coastguard Worker                                 dim_value: 4
186*89c4ff92SAndroid Build Coastguard Worker                               }
187*89c4ff92SAndroid Build Coastguard Worker                             }
188*89c4ff92SAndroid Build Coastguard Worker                           }
189*89c4ff92SAndroid Build Coastguard Worker                         }
190*89c4ff92SAndroid Build Coastguard Worker                       }
191*89c4ff92SAndroid Build Coastguard Worker                       input {
192*89c4ff92SAndroid Build Coastguard Worker                          name: "Input1"
193*89c4ff92SAndroid Build Coastguard Worker                          type {
194*89c4ff92SAndroid Build Coastguard Worker                            tensor_type {
195*89c4ff92SAndroid Build Coastguard Worker                              elem_type: 1
196*89c4ff92SAndroid Build Coastguard Worker                              shape {
197*89c4ff92SAndroid Build Coastguard Worker                                  dim {
198*89c4ff92SAndroid Build Coastguard Worker                                    dim_value: 4
199*89c4ff92SAndroid Build Coastguard Worker                                  }
200*89c4ff92SAndroid Build Coastguard Worker                              }
201*89c4ff92SAndroid Build Coastguard Worker                            }
202*89c4ff92SAndroid Build Coastguard Worker                          }
203*89c4ff92SAndroid Build Coastguard Worker                        }
204*89c4ff92SAndroid Build Coastguard Worker                        node {
205*89c4ff92SAndroid Build Coastguard Worker                             input: "Input0"
206*89c4ff92SAndroid Build Coastguard Worker                             input: "Input1"
207*89c4ff92SAndroid Build Coastguard Worker                             output: "Output"
208*89c4ff92SAndroid Build Coastguard Worker                             name: "addition"
209*89c4ff92SAndroid Build Coastguard Worker                             op_type: "Add"
210*89c4ff92SAndroid Build Coastguard Worker                             doc_string: ""
211*89c4ff92SAndroid Build Coastguard Worker                             domain: ""
212*89c4ff92SAndroid Build Coastguard Worker                           }
213*89c4ff92SAndroid Build Coastguard Worker                           output {
214*89c4ff92SAndroid Build Coastguard Worker                               name: "Output"
215*89c4ff92SAndroid Build Coastguard Worker                               type {
216*89c4ff92SAndroid Build Coastguard Worker                                  tensor_type {
217*89c4ff92SAndroid Build Coastguard Worker                                    elem_type: 1
218*89c4ff92SAndroid Build Coastguard Worker                                    shape {
219*89c4ff92SAndroid Build Coastguard Worker                                        dim {
220*89c4ff92SAndroid Build Coastguard Worker                                            dim_value: 1
221*89c4ff92SAndroid Build Coastguard Worker                                        }
222*89c4ff92SAndroid Build Coastguard Worker                                        dim {
223*89c4ff92SAndroid Build Coastguard Worker                                            dim_value: 1
224*89c4ff92SAndroid Build Coastguard Worker                                        }
225*89c4ff92SAndroid Build Coastguard Worker                                        dim {
226*89c4ff92SAndroid Build Coastguard Worker                                            dim_value: 1
227*89c4ff92SAndroid Build Coastguard Worker                                        }
228*89c4ff92SAndroid Build Coastguard Worker                                        dim {
229*89c4ff92SAndroid Build Coastguard Worker                                            dim_value: 4
230*89c4ff92SAndroid Build Coastguard Worker                                        }
231*89c4ff92SAndroid Build Coastguard Worker                                    }
232*89c4ff92SAndroid Build Coastguard Worker                                 }
233*89c4ff92SAndroid Build Coastguard Worker                             }
234*89c4ff92SAndroid Build Coastguard Worker                         }
235*89c4ff92SAndroid Build Coastguard Worker                     }
236*89c4ff92SAndroid Build Coastguard Worker                    opset_import {
237*89c4ff92SAndroid Build Coastguard Worker                       version: 7
238*89c4ff92SAndroid Build Coastguard Worker                     })";
239*89c4ff92SAndroid Build Coastguard Worker         Setup();
240*89c4ff92SAndroid Build Coastguard Worker     }
241*89c4ff92SAndroid Build Coastguard Worker };
242*89c4ff92SAndroid Build Coastguard Worker 
243*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(GetInputsMultipleFixture, "GetInputsMultipleInputs")
244*89c4ff92SAndroid Build Coastguard Worker {
245*89c4ff92SAndroid Build Coastguard Worker     ModelPtr model = armnnOnnxParser::OnnxParserImpl::LoadModelFromString(m_Prototext.c_str());
246*89c4ff92SAndroid Build Coastguard Worker     std::vector<std::string> tensors = armnnOnnxParser::OnnxParserImpl::GetInputs(model);
247*89c4ff92SAndroid Build Coastguard Worker     CHECK_EQ(2, tensors.size());
248*89c4ff92SAndroid Build Coastguard Worker     CHECK_EQ("Input0", tensors[0]);
249*89c4ff92SAndroid Build Coastguard Worker     CHECK_EQ("Input1", tensors[1]);
250*89c4ff92SAndroid Build Coastguard Worker }
251*89c4ff92SAndroid Build Coastguard Worker 
252*89c4ff92SAndroid Build Coastguard Worker }
253