// xref: /aosp_15_r20/external/armnn/src/armnnTfLiteParser/test/Pack.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
//
// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5 
6 #include "ParserFlatbuffersFixture.hpp"
7 
8 
9 TEST_SUITE("TensorflowLiteParser_Pack")
10 {
11 struct PackFixture : public ParserFlatbuffersFixture
12 {
PackFixturePackFixture13     explicit PackFixture(const std::string & inputShape,
14                          const unsigned int numInputs,
15                          const std::string & outputShape,
16                          const std::string & axis)
17     {
18         m_JsonString = R"(
19             {
20                 "version": 3,
21                 "operator_codes": [ { "builtin_code": "PACK" } ],
22                 "subgraphs": [ {
23                     "tensors": [)";
24 
25         for (unsigned int i = 0; i < numInputs; ++i)
26         {
27             m_JsonString += R"(
28                         {
29                             "shape": )" + inputShape + R"(,
30                             "type": "FLOAT32",
31                             "buffer": )" + std::to_string(i) + R"(,
32                             "name": "inputTensor)" + std::to_string(i + 1) + R"(",
33                             "quantization": {
34                                 "min": [ 0.0 ],
35                                 "max": [ 255.0 ],
36                                 "scale": [ 1.0 ],
37                                 "zero_point": [ 0 ],
38                             }
39                         },)";
40         }
41 
42         std::string inputIndexes;
43         for (unsigned int i = 0; i < numInputs-1; ++i)
44         {
45             inputIndexes += std::to_string(i) + R"(, )";
46         }
47         inputIndexes += std::to_string(numInputs-1);
48 
49         m_JsonString += R"(
50                         {
51                             "shape": )" + outputShape + R"( ,
52                             "type": "FLOAT32",
53                             "buffer": )" + std::to_string(numInputs) + R"(,
54                             "name": "outputTensor",
55                             "quantization": {
56                                 "min": [ 0.0 ],
57                                 "max": [ 255.0 ],
58                                 "scale": [ 1.0 ],
59                                 "zero_point": [ 0 ],
60                             }
61                         }
62                     ],
63                     "inputs": [ )" + inputIndexes + R"( ],
64                     "outputs": [ 2 ],
65                     "operators": [
66                         {
67                             "opcode_index": 0,
68                             "inputs": [ )" + inputIndexes + R"( ],
69                             "outputs": [ 2 ],
70                             "builtin_options_type": "PackOptions",
71                             "builtin_options": {
72                                 "axis": )" + axis + R"(,
73                                 "values_count": )" + std::to_string(numInputs) + R"(
74                             },
75                             "custom_options_format": "FLEXBUFFERS"
76                         }
77                     ],
78                 } ],
79                 "buffers" : [)";
80 
81             for (unsigned int i = 0; i < numInputs-1; ++i)
82             {
83                 m_JsonString += R"(
84                     { },)";
85             }
86             m_JsonString += R"(
87                     { }
88                 ]
89             })";
90         Setup();
91     }
92 };
93 
94 struct SimplePackFixture : PackFixture
95 {
SimplePackFixtureSimplePackFixture96     SimplePackFixture() : PackFixture("[ 3, 2, 3 ]",
97                                       2,
98                                       "[ 3, 2, 3, 2 ]",
99                                       "3") {}
100 };
101 
102 TEST_CASE_FIXTURE(SimplePackFixture, "ParsePack")
103 {
104     RunTest<4, armnn::DataType::Float32>(
105     0,
106     { {"inputTensor1", { 1, 2, 3,
107                          4, 5, 6,
108 
109                          7, 8, 9,
110                          10, 11, 12,
111 
112                          13, 14, 15,
113                          16, 17, 18 } },
114     {"inputTensor2", { 19, 20, 21,
115                        22, 23, 24,
116 
117                        25, 26, 27,
118                        28, 29, 30,
119 
120                        31, 32, 33,
121                        34, 35, 36 } } },
122     { {"outputTensor", { 1, 19,
123                          2, 20,
124                          3, 21,
125 
126                          4, 22,
127                          5, 23,
128                          6, 24,
129 
130 
131                          7, 25,
132                          8, 26,
133                          9, 27,
134 
135                          10, 28,
136                          11, 29,
137                          12, 30,
138 
139 
140                          13, 31,
141                          14, 32,
142                          15, 33,
143 
144                          16, 34,
145                          17, 35,
146                          18, 36 } } });
147 }
148 
149 }
150