1*89c4ff92SAndroid Build Coastguard Worker // 2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2020 Arm Ltd. All rights reserved. 3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT 4*89c4ff92SAndroid Build Coastguard Worker // 5*89c4ff92SAndroid Build Coastguard Worker #pragma once 6*89c4ff92SAndroid Build Coastguard Worker 7*89c4ff92SAndroid Build Coastguard Worker #include "Optimization.hpp" 8*89c4ff92SAndroid Build Coastguard Worker 9*89c4ff92SAndroid Build Coastguard Worker namespace armnn 10*89c4ff92SAndroid Build Coastguard Worker { 11*89c4ff92SAndroid Build Coastguard Worker namespace optimizations 12*89c4ff92SAndroid Build Coastguard Worker { 13*89c4ff92SAndroid Build Coastguard Worker 14*89c4ff92SAndroid Build Coastguard Worker class TransposeAsReshapeImpl 15*89c4ff92SAndroid Build Coastguard Worker { 16*89c4ff92SAndroid Build Coastguard Worker public: 17*89c4ff92SAndroid Build Coastguard Worker /// Run for every TransposeLayer. Replaces it with a ReshapeLayer if they are equivalent. Run(Graph & graph,TransposeLayer & transpose) const18*89c4ff92SAndroid Build Coastguard Worker void Run(Graph& graph, TransposeLayer& transpose) const 19*89c4ff92SAndroid Build Coastguard Worker { 20*89c4ff92SAndroid Build Coastguard Worker if (IsReshape(transpose)) 21*89c4ff92SAndroid Build Coastguard Worker { 22*89c4ff92SAndroid Build Coastguard Worker const TensorInfo& outInfo = transpose.GetOutputHandler().GetTensorInfo(); 23*89c4ff92SAndroid Build Coastguard Worker 24*89c4ff92SAndroid Build Coastguard Worker const std::string name = std::string("as_reshape-") + transpose.GetName(); 25*89c4ff92SAndroid Build Coastguard Worker const ReshapeDescriptor descriptor{outInfo.GetShape()}; 26*89c4ff92SAndroid Build Coastguard Worker // Inserts NewLayer so layers don't need to be re-sorted. 27*89c4ff92SAndroid Build Coastguard Worker auto reshape = graph.InsertNewLayer<ReshapeLayer>(transpose.GetInputSlot(0), descriptor, name.c_str()); 28*89c4ff92SAndroid Build Coastguard Worker 29*89c4ff92SAndroid Build Coastguard Worker // Bypass transpose. It will be deleted since it's left unconnected. 30*89c4ff92SAndroid Build Coastguard Worker transpose.GetOutputSlot().MoveAllConnections(reshape->GetOutputSlot()); 31*89c4ff92SAndroid Build Coastguard Worker } 32*89c4ff92SAndroid Build Coastguard Worker } 33*89c4ff92SAndroid Build Coastguard Worker 34*89c4ff92SAndroid Build Coastguard Worker protected: 35*89c4ff92SAndroid Build Coastguard Worker TransposeAsReshapeImpl() = default; 36*89c4ff92SAndroid Build Coastguard Worker ~TransposeAsReshapeImpl() = default; 37*89c4ff92SAndroid Build Coastguard Worker 38*89c4ff92SAndroid Build Coastguard Worker private: IsReshape(const TransposeLayer & layer)39*89c4ff92SAndroid Build Coastguard Worker static bool IsReshape(const TransposeLayer& layer) 40*89c4ff92SAndroid Build Coastguard Worker { 41*89c4ff92SAndroid Build Coastguard Worker const TensorShape& outShape = layer.GetOutputHandler().GetTensorInfo().GetShape(); 42*89c4ff92SAndroid Build Coastguard Worker const PermutationVector& permutation = layer.GetPermutation(); 43*89c4ff92SAndroid Build Coastguard Worker 44*89c4ff92SAndroid Build Coastguard Worker const unsigned int numDimensions = permutation.GetSize(); 45*89c4ff92SAndroid Build Coastguard Worker std::map<unsigned int, unsigned int> permuteMappings; 46*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < permutation.GetSize(); ++i) 47*89c4ff92SAndroid Build Coastguard Worker { 48*89c4ff92SAndroid Build Coastguard Worker permuteMappings[permutation[i]] = i; 49*89c4ff92SAndroid Build Coastguard Worker } 50*89c4ff92SAndroid Build Coastguard Worker 51*89c4ff92SAndroid Build Coastguard Worker std::vector<unsigned int> permuteVector; 52*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = 0; i < permutation.GetSize(); ++i) 53*89c4ff92SAndroid Build Coastguard Worker { 54*89c4ff92SAndroid Build Coastguard Worker permuteVector.push_back(permuteMappings.at(i)); 55*89c4ff92SAndroid Build Coastguard Worker } 56*89c4ff92SAndroid Build Coastguard Worker 57*89c4ff92SAndroid Build Coastguard Worker unsigned int lastGtOne = 0; 58*89c4ff92SAndroid Build Coastguard Worker while ((lastGtOne < numDimensions) && (outShape[(permuteVector[lastGtOne])] == 1U)) 59*89c4ff92SAndroid Build Coastguard Worker { 60*89c4ff92SAndroid Build Coastguard Worker ++lastGtOne; 61*89c4ff92SAndroid Build Coastguard Worker } 62*89c4ff92SAndroid Build Coastguard Worker 63*89c4ff92SAndroid Build Coastguard Worker bool isReshape = true; 64*89c4ff92SAndroid Build Coastguard Worker for (unsigned int i = lastGtOne + 1U; isReshape && (i < numDimensions); ++i) 65*89c4ff92SAndroid Build Coastguard Worker { 66*89c4ff92SAndroid Build Coastguard Worker if (outShape[permuteVector[i]] > 1U) 67*89c4ff92SAndroid Build Coastguard Worker { 68*89c4ff92SAndroid Build Coastguard Worker isReshape = permuteVector[lastGtOne] < permuteVector[i]; 69*89c4ff92SAndroid Build Coastguard Worker lastGtOne = i; 70*89c4ff92SAndroid Build Coastguard Worker } 71*89c4ff92SAndroid Build Coastguard Worker } 72*89c4ff92SAndroid Build Coastguard Worker 73*89c4ff92SAndroid Build Coastguard Worker return isReshape; 74*89c4ff92SAndroid Build Coastguard Worker } 75*89c4ff92SAndroid Build Coastguard Worker }; 76*89c4ff92SAndroid Build Coastguard Worker 77*89c4ff92SAndroid Build Coastguard Worker using TransposeAsReshape = OptimizeForType<TransposeLayer, TransposeAsReshapeImpl>; 78*89c4ff92SAndroid Build Coastguard Worker 79*89c4ff92SAndroid Build Coastguard Worker } // namespace optimizations 80*89c4ff92SAndroid Build Coastguard Worker } // namespace armnn 81