xref: /aosp_15_r20/external/armnn/src/armnnTfLiteParser/test/Reshape.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #include "ParserFlatbuffersFixture.hpp"
7 
8 
9 TEST_SUITE("TensorflowLiteParser_Reshape")
10 {
11 struct ReshapeFixture : public ParserFlatbuffersFixture
12 {
ReshapeFixtureReshapeFixture13     explicit ReshapeFixture(const std::string& inputShape,
14                             const std::string& outputShape,
15                             const std::string& newShape)
16     {
17         m_JsonString = R"(
18             {
19                 "version": 3,
20                 "operator_codes": [ { "builtin_code": "RESHAPE" } ],
21                 "subgraphs": [ {
22                     "tensors": [
23                         {)";
24         m_JsonString += R"(
25                             "shape" : )" + inputShape + ",";
26         m_JsonString += R"(
27                             "type": "UINT8",
28                             "buffer": 0,
29                             "name": "inputTensor",
30                             "quantization": {
31                                 "min": [ 0.0 ],
32                                 "max": [ 255.0 ],
33                                 "scale": [ 1.0 ],
34                                 "zero_point": [ 0 ],
35                             }
36                         },
37                         {)";
38         m_JsonString += R"(
39                             "shape" : )" + outputShape;
40         m_JsonString += R"(,
41                             "type": "UINT8",
42                             "buffer": 1,
43                             "name": "outputTensor",
44                             "quantization": {
45                                 "min": [ 0.0 ],
46                                 "max": [ 255.0 ],
47                                 "scale": [ 1.0 ],
48                                 "zero_point": [ 0 ],
49                             }
50                         }
51                     ],
52                     "inputs": [ 0 ],
53                     "outputs": [ 1 ],
54                     "operators": [
55                         {
56                             "opcode_index": 0,
57                             "inputs": [ 0 ],
58                             "outputs": [ 1 ],
59                             "builtin_options_type": "ReshapeOptions",
60                             "builtin_options": {)";
61         if (!newShape.empty())
62         {
63             m_JsonString += R"("new_shape" : )" + newShape;
64         }
65         m_JsonString += R"(},
66                             "custom_options_format": "FLEXBUFFERS"
67                         }
68                     ],
69                 } ],
70                 "buffers" : [ {}, {} ]
71             }
72         )";
73 
74     }
75 };
76 
77 struct ReshapeFixtureWithReshapeDims : ReshapeFixture
78 {
ReshapeFixtureWithReshapeDimsReshapeFixtureWithReshapeDims79     ReshapeFixtureWithReshapeDims() : ReshapeFixture("[ 1, 9 ]", "[ 3, 3 ]", "[ 3, 3 ]") {}
80 };
81 
82 TEST_CASE_FIXTURE(ReshapeFixtureWithReshapeDims, "ParseReshapeWithReshapeDims")
83 {
84     SetupSingleInputSingleOutput("inputTensor", "outputTensor");
85     RunTest<2, armnn::DataType::QAsymmU8>(0,
86                                                  { 1, 2, 3, 4, 5, 6, 7, 8, 9 },
87                                                  { 1, 2, 3, 4, 5, 6, 7, 8, 9 });
88     CHECK((m_Parser->GetNetworkOutputBindingInfo(0, "outputTensor").second.GetShape()
89                 == armnn::TensorShape({3,3})));
90 }
91 
92 struct ReshapeFixtureWithReshapeDimsFlatten : ReshapeFixture
93 {
ReshapeFixtureWithReshapeDimsFlattenReshapeFixtureWithReshapeDimsFlatten94     ReshapeFixtureWithReshapeDimsFlatten() : ReshapeFixture("[ 3, 3 ]", "[ 9 ]", "[ -1 ]") {}
95 };
96 
97 TEST_CASE_FIXTURE(ReshapeFixtureWithReshapeDimsFlatten, "ParseReshapeWithReshapeDimsFlatten")
98 {
99     SetupSingleInputSingleOutput("inputTensor", "outputTensor");
100     RunTest<1, armnn::DataType::QAsymmU8>(0,
101                                                  { 1, 2, 3, 4, 5, 6, 7, 8, 9 },
102                                                  { 1, 2, 3, 4, 5, 6, 7, 8, 9 });
103     CHECK((m_Parser->GetNetworkOutputBindingInfo(0, "outputTensor").second.GetShape()
104                 == armnn::TensorShape({9})));
105 }
106 
107 struct ReshapeFixtureWithReshapeDimsFlattenTwoDims : ReshapeFixture
108 {
ReshapeFixtureWithReshapeDimsFlattenTwoDimsReshapeFixtureWithReshapeDimsFlattenTwoDims109     ReshapeFixtureWithReshapeDimsFlattenTwoDims() : ReshapeFixture("[ 3, 2, 3 ]", "[ 2, 9 ]", "[ 2, -1 ]") {}
110 };
111 
112 TEST_CASE_FIXTURE(ReshapeFixtureWithReshapeDimsFlattenTwoDims, "ParseReshapeWithReshapeDimsFlattenTwoDims")
113 {
114     SetupSingleInputSingleOutput("inputTensor", "outputTensor");
115     RunTest<2, armnn::DataType::QAsymmU8>(0,
116                                                  { 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6 },
117                                                  { 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6 });
118     CHECK((m_Parser->GetNetworkOutputBindingInfo(0, "outputTensor").second.GetShape()
119                 == armnn::TensorShape({2,9})));
120 }
121 
122 struct ReshapeFixtureWithReshapeDimsFlattenOneDim : ReshapeFixture
123 {
ReshapeFixtureWithReshapeDimsFlattenOneDimReshapeFixtureWithReshapeDimsFlattenOneDim124     ReshapeFixtureWithReshapeDimsFlattenOneDim() : ReshapeFixture("[ 2, 9 ]", "[ 2, 3, 3 ]", "[ 2, -1, 3 ]") {}
125 };
126 
127 TEST_CASE_FIXTURE(ReshapeFixtureWithReshapeDimsFlattenOneDim, "ParseReshapeWithReshapeDimsFlattenOneDim")
128 {
129     SetupSingleInputSingleOutput("inputTensor", "outputTensor");
130     RunTest<3, armnn::DataType::QAsymmU8>(0,
131                                                  { 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6 },
132                                                  { 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6 });
133     CHECK((m_Parser->GetNetworkOutputBindingInfo(0, "outputTensor").second.GetShape()
134                 == armnn::TensorShape({2,3,3})));
135 }
136 
137 struct DynamicReshapeFixtureWithReshapeDimsFlattenOneDim : ReshapeFixture
138 {
DynamicReshapeFixtureWithReshapeDimsFlattenOneDimDynamicReshapeFixtureWithReshapeDimsFlattenOneDim139     DynamicReshapeFixtureWithReshapeDimsFlattenOneDim() : ReshapeFixture("[ 2, 9 ]",
140                                                                          "[ ]",
141                                                                          "[ 2, -1, 3 ]") {}
142 };
143 
144 TEST_CASE_FIXTURE(DynamicReshapeFixtureWithReshapeDimsFlattenOneDim, "DynParseReshapeWithReshapeDimsFlattenOneDim")
145 {
146     SetupSingleInputSingleOutput("inputTensor", "outputTensor");
147      RunTest<3,
148         armnn::DataType::QAsymmU8,
149         armnn::DataType::QAsymmU8>(0,
150                                    { { "inputTensor", {  1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6 } } },
151                                    { { "outputTensor", { 1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6 } } },
152                                    true);
153 }
154 
155 }
156