xref: /aosp_15_r20/external/armnn/src/armnn/optimizations/ConvertConstPermuteLayersToConstLayers.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2022 Arm Ltd and Contributors. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #pragma once
7*89c4ff92SAndroid Build Coastguard Worker 
8*89c4ff92SAndroid Build Coastguard Worker #include "Optimization.hpp"
9*89c4ff92SAndroid Build Coastguard Worker #include <armnnUtils/Permute.hpp>
10*89c4ff92SAndroid Build Coastguard Worker #include <ResolveType.hpp>
11*89c4ff92SAndroid Build Coastguard Worker 
12*89c4ff92SAndroid Build Coastguard Worker namespace armnn
13*89c4ff92SAndroid Build Coastguard Worker {
14*89c4ff92SAndroid Build Coastguard Worker namespace optimizations
15*89c4ff92SAndroid Build Coastguard Worker {
16*89c4ff92SAndroid Build Coastguard Worker 
17*89c4ff92SAndroid Build Coastguard Worker class ConvertConstPermuteLayersToConstLayers
18*89c4ff92SAndroid Build Coastguard Worker {
19*89c4ff92SAndroid Build Coastguard Worker public:
Run(Graph & graph,InputSlot & connection) const20*89c4ff92SAndroid Build Coastguard Worker     void Run(Graph& graph, InputSlot& connection) const
21*89c4ff92SAndroid Build Coastguard Worker     {
22*89c4ff92SAndroid Build Coastguard Worker         Layer& base = connection.GetConnectedOutputSlot()->GetOwningLayer();
23*89c4ff92SAndroid Build Coastguard Worker         Layer& child = connection.GetOwningLayer();
24*89c4ff92SAndroid Build Coastguard Worker 
25*89c4ff92SAndroid Build Coastguard Worker         ARMNN_ASSERT(base.GetType() == LayerType::Constant);
26*89c4ff92SAndroid Build Coastguard Worker         ARMNN_ASSERT(child.GetType() == LayerType::Permute);
27*89c4ff92SAndroid Build Coastguard Worker 
28*89c4ff92SAndroid Build Coastguard Worker         if (base.GetDataType() == child.GetDataType())
29*89c4ff92SAndroid Build Coastguard Worker         {
30*89c4ff92SAndroid Build Coastguard Worker             switch (base.GetDataType())
31*89c4ff92SAndroid Build Coastguard Worker             {
32*89c4ff92SAndroid Build Coastguard Worker                 case DataType::Float16:
33*89c4ff92SAndroid Build Coastguard Worker                     ReplaceConstPermuteLayer<DataType::Float16>(graph,
34*89c4ff92SAndroid Build Coastguard Worker                                                                  PolymorphicDowncast<ConstantLayer*>(&base),
35*89c4ff92SAndroid Build Coastguard Worker                                                                  PolymorphicDowncast<PermuteLayer*>(&child));
36*89c4ff92SAndroid Build Coastguard Worker                     break;
37*89c4ff92SAndroid Build Coastguard Worker                 case DataType::Float32:
38*89c4ff92SAndroid Build Coastguard Worker                     ReplaceConstPermuteLayer<DataType::Float32>(graph,
39*89c4ff92SAndroid Build Coastguard Worker                                                                  PolymorphicDowncast<ConstantLayer*>(&base),
40*89c4ff92SAndroid Build Coastguard Worker                                                                  PolymorphicDowncast<PermuteLayer*>(&child));
41*89c4ff92SAndroid Build Coastguard Worker                     break;
42*89c4ff92SAndroid Build Coastguard Worker                 case DataType::QAsymmU8:
43*89c4ff92SAndroid Build Coastguard Worker                     ReplaceConstPermuteLayer<DataType::QAsymmU8>(graph,
44*89c4ff92SAndroid Build Coastguard Worker                                                                  PolymorphicDowncast<ConstantLayer*>(&base),
45*89c4ff92SAndroid Build Coastguard Worker                                                                  PolymorphicDowncast<PermuteLayer*>(&child));
46*89c4ff92SAndroid Build Coastguard Worker                     break;
47*89c4ff92SAndroid Build Coastguard Worker                 case DataType::Signed32:
48*89c4ff92SAndroid Build Coastguard Worker                     ReplaceConstPermuteLayer<DataType::Signed32>(graph,
49*89c4ff92SAndroid Build Coastguard Worker                                                                  PolymorphicDowncast<ConstantLayer*>(&base),
50*89c4ff92SAndroid Build Coastguard Worker                                                                  PolymorphicDowncast<PermuteLayer*>(&child));
51*89c4ff92SAndroid Build Coastguard Worker                     break;
52*89c4ff92SAndroid Build Coastguard Worker                 case DataType::QSymmS16:
53*89c4ff92SAndroid Build Coastguard Worker                     ReplaceConstPermuteLayer<DataType::QSymmS16>(graph,
54*89c4ff92SAndroid Build Coastguard Worker                                                                  PolymorphicDowncast<ConstantLayer*>(&base),
55*89c4ff92SAndroid Build Coastguard Worker                                                                  PolymorphicDowncast<PermuteLayer*>(&child));
56*89c4ff92SAndroid Build Coastguard Worker                     break;
57*89c4ff92SAndroid Build Coastguard Worker                 case DataType::QSymmS8:
58*89c4ff92SAndroid Build Coastguard Worker                     ReplaceConstPermuteLayer<DataType::QSymmS8>(graph,
59*89c4ff92SAndroid Build Coastguard Worker                                                                  PolymorphicDowncast<ConstantLayer*>(&base),
60*89c4ff92SAndroid Build Coastguard Worker                                                                  PolymorphicDowncast<PermuteLayer*>(&child));
61*89c4ff92SAndroid Build Coastguard Worker                     break;
62*89c4ff92SAndroid Build Coastguard Worker                 case DataType::QAsymmS8:
63*89c4ff92SAndroid Build Coastguard Worker                     ReplaceConstPermuteLayer<DataType::QAsymmS8>(graph,
64*89c4ff92SAndroid Build Coastguard Worker                                                                  PolymorphicDowncast<ConstantLayer*>(&base),
65*89c4ff92SAndroid Build Coastguard Worker                                                                  PolymorphicDowncast<PermuteLayer*>(&child));
66*89c4ff92SAndroid Build Coastguard Worker                     break;
67*89c4ff92SAndroid Build Coastguard Worker                 case DataType::BFloat16:
68*89c4ff92SAndroid Build Coastguard Worker                     ReplaceConstPermuteLayer<DataType::BFloat16>(graph,
69*89c4ff92SAndroid Build Coastguard Worker                                                                  PolymorphicDowncast<ConstantLayer*>(&base),
70*89c4ff92SAndroid Build Coastguard Worker                                                                  PolymorphicDowncast<PermuteLayer*>(&child));
71*89c4ff92SAndroid Build Coastguard Worker                     break;
72*89c4ff92SAndroid Build Coastguard Worker                 case DataType::Signed64:
73*89c4ff92SAndroid Build Coastguard Worker                     ReplaceConstPermuteLayer<DataType::Signed64>(graph,
74*89c4ff92SAndroid Build Coastguard Worker                                                                  PolymorphicDowncast<ConstantLayer*>(&base),
75*89c4ff92SAndroid Build Coastguard Worker                                                                  PolymorphicDowncast<PermuteLayer*>(&child));
76*89c4ff92SAndroid Build Coastguard Worker                     break;
77*89c4ff92SAndroid Build Coastguard Worker                 case DataType::Boolean:
78*89c4ff92SAndroid Build Coastguard Worker                     ReplaceConstPermuteLayer<DataType::Boolean>(graph,
79*89c4ff92SAndroid Build Coastguard Worker                                                                  PolymorphicDowncast<ConstantLayer*>(&base),
80*89c4ff92SAndroid Build Coastguard Worker                                                                  PolymorphicDowncast<PermuteLayer*>(&child));
81*89c4ff92SAndroid Build Coastguard Worker                     break;
82*89c4ff92SAndroid Build Coastguard Worker             }
83*89c4ff92SAndroid Build Coastguard Worker         }
84*89c4ff92SAndroid Build Coastguard Worker     }
85*89c4ff92SAndroid Build Coastguard Worker protected:
86*89c4ff92SAndroid Build Coastguard Worker     ConvertConstPermuteLayersToConstLayers()  = default;
87*89c4ff92SAndroid Build Coastguard Worker     ~ConvertConstPermuteLayersToConstLayers() = default;
88*89c4ff92SAndroid Build Coastguard Worker private:
89*89c4ff92SAndroid Build Coastguard Worker     template<armnn::DataType ArmnnType,
90*89c4ff92SAndroid Build Coastguard Worker              typename T = armnn::ResolveType<ArmnnType>>
ReplaceConstPermuteLayer(Graph & graph,ConstantLayer * constantLayer,PermuteLayer * permuteLayer)91*89c4ff92SAndroid Build Coastguard Worker     static void ReplaceConstPermuteLayer(Graph& graph,
92*89c4ff92SAndroid Build Coastguard Worker                                          ConstantLayer* constantLayer,
93*89c4ff92SAndroid Build Coastguard Worker                                          PermuteLayer* permuteLayer)
94*89c4ff92SAndroid Build Coastguard Worker     {
95*89c4ff92SAndroid Build Coastguard Worker         IgnoreUnused(graph);
96*89c4ff92SAndroid Build Coastguard Worker         /**
97*89c4ff92SAndroid Build Coastguard Worker          * This optimisation is to find situations where a constant set of inputs is being provided to a Permute
98*89c4ff92SAndroid Build Coastguard Worker          * layer. In this case we don't want the overhead of Permuting the values on every inference, instead we
99*89c4ff92SAndroid Build Coastguard Worker          * want to Permute them once and store them in a Const layer to be used everytime as they will not change.
100*89c4ff92SAndroid Build Coastguard Worker          */
101*89c4ff92SAndroid Build Coastguard Worker         TensorInfo outputPermuteInfo = permuteLayer->GetOutputSlot(0).GetTensorInfo();
102*89c4ff92SAndroid Build Coastguard Worker         std::vector<T> newValues(outputPermuteInfo.GetNumElements());
103*89c4ff92SAndroid Build Coastguard Worker         armnnUtils::Permute(outputPermuteInfo.GetShape(), permuteLayer->GetPermutation(),
104*89c4ff92SAndroid Build Coastguard Worker                             constantLayer->m_LayerOutput->Map(true), newValues.data(),
105*89c4ff92SAndroid Build Coastguard Worker                             GetDataTypeSize(outputPermuteInfo.GetDataType()));
106*89c4ff92SAndroid Build Coastguard Worker 
107*89c4ff92SAndroid Build Coastguard Worker         TensorInfo newInfo = outputPermuteInfo;
108*89c4ff92SAndroid Build Coastguard Worker         newInfo.SetConstant(true);
109*89c4ff92SAndroid Build Coastguard Worker         ConstTensor newInput(newInfo, newValues);
110*89c4ff92SAndroid Build Coastguard Worker         constantLayer->m_LayerOutput.reset(new ScopedTensorHandle(newInput));
111*89c4ff92SAndroid Build Coastguard Worker 
112*89c4ff92SAndroid Build Coastguard Worker         // Moves connections in permute output to the constant layer.
113*89c4ff92SAndroid Build Coastguard Worker         // Permute layer will be removed if left unconnected.
114*89c4ff92SAndroid Build Coastguard Worker         permuteLayer->GetOutputSlot().MoveAllConnections(constantLayer->GetOutputSlot());
115*89c4ff92SAndroid Build Coastguard Worker 
116*89c4ff92SAndroid Build Coastguard Worker         // Updating the output tensor
117*89c4ff92SAndroid Build Coastguard Worker         constantLayer->GetOutputSlot(0).SetTensorInfo(newInfo);
118*89c4ff92SAndroid Build Coastguard Worker         ARMNN_ASSERT(constantLayer->GetOutputSlot(0).GetTensorInfo().IsConstant() == true);
119*89c4ff92SAndroid Build Coastguard Worker     }
120*89c4ff92SAndroid Build Coastguard Worker };
121*89c4ff92SAndroid Build Coastguard Worker 
122*89c4ff92SAndroid Build Coastguard Worker using FusePermuteIntoConstLayer = OptimizeForConnection<ConstantLayer,
123*89c4ff92SAndroid Build Coastguard Worker                                                         PermuteLayer,
124*89c4ff92SAndroid Build Coastguard Worker                                                         ConvertConstPermuteLayersToConstLayers>;
125*89c4ff92SAndroid Build Coastguard Worker 
126*89c4ff92SAndroid Build Coastguard Worker } // namespace optimizations
127*89c4ff92SAndroid Build Coastguard Worker } // namespace armnn