1*3e777be0SXin Li // 2*3e777be0SXin Li // Copyright © 2017 Arm Ltd and Contributors. All rights reserved. 3*3e777be0SXin Li // SPDX-License-Identifier: MIT 4*3e777be0SXin Li // 5*3e777be0SXin Li 6*3e777be0SXin Li #include "../DriverTestHelpers.hpp" 7*3e777be0SXin Li 8*3e777be0SXin Li DOCTEST_TEST_SUITE("FullyConnectedReshapeTests") 9*3e777be0SXin Li { 10*3e777be0SXin Li DOCTEST_TEST_CASE("TestFlattenFullyConnectedInput") 11*3e777be0SXin Li { 12*3e777be0SXin Li using armnn::TensorShape; 13*3e777be0SXin Li 14*3e777be0SXin Li // Pass through 2d input 15*3e777be0SXin Li DOCTEST_CHECK(FlattenFullyConnectedInput(TensorShape({2,2048}), 16*3e777be0SXin Li TensorShape({512, 2048})) == TensorShape({2, 2048})); 17*3e777be0SXin Li 18*3e777be0SXin Li // Trivial flattening of batched channels 19*3e777be0SXin Li DOCTEST_CHECK(FlattenFullyConnectedInput(TensorShape({97,1,1,2048}), 20*3e777be0SXin Li TensorShape({512, 2048})) == TensorShape({97, 2048})); 21*3e777be0SXin Li 22*3e777be0SXin Li // Flatten single batch of rows 23*3e777be0SXin Li DOCTEST_CHECK(FlattenFullyConnectedInput(TensorShape({1,97,1,2048}), 24*3e777be0SXin Li TensorShape({512, 2048})) == TensorShape({97, 2048})); 25*3e777be0SXin Li 26*3e777be0SXin Li // Flatten single batch of columns 27*3e777be0SXin Li DOCTEST_CHECK(FlattenFullyConnectedInput(TensorShape({1,1,97,2048}), 28*3e777be0SXin Li TensorShape({512, 2048})) == TensorShape({97, 2048})); 29*3e777be0SXin Li 30*3e777be0SXin Li // Move batches into input dimension 31*3e777be0SXin Li DOCTEST_CHECK(FlattenFullyConnectedInput(TensorShape({50,1,1,10}), 32*3e777be0SXin Li TensorShape({512, 20})) == TensorShape({25, 20})); 33*3e777be0SXin Li 34*3e777be0SXin Li // Flatten single batch of 3D data (e.g. convolution output) 35*3e777be0SXin Li DOCTEST_CHECK(FlattenFullyConnectedInput(TensorShape({1,16,16,10}), 36*3e777be0SXin Li TensorShape({512, 2560})) == TensorShape({1, 2560})); 37*3e777be0SXin Li } 38*3e777be0SXin Li 39*3e777be0SXin Li } 40