1*89c4ff92SAndroid Build Coastguard Worker // 2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017-2018,2020,2023 Arm Ltd and Contributors. All rights reserved. 3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT 4*89c4ff92SAndroid Build Coastguard Worker // 5*89c4ff92SAndroid Build Coastguard Worker #pragma once 6*89c4ff92SAndroid Build Coastguard Worker 7*89c4ff92SAndroid Build Coastguard Worker #include "Optimization.hpp" 8*89c4ff92SAndroid Build Coastguard Worker 9*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/PolymorphicDowncast.hpp> 10*89c4ff92SAndroid Build Coastguard Worker #include <armnnUtils/Permute.hpp> 11*89c4ff92SAndroid Build Coastguard Worker 12*89c4ff92SAndroid Build Coastguard Worker namespace armnn 13*89c4ff92SAndroid Build Coastguard Worker { 14*89c4ff92SAndroid Build Coastguard Worker namespace optimizations 15*89c4ff92SAndroid Build Coastguard Worker { 16*89c4ff92SAndroid Build Coastguard Worker class MovePermuteUpImpl 17*89c4ff92SAndroid Build Coastguard Worker { 18*89c4ff92SAndroid Build Coastguard Worker public: 19*89c4ff92SAndroid Build Coastguard Worker /// Run for every connection between a base Layer (any) and a child PermuteLayer. If the type 20*89c4ff92SAndroid Build Coastguard Worker /// of the base layer allows it, it moves the permutation to the inputs of the base layer. 21*89c4ff92SAndroid Build Coastguard Worker /// I.e., adds equivalent permutations before the inputs of the base layer and moves the 22*89c4ff92SAndroid Build Coastguard Worker /// connections in the output of the child permute layer to the output of the base layer. Run(Graph & graph,InputSlot & connection) const23*89c4ff92SAndroid Build Coastguard Worker void Run(Graph& graph, InputSlot& connection) const 24*89c4ff92SAndroid Build Coastguard Worker { 25*89c4ff92SAndroid Build Coastguard Worker OutputSlot& baseOutput = *connection.GetConnectedOutputSlot(); 26*89c4ff92SAndroid Build Coastguard Worker 27*89c4ff92SAndroid Build Coastguard Worker if (baseOutput.GetNumConnections() == 1U) 28*89c4ff92SAndroid Build Coastguard Worker { 29*89c4ff92SAndroid Build Coastguard Worker Layer& base = baseOutput.GetOwningLayer(); 30*89c4ff92SAndroid Build Coastguard Worker 31*89c4ff92SAndroid Build Coastguard Worker if (CanMovePermuteToInputs(base)) 32*89c4ff92SAndroid Build Coastguard Worker { 33*89c4ff92SAndroid Build Coastguard Worker auto permute = PolymorphicDowncast<PermuteLayer*>(&connection.GetOwningLayer()); 34*89c4ff92SAndroid Build Coastguard Worker const PermutationVector& perm = permute->GetPermutation(); 35*89c4ff92SAndroid Build Coastguard Worker 36*89c4ff92SAndroid Build Coastguard Worker // Inserts an equivalent permute before every input of the base layer. 37*89c4ff92SAndroid Build Coastguard Worker for (auto baseInput = base.BeginInputSlots(); baseInput != base.EndInputSlots(); ++baseInput) 38*89c4ff92SAndroid Build Coastguard Worker { 39*89c4ff92SAndroid Build Coastguard Worker // Inserts a new permute layer. 40*89c4ff92SAndroid Build Coastguard Worker const std::string name = std::string("moved_up-") + permute->GetName(); 41*89c4ff92SAndroid Build Coastguard Worker PermuteLayer& permLayer = *graph.InsertNewLayer<PermuteLayer>(*baseInput, perm, name.c_str()); 42*89c4ff92SAndroid Build Coastguard Worker 43*89c4ff92SAndroid Build Coastguard Worker // Sets output tensor info for the new layer. 44*89c4ff92SAndroid Build Coastguard Worker OutputSlot& parentOutput = *permLayer.GetInputSlot(0).GetConnectedOutputSlot(); 45*89c4ff92SAndroid Build Coastguard Worker const TensorInfo permOutInfo = armnnUtils::Permuted(parentOutput.GetTensorInfo(), perm); 46*89c4ff92SAndroid Build Coastguard Worker permLayer.GetOutputHandler().SetTensorInfo(permOutInfo); 47*89c4ff92SAndroid Build Coastguard Worker } 48*89c4ff92SAndroid Build Coastguard Worker 49*89c4ff92SAndroid Build Coastguard Worker // Bypasses permute. It will be removed as it's left unconnected. 50*89c4ff92SAndroid Build Coastguard Worker permute->GetOutputSlot().MoveAllConnections(base.GetOutputSlot()); 51*89c4ff92SAndroid Build Coastguard Worker } 52*89c4ff92SAndroid Build Coastguard Worker } 53*89c4ff92SAndroid Build Coastguard Worker } 54*89c4ff92SAndroid Build Coastguard Worker 55*89c4ff92SAndroid Build Coastguard Worker protected: 56*89c4ff92SAndroid Build Coastguard Worker MovePermuteUpImpl() = default; 57*89c4ff92SAndroid Build Coastguard Worker ~MovePermuteUpImpl() = default; 58*89c4ff92SAndroid Build Coastguard Worker 59*89c4ff92SAndroid Build Coastguard Worker private: CanMovePermuteToInputs(const Layer & base)60*89c4ff92SAndroid Build Coastguard Worker static bool CanMovePermuteToInputs(const Layer& base) 61*89c4ff92SAndroid Build Coastguard Worker { 62*89c4ff92SAndroid Build Coastguard Worker switch (base.GetType()) 63*89c4ff92SAndroid Build Coastguard Worker { 64*89c4ff92SAndroid Build Coastguard Worker case LayerType::Activation: 65*89c4ff92SAndroid Build Coastguard Worker case LayerType::Addition: 66*89c4ff92SAndroid Build Coastguard Worker case LayerType::FakeQuantization: 67*89c4ff92SAndroid Build Coastguard Worker case LayerType::Floor: 68*89c4ff92SAndroid Build Coastguard Worker case LayerType::MemCopy: 69*89c4ff92SAndroid Build Coastguard Worker case LayerType::Multiplication: 70*89c4ff92SAndroid Build Coastguard Worker return true; 71*89c4ff92SAndroid Build Coastguard Worker case LayerType::ElementwiseBinary: 72*89c4ff92SAndroid Build Coastguard Worker { 73*89c4ff92SAndroid Build Coastguard Worker auto descriptor = PolymorphicDowncast<const ElementwiseBinaryDescriptor*>(&base.GetParameters()); 74*89c4ff92SAndroid Build Coastguard Worker return (descriptor->m_Operation == BinaryOperation::Add || 75*89c4ff92SAndroid Build Coastguard Worker descriptor->m_Operation == BinaryOperation::Mul); 76*89c4ff92SAndroid Build Coastguard Worker } 77*89c4ff92SAndroid Build Coastguard Worker default: 78*89c4ff92SAndroid Build Coastguard Worker return false; 79*89c4ff92SAndroid Build Coastguard Worker } 80*89c4ff92SAndroid Build Coastguard Worker } 81*89c4ff92SAndroid Build Coastguard Worker }; 82*89c4ff92SAndroid Build Coastguard Worker 83*89c4ff92SAndroid Build Coastguard Worker using MovePermuteUp = OptimizeForConnection<Layer, PermuteLayer, MovePermuteUpImpl>; 84*89c4ff92SAndroid Build Coastguard Worker 85*89c4ff92SAndroid Build Coastguard Worker } // namespace optimizations 86*89c4ff92SAndroid Build Coastguard Worker } // namespace armnn 87