1*89c4ff92SAndroid Build Coastguard Worker // 2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2022 Arm Ltd and Contributors. All rights reserved. 3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT 4*89c4ff92SAndroid Build Coastguard Worker // 5*89c4ff92SAndroid Build Coastguard Worker 6*89c4ff92SAndroid Build Coastguard Worker #pragma once 7*89c4ff92SAndroid Build Coastguard Worker 8*89c4ff92SAndroid Build Coastguard Worker #include "Optimization.hpp" 9*89c4ff92SAndroid Build Coastguard Worker #include <armnnUtils/Permute.hpp> 10*89c4ff92SAndroid Build Coastguard Worker #include <ResolveType.hpp> 11*89c4ff92SAndroid Build Coastguard Worker 12*89c4ff92SAndroid Build Coastguard Worker namespace armnn 13*89c4ff92SAndroid Build Coastguard Worker { 14*89c4ff92SAndroid Build Coastguard Worker namespace optimizations 15*89c4ff92SAndroid Build Coastguard Worker { 16*89c4ff92SAndroid Build Coastguard Worker 17*89c4ff92SAndroid Build Coastguard Worker class ConvertConstPermuteLayersToConstLayers 18*89c4ff92SAndroid Build Coastguard Worker { 19*89c4ff92SAndroid Build Coastguard Worker public: Run(Graph & graph,InputSlot & connection) const20*89c4ff92SAndroid Build Coastguard Worker void Run(Graph& graph, InputSlot& connection) const 21*89c4ff92SAndroid Build Coastguard Worker { 22*89c4ff92SAndroid Build Coastguard Worker Layer& base = connection.GetConnectedOutputSlot()->GetOwningLayer(); 23*89c4ff92SAndroid Build Coastguard Worker Layer& child = connection.GetOwningLayer(); 24*89c4ff92SAndroid Build Coastguard Worker 25*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(base.GetType() == LayerType::Constant); 26*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(child.GetType() == LayerType::Permute); 27*89c4ff92SAndroid Build Coastguard Worker 28*89c4ff92SAndroid Build Coastguard Worker if (base.GetDataType() == child.GetDataType()) 29*89c4ff92SAndroid Build Coastguard Worker { 30*89c4ff92SAndroid Build Coastguard Worker switch (base.GetDataType()) 31*89c4ff92SAndroid Build Coastguard Worker { 32*89c4ff92SAndroid Build Coastguard Worker case DataType::Float16: 33*89c4ff92SAndroid Build Coastguard Worker ReplaceConstPermuteLayer<DataType::Float16>(graph, 34*89c4ff92SAndroid Build Coastguard Worker PolymorphicDowncast<ConstantLayer*>(&base), 35*89c4ff92SAndroid Build Coastguard Worker PolymorphicDowncast<PermuteLayer*>(&child)); 36*89c4ff92SAndroid Build Coastguard Worker break; 37*89c4ff92SAndroid Build Coastguard Worker case DataType::Float32: 38*89c4ff92SAndroid Build Coastguard Worker ReplaceConstPermuteLayer<DataType::Float32>(graph, 39*89c4ff92SAndroid Build Coastguard Worker PolymorphicDowncast<ConstantLayer*>(&base), 40*89c4ff92SAndroid Build Coastguard Worker PolymorphicDowncast<PermuteLayer*>(&child)); 41*89c4ff92SAndroid Build Coastguard Worker break; 42*89c4ff92SAndroid Build Coastguard Worker case DataType::QAsymmU8: 43*89c4ff92SAndroid Build Coastguard Worker ReplaceConstPermuteLayer<DataType::QAsymmU8>(graph, 44*89c4ff92SAndroid Build Coastguard Worker PolymorphicDowncast<ConstantLayer*>(&base), 45*89c4ff92SAndroid Build Coastguard Worker PolymorphicDowncast<PermuteLayer*>(&child)); 46*89c4ff92SAndroid Build Coastguard Worker break; 47*89c4ff92SAndroid Build Coastguard Worker case DataType::Signed32: 48*89c4ff92SAndroid Build Coastguard Worker ReplaceConstPermuteLayer<DataType::Signed32>(graph, 49*89c4ff92SAndroid Build Coastguard Worker PolymorphicDowncast<ConstantLayer*>(&base), 50*89c4ff92SAndroid Build Coastguard Worker PolymorphicDowncast<PermuteLayer*>(&child)); 51*89c4ff92SAndroid Build Coastguard Worker break; 52*89c4ff92SAndroid Build Coastguard Worker case DataType::QSymmS16: 53*89c4ff92SAndroid Build Coastguard Worker ReplaceConstPermuteLayer<DataType::QSymmS16>(graph, 54*89c4ff92SAndroid Build Coastguard Worker PolymorphicDowncast<ConstantLayer*>(&base), 55*89c4ff92SAndroid Build Coastguard Worker PolymorphicDowncast<PermuteLayer*>(&child)); 56*89c4ff92SAndroid Build Coastguard Worker break; 57*89c4ff92SAndroid Build Coastguard Worker case DataType::QSymmS8: 58*89c4ff92SAndroid Build Coastguard Worker ReplaceConstPermuteLayer<DataType::QSymmS8>(graph, 59*89c4ff92SAndroid Build Coastguard Worker PolymorphicDowncast<ConstantLayer*>(&base), 60*89c4ff92SAndroid Build Coastguard Worker PolymorphicDowncast<PermuteLayer*>(&child)); 61*89c4ff92SAndroid Build Coastguard Worker break; 62*89c4ff92SAndroid Build Coastguard Worker case DataType::QAsymmS8: 63*89c4ff92SAndroid Build Coastguard Worker ReplaceConstPermuteLayer<DataType::QAsymmS8>(graph, 64*89c4ff92SAndroid Build Coastguard Worker PolymorphicDowncast<ConstantLayer*>(&base), 65*89c4ff92SAndroid Build Coastguard Worker PolymorphicDowncast<PermuteLayer*>(&child)); 66*89c4ff92SAndroid Build Coastguard Worker break; 67*89c4ff92SAndroid Build Coastguard Worker case DataType::BFloat16: 68*89c4ff92SAndroid Build Coastguard Worker ReplaceConstPermuteLayer<DataType::BFloat16>(graph, 69*89c4ff92SAndroid Build Coastguard Worker PolymorphicDowncast<ConstantLayer*>(&base), 70*89c4ff92SAndroid Build Coastguard Worker PolymorphicDowncast<PermuteLayer*>(&child)); 71*89c4ff92SAndroid Build Coastguard Worker break; 72*89c4ff92SAndroid Build Coastguard Worker case DataType::Signed64: 73*89c4ff92SAndroid Build Coastguard Worker ReplaceConstPermuteLayer<DataType::Signed64>(graph, 74*89c4ff92SAndroid Build Coastguard Worker PolymorphicDowncast<ConstantLayer*>(&base), 75*89c4ff92SAndroid Build Coastguard Worker PolymorphicDowncast<PermuteLayer*>(&child)); 76*89c4ff92SAndroid Build Coastguard Worker break; 77*89c4ff92SAndroid Build Coastguard Worker case DataType::Boolean: 78*89c4ff92SAndroid Build Coastguard Worker ReplaceConstPermuteLayer<DataType::Boolean>(graph, 79*89c4ff92SAndroid Build Coastguard Worker PolymorphicDowncast<ConstantLayer*>(&base), 80*89c4ff92SAndroid Build Coastguard Worker PolymorphicDowncast<PermuteLayer*>(&child)); 81*89c4ff92SAndroid Build Coastguard Worker break; 82*89c4ff92SAndroid Build Coastguard Worker } 83*89c4ff92SAndroid Build Coastguard Worker } 84*89c4ff92SAndroid Build Coastguard Worker } 85*89c4ff92SAndroid Build Coastguard Worker protected: 86*89c4ff92SAndroid Build Coastguard Worker ConvertConstPermuteLayersToConstLayers() = default; 87*89c4ff92SAndroid Build Coastguard Worker ~ConvertConstPermuteLayersToConstLayers() = default; 88*89c4ff92SAndroid Build Coastguard Worker private: 89*89c4ff92SAndroid Build Coastguard Worker template<armnn::DataType ArmnnType, 90*89c4ff92SAndroid Build Coastguard Worker typename T = armnn::ResolveType<ArmnnType>> ReplaceConstPermuteLayer(Graph & graph,ConstantLayer * constantLayer,PermuteLayer * permuteLayer)91*89c4ff92SAndroid Build Coastguard Worker static void ReplaceConstPermuteLayer(Graph& graph, 92*89c4ff92SAndroid Build Coastguard Worker ConstantLayer* constantLayer, 93*89c4ff92SAndroid Build Coastguard Worker PermuteLayer* permuteLayer) 94*89c4ff92SAndroid Build Coastguard Worker { 95*89c4ff92SAndroid Build Coastguard Worker IgnoreUnused(graph); 96*89c4ff92SAndroid Build Coastguard Worker /** 97*89c4ff92SAndroid Build Coastguard Worker * This optimisation is to find situations where a constant set of inputs is being provided to a Permute 98*89c4ff92SAndroid Build Coastguard Worker * layer. In this case we don't want the overhead of Permuting the values on every inference, instead we 99*89c4ff92SAndroid Build Coastguard Worker * want to Permute them once and store them in a Const layer to be used everytime as they will not change. 100*89c4ff92SAndroid Build Coastguard Worker */ 101*89c4ff92SAndroid Build Coastguard Worker TensorInfo outputPermuteInfo = permuteLayer->GetOutputSlot(0).GetTensorInfo(); 102*89c4ff92SAndroid Build Coastguard Worker std::vector<T> newValues(outputPermuteInfo.GetNumElements()); 103*89c4ff92SAndroid Build Coastguard Worker armnnUtils::Permute(outputPermuteInfo.GetShape(), permuteLayer->GetPermutation(), 104*89c4ff92SAndroid Build Coastguard Worker constantLayer->m_LayerOutput->Map(true), newValues.data(), 105*89c4ff92SAndroid Build Coastguard Worker GetDataTypeSize(outputPermuteInfo.GetDataType())); 106*89c4ff92SAndroid Build Coastguard Worker 107*89c4ff92SAndroid Build Coastguard Worker TensorInfo newInfo = outputPermuteInfo; 108*89c4ff92SAndroid Build Coastguard Worker newInfo.SetConstant(true); 109*89c4ff92SAndroid Build Coastguard Worker ConstTensor newInput(newInfo, newValues); 110*89c4ff92SAndroid Build Coastguard Worker constantLayer->m_LayerOutput.reset(new ScopedTensorHandle(newInput)); 111*89c4ff92SAndroid Build Coastguard Worker 112*89c4ff92SAndroid Build Coastguard Worker // Moves connections in permute output to the constant layer. 113*89c4ff92SAndroid Build Coastguard Worker // Permute layer will be removed if left unconnected. 114*89c4ff92SAndroid Build Coastguard Worker permuteLayer->GetOutputSlot().MoveAllConnections(constantLayer->GetOutputSlot()); 115*89c4ff92SAndroid Build Coastguard Worker 116*89c4ff92SAndroid Build Coastguard Worker // Updating the output tensor 117*89c4ff92SAndroid Build Coastguard Worker constantLayer->GetOutputSlot(0).SetTensorInfo(newInfo); 118*89c4ff92SAndroid Build Coastguard Worker ARMNN_ASSERT(constantLayer->GetOutputSlot(0).GetTensorInfo().IsConstant() == true); 119*89c4ff92SAndroid Build Coastguard Worker } 120*89c4ff92SAndroid Build Coastguard Worker }; 121*89c4ff92SAndroid Build Coastguard Worker 122*89c4ff92SAndroid Build Coastguard Worker using FusePermuteIntoConstLayer = OptimizeForConnection<ConstantLayer, 123*89c4ff92SAndroid Build Coastguard Worker PermuteLayer, 124*89c4ff92SAndroid Build Coastguard Worker ConvertConstPermuteLayersToConstLayers>; 125*89c4ff92SAndroid Build Coastguard Worker 126*89c4ff92SAndroid Build Coastguard Worker } // namespace optimizations 127*89c4ff92SAndroid Build Coastguard Worker } // namespace armnn