1*89c4ff92SAndroid Build Coastguard Worker // 2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2019 Arm Ltd and Contributors. All rights reserved. 3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT 4*89c4ff92SAndroid Build Coastguard Worker // 5*89c4ff92SAndroid Build Coastguard Worker 6*89c4ff92SAndroid Build Coastguard Worker #include "ParserFlatbuffersFixture.hpp" 7*89c4ff92SAndroid Build Coastguard Worker 8*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("TensorflowLiteParser_Transpose") 9*89c4ff92SAndroid Build Coastguard Worker { 10*89c4ff92SAndroid Build Coastguard Worker struct TransposeFixture : public ParserFlatbuffersFixture 11*89c4ff92SAndroid Build Coastguard Worker { TransposeFixtureTransposeFixture12*89c4ff92SAndroid Build Coastguard Worker explicit TransposeFixture(const std::string & inputShape, 13*89c4ff92SAndroid Build Coastguard Worker const std::string & permuteData, 14*89c4ff92SAndroid Build Coastguard Worker const std::string & outputShape) 15*89c4ff92SAndroid Build Coastguard Worker { 16*89c4ff92SAndroid Build Coastguard Worker m_JsonString = R"( 17*89c4ff92SAndroid Build Coastguard Worker { 18*89c4ff92SAndroid Build Coastguard Worker "version": 3, 19*89c4ff92SAndroid Build Coastguard Worker "operator_codes": [ 20*89c4ff92SAndroid Build Coastguard Worker { 21*89c4ff92SAndroid Build Coastguard Worker "builtin_code": "TRANSPOSE", 22*89c4ff92SAndroid Build Coastguard Worker "version": 1 23*89c4ff92SAndroid Build Coastguard Worker } 24*89c4ff92SAndroid Build Coastguard Worker ], 25*89c4ff92SAndroid Build Coastguard Worker "subgraphs": [ 26*89c4ff92SAndroid Build Coastguard Worker { 27*89c4ff92SAndroid Build Coastguard Worker "tensors": [ 28*89c4ff92SAndroid Build Coastguard Worker { 29*89c4ff92SAndroid Build Coastguard Worker "shape": )" + inputShape + R"(, 30*89c4ff92SAndroid Build Coastguard Worker "type": "FLOAT32", 31*89c4ff92SAndroid Build Coastguard Worker "buffer": 0, 32*89c4ff92SAndroid Build Coastguard Worker "name": "inputTensor", 33*89c4ff92SAndroid Build Coastguard Worker "quantization": { 34*89c4ff92SAndroid Build Coastguard Worker "min": [ 35*89c4ff92SAndroid Build Coastguard Worker 0.0 36*89c4ff92SAndroid Build Coastguard Worker ], 37*89c4ff92SAndroid Build Coastguard Worker "max": [ 38*89c4ff92SAndroid Build Coastguard Worker 255.0 39*89c4ff92SAndroid Build Coastguard Worker ], 40*89c4ff92SAndroid Build Coastguard Worker "details_type": 0, 41*89c4ff92SAndroid Build Coastguard Worker "quantized_dimension": 0 42*89c4ff92SAndroid Build Coastguard Worker }, 43*89c4ff92SAndroid Build Coastguard Worker "is_variable": false 44*89c4ff92SAndroid Build Coastguard Worker }, 45*89c4ff92SAndroid Build Coastguard Worker { 46*89c4ff92SAndroid Build Coastguard Worker "shape": )" + outputShape + R"(, 47*89c4ff92SAndroid Build Coastguard Worker "type": "FLOAT32", 48*89c4ff92SAndroid Build Coastguard Worker "buffer": 1, 49*89c4ff92SAndroid Build Coastguard Worker "name": "outputTensor", 50*89c4ff92SAndroid Build Coastguard Worker "quantization": { 51*89c4ff92SAndroid Build Coastguard Worker "details_type": 0, 52*89c4ff92SAndroid Build Coastguard Worker "quantized_dimension": 0 53*89c4ff92SAndroid Build Coastguard Worker }, 54*89c4ff92SAndroid Build Coastguard Worker "is_variable": false 55*89c4ff92SAndroid Build Coastguard Worker })"; 56*89c4ff92SAndroid Build Coastguard Worker m_JsonString += R"(, 57*89c4ff92SAndroid Build Coastguard Worker { 58*89c4ff92SAndroid Build Coastguard Worker "shape": [ 59*89c4ff92SAndroid Build Coastguard Worker 3 60*89c4ff92SAndroid Build Coastguard Worker ], 61*89c4ff92SAndroid Build Coastguard Worker "type": "INT32", 62*89c4ff92SAndroid Build Coastguard Worker "buffer": 2, 63*89c4ff92SAndroid Build Coastguard Worker "name": "permuteTensor", 64*89c4ff92SAndroid Build Coastguard Worker "quantization": { 65*89c4ff92SAndroid Build Coastguard Worker "details_type": 0, 66*89c4ff92SAndroid Build Coastguard Worker "quantized_dimension": 0 67*89c4ff92SAndroid Build Coastguard Worker }, 68*89c4ff92SAndroid Build Coastguard Worker "is_variable": false 69*89c4ff92SAndroid Build Coastguard Worker })"; 70*89c4ff92SAndroid Build Coastguard Worker m_JsonString += R"(], 71*89c4ff92SAndroid Build Coastguard Worker "inputs": [ 72*89c4ff92SAndroid Build Coastguard Worker 0 73*89c4ff92SAndroid Build Coastguard Worker ], 74*89c4ff92SAndroid Build Coastguard Worker "outputs": [ 75*89c4ff92SAndroid Build Coastguard Worker 1 76*89c4ff92SAndroid Build Coastguard Worker ], 77*89c4ff92SAndroid Build Coastguard Worker "operators": [ 78*89c4ff92SAndroid Build Coastguard Worker { 79*89c4ff92SAndroid Build Coastguard Worker "opcode_index": 0, 80*89c4ff92SAndroid Build Coastguard Worker "inputs": [ 81*89c4ff92SAndroid Build Coastguard Worker 0)"; 82*89c4ff92SAndroid Build Coastguard Worker m_JsonString += R"(,2)"; 83*89c4ff92SAndroid Build Coastguard Worker m_JsonString += R"(], 84*89c4ff92SAndroid Build Coastguard Worker "outputs": [ 85*89c4ff92SAndroid Build Coastguard Worker 1 86*89c4ff92SAndroid Build Coastguard Worker ], 87*89c4ff92SAndroid Build Coastguard Worker "builtin_options_type": "TransposeOptions", 88*89c4ff92SAndroid Build Coastguard Worker "builtin_options": { 89*89c4ff92SAndroid Build Coastguard Worker }, 90*89c4ff92SAndroid Build Coastguard Worker "custom_options_format": "FLEXBUFFERS" 91*89c4ff92SAndroid Build Coastguard Worker } 92*89c4ff92SAndroid Build Coastguard Worker ] 93*89c4ff92SAndroid Build Coastguard Worker } 94*89c4ff92SAndroid Build Coastguard Worker ], 95*89c4ff92SAndroid Build Coastguard Worker "description": "TOCO Converted.", 96*89c4ff92SAndroid Build Coastguard Worker "buffers": [ 97*89c4ff92SAndroid Build Coastguard Worker { }, 98*89c4ff92SAndroid Build Coastguard Worker { })"; 99*89c4ff92SAndroid Build Coastguard Worker if (!permuteData.empty()) 100*89c4ff92SAndroid Build Coastguard Worker { 101*89c4ff92SAndroid Build Coastguard Worker m_JsonString += R"(,{"data": )" + permuteData + R"( })"; 102*89c4ff92SAndroid Build Coastguard Worker } 103*89c4ff92SAndroid Build Coastguard Worker m_JsonString += R"( 104*89c4ff92SAndroid Build Coastguard Worker ] 105*89c4ff92SAndroid Build Coastguard Worker } 106*89c4ff92SAndroid Build Coastguard Worker )"; 107*89c4ff92SAndroid Build Coastguard Worker Setup(); 108*89c4ff92SAndroid Build Coastguard Worker } 109*89c4ff92SAndroid Build Coastguard Worker }; 110*89c4ff92SAndroid Build Coastguard Worker 111*89c4ff92SAndroid Build Coastguard Worker // Note that this assumes the Tensorflow permutation vector implementation as opposed to the armnn implemenation. 112*89c4ff92SAndroid Build Coastguard Worker struct TransposeFixtureWithPermuteData : TransposeFixture 113*89c4ff92SAndroid Build Coastguard Worker { TransposeFixtureWithPermuteDataTransposeFixtureWithPermuteData114*89c4ff92SAndroid Build Coastguard Worker TransposeFixtureWithPermuteData() : TransposeFixture("[ 2, 2, 3 ]", 115*89c4ff92SAndroid Build Coastguard Worker "[ 0, 0, 0, 0, 2, 0, 0, 0, 1, 0, 0, 0 ]", 116*89c4ff92SAndroid Build Coastguard Worker "[ 2, 3, 2 ]") {} 117*89c4ff92SAndroid Build Coastguard Worker }; 118*89c4ff92SAndroid Build Coastguard Worker 119*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(TransposeFixtureWithPermuteData, "TransposeWithPermuteData") 120*89c4ff92SAndroid Build Coastguard Worker { 121*89c4ff92SAndroid Build Coastguard Worker RunTest<3, armnn::DataType::Float32>( 122*89c4ff92SAndroid Build Coastguard Worker 0, 123*89c4ff92SAndroid Build Coastguard Worker {{"inputTensor", { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12 }}}, 124*89c4ff92SAndroid Build Coastguard Worker {{"outputTensor", { 1, 4, 2, 5, 3, 6, 7, 10, 8, 11, 9, 12 }}}); 125*89c4ff92SAndroid Build Coastguard Worker 126*89c4ff92SAndroid Build Coastguard Worker CHECK((m_Parser->GetNetworkOutputBindingInfo(0, "outputTensor").second.GetShape() 127*89c4ff92SAndroid Build Coastguard Worker == armnn::TensorShape({2,3,2}))); 128*89c4ff92SAndroid Build Coastguard Worker } 129*89c4ff92SAndroid Build Coastguard Worker 130*89c4ff92SAndroid Build Coastguard Worker // Tensorflow default permutation behavior assumes no permute argument will create permute vector [n-1...0], 131*89c4ff92SAndroid Build Coastguard Worker // where n is the number of dimensions of the input tensor 132*89c4ff92SAndroid Build Coastguard Worker // In this case we should get output shape 3,2,2 given default permutation vector 2,1,0 133*89c4ff92SAndroid Build Coastguard Worker struct TransposeFixtureWithoutPermuteData : TransposeFixture 134*89c4ff92SAndroid Build Coastguard Worker { TransposeFixtureWithoutPermuteDataTransposeFixtureWithoutPermuteData135*89c4ff92SAndroid Build Coastguard Worker TransposeFixtureWithoutPermuteData() : TransposeFixture("[ 2, 2, 3 ]", 136*89c4ff92SAndroid Build Coastguard Worker "[ 2, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0 ]", 137*89c4ff92SAndroid Build Coastguard Worker "[ 3, 2, 2 ]") {} 138*89c4ff92SAndroid Build Coastguard Worker }; 139*89c4ff92SAndroid Build Coastguard Worker 140*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(TransposeFixtureWithoutPermuteData, "TransposeWithoutPermuteDims") 141*89c4ff92SAndroid Build Coastguard Worker { 142*89c4ff92SAndroid Build Coastguard Worker RunTest<3, armnn::DataType::Float32>( 143*89c4ff92SAndroid Build Coastguard Worker 0, 144*89c4ff92SAndroid Build Coastguard Worker {{"inputTensor", { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12 }}}, 145*89c4ff92SAndroid Build Coastguard Worker {{"outputTensor", { 1, 7, 4, 10, 2, 8, 5, 11, 3, 9, 6, 12 }}}); 146*89c4ff92SAndroid Build Coastguard Worker 147*89c4ff92SAndroid Build Coastguard Worker CHECK((m_Parser->GetNetworkOutputBindingInfo(0, "outputTensor").second.GetShape() 148*89c4ff92SAndroid Build Coastguard Worker == armnn::TensorShape({3,2,2}))); 149*89c4ff92SAndroid Build Coastguard Worker } 150*89c4ff92SAndroid Build Coastguard Worker 151*89c4ff92SAndroid Build Coastguard Worker }