1*3e777be0SXin Li // 2*3e777be0SXin Li // Copyright © 2017 Arm Ltd. All rights reserved. 3*3e777be0SXin Li // SPDX-License-Identifier: MIT 4*3e777be0SXin Li // 5*3e777be0SXin Li 6*3e777be0SXin Li #pragma once 7*3e777be0SXin Li 8*3e777be0SXin Li #include <armnn/Tensor.hpp> 9*3e777be0SXin Li 10*3e777be0SXin Li #include "../ConversionUtils.hpp" 11*3e777be0SXin Li 12*3e777be0SXin Li namespace armnn_driver 13*3e777be0SXin Li { 14*3e777be0SXin Li FlattenFullyConnectedInput(const armnn::TensorShape & inputShape,const armnn::TensorShape & weightsShape)15*3e777be0SXin Liinline armnn::TensorShape FlattenFullyConnectedInput(const armnn::TensorShape& inputShape, 16*3e777be0SXin Li const armnn::TensorShape& weightsShape) 17*3e777be0SXin Li { 18*3e777be0SXin Li if (inputShape.GetNumDimensions() > 2U) 19*3e777be0SXin Li { 20*3e777be0SXin Li unsigned int totalInputElements = inputShape.GetNumElements(); 21*3e777be0SXin Li unsigned int inputSize = weightsShape[1]; 22*3e777be0SXin Li 23*3e777be0SXin Li unsigned int batchSize = totalInputElements / inputSize; 24*3e777be0SXin Li 25*3e777be0SXin Li if(totalInputElements % batchSize != 0) 26*3e777be0SXin Li { 27*3e777be0SXin Li throw std::runtime_error("Failed to deduce tensor shape"); 28*3e777be0SXin Li } 29*3e777be0SXin Li 30*3e777be0SXin Li return armnn::TensorShape({batchSize, inputSize}); 31*3e777be0SXin Li } 32*3e777be0SXin Li else 33*3e777be0SXin Li { 34*3e777be0SXin Li return inputShape; 35*3e777be0SXin Li } 36*3e777be0SXin Li } 37*3e777be0SXin Li VerifyFullyConnectedShapes(const armnn::TensorShape & inputShape,const armnn::TensorShape & weightsShape,const armnn::TensorShape & outputShape,bool transposeWeightMatrix)38*3e777be0SXin Liinline bool VerifyFullyConnectedShapes(const armnn::TensorShape& inputShape, 39*3e777be0SXin Li const armnn::TensorShape& weightsShape, 40*3e777be0SXin Li const armnn::TensorShape& outputShape, 41*3e777be0SXin Li bool transposeWeightMatrix) 42*3e777be0SXin Li { 43*3e777be0SXin Li unsigned int dimIdx = transposeWeightMatrix ? 0 : 1; 44*3e777be0SXin Li return (inputShape[0] == outputShape[0] && weightsShape[dimIdx] == outputShape[1]); 45*3e777be0SXin Li } 46*3e777be0SXin Li 47*3e777be0SXin Li }