xref: /aosp_15_r20/external/armnn/src/armnnTfLiteParser/test/InputOutputTensorNames.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include "ParserFlatbuffersFixture.hpp"
7*89c4ff92SAndroid Build Coastguard Worker #include "../TfLiteParser.hpp"
8*89c4ff92SAndroid Build Coastguard Worker 
9*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("TensorflowLiteParser_InputOutputTensorNames")
10*89c4ff92SAndroid Build Coastguard Worker {
11*89c4ff92SAndroid Build Coastguard Worker struct EmptyNetworkFixture : public ParserFlatbuffersFixture
12*89c4ff92SAndroid Build Coastguard Worker {
EmptyNetworkFixtureEmptyNetworkFixture13*89c4ff92SAndroid Build Coastguard Worker     explicit EmptyNetworkFixture() {
14*89c4ff92SAndroid Build Coastguard Worker         m_JsonString = R"(
15*89c4ff92SAndroid Build Coastguard Worker             {
16*89c4ff92SAndroid Build Coastguard Worker                 "version": 3,
17*89c4ff92SAndroid Build Coastguard Worker                 "operator_codes": [],
18*89c4ff92SAndroid Build Coastguard Worker                 "subgraphs": [ {} ]
19*89c4ff92SAndroid Build Coastguard Worker             })";
20*89c4ff92SAndroid Build Coastguard Worker     }
21*89c4ff92SAndroid Build Coastguard Worker };
22*89c4ff92SAndroid Build Coastguard Worker 
23*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(EmptyNetworkFixture, "EmptyNetworkHasNoInputsAndOutputs")
24*89c4ff92SAndroid Build Coastguard Worker {
25*89c4ff92SAndroid Build Coastguard Worker     Setup(false);
26*89c4ff92SAndroid Build Coastguard Worker     CHECK(m_Parser->GetSubgraphCount() == 1);
27*89c4ff92SAndroid Build Coastguard Worker     CHECK(m_Parser->GetSubgraphInputTensorNames(0).size() == 0);
28*89c4ff92SAndroid Build Coastguard Worker     CHECK(m_Parser->GetSubgraphOutputTensorNames(0).size() == 0);
29*89c4ff92SAndroid Build Coastguard Worker }
30*89c4ff92SAndroid Build Coastguard Worker 
31*89c4ff92SAndroid Build Coastguard Worker struct MissingTensorsFixture : public ParserFlatbuffersFixture
32*89c4ff92SAndroid Build Coastguard Worker {
MissingTensorsFixtureMissingTensorsFixture33*89c4ff92SAndroid Build Coastguard Worker     explicit MissingTensorsFixture()
34*89c4ff92SAndroid Build Coastguard Worker     {
35*89c4ff92SAndroid Build Coastguard Worker         m_JsonString = R"(
36*89c4ff92SAndroid Build Coastguard Worker             {
37*89c4ff92SAndroid Build Coastguard Worker                 "version": 3,
38*89c4ff92SAndroid Build Coastguard Worker                 "operator_codes": [],
39*89c4ff92SAndroid Build Coastguard Worker                 "subgraphs": [{
40*89c4ff92SAndroid Build Coastguard Worker                     "inputs" : [ 0, 1 ],
41*89c4ff92SAndroid Build Coastguard Worker                     "outputs" : [ 2, 3 ],
42*89c4ff92SAndroid Build Coastguard Worker                 }]
43*89c4ff92SAndroid Build Coastguard Worker             })";
44*89c4ff92SAndroid Build Coastguard Worker     }
45*89c4ff92SAndroid Build Coastguard Worker };
46*89c4ff92SAndroid Build Coastguard Worker 
47*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(MissingTensorsFixture, "MissingTensorsThrowException")
48*89c4ff92SAndroid Build Coastguard Worker {
49*89c4ff92SAndroid Build Coastguard Worker     // this throws because it cannot do the input output tensor connections
50*89c4ff92SAndroid Build Coastguard Worker     CHECK_THROWS_AS(Setup(), armnn::ParseException);
51*89c4ff92SAndroid Build Coastguard Worker }
52*89c4ff92SAndroid Build Coastguard Worker 
53*89c4ff92SAndroid Build Coastguard Worker struct InvalidTensorsFixture : public ParserFlatbuffersFixture
54*89c4ff92SAndroid Build Coastguard Worker {
InvalidTensorsFixtureInvalidTensorsFixture55*89c4ff92SAndroid Build Coastguard Worker     explicit InvalidTensorsFixture()
56*89c4ff92SAndroid Build Coastguard Worker     {
57*89c4ff92SAndroid Build Coastguard Worker         m_JsonString = R"(
58*89c4ff92SAndroid Build Coastguard Worker             {
59*89c4ff92SAndroid Build Coastguard Worker                 "version": 3,
60*89c4ff92SAndroid Build Coastguard Worker                 "operator_codes": [ ],
61*89c4ff92SAndroid Build Coastguard Worker                 "subgraphs": [{
62*89c4ff92SAndroid Build Coastguard Worker                     "tensors": [ {
63*89c4ff92SAndroid Build Coastguard Worker                         "shape": [ 1, 1, 1, 1, 1, 1 ],
64*89c4ff92SAndroid Build Coastguard Worker                         "type": "FLOAT32",
65*89c4ff92SAndroid Build Coastguard Worker                         "name": "In",
66*89c4ff92SAndroid Build Coastguard Worker                         "buffer": 0
67*89c4ff92SAndroid Build Coastguard Worker                     }, {
68*89c4ff92SAndroid Build Coastguard Worker                         "shape": [ 1, 1, 1, 1, 1, 1 ],
69*89c4ff92SAndroid Build Coastguard Worker                         "type": "FLOAT32",
70*89c4ff92SAndroid Build Coastguard Worker                         "name": "Out",
71*89c4ff92SAndroid Build Coastguard Worker                         "buffer": 1
72*89c4ff92SAndroid Build Coastguard Worker                     }],
73*89c4ff92SAndroid Build Coastguard Worker                     "inputs" : [ 0 ],
74*89c4ff92SAndroid Build Coastguard Worker                     "outputs" : [ 1 ],
75*89c4ff92SAndroid Build Coastguard Worker                 }]
76*89c4ff92SAndroid Build Coastguard Worker             })";
77*89c4ff92SAndroid Build Coastguard Worker     }
78*89c4ff92SAndroid Build Coastguard Worker };
79*89c4ff92SAndroid Build Coastguard Worker 
80*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(InvalidTensorsFixture, "InvalidTensorsThrowException")
81*89c4ff92SAndroid Build Coastguard Worker {
82*89c4ff92SAndroid Build Coastguard Worker     // Tensor numDimensions must be less than or equal to MaxNumOfTensorDimensions
83*89c4ff92SAndroid Build Coastguard Worker     static_assert(armnn::MaxNumOfTensorDimensions == 5, "Please update InvalidTensorsFixture");
84*89c4ff92SAndroid Build Coastguard Worker     CHECK_THROWS_AS(Setup(), armnn::InvalidArgumentException);
85*89c4ff92SAndroid Build Coastguard Worker }
86*89c4ff92SAndroid Build Coastguard Worker 
87*89c4ff92SAndroid Build Coastguard Worker struct ValidTensorsFixture : public ParserFlatbuffersFixture
88*89c4ff92SAndroid Build Coastguard Worker {
ValidTensorsFixtureValidTensorsFixture89*89c4ff92SAndroid Build Coastguard Worker     explicit ValidTensorsFixture()
90*89c4ff92SAndroid Build Coastguard Worker     {
91*89c4ff92SAndroid Build Coastguard Worker         m_JsonString = R"(
92*89c4ff92SAndroid Build Coastguard Worker             {
93*89c4ff92SAndroid Build Coastguard Worker                 "version": 3,
94*89c4ff92SAndroid Build Coastguard Worker                 "operator_codes": [ { "builtin_code": "AVERAGE_POOL_2D" } ],
95*89c4ff92SAndroid Build Coastguard Worker                 "subgraphs": [{
96*89c4ff92SAndroid Build Coastguard Worker                     "tensors": [ {
97*89c4ff92SAndroid Build Coastguard Worker                         "shape": [ 1, 1, 1, 1 ],
98*89c4ff92SAndroid Build Coastguard Worker                         "type": "FLOAT32",
99*89c4ff92SAndroid Build Coastguard Worker                         "name": "In",
100*89c4ff92SAndroid Build Coastguard Worker                         "buffer": 0,
101*89c4ff92SAndroid Build Coastguard Worker                     }, {
102*89c4ff92SAndroid Build Coastguard Worker                         "shape": [ 1, 1, 1, 1 ],
103*89c4ff92SAndroid Build Coastguard Worker                         "type": "FLOAT32",
104*89c4ff92SAndroid Build Coastguard Worker                         "name": "Out",
105*89c4ff92SAndroid Build Coastguard Worker                         "buffer": 1,
106*89c4ff92SAndroid Build Coastguard Worker                     }],
107*89c4ff92SAndroid Build Coastguard Worker                     "inputs" : [ 0 ],
108*89c4ff92SAndroid Build Coastguard Worker                     "outputs" : [ 1 ],
109*89c4ff92SAndroid Build Coastguard Worker                     "operators": [{
110*89c4ff92SAndroid Build Coastguard Worker                         "opcode_index": 0,
111*89c4ff92SAndroid Build Coastguard Worker                         "inputs": [ 0 ],
112*89c4ff92SAndroid Build Coastguard Worker                         "outputs": [ 1 ],
113*89c4ff92SAndroid Build Coastguard Worker                         "builtin_options_type": "Pool2DOptions",
114*89c4ff92SAndroid Build Coastguard Worker                         "builtin_options":
115*89c4ff92SAndroid Build Coastguard Worker                         {
116*89c4ff92SAndroid Build Coastguard Worker                             "padding": "VALID",
117*89c4ff92SAndroid Build Coastguard Worker                             "stride_w": 1,
118*89c4ff92SAndroid Build Coastguard Worker                             "stride_h": 1,
119*89c4ff92SAndroid Build Coastguard Worker                             "filter_width": 1,
120*89c4ff92SAndroid Build Coastguard Worker                             "filter_height": 1,
121*89c4ff92SAndroid Build Coastguard Worker                             "fused_activation_function": "NONE"
122*89c4ff92SAndroid Build Coastguard Worker                         },
123*89c4ff92SAndroid Build Coastguard Worker                         "custom_options_format": "FLEXBUFFERS"
124*89c4ff92SAndroid Build Coastguard Worker                     }]
125*89c4ff92SAndroid Build Coastguard Worker                 }]
126*89c4ff92SAndroid Build Coastguard Worker             })";
127*89c4ff92SAndroid Build Coastguard Worker     }
128*89c4ff92SAndroid Build Coastguard Worker };
129*89c4ff92SAndroid Build Coastguard Worker 
130*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(ValidTensorsFixture, "GetValidInputOutputTensorNames")
131*89c4ff92SAndroid Build Coastguard Worker {
132*89c4ff92SAndroid Build Coastguard Worker     Setup();
133*89c4ff92SAndroid Build Coastguard Worker     CHECK_EQ(m_Parser->GetSubgraphInputTensorNames(0).size(), 1u);
134*89c4ff92SAndroid Build Coastguard Worker     CHECK_EQ(m_Parser->GetSubgraphOutputTensorNames(0).size(), 1u);
135*89c4ff92SAndroid Build Coastguard Worker     CHECK_EQ(m_Parser->GetSubgraphInputTensorNames(0)[0], "In");
136*89c4ff92SAndroid Build Coastguard Worker     CHECK_EQ(m_Parser->GetSubgraphOutputTensorNames(0)[0], "Out");
137*89c4ff92SAndroid Build Coastguard Worker }
138*89c4ff92SAndroid Build Coastguard Worker 
139*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(ValidTensorsFixture, "ThrowIfSubgraphIdInvalidForInOutNames")
140*89c4ff92SAndroid Build Coastguard Worker {
141*89c4ff92SAndroid Build Coastguard Worker     Setup();
142*89c4ff92SAndroid Build Coastguard Worker 
143*89c4ff92SAndroid Build Coastguard Worker     // these throw because of the invalid subgraph id
144*89c4ff92SAndroid Build Coastguard Worker     CHECK_THROWS_AS(m_Parser->GetSubgraphInputTensorNames(1), armnn::ParseException);
145*89c4ff92SAndroid Build Coastguard Worker     CHECK_THROWS_AS(m_Parser->GetSubgraphOutputTensorNames(1), armnn::ParseException);
146*89c4ff92SAndroid Build Coastguard Worker }
147*89c4ff92SAndroid Build Coastguard Worker 
148*89c4ff92SAndroid Build Coastguard Worker struct Rank0TensorFixture : public ParserFlatbuffersFixture
149*89c4ff92SAndroid Build Coastguard Worker {
Rank0TensorFixtureRank0TensorFixture150*89c4ff92SAndroid Build Coastguard Worker     explicit Rank0TensorFixture()
151*89c4ff92SAndroid Build Coastguard Worker     {
152*89c4ff92SAndroid Build Coastguard Worker         m_JsonString = R"(
153*89c4ff92SAndroid Build Coastguard Worker             {
154*89c4ff92SAndroid Build Coastguard Worker                 "version": 3,
155*89c4ff92SAndroid Build Coastguard Worker                 "operator_codes": [ { "builtin_code": "MINIMUM" } ],
156*89c4ff92SAndroid Build Coastguard Worker                 "subgraphs": [{
157*89c4ff92SAndroid Build Coastguard Worker                     "tensors": [ {
158*89c4ff92SAndroid Build Coastguard Worker                         "shape": [  ],
159*89c4ff92SAndroid Build Coastguard Worker                         "type": "FLOAT32",
160*89c4ff92SAndroid Build Coastguard Worker                         "name": "In0",
161*89c4ff92SAndroid Build Coastguard Worker                         "buffer": 0,
162*89c4ff92SAndroid Build Coastguard Worker                     }, {
163*89c4ff92SAndroid Build Coastguard Worker                         "shape": [  ],
164*89c4ff92SAndroid Build Coastguard Worker                         "type": "FLOAT32",
165*89c4ff92SAndroid Build Coastguard Worker                         "name": "In1",
166*89c4ff92SAndroid Build Coastguard Worker                         "buffer": 1,
167*89c4ff92SAndroid Build Coastguard Worker                     }, {
168*89c4ff92SAndroid Build Coastguard Worker                         "shape": [ ],
169*89c4ff92SAndroid Build Coastguard Worker                         "type": "FLOAT32",
170*89c4ff92SAndroid Build Coastguard Worker                         "name": "Out",
171*89c4ff92SAndroid Build Coastguard Worker                         "buffer": 2,
172*89c4ff92SAndroid Build Coastguard Worker                     }],
173*89c4ff92SAndroid Build Coastguard Worker                     "inputs" : [ 0, 1 ],
174*89c4ff92SAndroid Build Coastguard Worker                     "outputs" : [ 2 ],
175*89c4ff92SAndroid Build Coastguard Worker                     "operators": [{
176*89c4ff92SAndroid Build Coastguard Worker                         "opcode_index": 0,
177*89c4ff92SAndroid Build Coastguard Worker                         "inputs": [ 0, 1 ],
178*89c4ff92SAndroid Build Coastguard Worker                         "outputs": [ 2 ],
179*89c4ff92SAndroid Build Coastguard Worker                         "custom_options_format": "FLEXBUFFERS"
180*89c4ff92SAndroid Build Coastguard Worker                     }]
181*89c4ff92SAndroid Build Coastguard Worker                 }]
182*89c4ff92SAndroid Build Coastguard Worker             }
183*89c4ff92SAndroid Build Coastguard Worker         )";
184*89c4ff92SAndroid Build Coastguard Worker     }
185*89c4ff92SAndroid Build Coastguard Worker };
186*89c4ff92SAndroid Build Coastguard Worker 
187*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(Rank0TensorFixture, "Rank0Tensor")
188*89c4ff92SAndroid Build Coastguard Worker {
189*89c4ff92SAndroid Build Coastguard Worker     Setup();
190*89c4ff92SAndroid Build Coastguard Worker     CHECK_EQ(m_Parser->GetSubgraphInputTensorNames(0).size(), 2u);
191*89c4ff92SAndroid Build Coastguard Worker     CHECK_EQ(m_Parser->GetSubgraphOutputTensorNames(0).size(), 1u);
192*89c4ff92SAndroid Build Coastguard Worker     CHECK_EQ(m_Parser->GetSubgraphInputTensorNames(0)[0], "In0");
193*89c4ff92SAndroid Build Coastguard Worker     CHECK_EQ(m_Parser->GetSubgraphInputTensorNames(0)[1], "In1");
194*89c4ff92SAndroid Build Coastguard Worker     CHECK_EQ(m_Parser->GetSubgraphOutputTensorNames(0)[0], "Out");
195*89c4ff92SAndroid Build Coastguard Worker }
196*89c4ff92SAndroid Build Coastguard Worker 
197*89c4ff92SAndroid Build Coastguard Worker }
198