1*89c4ff92SAndroid Build Coastguard Worker // 2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd and Contributors. All rights reserved. 3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT 4*89c4ff92SAndroid Build Coastguard Worker // 5*89c4ff92SAndroid Build Coastguard Worker 6*89c4ff92SAndroid Build Coastguard Worker #include "../TfLiteParser.hpp" 7*89c4ff92SAndroid Build Coastguard Worker 8*89c4ff92SAndroid Build Coastguard Worker #include <doctest/doctest.h> 9*89c4ff92SAndroid Build Coastguard Worker 10*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("TensorflowLiteParser_OutputShapeOfSqueeze") 11*89c4ff92SAndroid Build Coastguard Worker { 12*89c4ff92SAndroid Build Coastguard Worker 13*89c4ff92SAndroid Build Coastguard Worker struct TfLiteParserFixture 14*89c4ff92SAndroid Build Coastguard Worker { 15*89c4ff92SAndroid Build Coastguard Worker 16*89c4ff92SAndroid Build Coastguard Worker armnnTfLiteParser::TfLiteParserImpl m_Parser; 17*89c4ff92SAndroid Build Coastguard Worker unsigned int m_InputShape[4]; 18*89c4ff92SAndroid Build Coastguard Worker TfLiteParserFixtureTfLiteParserFixture19*89c4ff92SAndroid Build Coastguard Worker TfLiteParserFixture() : m_Parser( ), m_InputShape { 1, 2, 2, 1 } {} ~TfLiteParserFixtureTfLiteParserFixture20*89c4ff92SAndroid Build Coastguard Worker ~TfLiteParserFixture() { } 21*89c4ff92SAndroid Build Coastguard Worker 22*89c4ff92SAndroid Build Coastguard Worker }; 23*89c4ff92SAndroid Build Coastguard Worker 24*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(TfLiteParserFixture, "EmptySqueezeDims_OutputWithAllDimensionsSqueezed") 25*89c4ff92SAndroid Build Coastguard Worker { 26*89c4ff92SAndroid Build Coastguard Worker 27*89c4ff92SAndroid Build Coastguard Worker std::vector<uint32_t> squeezeDims = { }; 28*89c4ff92SAndroid Build Coastguard Worker 29*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo inputTensorInfo = armnn::TensorInfo(4, m_InputShape, armnn::DataType::Float32); 30*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputTensorInfo = m_Parser.OutputShapeOfSqueeze(squeezeDims, inputTensorInfo); 31*89c4ff92SAndroid Build Coastguard Worker CHECK(outputTensorInfo.GetNumElements() == 4); 32*89c4ff92SAndroid Build Coastguard Worker CHECK(outputTensorInfo.GetNumDimensions() == 2); 33*89c4ff92SAndroid Build Coastguard Worker CHECK((outputTensorInfo.GetShape() == armnn::TensorShape({ 2, 2 }))); 34*89c4ff92SAndroid Build Coastguard Worker }; 35*89c4ff92SAndroid Build Coastguard Worker 36*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(TfLiteParserFixture, "SqueezeDimsNotIncludingSizeOneDimensions_NoDimensionsSqueezedInOutput") 37*89c4ff92SAndroid Build Coastguard Worker { 38*89c4ff92SAndroid Build Coastguard Worker std::vector<uint32_t> squeezeDims = { 1, 2 }; 39*89c4ff92SAndroid Build Coastguard Worker 40*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo inputTensorInfo = armnn::TensorInfo(4, m_InputShape, armnn::DataType::Float32); 41*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputTensorInfo = m_Parser.OutputShapeOfSqueeze(squeezeDims, inputTensorInfo); 42*89c4ff92SAndroid Build Coastguard Worker CHECK(outputTensorInfo.GetNumElements() == 4); 43*89c4ff92SAndroid Build Coastguard Worker CHECK(outputTensorInfo.GetNumDimensions() == 4); 44*89c4ff92SAndroid Build Coastguard Worker CHECK((outputTensorInfo.GetShape() == armnn::TensorShape({ 1, 2, 2, 1 }))); 45*89c4ff92SAndroid Build Coastguard Worker }; 46*89c4ff92SAndroid Build Coastguard Worker 47*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(TfLiteParserFixture, "SqueezeDimsRangePartial_OutputWithDimensionsWithinRangeSqueezed") 48*89c4ff92SAndroid Build Coastguard Worker { 49*89c4ff92SAndroid Build Coastguard Worker std::vector<uint32_t> squeezeDims = { 1, 3 }; 50*89c4ff92SAndroid Build Coastguard Worker 51*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo inputTensorInfo = armnn::TensorInfo(4, m_InputShape, armnn::DataType::Float32); 52*89c4ff92SAndroid Build Coastguard Worker armnn::TensorInfo outputTensorInfo = m_Parser.OutputShapeOfSqueeze(squeezeDims, inputTensorInfo); 53*89c4ff92SAndroid Build Coastguard Worker CHECK(outputTensorInfo.GetNumElements() == 4); 54*89c4ff92SAndroid Build Coastguard Worker CHECK(outputTensorInfo.GetNumDimensions() == 3); 55*89c4ff92SAndroid Build Coastguard Worker CHECK((outputTensorInfo.GetShape() == armnn::TensorShape({ 1, 2, 2 }))); 56*89c4ff92SAndroid Build Coastguard Worker }; 57*89c4ff92SAndroid Build Coastguard Worker 58*89c4ff92SAndroid Build Coastguard Worker }