xref: /aosp_15_r20/external/armnn/src/armnnOnnxParser/test/GetInputsOutputs.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
//
// Copyright © 2017 Arm Ltd. All rights reserved.
// SPDX-License-Identifier: MIT
//
5 
6 #include "../OnnxParser.hpp"
7 #include  "ParserPrototxtFixture.hpp"
8 #include <onnx/onnx.pb.h>
9 #include "google/protobuf/stubs/logging.h"
10 
11 using ModelPtr = std::unique_ptr<onnx::ModelProto>;
12 
13 TEST_SUITE("OnnxParser_GetInputsOutputs")
14 {
15 struct GetInputsOutputsMainFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
16 {
GetInputsOutputsMainFixtureGetInputsOutputsMainFixture17     explicit GetInputsOutputsMainFixture()
18     {
19         m_Prototext = R"(
20                    ir_version: 3
21                    producer_name:  "CNTK"
22                    producer_version:  "2.5.1"
23                    domain:  "ai.cntk"
24                    model_version: 1
25                    graph {
26                      name:  "CNTKGraph"
27                      input {
28                         name: "Input"
29                         type {
30                           tensor_type {
31                             elem_type: 1
32                             shape {
33                               dim {
34                                 dim_value: 4
35                               }
36                             }
37                           }
38                         }
39                       }
40                      node {
41                          input: "Input"
42                          output: "Output"
43                          name: "ActivationLayer"
44                          op_type: "Relu"
45                     }
46                       output {
47                           name: "Output"
48                           type {
49                              tensor_type {
50                                elem_type: 1
51                                shape {
52                                    dim {
53                                        dim_value: 4
54                                    }
55                                }
56                             }
57                          }
58                       }
59                     }
60                    opset_import {
61                       version: 7
62                     })";
63         Setup();
64     }
65 };
66 
67 
68 TEST_CASE_FIXTURE(GetInputsOutputsMainFixture, "GetInput")
69 {
70     ModelPtr model = armnnOnnxParser::OnnxParserImpl::LoadModelFromString(m_Prototext.c_str());
71     std::vector<std::string> tensors = armnnOnnxParser::OnnxParserImpl::GetInputs(model);
72     CHECK_EQ(1, tensors.size());
73     CHECK_EQ("Input", tensors[0]);
74 
75 }
76 
77 TEST_CASE_FIXTURE(GetInputsOutputsMainFixture, "GetOutput")
78 {
79     ModelPtr model = armnnOnnxParser::OnnxParserImpl::LoadModelFromString(m_Prototext.c_str());
80     std::vector<std::string> tensors = armnnOnnxParser::OnnxParserImpl::GetOutputs(model);
81     CHECK_EQ(1, tensors.size());
82     CHECK_EQ("Output", tensors[0]);
83 }
84 
85 struct GetEmptyInputsOutputsFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
86 {
GetEmptyInputsOutputsFixtureGetEmptyInputsOutputsFixture87     GetEmptyInputsOutputsFixture()
88     {
89         m_Prototext = R"(
90                    ir_version: 3
91                    producer_name:  "CNTK "
92                    producer_version:  "2.5.1 "
93                    domain:  "ai.cntk "
94                    model_version: 1
95                    graph {
96                      name:  "CNTKGraph "
97                      node {
98                         output:  "Output"
99                         attribute {
100                           name: "value"
101                           t {
102                               dims: 7
103                               data_type: 1
104                               float_data: 0.0
105                               float_data: 1.0
106                               float_data: 2.0
107                               float_data: 3.0
108                               float_data: 4.0
109                               float_data: 5.0
110                               float_data: 6.0
111 
112                           }
113                           type: 1
114                         }
115                         name:  "constantNode"
116                         op_type:  "Constant"
117                       }
118                       output {
119                           name:  "Output"
120                           type {
121                              tensor_type {
122                                elem_type: 1
123                                shape {
124                                  dim {
125                                     dim_value: 7
126                                  }
127                                }
128                              }
129                           }
130                       }
131                    }
132                    opset_import {
133                       version: 7
134                     })";
135         Setup();
136     }
137 };
138 
139 TEST_CASE_FIXTURE(GetEmptyInputsOutputsFixture, "GetEmptyInputs")
140 {
141     ModelPtr model = armnnOnnxParser::OnnxParserImpl::LoadModelFromString(m_Prototext.c_str());
142     std::vector<std::string> tensors = armnnOnnxParser::OnnxParserImpl::GetInputs(model);
143     CHECK_EQ(0, tensors.size());
144 }
145 
146 TEST_CASE("GetInputsNullModel")
147 {
148     CHECK_THROWS_AS(armnnOnnxParser::OnnxParserImpl::LoadModelFromString(""), armnn::InvalidArgumentException);
149 }
150 
151 TEST_CASE("GetOutputsNullModel")
152 {
153     auto silencer = google::protobuf::LogSilencer(); //get rid of errors from protobuf
154     CHECK_THROWS_AS(armnnOnnxParser::OnnxParserImpl::LoadModelFromString("nknnk"), armnn::ParseException);
155 }
156 
157 struct GetInputsMultipleFixture : public armnnUtils::ParserPrototxtFixture<armnnOnnxParser::IOnnxParser>
158 {
GetInputsMultipleFixtureGetInputsMultipleFixture159     GetInputsMultipleFixture() {
160 
161         m_Prototext = R"(
162                    ir_version: 3
163                    producer_name:  "CNTK"
164                    producer_version:  "2.5.1"
165                    domain:  "ai.cntk"
166                    model_version: 1
167                    graph {
168                      name:  "CNTKGraph"
169                      input {
170                         name: "Input0"
171                         type {
172                           tensor_type {
173                             elem_type: 1
174                             shape {
175                               dim {
176                                 dim_value: 1
177                               }
178                               dim {
179                                 dim_value: 1
180                               }
181                               dim {
182                                 dim_value: 1
183                               }
184                               dim {
185                                 dim_value: 4
186                               }
187                             }
188                           }
189                         }
190                       }
191                       input {
192                          name: "Input1"
193                          type {
194                            tensor_type {
195                              elem_type: 1
196                              shape {
197                                  dim {
198                                    dim_value: 4
199                                  }
200                              }
201                            }
202                          }
203                        }
204                        node {
205                             input: "Input0"
206                             input: "Input1"
207                             output: "Output"
208                             name: "addition"
209                             op_type: "Add"
210                             doc_string: ""
211                             domain: ""
212                           }
213                           output {
214                               name: "Output"
215                               type {
216                                  tensor_type {
217                                    elem_type: 1
218                                    shape {
219                                        dim {
220                                            dim_value: 1
221                                        }
222                                        dim {
223                                            dim_value: 1
224                                        }
225                                        dim {
226                                            dim_value: 1
227                                        }
228                                        dim {
229                                            dim_value: 4
230                                        }
231                                    }
232                                 }
233                             }
234                         }
235                     }
236                    opset_import {
237                       version: 7
238                     })";
239         Setup();
240     }
241 };
242 
243 TEST_CASE_FIXTURE(GetInputsMultipleFixture, "GetInputsMultipleInputs")
244 {
245     ModelPtr model = armnnOnnxParser::OnnxParserImpl::LoadModelFromString(m_Prototext.c_str());
246     std::vector<std::string> tensors = armnnOnnxParser::OnnxParserImpl::GetInputs(model);
247     CHECK_EQ(2, tensors.size());
248     CHECK_EQ("Input0", tensors[0]);
249     CHECK_EQ("Input1", tensors[1]);
250 }
251 
252 }
253