xref: /aosp_15_r20/external/armnn/src/armnn/optimizations/PermuteDepthwiseConv2dWeights.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker #pragma once
6*89c4ff92SAndroid Build Coastguard Worker 
7*89c4ff92SAndroid Build Coastguard Worker #include "Optimization.hpp"
8*89c4ff92SAndroid Build Coastguard Worker #include "NetworkUtils.hpp"
9*89c4ff92SAndroid Build Coastguard Worker 
10*89c4ff92SAndroid Build Coastguard Worker #include <armnnUtils/Permute.hpp>
11*89c4ff92SAndroid Build Coastguard Worker 
12*89c4ff92SAndroid Build Coastguard Worker #include <fmt/format.h>
13*89c4ff92SAndroid Build Coastguard Worker 
14*89c4ff92SAndroid Build Coastguard Worker namespace armnn
15*89c4ff92SAndroid Build Coastguard Worker {
16*89c4ff92SAndroid Build Coastguard Worker namespace optimizations
17*89c4ff92SAndroid Build Coastguard Worker {
18*89c4ff92SAndroid Build Coastguard Worker 
19*89c4ff92SAndroid Build Coastguard Worker class PermuteDepthwiseConv2dWeightsImpl
20*89c4ff92SAndroid Build Coastguard Worker {
21*89c4ff92SAndroid Build Coastguard Worker public:
22*89c4ff92SAndroid Build Coastguard Worker 
Run(Graph & graph,Layer & layer) const23*89c4ff92SAndroid Build Coastguard Worker     void Run(Graph& graph, Layer& layer) const
24*89c4ff92SAndroid Build Coastguard Worker     {
25*89c4ff92SAndroid Build Coastguard Worker         if (layer.GetType() == LayerType::DepthwiseConvolution2d)
26*89c4ff92SAndroid Build Coastguard Worker         {
27*89c4ff92SAndroid Build Coastguard Worker             AddPermuteLayer(graph, PolymorphicDowncast<DepthwiseConvolution2dLayer*>(&layer));
28*89c4ff92SAndroid Build Coastguard Worker         }
29*89c4ff92SAndroid Build Coastguard Worker     }
30*89c4ff92SAndroid Build Coastguard Worker 
31*89c4ff92SAndroid Build Coastguard Worker protected:
32*89c4ff92SAndroid Build Coastguard Worker     PermuteDepthwiseConv2dWeightsImpl() = default;
33*89c4ff92SAndroid Build Coastguard Worker     ~PermuteDepthwiseConv2dWeightsImpl() = default;
34*89c4ff92SAndroid Build Coastguard Worker 
35*89c4ff92SAndroid Build Coastguard Worker private:
36*89c4ff92SAndroid Build Coastguard Worker     /// ArmNN format for weights for depthwise is [1, H, W, C] independently of the input/output layout
37*89c4ff92SAndroid Build Coastguard Worker     ///
38*89c4ff92SAndroid Build Coastguard Worker     /// ACL format for weights for depthwise is:
39*89c4ff92SAndroid Build Coastguard Worker     /// - [1, H, W, C] for [N, H, W, C] input/output layout (matches with ArmNN)
40*89c4ff92SAndroid Build Coastguard Worker     /// - [1, C, H, W] for [N, C, H, W] input/output layout
41*89c4ff92SAndroid Build Coastguard Worker     ///
42*89c4ff92SAndroid Build Coastguard Worker     /// Therefore ArmNN weights have to be permuted when input/output layout is [N, C, H, W] to pass them to ACL.
AddPermuteLayer(Graph & graph,DepthwiseConvolution2dLayer * layer)43*89c4ff92SAndroid Build Coastguard Worker     static void AddPermuteLayer(Graph& graph, DepthwiseConvolution2dLayer* layer)
44*89c4ff92SAndroid Build Coastguard Worker     {
45*89c4ff92SAndroid Build Coastguard Worker         TensorInfo inputInfo = layer->GetInputSlot(0).GetConnectedOutputSlot()->GetTensorInfo();
46*89c4ff92SAndroid Build Coastguard Worker         TensorInfo weightInfo = layer->GetInputSlot(1).GetConnectedOutputSlot()->GetTensorInfo();
47*89c4ff92SAndroid Build Coastguard Worker         if (layer->GetParameters().m_DataLayout == armnn::DataLayout::NHWC)
48*89c4ff92SAndroid Build Coastguard Worker         {
49*89c4ff92SAndroid Build Coastguard Worker             // No permutation required. Input and weights data layouts are the same.
50*89c4ff92SAndroid Build Coastguard Worker             return;
51*89c4ff92SAndroid Build Coastguard Worker         }
52*89c4ff92SAndroid Build Coastguard Worker         else if (layer->GetParameters().m_DataLayout == armnn::DataLayout::NCHW)
53*89c4ff92SAndroid Build Coastguard Worker         {
54*89c4ff92SAndroid Build Coastguard Worker             // Weights permutation required. Weights [N,H,W,C] and input [N,C,H,W] data layouts are different.
55*89c4ff92SAndroid Build Coastguard Worker             // [ 1, H, W, I*M] --> [ 1, I * M, H, W ]
56*89c4ff92SAndroid Build Coastguard Worker             PermutationVector permutationVector = { 0, 2, 3, 1 };
57*89c4ff92SAndroid Build Coastguard Worker             TensorInfo weightsPermuted = armnnUtils::Permuted(weightInfo, permutationVector);
58*89c4ff92SAndroid Build Coastguard Worker 
59*89c4ff92SAndroid Build Coastguard Worker             // Inserts NewLayer so layers don't need to be re-sorted.
60*89c4ff92SAndroid Build Coastguard Worker             PermuteLayer* permuteLayer =
61*89c4ff92SAndroid Build Coastguard Worker                 graph.InsertNewLayer<PermuteLayer>(layer->GetInputSlot(1),
62*89c4ff92SAndroid Build Coastguard Worker                                                    PermuteDescriptor(permutationVector),
63*89c4ff92SAndroid Build Coastguard Worker                                                    "permute_layer");
64*89c4ff92SAndroid Build Coastguard Worker             permuteLayer->GetOutputSlot().SetTensorInfo(weightsPermuted);
65*89c4ff92SAndroid Build Coastguard Worker 
66*89c4ff92SAndroid Build Coastguard Worker             // Assign Permute BackendId to be the same as the Depthwise Conv2d BackendId.
67*89c4ff92SAndroid Build Coastguard Worker             // Needed as backends have already been assigned at this stage.
68*89c4ff92SAndroid Build Coastguard Worker             permuteLayer->SetBackendId(layer->GetBackendId());
69*89c4ff92SAndroid Build Coastguard Worker         }
70*89c4ff92SAndroid Build Coastguard Worker         else
71*89c4ff92SAndroid Build Coastguard Worker         {
72*89c4ff92SAndroid Build Coastguard Worker             throw InvalidArgumentException(fmt::format("Unknown data layout for tensor info conversion: {}",
73*89c4ff92SAndroid Build Coastguard Worker                                                        GetDataLayoutName(layer->GetParameters().m_DataLayout)));
74*89c4ff92SAndroid Build Coastguard Worker         }
75*89c4ff92SAndroid Build Coastguard Worker     }
76*89c4ff92SAndroid Build Coastguard Worker };
77*89c4ff92SAndroid Build Coastguard Worker 
78*89c4ff92SAndroid Build Coastguard Worker using PermuteDepthwiseConv2dWeights = OptimizeForType<Layer, PermuteDepthwiseConv2dWeightsImpl>;
79*89c4ff92SAndroid Build Coastguard Worker 
80*89c4ff92SAndroid Build Coastguard Worker } // namespace optimizations
81*89c4ff92SAndroid Build Coastguard Worker } // namespace armnn
82