// xref: /aosp_15_r20/external/armnn/src/armnnTfLiteParser/test/SplitV.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
//
// Copyright © 2020 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//

#include "ParserFlatbuffersFixture.hpp"

#include <string>

9 TEST_SUITE("TensorflowLiteParser")
10 {
11 struct SplitVFixture : public ParserFlatbuffersFixture
12 {
SplitVFixtureSplitVFixture13     explicit SplitVFixture(const std::string& inputShape,
14                            const std::string& splitValues,
15                            const std::string& sizeSplitsShape,
16                            const std::string& axisShape,
17                            const std::string& numSplits,
18                            const std::string& outputShape1,
19                            const std::string& outputShape2,
20                            const std::string& axisData,
21                            const std::string& dataType)
22     {
23         m_JsonString = R"(
24             {
25                 "version": 3,
26                 "operator_codes": [ { "builtin_code": "SPLIT_V" } ],
27                 "subgraphs": [ {
28                     "tensors": [
29                         {
30                             "shape": )" + inputShape + R"(,
31                             "type": )" + dataType + R"(,
32                             "buffer": 0,
33                             "name": "inputTensor",
34                             "quantization": {
35                                 "min": [ 0.0 ],
36                                 "max": [ 255.0 ],
37                                 "scale": [ 1.0 ],
38                                 "zero_point": [ 0 ],
39                             }
40                         },
41                         {
42                             "shape": )" + sizeSplitsShape + R"(,
43                             "type": "INT32",
44                             "buffer": 1,
45                             "name": "sizeSplits",
46                             "quantization": {
47                                 "min": [ 0.0 ],
48                                 "max": [ 255.0 ],
49                                 "scale": [ 1.0 ],
50                                 "zero_point": [ 0 ],
51                             }
52                         },
53                         {
54                             "shape": )" + axisShape + R"(,
55                             "type": "INT32",
56                             "buffer": 2,
57                             "name": "axis",
58                             "quantization": {
59                                 "min": [ 0.0 ],
60                                 "max": [ 255.0 ],
61                                 "scale": [ 1.0 ],
62                                 "zero_point": [ 0 ],
63                             }
64                         },
65                         {
66                             "shape": )" + outputShape1 + R"( ,
67                             "type":)" + dataType + R"(,
68                             "buffer": 3,
69                             "name": "outputTensor1",
70                             "quantization": {
71                                 "min": [ 0.0 ],
72                                 "max": [ 255.0 ],
73                                 "scale": [ 1.0 ],
74                                 "zero_point": [ 0 ],
75                             }
76                         },
77                         {
78                             "shape": )" + outputShape2 + R"( ,
79                             "type":)" + dataType + R"(,
80                             "buffer": 4,
81                             "name": "outputTensor2",
82                             "quantization": {
83                                 "min": [ 0.0 ],
84                                 "max": [ 255.0 ],
85                                 "scale": [ 1.0 ],
86                                 "zero_point": [ 0 ],
87                             }
88                         }
89                     ],
90                     "inputs": [ 0, 1, 2 ],
91                     "outputs": [ 3, 4 ],
92                     "operators": [
93                         {
94                             "opcode_index": 0,
95                             "inputs": [ 0, 1, 2 ],
96                             "outputs": [ 3, 4 ],
97                             "builtin_options_type": "SplitVOptions",
98                             "builtin_options": {
99                                 "num_splits": )" + numSplits + R"(
100                             },
101                             "custom_options_format": "FLEXBUFFERS"
102                         }
103                     ],
104                 } ],
105                 "buffers" : [ {}, { "data": )" + splitValues + R"( }, { "data": )" + axisData + R"( }, {}, {}]
106             }
107         )";
108 
109         Setup();
110     }
111 };
112 
113 /*
114  *  Tested inferred splitSizes with splitValues [-1, 1] locally.
115  */
116 
117 struct SimpleSplitVAxisOneFixture : SplitVFixture
118 {
SimpleSplitVAxisOneFixtureSimpleSplitVAxisOneFixture119     SimpleSplitVAxisOneFixture()
120         : SplitVFixture( "[ 4, 2, 2, 2 ]", "[ 1, 0, 0, 0, 3, 0, 0, 0 ]", "[ 2 ]","[ ]", "2",
121                          "[ 1, 2, 2, 2 ]", "[ 3, 2, 2, 2 ]", "[ 0, 0, 0, 0 ]", "FLOAT32")
122     {}
123 };
124 
125 TEST_CASE_FIXTURE(SimpleSplitVAxisOneFixture, "ParseAxisOneSplitVTwo")
126 {
127     RunTest<4, armnn::DataType::Float32>(
128         0,
129         { {"inputTensor",   { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f,
130                               9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f,
131                               17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f,
132                               25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f } } },
133         { {"outputTensor1", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f } },
134           {"outputTensor2", { 9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f,
135                               17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f,
136                               25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f } } } );
137 }
138 
139 struct SimpleSplitVAxisTwoFixture : SplitVFixture
140 {
SimpleSplitVAxisTwoFixtureSimpleSplitVAxisTwoFixture141     SimpleSplitVAxisTwoFixture()
142         : SplitVFixture( "[ 2, 4, 2, 2 ]", "[ 3, 0, 0, 0, 1, 0, 0, 0 ]", "[ 2 ]","[ ]", "2",
143                          "[ 2, 3, 2, 2 ]", "[ 2, 1, 2, 2 ]", "[ 1, 0, 0, 0 ]", "FLOAT32")
144     {}
145 };
146 
147 TEST_CASE_FIXTURE(SimpleSplitVAxisTwoFixture, "ParseAxisTwoSplitVTwo")
148 {
149     RunTest<4, armnn::DataType::Float32>(
150         0,
151         { {"inputTensor",   { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f,
152                               9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f,
153                               17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f,
154                               25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f } } },
155         { {"outputTensor1", { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f,
156                               9.0f, 10.0f, 11.0f, 12.0f, 17.0f, 18.0f, 19.0f, 20.0f,
157                               21.0f, 22.0f, 23.0f, 24.0f, 25.0f, 26.0f, 27.0f, 28.0f } },
158           {"outputTensor2", { 13.0f, 14.0f, 15.0f, 16.0f, 29.0f, 30.0f, 31.0f, 32.0f } } } );
159 }
160 
161 struct SimpleSplitVAxisThreeFixture : SplitVFixture
162 {
SimpleSplitVAxisThreeFixtureSimpleSplitVAxisThreeFixture163     SimpleSplitVAxisThreeFixture()
164         : SplitVFixture( "[ 2, 2, 4, 2 ]", "[ 1, 0, 0, 0, 3, 0, 0, 0 ]", "[ 2 ]","[ ]", "2",
165                          "[ 2, 2, 1, 2 ]", "[ 2, 2, 3, 2 ]", "[ 2, 0, 0, 0 ]", "FLOAT32")
166     {}
167 };
168 
169 TEST_CASE_FIXTURE(SimpleSplitVAxisThreeFixture, "ParseAxisThreeSplitVTwo")
170 {
171     RunTest<4, armnn::DataType::Float32>(
172         0,
173         { {"inputTensor",   { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f,
174                               9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f,
175                               17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f,
176                               25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f } } },
177         { {"outputTensor1", { 1.0f, 2.0f, 9.0f, 10.0f, 17.0f, 18.0f, 25.0f, 26.0f } },
178           {"outputTensor2", { 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 11.0f, 12.0f,
179                               13.0f, 14.0f, 15.0f, 16.0f, 19.0f, 20.0f, 21.0f, 22.0f,
180                               23.0f, 24.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f } } } );
181 }
182 
183 struct SimpleSplitVAxisFourFixture : SplitVFixture
184 {
SimpleSplitVAxisFourFixtureSimpleSplitVAxisFourFixture185     SimpleSplitVAxisFourFixture()
186         : SplitVFixture( "[ 2, 2, 2, 4 ]", "[ 3, 0, 0, 0, 1, 0, 0, 0 ]", "[ 2 ]","[ ]", "2",
187                          "[ 2, 2, 2, 3 ]", "[ 2, 2, 2, 1 ]", "[ 3, 0, 0, 0 ]", "FLOAT32")
188     {}
189 };
190 
191 TEST_CASE_FIXTURE(SimpleSplitVAxisFourFixture, "ParseAxisFourSplitVTwo")
192 {
193     RunTest<4, armnn::DataType::Float32>(
194         0,
195         { {"inputTensor",   { 1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f,
196                               9.0f, 10.0f, 11.0f, 12.0f, 13.0f, 14.0f, 15.0f, 16.0f,
197                               17.0f, 18.0f, 19.0f, 20.0f, 21.0f, 22.0f, 23.0f, 24.0f,
198                               25.0f, 26.0f, 27.0f, 28.0f, 29.0f, 30.0f, 31.0f, 32.0f } } },
199         { {"outputTensor1", { 1.0f, 2.0f, 3.0f, 5.0f, 6.0f, 7.0f, 9.0f, 10.0f,
200                               11.0f, 13.0f, 14.0f, 15.0f, 17.0f, 18.0f, 19.0f, 21.0f,
201                               22.0f, 23.0f, 25.0f, 26.0f, 27.0f, 29.0f, 30.0f, 31.0f} },
202           {"outputTensor2", { 4.0f, 8.0f, 12.0f, 16.0f, 20.0f, 24.0f, 28.0f, 32.0f } } } );
203 }
204 
205 }
206