1*89c4ff92SAndroid Build Coastguard Worker // 2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2022 Arm Ltd and Contributors. All rights reserved. 3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT 4*89c4ff92SAndroid Build Coastguard Worker // 5*89c4ff92SAndroid Build Coastguard Worker #pragma once 6*89c4ff92SAndroid Build Coastguard Worker 7*89c4ff92SAndroid Build Coastguard Worker #include "Optimization.hpp" 8*89c4ff92SAndroid Build Coastguard Worker #include "NetworkUtils.hpp" 9*89c4ff92SAndroid Build Coastguard Worker 10*89c4ff92SAndroid Build Coastguard Worker #include <armnnUtils/Permute.hpp> 11*89c4ff92SAndroid Build Coastguard Worker 12*89c4ff92SAndroid Build Coastguard Worker #include <fmt/format.h> 13*89c4ff92SAndroid Build Coastguard Worker 14*89c4ff92SAndroid Build Coastguard Worker namespace armnn 15*89c4ff92SAndroid Build Coastguard Worker { 16*89c4ff92SAndroid Build Coastguard Worker namespace optimizations 17*89c4ff92SAndroid Build Coastguard Worker { 18*89c4ff92SAndroid Build Coastguard Worker 19*89c4ff92SAndroid Build Coastguard Worker class PermuteDepthwiseConv2dWeightsImpl 20*89c4ff92SAndroid Build Coastguard Worker { 21*89c4ff92SAndroid Build Coastguard Worker public: 22*89c4ff92SAndroid Build Coastguard Worker Run(Graph & graph,Layer & layer) const23*89c4ff92SAndroid Build Coastguard Worker void Run(Graph& graph, Layer& layer) const 24*89c4ff92SAndroid Build Coastguard Worker { 25*89c4ff92SAndroid Build Coastguard Worker if (layer.GetType() == LayerType::DepthwiseConvolution2d) 26*89c4ff92SAndroid Build Coastguard Worker { 27*89c4ff92SAndroid Build Coastguard Worker AddPermuteLayer(graph, PolymorphicDowncast<DepthwiseConvolution2dLayer*>(&layer)); 28*89c4ff92SAndroid Build Coastguard Worker } 29*89c4ff92SAndroid Build Coastguard Worker } 30*89c4ff92SAndroid Build Coastguard Worker 31*89c4ff92SAndroid Build Coastguard Worker protected: 32*89c4ff92SAndroid Build Coastguard Worker PermuteDepthwiseConv2dWeightsImpl() = default; 33*89c4ff92SAndroid Build Coastguard Worker ~PermuteDepthwiseConv2dWeightsImpl() = default; 34*89c4ff92SAndroid Build Coastguard Worker 35*89c4ff92SAndroid Build Coastguard Worker private: 36*89c4ff92SAndroid Build Coastguard Worker /// ArmNN format for weights for depthwise is [1, H, W, C] independently of the input/output layout 37*89c4ff92SAndroid Build Coastguard Worker /// 38*89c4ff92SAndroid Build Coastguard Worker /// ACL format for weights for depthwise is: 39*89c4ff92SAndroid Build Coastguard Worker /// - [1, H, W, C] for [N, H, W, C] input/output layout (matches with ArmNN) 40*89c4ff92SAndroid Build Coastguard Worker /// - [1, C, H, W] for [N, C, H, W] input/output layout 41*89c4ff92SAndroid Build Coastguard Worker /// 42*89c4ff92SAndroid Build Coastguard Worker /// Therefore ArmNN weights have to be permuted when input/output layout is [N, C, H, W] to pass them to ACL. AddPermuteLayer(Graph & graph,DepthwiseConvolution2dLayer * layer)43*89c4ff92SAndroid Build Coastguard Worker static void AddPermuteLayer(Graph& graph, DepthwiseConvolution2dLayer* layer) 44*89c4ff92SAndroid Build Coastguard Worker { 45*89c4ff92SAndroid Build Coastguard Worker TensorInfo inputInfo = layer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo(); 46*89c4ff92SAndroid Build Coastguard Worker TensorInfo weightInfo = layer->GetInputSlot(1).GetConnectedOutputSlot()->GetTensorInfo(); 47*89c4ff92SAndroid Build Coastguard Worker if (layer->GetParameters().m_DataLayout == armnn::DataLayout::NHWC) 48*89c4ff92SAndroid Build Coastguard Worker { 49*89c4ff92SAndroid Build Coastguard Worker // No permutation required. Input and weights data layouts are the same. 50*89c4ff92SAndroid Build Coastguard Worker return; 51*89c4ff92SAndroid Build Coastguard Worker } 52*89c4ff92SAndroid Build Coastguard Worker else if (layer->GetParameters().m_DataLayout == armnn::DataLayout::NCHW) 53*89c4ff92SAndroid Build Coastguard Worker { 54*89c4ff92SAndroid Build Coastguard Worker // Weights permutation required. Weights [N,H,W,C] and input [N,C,H,W] data layouts are different. 55*89c4ff92SAndroid Build Coastguard Worker // [ 1, H, W, I*M] --> [ 1, I * M, H, W ] 56*89c4ff92SAndroid Build Coastguard Worker PermutationVector permutationVector = { 0, 2, 3, 1 }; 57*89c4ff92SAndroid Build Coastguard Worker TensorInfo weightsPermuted = armnnUtils::Permuted(weightInfo, permutationVector); 58*89c4ff92SAndroid Build Coastguard Worker 59*89c4ff92SAndroid Build Coastguard Worker // Inserts NewLayer so layers don't need to be re-sorted. 60*89c4ff92SAndroid Build Coastguard Worker PermuteLayer* permuteLayer = 61*89c4ff92SAndroid Build Coastguard Worker graph.InsertNewLayer<PermuteLayer>(layer->GetInputSlot(1), 62*89c4ff92SAndroid Build Coastguard Worker PermuteDescriptor(permutationVector), 63*89c4ff92SAndroid Build Coastguard Worker "permute_layer"); 64*89c4ff92SAndroid Build Coastguard Worker permuteLayer->GetOutputSlot().SetTensorInfo(weightsPermuted); 65*89c4ff92SAndroid Build Coastguard Worker 66*89c4ff92SAndroid Build Coastguard Worker // Assign Permute BackendId to be the same as the Depthwise Conv2d BackendId. 67*89c4ff92SAndroid Build Coastguard Worker // Needed as backends have already been assigned at this stage. 68*89c4ff92SAndroid Build Coastguard Worker permuteLayer->SetBackendId(layer->GetBackendId()); 69*89c4ff92SAndroid Build Coastguard Worker } 70*89c4ff92SAndroid Build Coastguard Worker else 71*89c4ff92SAndroid Build Coastguard Worker { 72*89c4ff92SAndroid Build Coastguard Worker throw InvalidArgumentException(fmt::format("Unknown data layout for tensor info conversion: {}", 73*89c4ff92SAndroid Build Coastguard Worker GetDataLayoutName(layer->GetParameters().m_DataLayout))); 74*89c4ff92SAndroid Build Coastguard Worker } 75*89c4ff92SAndroid Build Coastguard Worker } 76*89c4ff92SAndroid Build Coastguard Worker }; 77*89c4ff92SAndroid Build Coastguard Worker 78*89c4ff92SAndroid Build Coastguard Worker using PermuteDepthwiseConv2dWeights = OptimizeForType<Layer, PermuteDepthwiseConv2dWeightsImpl>; 79*89c4ff92SAndroid Build Coastguard Worker 80*89c4ff92SAndroid Build Coastguard Worker } // namespace optimizations 81*89c4ff92SAndroid Build Coastguard Worker } // namespace armnn 82