xref: /aosp_15_r20/external/armnn/src/armnnTfLiteParser/test/Reshape.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include "ParserFlatbuffersFixture.hpp"
7*89c4ff92SAndroid Build Coastguard Worker 
8*89c4ff92SAndroid Build Coastguard Worker 
9*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("TensorflowLiteParser_Reshape")
10*89c4ff92SAndroid Build Coastguard Worker {
11*89c4ff92SAndroid Build Coastguard Worker struct ReshapeFixture : public ParserFlatbuffersFixture
12*89c4ff92SAndroid Build Coastguard Worker {
ReshapeFixtureReshapeFixture13*89c4ff92SAndroid Build Coastguard Worker     explicit ReshapeFixture(const std::string& inputShape,
14*89c4ff92SAndroid Build Coastguard Worker                             const std::string& outputShape,
15*89c4ff92SAndroid Build Coastguard Worker                             const std::string& newShape)
16*89c4ff92SAndroid Build Coastguard Worker     {
17*89c4ff92SAndroid Build Coastguard Worker         m_JsonString = R"(
18*89c4ff92SAndroid Build Coastguard Worker             {
19*89c4ff92SAndroid Build Coastguard Worker                 "version": 3,
20*89c4ff92SAndroid Build Coastguard Worker                 "operator_codes": [ { "builtin_code": "RESHAPE" } ],
21*89c4ff92SAndroid Build Coastguard Worker                 "subgraphs": [ {
22*89c4ff92SAndroid Build Coastguard Worker                     "tensors": [
23*89c4ff92SAndroid Build Coastguard Worker                         {)";
24*89c4ff92SAndroid Build Coastguard Worker         m_JsonString += R"(
25*89c4ff92SAndroid Build Coastguard Worker                             "shape" : )" + inputShape + ",";
26*89c4ff92SAndroid Build Coastguard Worker         m_JsonString += R"(
27*89c4ff92SAndroid Build Coastguard Worker                             "type": "UINT8",
28*89c4ff92SAndroid Build Coastguard Worker                             "buffer": 0,
29*89c4ff92SAndroid Build Coastguard Worker                             "name": "inputTensor",
30*89c4ff92SAndroid Build Coastguard Worker                             "quantization": {
31*89c4ff92SAndroid Build Coastguard Worker                                 "min": [ 0.0 ],
32*89c4ff92SAndroid Build Coastguard Worker                                 "max": [ 255.0 ],
33*89c4ff92SAndroid Build Coastguard Worker                                 "scale": [ 1.0 ],
34*89c4ff92SAndroid Build Coastguard Worker                                 "zero_point": [ 0 ],
35*89c4ff92SAndroid Build Coastguard Worker                             }
36*89c4ff92SAndroid Build Coastguard Worker                         },
37*89c4ff92SAndroid Build Coastguard Worker                         {)";
38*89c4ff92SAndroid Build Coastguard Worker         m_JsonString += R"(
39*89c4ff92SAndroid Build Coastguard Worker                             "shape" : )" + outputShape;
40*89c4ff92SAndroid Build Coastguard Worker         m_JsonString += R"(,
41*89c4ff92SAndroid Build Coastguard Worker                             "type": "UINT8",
42*89c4ff92SAndroid Build Coastguard Worker                             "buffer": 1,
43*89c4ff92SAndroid Build Coastguard Worker                             "name": "outputTensor",
44*89c4ff92SAndroid Build Coastguard Worker                             "quantization": {
45*89c4ff92SAndroid Build Coastguard Worker                                 "min": [ 0.0 ],
46*89c4ff92SAndroid Build Coastguard Worker                                 "max": [ 255.0 ],
47*89c4ff92SAndroid Build Coastguard Worker                                 "scale": [ 1.0 ],
48*89c4ff92SAndroid Build Coastguard Worker                                 "zero_point": [ 0 ],
49*89c4ff92SAndroid Build Coastguard Worker                             }
50*89c4ff92SAndroid Build Coastguard Worker                         }
51*89c4ff92SAndroid Build Coastguard Worker                     ],
52*89c4ff92SAndroid Build Coastguard Worker                     "inputs": [ 0 ],
53*89c4ff92SAndroid Build Coastguard Worker                     "outputs": [ 1 ],
54*89c4ff92SAndroid Build Coastguard Worker                     "operators": [
55*89c4ff92SAndroid Build Coastguard Worker                         {
56*89c4ff92SAndroid Build Coastguard Worker                             "opcode_index": 0,
57*89c4ff92SAndroid Build Coastguard Worker                             "inputs": [ 0 ],
58*89c4ff92SAndroid Build Coastguard Worker                             "outputs": [ 1 ],
59*89c4ff92SAndroid Build Coastguard Worker                             "builtin_options_type": "ReshapeOptions",
60*89c4ff92SAndroid Build Coastguard Worker                             "builtin_options": {)";
61*89c4ff92SAndroid Build Coastguard Worker         if (!newShape.empty())
62*89c4ff92SAndroid Build Coastguard Worker         {
63*89c4ff92SAndroid Build Coastguard Worker             m_JsonString += R"("new_shape" : )" + newShape;
64*89c4ff92SAndroid Build Coastguard Worker         }
65*89c4ff92SAndroid Build Coastguard Worker         m_JsonString += R"(},
66*89c4ff92SAndroid Build Coastguard Worker                             "custom_options_format": "FLEXBUFFERS"
67*89c4ff92SAndroid Build Coastguard Worker                         }
68*89c4ff92SAndroid Build Coastguard Worker                     ],
69*89c4ff92SAndroid Build Coastguard Worker                 } ],
70*89c4ff92SAndroid Build Coastguard Worker                 "buffers" : [ {}, {} ]
71*89c4ff92SAndroid Build Coastguard Worker             }
72*89c4ff92SAndroid Build Coastguard Worker         )";
73*89c4ff92SAndroid Build Coastguard Worker 
74*89c4ff92SAndroid Build Coastguard Worker     }
75*89c4ff92SAndroid Build Coastguard Worker };
76*89c4ff92SAndroid Build Coastguard Worker 
77*89c4ff92SAndroid Build Coastguard Worker struct ReshapeFixtureWithReshapeDims : ReshapeFixture
78*89c4ff92SAndroid Build Coastguard Worker {
ReshapeFixtureWithReshapeDimsReshapeFixtureWithReshapeDims79*89c4ff92SAndroid Build Coastguard Worker     ReshapeFixtureWithReshapeDims() : ReshapeFixture("[ 1, 9 ]", "[ 3, 3 ]", "[ 3, 3 ]") {}
80*89c4ff92SAndroid Build Coastguard Worker };
81*89c4ff92SAndroid Build Coastguard Worker 
82*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(ReshapeFixtureWithReshapeDims, "ParseReshapeWithReshapeDims")
83*89c4ff92SAndroid Build Coastguard Worker {
84*89c4ff92SAndroid Build Coastguard Worker     SetupSingleInputSingleOutput("inputTensor", "outputTensor");
85*89c4ff92SAndroid Build Coastguard Worker     RunTest<2, armnn::DataType::QAsymmU8>(0,
86*89c4ff92SAndroid Build Coastguard Worker                                                  { 1, 2, 3, 4, 5, 6, 7, 8, 9 },
87*89c4ff92SAndroid Build Coastguard Worker                                                  { 1, 2, 3, 4, 5, 6, 7, 8, 9 });
88*89c4ff92SAndroid Build Coastguard Worker     CHECK((m_Parser->GetNetworkOutputBindingInfo(0, "outputTensor").second.GetShape()
89*89c4ff92SAndroid Build Coastguard Worker                 == armnn::TensorShape({3,3})));
90*89c4ff92SAndroid Build Coastguard Worker }
91*89c4ff92SAndroid Build Coastguard Worker 
92*89c4ff92SAndroid Build Coastguard Worker struct ReshapeFixtureWithReshapeDimsFlatten : ReshapeFixture
93*89c4ff92SAndroid Build Coastguard Worker {
ReshapeFixtureWithReshapeDimsFlattenReshapeFixtureWithReshapeDimsFlatten94*89c4ff92SAndroid Build Coastguard Worker     ReshapeFixtureWithReshapeDimsFlatten() : ReshapeFixture("[ 3, 3 ]", "[ 9 ]", "[ -1 ]") {}
95*89c4ff92SAndroid Build Coastguard Worker };
96*89c4ff92SAndroid Build Coastguard Worker 
97*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(ReshapeFixtureWithReshapeDimsFlatten, "ParseReshapeWithReshapeDimsFlatten")
98*89c4ff92SAndroid Build Coastguard Worker {
99*89c4ff92SAndroid Build Coastguard Worker     SetupSingleInputSingleOutput("inputTensor", "outputTensor");
100*89c4ff92SAndroid Build Coastguard Worker     RunTest<1, armnn::DataType::QAsymmU8>(0,
101*89c4ff92SAndroid Build Coastguard Worker                                                  { 1, 2, 3, 4, 5, 6, 7, 8, 9 },
102*89c4ff92SAndroid Build Coastguard Worker                                                  { 1, 2, 3, 4, 5, 6, 7, 8, 9 });
103*89c4ff92SAndroid Build Coastguard Worker     CHECK((m_Parser->GetNetworkOutputBindingInfo(0, "outputTensor").second.GetShape()
104*89c4ff92SAndroid Build Coastguard Worker                 == armnn::TensorShape({9})));
105*89c4ff92SAndroid Build Coastguard Worker }
106*89c4ff92SAndroid Build Coastguard Worker 
107*89c4ff92SAndroid Build Coastguard Worker struct ReshapeFixtureWithReshapeDimsFlattenTwoDims : ReshapeFixture
108*89c4ff92SAndroid Build Coastguard Worker {
ReshapeFixtureWithReshapeDimsFlattenTwoDimsReshapeFixtureWithReshapeDimsFlattenTwoDims109*89c4ff92SAndroid Build Coastguard Worker     ReshapeFixtureWithReshapeDimsFlattenTwoDims() : ReshapeFixture("[ 3, 2, 3 ]", "[ 2, 9 ]", "[ 2, -1 ]") {}
110*89c4ff92SAndroid Build Coastguard Worker };
111*89c4ff92SAndroid Build Coastguard Worker 
112*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(ReshapeFixtureWithReshapeDimsFlattenTwoDims, "ParseReshapeWithReshapeDimsFlattenTwoDims")
113*89c4ff92SAndroid Build Coastguard Worker {
114*89c4ff92SAndroid Build Coastguard Worker     SetupSingleInputSingleOutput("inputTensor", "outputTensor");
115*89c4ff92SAndroid Build Coastguard Worker     RunTest<2, armnn::DataType::QAsymmU8>(0,
116*89c4ff92SAndroid Build Coastguard Worker                                                  { 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6 },
117*89c4ff92SAndroid Build Coastguard Worker                                                  { 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6 });
118*89c4ff92SAndroid Build Coastguard Worker     CHECK((m_Parser->GetNetworkOutputBindingInfo(0, "outputTensor").second.GetShape()
119*89c4ff92SAndroid Build Coastguard Worker                 == armnn::TensorShape({2,9})));
120*89c4ff92SAndroid Build Coastguard Worker }
121*89c4ff92SAndroid Build Coastguard Worker 
122*89c4ff92SAndroid Build Coastguard Worker struct ReshapeFixtureWithReshapeDimsFlattenOneDim : ReshapeFixture
123*89c4ff92SAndroid Build Coastguard Worker {
ReshapeFixtureWithReshapeDimsFlattenOneDimReshapeFixtureWithReshapeDimsFlattenOneDim124*89c4ff92SAndroid Build Coastguard Worker     ReshapeFixtureWithReshapeDimsFlattenOneDim() : ReshapeFixture("[ 2, 9 ]", "[ 2, 3, 3 ]", "[ 2, -1, 3 ]") {}
125*89c4ff92SAndroid Build Coastguard Worker };
126*89c4ff92SAndroid Build Coastguard Worker 
127*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(ReshapeFixtureWithReshapeDimsFlattenOneDim, "ParseReshapeWithReshapeDimsFlattenOneDim")
128*89c4ff92SAndroid Build Coastguard Worker {
129*89c4ff92SAndroid Build Coastguard Worker     SetupSingleInputSingleOutput("inputTensor", "outputTensor");
130*89c4ff92SAndroid Build Coastguard Worker     RunTest<3, armnn::DataType::QAsymmU8>(0,
131*89c4ff92SAndroid Build Coastguard Worker                                                  { 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6 },
132*89c4ff92SAndroid Build Coastguard Worker                                                  { 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6 });
133*89c4ff92SAndroid Build Coastguard Worker     CHECK((m_Parser->GetNetworkOutputBindingInfo(0, "outputTensor").second.GetShape()
134*89c4ff92SAndroid Build Coastguard Worker                 == armnn::TensorShape({2,3,3})));
135*89c4ff92SAndroid Build Coastguard Worker }
136*89c4ff92SAndroid Build Coastguard Worker 
137*89c4ff92SAndroid Build Coastguard Worker struct DynamicReshapeFixtureWithReshapeDimsFlattenOneDim : ReshapeFixture
138*89c4ff92SAndroid Build Coastguard Worker {
DynamicReshapeFixtureWithReshapeDimsFlattenOneDimDynamicReshapeFixtureWithReshapeDimsFlattenOneDim139*89c4ff92SAndroid Build Coastguard Worker     DynamicReshapeFixtureWithReshapeDimsFlattenOneDim() : ReshapeFixture("[ 2, 9 ]",
140*89c4ff92SAndroid Build Coastguard Worker                                                                          "[ ]",
141*89c4ff92SAndroid Build Coastguard Worker                                                                          "[ 2, -1, 3 ]") {}
142*89c4ff92SAndroid Build Coastguard Worker };
143*89c4ff92SAndroid Build Coastguard Worker 
144*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(DynamicReshapeFixtureWithReshapeDimsFlattenOneDim, "DynParseReshapeWithReshapeDimsFlattenOneDim")
145*89c4ff92SAndroid Build Coastguard Worker {
146*89c4ff92SAndroid Build Coastguard Worker     SetupSingleInputSingleOutput("inputTensor", "outputTensor");
147*89c4ff92SAndroid Build Coastguard Worker      RunTest<3,
148*89c4ff92SAndroid Build Coastguard Worker         armnn::DataType::QAsymmU8,
149*89c4ff92SAndroid Build Coastguard Worker         armnn::DataType::QAsymmU8>(0,
150*89c4ff92SAndroid Build Coastguard Worker                                    { { "inputTensor", {  1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6 } } },
151*89c4ff92SAndroid Build Coastguard Worker                                    { { "outputTensor", { 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6 } } },
152*89c4ff92SAndroid Build Coastguard Worker                                    true);
153*89c4ff92SAndroid Build Coastguard Worker }
154*89c4ff92SAndroid Build Coastguard Worker 
155*89c4ff92SAndroid Build Coastguard Worker }
156