1 // 2 // Copyright © 2020 Arm Ltd. All rights reserved. 3 // SPDX-License-Identifier: MIT 4 // 5 #pragma once 6 7 #include "Optimization.hpp" 8 9 namespace armnn 10 { 11 namespace optimizations 12 { 13 14 class TransposeAsReshapeImpl 15 { 16 public: 17 /// Run for every TransposeLayer. Replaces it with a ReshapeLayer if they are equivalent. Run(Graph & graph,TransposeLayer & transpose) const18 void Run(Graph& graph, TransposeLayer& transpose) const 19 { 20 if (IsReshape(transpose)) 21 { 22 const TensorInfo& outInfo = transpose.GetOutputHandler().GetTensorInfo(); 23 24 const std::string name = std::string("as_reshape-") + transpose.GetName(); 25 const ReshapeDescriptor descriptor{outInfo.GetShape()}; 26 // Inserts NewLayer so layers don't need to be re-sorted. 27 auto reshape = graph.InsertNewLayer<ReshapeLayer>(transpose.GetInputSlot(0), descriptor, name.c_str()); 28 29 // Bypass transpose. It will be deleted since it's left unconnected. 30 transpose.GetOutputSlot().MoveAllConnections(reshape->GetOutputSlot()); 31 } 32 } 33 34 protected: 35 TransposeAsReshapeImpl() = default; 36 ~TransposeAsReshapeImpl() = default; 37 38 private: IsReshape(const TransposeLayer & layer)39 static bool IsReshape(const TransposeLayer& layer) 40 { 41 const TensorShape& outShape = layer.GetOutputHandler().GetTensorInfo().GetShape(); 42 const PermutationVector& permutation = layer.GetPermutation(); 43 44 const unsigned int numDimensions = permutation.GetSize(); 45 std::map<unsigned int, unsigned int> permuteMappings; 46 for (unsigned int i = 0; i < permutation.GetSize(); ++i) 47 { 48 permuteMappings[permutation[i]] = i; 49 } 50 51 std::vector<unsigned int> permuteVector; 52 for (unsigned int i = 0; i < permutation.GetSize(); ++i) 53 { 54 permuteVector.push_back(permuteMappings.at(i)); 55 } 56 57 unsigned int lastGtOne = 0; 58 while ((lastGtOne < numDimensions) && (outShape[(permuteVector[lastGtOne])] == 1U)) 59 { 60 ++lastGtOne; 61 } 62 63 bool isReshape = true; 64 for (unsigned int i = lastGtOne + 1U; isReshape && (i < numDimensions); ++i) 65 { 66 if (outShape[permuteVector[i]] > 1U) 67 { 68 isReshape = permuteVector[lastGtOne] < permuteVector[i]; 69 lastGtOne = i; 70 } 71 } 72 73 return isReshape; 74 } 75 }; 76 77 using TransposeAsReshape = OptimizeForType<TransposeLayer, TransposeAsReshapeImpl>; 78 79 } // namespace optimizations 80 } // namespace armnn 81