xref: /aosp_15_r20/external/armnn/src/armnn/optimizations/MovePermuteUp.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017-2018,2020,2023 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker #pragma once
6*89c4ff92SAndroid Build Coastguard Worker 
7*89c4ff92SAndroid Build Coastguard Worker #include "Optimization.hpp"
8*89c4ff92SAndroid Build Coastguard Worker 
9*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/PolymorphicDowncast.hpp>
10*89c4ff92SAndroid Build Coastguard Worker #include <armnnUtils/Permute.hpp>
11*89c4ff92SAndroid Build Coastguard Worker 
12*89c4ff92SAndroid Build Coastguard Worker namespace armnn
13*89c4ff92SAndroid Build Coastguard Worker {
14*89c4ff92SAndroid Build Coastguard Worker namespace optimizations
15*89c4ff92SAndroid Build Coastguard Worker {
16*89c4ff92SAndroid Build Coastguard Worker class MovePermuteUpImpl
17*89c4ff92SAndroid Build Coastguard Worker {
18*89c4ff92SAndroid Build Coastguard Worker public:
19*89c4ff92SAndroid Build Coastguard Worker     /// Run for every connection between a base Layer (any) and a child PermuteLayer. If the type
20*89c4ff92SAndroid Build Coastguard Worker     /// of the base layer allows it, it moves the permutation to the inputs of the base layer.
21*89c4ff92SAndroid Build Coastguard Worker     /// I.e., adds equivalent permutations before the inputs of the base layer and moves the
22*89c4ff92SAndroid Build Coastguard Worker     /// connections in the output of the child permute layer to the output of the base layer.
Run(Graph & graph,InputSlot & connection) const23*89c4ff92SAndroid Build Coastguard Worker     void Run(Graph& graph, InputSlot& connection) const
24*89c4ff92SAndroid Build Coastguard Worker     {
25*89c4ff92SAndroid Build Coastguard Worker         OutputSlot& baseOutput = *connection.GetConnectedOutputSlot();
26*89c4ff92SAndroid Build Coastguard Worker 
27*89c4ff92SAndroid Build Coastguard Worker         if (baseOutput.GetNumConnections() == 1U)
28*89c4ff92SAndroid Build Coastguard Worker         {
29*89c4ff92SAndroid Build Coastguard Worker             Layer& base = baseOutput.GetOwningLayer();
30*89c4ff92SAndroid Build Coastguard Worker 
31*89c4ff92SAndroid Build Coastguard Worker             if (CanMovePermuteToInputs(base))
32*89c4ff92SAndroid Build Coastguard Worker             {
33*89c4ff92SAndroid Build Coastguard Worker                 auto permute = PolymorphicDowncast<PermuteLayer*>(&connection.GetOwningLayer());
34*89c4ff92SAndroid Build Coastguard Worker                 const PermutationVector& perm = permute->GetPermutation();
35*89c4ff92SAndroid Build Coastguard Worker 
36*89c4ff92SAndroid Build Coastguard Worker                 // Inserts an equivalent permute before every input of the base layer.
37*89c4ff92SAndroid Build Coastguard Worker                 for (auto baseInput = base.BeginInputSlots(); baseInput != base.EndInputSlots(); ++baseInput)
38*89c4ff92SAndroid Build Coastguard Worker                 {
39*89c4ff92SAndroid Build Coastguard Worker                     // Inserts a new permute layer.
40*89c4ff92SAndroid Build Coastguard Worker                     const std::string name = std::string("moved_up-") + permute->GetName();
41*89c4ff92SAndroid Build Coastguard Worker                     PermuteLayer& permLayer = *graph.InsertNewLayer<PermuteLayer>(*baseInput, perm, name.c_str());
42*89c4ff92SAndroid Build Coastguard Worker 
43*89c4ff92SAndroid Build Coastguard Worker                     // Sets output tensor info for the new layer.
44*89c4ff92SAndroid Build Coastguard Worker                     OutputSlot& parentOutput = *permLayer.GetInputSlot(0).GetConnectedOutputSlot();
45*89c4ff92SAndroid Build Coastguard Worker                     const TensorInfo permOutInfo = armnnUtils::Permuted(parentOutput.GetTensorInfo(), perm);
46*89c4ff92SAndroid Build Coastguard Worker                     permLayer.GetOutputHandler().SetTensorInfo(permOutInfo);
47*89c4ff92SAndroid Build Coastguard Worker                 }
48*89c4ff92SAndroid Build Coastguard Worker 
49*89c4ff92SAndroid Build Coastguard Worker                 // Bypasses permute. It will be removed as it's left unconnected.
50*89c4ff92SAndroid Build Coastguard Worker                 permute->GetOutputSlot().MoveAllConnections(base.GetOutputSlot());
51*89c4ff92SAndroid Build Coastguard Worker             }
52*89c4ff92SAndroid Build Coastguard Worker         }
53*89c4ff92SAndroid Build Coastguard Worker     }
54*89c4ff92SAndroid Build Coastguard Worker 
55*89c4ff92SAndroid Build Coastguard Worker protected:
56*89c4ff92SAndroid Build Coastguard Worker     MovePermuteUpImpl() = default;
57*89c4ff92SAndroid Build Coastguard Worker     ~MovePermuteUpImpl() = default;
58*89c4ff92SAndroid Build Coastguard Worker 
59*89c4ff92SAndroid Build Coastguard Worker private:
CanMovePermuteToInputs(const Layer & base)60*89c4ff92SAndroid Build Coastguard Worker     static bool CanMovePermuteToInputs(const Layer& base)
61*89c4ff92SAndroid Build Coastguard Worker     {
62*89c4ff92SAndroid Build Coastguard Worker         switch (base.GetType())
63*89c4ff92SAndroid Build Coastguard Worker         {
64*89c4ff92SAndroid Build Coastguard Worker             case LayerType::Activation:
65*89c4ff92SAndroid Build Coastguard Worker             case LayerType::Addition:
66*89c4ff92SAndroid Build Coastguard Worker             case LayerType::FakeQuantization:
67*89c4ff92SAndroid Build Coastguard Worker             case LayerType::Floor:
68*89c4ff92SAndroid Build Coastguard Worker             case LayerType::MemCopy:
69*89c4ff92SAndroid Build Coastguard Worker             case LayerType::Multiplication:
70*89c4ff92SAndroid Build Coastguard Worker                 return true;
71*89c4ff92SAndroid Build Coastguard Worker             case LayerType::ElementwiseBinary:
72*89c4ff92SAndroid Build Coastguard Worker             {
73*89c4ff92SAndroid Build Coastguard Worker                 auto descriptor = PolymorphicDowncast<const ElementwiseBinaryDescriptor*>(&base.GetParameters());
74*89c4ff92SAndroid Build Coastguard Worker                 return (descriptor->m_Operation == BinaryOperation::Add ||
75*89c4ff92SAndroid Build Coastguard Worker                         descriptor->m_Operation == BinaryOperation::Mul);
76*89c4ff92SAndroid Build Coastguard Worker             }
77*89c4ff92SAndroid Build Coastguard Worker             default:
78*89c4ff92SAndroid Build Coastguard Worker                 return false;
79*89c4ff92SAndroid Build Coastguard Worker         }
80*89c4ff92SAndroid Build Coastguard Worker     }
81*89c4ff92SAndroid Build Coastguard Worker };
82*89c4ff92SAndroid Build Coastguard Worker 
83*89c4ff92SAndroid Build Coastguard Worker using MovePermuteUp = OptimizeForConnection<Layer, PermuteLayer, MovePermuteUpImpl>;
84*89c4ff92SAndroid Build Coastguard Worker 
85*89c4ff92SAndroid Build Coastguard Worker } // namespace optimizations
86*89c4ff92SAndroid Build Coastguard Worker } // namespace armnn
87