xref: /aosp_15_r20/external/armnn/src/armnnTfLiteParser/test/Transpose.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2019 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
#include "ParserFlatbuffersFixture.hpp"

#include <cstddef>
#include <string>
7*89c4ff92SAndroid Build Coastguard Worker 
8*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("TensorflowLiteParser_Transpose")
9*89c4ff92SAndroid Build Coastguard Worker {
10*89c4ff92SAndroid Build Coastguard Worker struct TransposeFixture : public ParserFlatbuffersFixture
11*89c4ff92SAndroid Build Coastguard Worker {
TransposeFixtureTransposeFixture12*89c4ff92SAndroid Build Coastguard Worker     explicit TransposeFixture(const std::string & inputShape,
13*89c4ff92SAndroid Build Coastguard Worker                               const std::string & permuteData,
14*89c4ff92SAndroid Build Coastguard Worker                               const std::string & outputShape)
15*89c4ff92SAndroid Build Coastguard Worker     {
16*89c4ff92SAndroid Build Coastguard Worker         m_JsonString = R"(
17*89c4ff92SAndroid Build Coastguard Worker             {
18*89c4ff92SAndroid Build Coastguard Worker                   "version": 3,
19*89c4ff92SAndroid Build Coastguard Worker                   "operator_codes": [
20*89c4ff92SAndroid Build Coastguard Worker                     {
21*89c4ff92SAndroid Build Coastguard Worker                       "builtin_code": "TRANSPOSE",
22*89c4ff92SAndroid Build Coastguard Worker                       "version": 1
23*89c4ff92SAndroid Build Coastguard Worker                     }
24*89c4ff92SAndroid Build Coastguard Worker                   ],
25*89c4ff92SAndroid Build Coastguard Worker                   "subgraphs": [
26*89c4ff92SAndroid Build Coastguard Worker                     {
27*89c4ff92SAndroid Build Coastguard Worker                       "tensors": [
28*89c4ff92SAndroid Build Coastguard Worker                         {
29*89c4ff92SAndroid Build Coastguard Worker                           "shape": )" + inputShape + R"(,
30*89c4ff92SAndroid Build Coastguard Worker                           "type": "FLOAT32",
31*89c4ff92SAndroid Build Coastguard Worker                           "buffer": 0,
32*89c4ff92SAndroid Build Coastguard Worker                           "name": "inputTensor",
33*89c4ff92SAndroid Build Coastguard Worker                           "quantization": {
34*89c4ff92SAndroid Build Coastguard Worker                             "min": [
35*89c4ff92SAndroid Build Coastguard Worker                               0.0
36*89c4ff92SAndroid Build Coastguard Worker                             ],
37*89c4ff92SAndroid Build Coastguard Worker                             "max": [
38*89c4ff92SAndroid Build Coastguard Worker                               255.0
39*89c4ff92SAndroid Build Coastguard Worker                             ],
40*89c4ff92SAndroid Build Coastguard Worker                             "details_type": 0,
41*89c4ff92SAndroid Build Coastguard Worker                             "quantized_dimension": 0
42*89c4ff92SAndroid Build Coastguard Worker                           },
43*89c4ff92SAndroid Build Coastguard Worker                           "is_variable": false
44*89c4ff92SAndroid Build Coastguard Worker                         },
45*89c4ff92SAndroid Build Coastguard Worker                         {
46*89c4ff92SAndroid Build Coastguard Worker                           "shape": )" + outputShape + R"(,
47*89c4ff92SAndroid Build Coastguard Worker                           "type": "FLOAT32",
48*89c4ff92SAndroid Build Coastguard Worker                           "buffer": 1,
49*89c4ff92SAndroid Build Coastguard Worker                           "name": "outputTensor",
50*89c4ff92SAndroid Build Coastguard Worker                           "quantization": {
51*89c4ff92SAndroid Build Coastguard Worker                             "details_type": 0,
52*89c4ff92SAndroid Build Coastguard Worker                             "quantized_dimension": 0
53*89c4ff92SAndroid Build Coastguard Worker                           },
54*89c4ff92SAndroid Build Coastguard Worker                           "is_variable": false
55*89c4ff92SAndroid Build Coastguard Worker                         })";
56*89c4ff92SAndroid Build Coastguard Worker         m_JsonString += R"(,
57*89c4ff92SAndroid Build Coastguard Worker                           {
58*89c4ff92SAndroid Build Coastguard Worker                             "shape": [
59*89c4ff92SAndroid Build Coastguard Worker                               3
60*89c4ff92SAndroid Build Coastguard Worker                             ],
61*89c4ff92SAndroid Build Coastguard Worker                             "type": "INT32",
62*89c4ff92SAndroid Build Coastguard Worker                             "buffer": 2,
63*89c4ff92SAndroid Build Coastguard Worker                             "name": "permuteTensor",
64*89c4ff92SAndroid Build Coastguard Worker                             "quantization": {
65*89c4ff92SAndroid Build Coastguard Worker                               "details_type": 0,
66*89c4ff92SAndroid Build Coastguard Worker                               "quantized_dimension": 0
67*89c4ff92SAndroid Build Coastguard Worker                             },
68*89c4ff92SAndroid Build Coastguard Worker                             "is_variable": false
69*89c4ff92SAndroid Build Coastguard Worker                           })";
70*89c4ff92SAndroid Build Coastguard Worker         m_JsonString += R"(],
71*89c4ff92SAndroid Build Coastguard Worker                       "inputs": [
72*89c4ff92SAndroid Build Coastguard Worker                         0
73*89c4ff92SAndroid Build Coastguard Worker                       ],
74*89c4ff92SAndroid Build Coastguard Worker                       "outputs": [
75*89c4ff92SAndroid Build Coastguard Worker                         1
76*89c4ff92SAndroid Build Coastguard Worker                       ],
77*89c4ff92SAndroid Build Coastguard Worker                       "operators": [
78*89c4ff92SAndroid Build Coastguard Worker                         {
79*89c4ff92SAndroid Build Coastguard Worker                           "opcode_index": 0,
80*89c4ff92SAndroid Build Coastguard Worker                           "inputs": [
81*89c4ff92SAndroid Build Coastguard Worker                             0)";
82*89c4ff92SAndroid Build Coastguard Worker         m_JsonString += R"(,2)";
83*89c4ff92SAndroid Build Coastguard Worker         m_JsonString += R"(],
84*89c4ff92SAndroid Build Coastguard Worker                           "outputs": [
85*89c4ff92SAndroid Build Coastguard Worker                             1
86*89c4ff92SAndroid Build Coastguard Worker                           ],
87*89c4ff92SAndroid Build Coastguard Worker                           "builtin_options_type": "TransposeOptions",
88*89c4ff92SAndroid Build Coastguard Worker                           "builtin_options": {
89*89c4ff92SAndroid Build Coastguard Worker                           },
90*89c4ff92SAndroid Build Coastguard Worker                           "custom_options_format": "FLEXBUFFERS"
91*89c4ff92SAndroid Build Coastguard Worker                         }
92*89c4ff92SAndroid Build Coastguard Worker                       ]
93*89c4ff92SAndroid Build Coastguard Worker                     }
94*89c4ff92SAndroid Build Coastguard Worker                   ],
95*89c4ff92SAndroid Build Coastguard Worker                   "description": "TOCO Converted.",
96*89c4ff92SAndroid Build Coastguard Worker                   "buffers": [
97*89c4ff92SAndroid Build Coastguard Worker                     { },
98*89c4ff92SAndroid Build Coastguard Worker                     { })";
99*89c4ff92SAndroid Build Coastguard Worker         if (!permuteData.empty())
100*89c4ff92SAndroid Build Coastguard Worker         {
101*89c4ff92SAndroid Build Coastguard Worker             m_JsonString += R"(,{"data": )" + permuteData + R"( })";
102*89c4ff92SAndroid Build Coastguard Worker         }
103*89c4ff92SAndroid Build Coastguard Worker         m_JsonString += R"(
104*89c4ff92SAndroid Build Coastguard Worker                   ]
105*89c4ff92SAndroid Build Coastguard Worker                 }
106*89c4ff92SAndroid Build Coastguard Worker         )";
107*89c4ff92SAndroid Build Coastguard Worker         Setup();
108*89c4ff92SAndroid Build Coastguard Worker     }
109*89c4ff92SAndroid Build Coastguard Worker };
110*89c4ff92SAndroid Build Coastguard Worker 
// Note that this assumes the TensorFlow permutation vector semantics (output dimension i is taken from
// input dimension perm[i]) as opposed to the Arm NN implementation.
112*89c4ff92SAndroid Build Coastguard Worker struct TransposeFixtureWithPermuteData : TransposeFixture
113*89c4ff92SAndroid Build Coastguard Worker {
TransposeFixtureWithPermuteDataTransposeFixtureWithPermuteData114*89c4ff92SAndroid Build Coastguard Worker     TransposeFixtureWithPermuteData() : TransposeFixture("[ 2, 2, 3 ]",
115*89c4ff92SAndroid Build Coastguard Worker                                                          "[ 0, 0, 0, 0, 2, 0, 0, 0, 1, 0, 0, 0 ]",
116*89c4ff92SAndroid Build Coastguard Worker                                                          "[ 2, 3, 2 ]") {}
117*89c4ff92SAndroid Build Coastguard Worker };
118*89c4ff92SAndroid Build Coastguard Worker 
119*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(TransposeFixtureWithPermuteData, "TransposeWithPermuteData")
120*89c4ff92SAndroid Build Coastguard Worker {
121*89c4ff92SAndroid Build Coastguard Worker     RunTest<3, armnn::DataType::Float32>(
122*89c4ff92SAndroid Build Coastguard Worker       0,
123*89c4ff92SAndroid Build Coastguard Worker       {{"inputTensor", { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12 }}},
124*89c4ff92SAndroid Build Coastguard Worker       {{"outputTensor", { 1, 4, 2, 5, 3, 6, 7, 10, 8, 11, 9, 12 }}});
125*89c4ff92SAndroid Build Coastguard Worker 
126*89c4ff92SAndroid Build Coastguard Worker     CHECK((m_Parser->GetNetworkOutputBindingInfo(0, "outputTensor").second.GetShape()
127*89c4ff92SAndroid Build Coastguard Worker                 == armnn::TensorShape({2,3,2})));
128*89c4ff92SAndroid Build Coastguard Worker }
129*89c4ff92SAndroid Build Coastguard Worker 
// When no permute argument is given, TensorFlow's default is the reversed permutation [n-1, ..., 0],
// where n is the number of dimensions of the input tensor.
// This fixture supplies that default permutation [2, 1, 0] explicitly as permute data, so for an
// input of shape [2, 2, 3] the expected output shape is [3, 2, 2].
133*89c4ff92SAndroid Build Coastguard Worker struct TransposeFixtureWithoutPermuteData : TransposeFixture
134*89c4ff92SAndroid Build Coastguard Worker {
TransposeFixtureWithoutPermuteDataTransposeFixtureWithoutPermuteData135*89c4ff92SAndroid Build Coastguard Worker     TransposeFixtureWithoutPermuteData() : TransposeFixture("[ 2, 2, 3 ]",
136*89c4ff92SAndroid Build Coastguard Worker                                                             "[ 2, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0 ]",
137*89c4ff92SAndroid Build Coastguard Worker                                                             "[ 3, 2, 2 ]") {}
138*89c4ff92SAndroid Build Coastguard Worker };
139*89c4ff92SAndroid Build Coastguard Worker 
140*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(TransposeFixtureWithoutPermuteData, "TransposeWithoutPermuteDims")
141*89c4ff92SAndroid Build Coastguard Worker {
142*89c4ff92SAndroid Build Coastguard Worker     RunTest<3, armnn::DataType::Float32>(
143*89c4ff92SAndroid Build Coastguard Worker         0,
144*89c4ff92SAndroid Build Coastguard Worker         {{"inputTensor", { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12 }}},
145*89c4ff92SAndroid Build Coastguard Worker         {{"outputTensor", { 1, 7, 4, 10, 2, 8, 5, 11, 3, 9, 6, 12 }}});
146*89c4ff92SAndroid Build Coastguard Worker 
147*89c4ff92SAndroid Build Coastguard Worker     CHECK((m_Parser->GetNetworkOutputBindingInfo(0, "outputTensor").second.GetShape()
148*89c4ff92SAndroid Build Coastguard Worker                 == armnn::TensorShape({3,2,2})));
149*89c4ff92SAndroid Build Coastguard Worker }
150*89c4ff92SAndroid Build Coastguard Worker 
151*89c4ff92SAndroid Build Coastguard Worker }