xref: /aosp_15_r20/external/armnn/src/armnnTfLiteParser/test/InputOutputTensorNames.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "ParserFlatbuffersFixture.hpp"
7 #include "../TfLiteParser.hpp"
8 
9 TEST_SUITE("TensorflowLiteParser_InputOutputTensorNames")
10 {
11 struct EmptyNetworkFixture : public ParserFlatbuffersFixture
12 {
EmptyNetworkFixtureEmptyNetworkFixture13     explicit EmptyNetworkFixture() {
14         m_JsonString = R"(
15             {
16                 "version": 3,
17                 "operator_codes": [],
18                 "subgraphs": [ {} ]
19             })";
20     }
21 };
22 
23 TEST_CASE_FIXTURE(EmptyNetworkFixture, "EmptyNetworkHasNoInputsAndOutputs")
24 {
25     Setup(false);
26     CHECK(m_Parser->GetSubgraphCount() == 1);
27     CHECK(m_Parser->GetSubgraphInputTensorNames(0).size() == 0);
28     CHECK(m_Parser->GetSubgraphOutputTensorNames(0).size() == 0);
29 }
30 
31 struct MissingTensorsFixture : public ParserFlatbuffersFixture
32 {
MissingTensorsFixtureMissingTensorsFixture33     explicit MissingTensorsFixture()
34     {
35         m_JsonString = R"(
36             {
37                 "version": 3,
38                 "operator_codes": [],
39                 "subgraphs": [{
40                     "inputs" : [ 0, 1 ],
41                     "outputs" : [ 2, 3 ],
42                 }]
43             })";
44     }
45 };
46 
47 TEST_CASE_FIXTURE(MissingTensorsFixture, "MissingTensorsThrowException")
48 {
49     // this throws because it cannot do the input output tensor connections
50     CHECK_THROWS_AS(Setup(), armnn::ParseException);
51 }
52 
53 struct InvalidTensorsFixture : public ParserFlatbuffersFixture
54 {
InvalidTensorsFixtureInvalidTensorsFixture55     explicit InvalidTensorsFixture()
56     {
57         m_JsonString = R"(
58             {
59                 "version": 3,
60                 "operator_codes": [ ],
61                 "subgraphs": [{
62                     "tensors": [ {
63                         "shape": [ 1, 1, 1, 1, 1, 1 ],
64                         "type": "FLOAT32",
65                         "name": "In",
66                         "buffer": 0
67                     }, {
68                         "shape": [ 1, 1, 1, 1, 1, 1 ],
69                         "type": "FLOAT32",
70                         "name": "Out",
71                         "buffer": 1
72                     }],
73                     "inputs" : [ 0 ],
74                     "outputs" : [ 1 ],
75                 }]
76             })";
77     }
78 };
79 
80 TEST_CASE_FIXTURE(InvalidTensorsFixture, "InvalidTensorsThrowException")
81 {
82     // Tensor numDimensions must be less than or equal to MaxNumOfTensorDimensions
83     static_assert(armnn::MaxNumOfTensorDimensions == 5, "Please update InvalidTensorsFixture");
84     CHECK_THROWS_AS(Setup(), armnn::InvalidArgumentException);
85 }
86 
87 struct ValidTensorsFixture : public ParserFlatbuffersFixture
88 {
ValidTensorsFixtureValidTensorsFixture89     explicit ValidTensorsFixture()
90     {
91         m_JsonString = R"(
92             {
93                 "version": 3,
94                 "operator_codes": [ { "builtin_code": "AVERAGE_POOL_2D" } ],
95                 "subgraphs": [{
96                     "tensors": [ {
97                         "shape": [ 1, 1, 1, 1 ],
98                         "type": "FLOAT32",
99                         "name": "In",
100                         "buffer": 0,
101                     }, {
102                         "shape": [ 1, 1, 1, 1 ],
103                         "type": "FLOAT32",
104                         "name": "Out",
105                         "buffer": 1,
106                     }],
107                     "inputs" : [ 0 ],
108                     "outputs" : [ 1 ],
109                     "operators": [{
110                         "opcode_index": 0,
111                         "inputs": [ 0 ],
112                         "outputs": [ 1 ],
113                         "builtin_options_type": "Pool2DOptions",
114                         "builtin_options":
115                         {
116                             "padding": "VALID",
117                             "stride_w": 1,
118                             "stride_h": 1,
119                             "filter_width": 1,
120                             "filter_height": 1,
121                             "fused_activation_function": "NONE"
122                         },
123                         "custom_options_format": "FLEXBUFFERS"
124                     }]
125                 }]
126             })";
127     }
128 };
129 
130 TEST_CASE_FIXTURE(ValidTensorsFixture, "GetValidInputOutputTensorNames")
131 {
132     Setup();
133     CHECK_EQ(m_Parser->GetSubgraphInputTensorNames(0).size(), 1u);
134     CHECK_EQ(m_Parser->GetSubgraphOutputTensorNames(0).size(), 1u);
135     CHECK_EQ(m_Parser->GetSubgraphInputTensorNames(0)[0], "In");
136     CHECK_EQ(m_Parser->GetSubgraphOutputTensorNames(0)[0], "Out");
137 }
138 
139 TEST_CASE_FIXTURE(ValidTensorsFixture, "ThrowIfSubgraphIdInvalidForInOutNames")
140 {
141     Setup();
142 
143     // these throw because of the invalid subgraph id
144     CHECK_THROWS_AS(m_Parser->GetSubgraphInputTensorNames(1), armnn::ParseException);
145     CHECK_THROWS_AS(m_Parser->GetSubgraphOutputTensorNames(1), armnn::ParseException);
146 }
147 
148 struct Rank0TensorFixture : public ParserFlatbuffersFixture
149 {
Rank0TensorFixtureRank0TensorFixture150     explicit Rank0TensorFixture()
151     {
152         m_JsonString = R"(
153             {
154                 "version": 3,
155                 "operator_codes": [ { "builtin_code": "MINIMUM" } ],
156                 "subgraphs": [{
157                     "tensors": [ {
158                         "shape": [  ],
159                         "type": "FLOAT32",
160                         "name": "In0",
161                         "buffer": 0,
162                     }, {
163                         "shape": [  ],
164                         "type": "FLOAT32",
165                         "name": "In1",
166                         "buffer": 1,
167                     }, {
168                         "shape": [ ],
169                         "type": "FLOAT32",
170                         "name": "Out",
171                         "buffer": 2,
172                     }],
173                     "inputs" : [ 0, 1 ],
174                     "outputs" : [ 2 ],
175                     "operators": [{
176                         "opcode_index": 0,
177                         "inputs": [ 0, 1 ],
178                         "outputs": [ 2 ],
179                         "custom_options_format": "FLEXBUFFERS"
180                     }]
181                 }]
182             }
183         )";
184     }
185 };
186 
187 TEST_CASE_FIXTURE(Rank0TensorFixture, "Rank0Tensor")
188 {
189     Setup();
190     CHECK_EQ(m_Parser->GetSubgraphInputTensorNames(0).size(), 2u);
191     CHECK_EQ(m_Parser->GetSubgraphOutputTensorNames(0).size(), 1u);
192     CHECK_EQ(m_Parser->GetSubgraphInputTensorNames(0)[0], "In0");
193     CHECK_EQ(m_Parser->GetSubgraphInputTensorNames(0)[1], "In1");
194     CHECK_EQ(m_Parser->GetSubgraphOutputTensorNames(0)[0], "Out");
195 }
196 
197 }
198