xref: /aosp_15_r20/external/armnn/src/armnn/optimizations/TransposeAsReshape.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2020 Arm Ltd. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker #pragma once
6*89c4ff92SAndroid Build Coastguard Worker 
7*89c4ff92SAndroid Build Coastguard Worker #include "Optimization.hpp"
8*89c4ff92SAndroid Build Coastguard Worker 
9*89c4ff92SAndroid Build Coastguard Worker namespace armnn
10*89c4ff92SAndroid Build Coastguard Worker {
11*89c4ff92SAndroid Build Coastguard Worker namespace optimizations
12*89c4ff92SAndroid Build Coastguard Worker {
13*89c4ff92SAndroid Build Coastguard Worker 
14*89c4ff92SAndroid Build Coastguard Worker class TransposeAsReshapeImpl
15*89c4ff92SAndroid Build Coastguard Worker {
16*89c4ff92SAndroid Build Coastguard Worker public:
17*89c4ff92SAndroid Build Coastguard Worker     /// Run for every TransposeLayer. Replaces it with a ReshapeLayer if they are equivalent.
Run(Graph & graph,TransposeLayer & transpose) const18*89c4ff92SAndroid Build Coastguard Worker     void Run(Graph& graph, TransposeLayer& transpose) const
19*89c4ff92SAndroid Build Coastguard Worker     {
20*89c4ff92SAndroid Build Coastguard Worker         if (IsReshape(transpose))
21*89c4ff92SAndroid Build Coastguard Worker         {
22*89c4ff92SAndroid Build Coastguard Worker             const TensorInfo& outInfo = transpose.GetOutputHandler().GetTensorInfo();
23*89c4ff92SAndroid Build Coastguard Worker 
24*89c4ff92SAndroid Build Coastguard Worker             const std::string name = std::string("as_reshape-") + transpose.GetName();
25*89c4ff92SAndroid Build Coastguard Worker             const ReshapeDescriptor descriptor{outInfo.GetShape()};
26*89c4ff92SAndroid Build Coastguard Worker             // Inserts NewLayer so layers don't need to be re-sorted.
27*89c4ff92SAndroid Build Coastguard Worker             auto reshape = graph.InsertNewLayer<ReshapeLayer>(transpose.GetInputSlot(0), descriptor, name.c_str());
28*89c4ff92SAndroid Build Coastguard Worker 
29*89c4ff92SAndroid Build Coastguard Worker             // Bypass transpose. It will be deleted since it's left unconnected.
30*89c4ff92SAndroid Build Coastguard Worker             transpose.GetOutputSlot().MoveAllConnections(reshape->GetOutputSlot());
31*89c4ff92SAndroid Build Coastguard Worker         }
32*89c4ff92SAndroid Build Coastguard Worker     }
33*89c4ff92SAndroid Build Coastguard Worker 
34*89c4ff92SAndroid Build Coastguard Worker protected:
35*89c4ff92SAndroid Build Coastguard Worker     TransposeAsReshapeImpl() = default;
36*89c4ff92SAndroid Build Coastguard Worker     ~TransposeAsReshapeImpl() = default;
37*89c4ff92SAndroid Build Coastguard Worker 
38*89c4ff92SAndroid Build Coastguard Worker private:
IsReshape(const TransposeLayer & layer)39*89c4ff92SAndroid Build Coastguard Worker     static bool IsReshape(const TransposeLayer& layer)
40*89c4ff92SAndroid Build Coastguard Worker     {
41*89c4ff92SAndroid Build Coastguard Worker         const TensorShape& outShape = layer.GetOutputHandler().GetTensorInfo().GetShape();
42*89c4ff92SAndroid Build Coastguard Worker         const PermutationVector& permutation = layer.GetPermutation();
43*89c4ff92SAndroid Build Coastguard Worker 
44*89c4ff92SAndroid Build Coastguard Worker         const unsigned int numDimensions = permutation.GetSize();
45*89c4ff92SAndroid Build Coastguard Worker         std::map<unsigned int, unsigned int> permuteMappings;
46*89c4ff92SAndroid Build Coastguard Worker         for (unsigned int i = 0; i < permutation.GetSize(); ++i)
47*89c4ff92SAndroid Build Coastguard Worker         {
48*89c4ff92SAndroid Build Coastguard Worker             permuteMappings[permutation[i]] = i;
49*89c4ff92SAndroid Build Coastguard Worker         }
50*89c4ff92SAndroid Build Coastguard Worker 
51*89c4ff92SAndroid Build Coastguard Worker         std::vector<unsigned int> permuteVector;
52*89c4ff92SAndroid Build Coastguard Worker         for (unsigned int i = 0; i < permutation.GetSize(); ++i)
53*89c4ff92SAndroid Build Coastguard Worker         {
54*89c4ff92SAndroid Build Coastguard Worker             permuteVector.push_back(permuteMappings.at(i));
55*89c4ff92SAndroid Build Coastguard Worker         }
56*89c4ff92SAndroid Build Coastguard Worker 
57*89c4ff92SAndroid Build Coastguard Worker         unsigned int lastGtOne = 0;
58*89c4ff92SAndroid Build Coastguard Worker         while ((lastGtOne < numDimensions) && (outShape[(permuteVector[lastGtOne])] == 1U))
59*89c4ff92SAndroid Build Coastguard Worker         {
60*89c4ff92SAndroid Build Coastguard Worker             ++lastGtOne;
61*89c4ff92SAndroid Build Coastguard Worker         }
62*89c4ff92SAndroid Build Coastguard Worker 
63*89c4ff92SAndroid Build Coastguard Worker         bool isReshape = true;
64*89c4ff92SAndroid Build Coastguard Worker         for (unsigned int i = lastGtOne + 1U; isReshape && (i < numDimensions); ++i)
65*89c4ff92SAndroid Build Coastguard Worker         {
66*89c4ff92SAndroid Build Coastguard Worker             if (outShape[permuteVector[i]] > 1U)
67*89c4ff92SAndroid Build Coastguard Worker             {
68*89c4ff92SAndroid Build Coastguard Worker                 isReshape = permuteVector[lastGtOne] < permuteVector[i];
69*89c4ff92SAndroid Build Coastguard Worker                 lastGtOne = i;
70*89c4ff92SAndroid Build Coastguard Worker             }
71*89c4ff92SAndroid Build Coastguard Worker         }
72*89c4ff92SAndroid Build Coastguard Worker 
73*89c4ff92SAndroid Build Coastguard Worker         return isReshape;
74*89c4ff92SAndroid Build Coastguard Worker     }
75*89c4ff92SAndroid Build Coastguard Worker };
76*89c4ff92SAndroid Build Coastguard Worker 
77*89c4ff92SAndroid Build Coastguard Worker using TransposeAsReshape = OptimizeForType<TransposeLayer, TransposeAsReshapeImpl>;
78*89c4ff92SAndroid Build Coastguard Worker 
79*89c4ff92SAndroid Build Coastguard Worker } // namespace optimizations
80*89c4ff92SAndroid Build Coastguard Worker } // namespace armnn
81