xref: /aosp_15_r20/external/armnn/src/armnnTfLiteParser/test/OutputShapeOfSqueeze.cpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
//
// Copyright © 2017 Arm Ltd and Contributors. All rights reserved.
// SPDX-License-Identifier: MIT
//
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #include "../TfLiteParser.hpp"
7*89c4ff92SAndroid Build Coastguard Worker 
8*89c4ff92SAndroid Build Coastguard Worker #include <doctest/doctest.h>
9*89c4ff92SAndroid Build Coastguard Worker 
10*89c4ff92SAndroid Build Coastguard Worker TEST_SUITE("TensorflowLiteParser_OutputShapeOfSqueeze")
11*89c4ff92SAndroid Build Coastguard Worker {
12*89c4ff92SAndroid Build Coastguard Worker 
13*89c4ff92SAndroid Build Coastguard Worker struct TfLiteParserFixture
14*89c4ff92SAndroid Build Coastguard Worker {
15*89c4ff92SAndroid Build Coastguard Worker 
16*89c4ff92SAndroid Build Coastguard Worker     armnnTfLiteParser::TfLiteParserImpl m_Parser;
17*89c4ff92SAndroid Build Coastguard Worker     unsigned int m_InputShape[4];
18*89c4ff92SAndroid Build Coastguard Worker 
TfLiteParserFixtureTfLiteParserFixture19*89c4ff92SAndroid Build Coastguard Worker     TfLiteParserFixture() : m_Parser( ), m_InputShape { 1, 2, 2, 1 } {}
~TfLiteParserFixtureTfLiteParserFixture20*89c4ff92SAndroid Build Coastguard Worker     ~TfLiteParserFixture()          {  }
21*89c4ff92SAndroid Build Coastguard Worker 
22*89c4ff92SAndroid Build Coastguard Worker };
23*89c4ff92SAndroid Build Coastguard Worker 
24*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(TfLiteParserFixture, "EmptySqueezeDims_OutputWithAllDimensionsSqueezed")
25*89c4ff92SAndroid Build Coastguard Worker {
26*89c4ff92SAndroid Build Coastguard Worker 
27*89c4ff92SAndroid Build Coastguard Worker     std::vector<uint32_t> squeezeDims = {  };
28*89c4ff92SAndroid Build Coastguard Worker 
29*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo inputTensorInfo = armnn::TensorInfo(4, m_InputShape, armnn::DataType::Float32);
30*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo outputTensorInfo = m_Parser.OutputShapeOfSqueeze(squeezeDims, inputTensorInfo);
31*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputTensorInfo.GetNumElements() == 4);
32*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputTensorInfo.GetNumDimensions() == 2);
33*89c4ff92SAndroid Build Coastguard Worker     CHECK((outputTensorInfo.GetShape() == armnn::TensorShape({ 2, 2 })));
34*89c4ff92SAndroid Build Coastguard Worker };
35*89c4ff92SAndroid Build Coastguard Worker 
36*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(TfLiteParserFixture, "SqueezeDimsNotIncludingSizeOneDimensions_NoDimensionsSqueezedInOutput")
37*89c4ff92SAndroid Build Coastguard Worker {
38*89c4ff92SAndroid Build Coastguard Worker     std::vector<uint32_t> squeezeDims = { 1, 2 };
39*89c4ff92SAndroid Build Coastguard Worker 
40*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo inputTensorInfo = armnn::TensorInfo(4, m_InputShape, armnn::DataType::Float32);
41*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo outputTensorInfo = m_Parser.OutputShapeOfSqueeze(squeezeDims, inputTensorInfo);
42*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputTensorInfo.GetNumElements() == 4);
43*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputTensorInfo.GetNumDimensions() == 4);
44*89c4ff92SAndroid Build Coastguard Worker     CHECK((outputTensorInfo.GetShape() == armnn::TensorShape({ 1, 2, 2, 1 })));
45*89c4ff92SAndroid Build Coastguard Worker };
46*89c4ff92SAndroid Build Coastguard Worker 
47*89c4ff92SAndroid Build Coastguard Worker TEST_CASE_FIXTURE(TfLiteParserFixture, "SqueezeDimsRangePartial_OutputWithDimensionsWithinRangeSqueezed")
48*89c4ff92SAndroid Build Coastguard Worker {
49*89c4ff92SAndroid Build Coastguard Worker     std::vector<uint32_t> squeezeDims = { 1, 3 };
50*89c4ff92SAndroid Build Coastguard Worker 
51*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo inputTensorInfo = armnn::TensorInfo(4, m_InputShape, armnn::DataType::Float32);
52*89c4ff92SAndroid Build Coastguard Worker     armnn::TensorInfo outputTensorInfo = m_Parser.OutputShapeOfSqueeze(squeezeDims, inputTensorInfo);
53*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputTensorInfo.GetNumElements() == 4);
54*89c4ff92SAndroid Build Coastguard Worker     CHECK(outputTensorInfo.GetNumDimensions() == 3);
55*89c4ff92SAndroid Build Coastguard Worker     CHECK((outputTensorInfo.GetShape() == armnn::TensorShape({ 1, 2, 2 })));
56*89c4ff92SAndroid Build Coastguard Worker };
57*89c4ff92SAndroid Build Coastguard Worker 
58*89c4ff92SAndroid Build Coastguard Worker }