xref: /aosp_15_r20/external/armnn/src/armnn/optimizations/TransposeAsReshape.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2020 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 #pragma once
6 
7 #include "Optimization.hpp"
8 
9 namespace armnn
10 {
11 namespace optimizations
12 {
13 
14 class TransposeAsReshapeImpl
15 {
16 public:
17     /// Run for every TransposeLayer. Replaces it with a ReshapeLayer if they are equivalent.
Run(Graph & graph,TransposeLayer & transpose) const18     void Run(Graph& graph, TransposeLayer& transpose) const
19     {
20         if (IsReshape(transpose))
21         {
22             const TensorInfo& outInfo = transpose.GetOutputHandler().GetTensorInfo();
23 
24             const std::string name = std::string("as_reshape-") + transpose.GetName();
25             const ReshapeDescriptor descriptor{outInfo.GetShape()};
26             // Inserts NewLayer so layers don't need to be re-sorted.
27             auto reshape = graph.InsertNewLayer<ReshapeLayer>(transpose.GetInputSlot(0), descriptor, name.c_str());
28 
29             // Bypass transpose. It will be deleted since it's left unconnected.
30             transpose.GetOutputSlot().MoveAllConnections(reshape->GetOutputSlot());
31         }
32     }
33 
34 protected:
35     TransposeAsReshapeImpl() = default;
36     ~TransposeAsReshapeImpl() = default;
37 
38 private:
IsReshape(const TransposeLayer & layer)39     static bool IsReshape(const TransposeLayer& layer)
40     {
41         const TensorShape& outShape = layer.GetOutputHandler().GetTensorInfo().GetShape();
42         const PermutationVector& permutation = layer.GetPermutation();
43 
44         const unsigned int numDimensions = permutation.GetSize();
45         std::map<unsigned int, unsigned int> permuteMappings;
46         for (unsigned int i = 0; i < permutation.GetSize(); ++i)
47         {
48             permuteMappings[permutation[i]] = i;
49         }
50 
51         std::vector<unsigned int> permuteVector;
52         for (unsigned int i = 0; i < permutation.GetSize(); ++i)
53         {
54             permuteVector.push_back(permuteMappings.at(i));
55         }
56 
57         unsigned int lastGtOne = 0;
58         while ((lastGtOne < numDimensions) && (outShape[(permuteVector[lastGtOne])] == 1U))
59         {
60             ++lastGtOne;
61         }
62 
63         bool isReshape = true;
64         for (unsigned int i = lastGtOne + 1U; isReshape && (i < numDimensions); ++i)
65         {
66             if (outShape[permuteVector[i]] > 1U)
67             {
68                 isReshape = permuteVector[lastGtOne] < permuteVector[i];
69                 lastGtOne = i;
70             }
71         }
72 
73         return isReshape;
74     }
75 };
76 
77 using TransposeAsReshape = OptimizeForType<TransposeLayer, TransposeAsReshapeImpl>;
78 
79 } // namespace optimizations
80 } // namespace armnn
81