xref: /aosp_15_r20/external/armnn/src/backends/backendsCommon/LayerSupportRules.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1*89c4ff92SAndroid Build Coastguard Worker //
2*89c4ff92SAndroid Build Coastguard Worker // Copyright © 2017 Arm Ltd. All rights reserved.
3*89c4ff92SAndroid Build Coastguard Worker // SPDX-License-Identifier: MIT
4*89c4ff92SAndroid Build Coastguard Worker //
5*89c4ff92SAndroid Build Coastguard Worker 
6*89c4ff92SAndroid Build Coastguard Worker #pragma once
7*89c4ff92SAndroid Build Coastguard Worker 
8*89c4ff92SAndroid Build Coastguard Worker #include <armnn/utility/Assert.hpp>
9*89c4ff92SAndroid Build Coastguard Worker #include <algorithm>
10*89c4ff92SAndroid Build Coastguard Worker 
11*89c4ff92SAndroid Build Coastguard Worker namespace armnn
12*89c4ff92SAndroid Build Coastguard Worker {
13*89c4ff92SAndroid Build Coastguard Worker 
GetBiasTypeFromWeightsType(armnn::Optional<armnn::DataType> weightsType)14*89c4ff92SAndroid Build Coastguard Worker inline armnn::Optional<armnn::DataType> GetBiasTypeFromWeightsType(armnn::Optional<armnn::DataType> weightsType)
15*89c4ff92SAndroid Build Coastguard Worker {
16*89c4ff92SAndroid Build Coastguard Worker     if (!weightsType)
17*89c4ff92SAndroid Build Coastguard Worker     {
18*89c4ff92SAndroid Build Coastguard Worker         return weightsType;
19*89c4ff92SAndroid Build Coastguard Worker     }
20*89c4ff92SAndroid Build Coastguard Worker 
21*89c4ff92SAndroid Build Coastguard Worker     switch(weightsType.value())
22*89c4ff92SAndroid Build Coastguard Worker     {
23*89c4ff92SAndroid Build Coastguard Worker         case armnn::DataType::Float16:
24*89c4ff92SAndroid Build Coastguard Worker         case armnn::DataType::Float32:
25*89c4ff92SAndroid Build Coastguard Worker             return weightsType;
26*89c4ff92SAndroid Build Coastguard Worker         case armnn::DataType::QAsymmS8:
27*89c4ff92SAndroid Build Coastguard Worker         case armnn::DataType::QAsymmU8:
28*89c4ff92SAndroid Build Coastguard Worker         case armnn::DataType::QSymmS8:
29*89c4ff92SAndroid Build Coastguard Worker         case armnn::DataType::QSymmS16:
30*89c4ff92SAndroid Build Coastguard Worker             return armnn::DataType::Signed32;
31*89c4ff92SAndroid Build Coastguard Worker         default:
32*89c4ff92SAndroid Build Coastguard Worker             ARMNN_ASSERT_MSG(false, "GetBiasTypeFromWeightsType(): Unsupported data type.");
33*89c4ff92SAndroid Build Coastguard Worker     }
34*89c4ff92SAndroid Build Coastguard Worker     return armnn::EmptyOptional();
35*89c4ff92SAndroid Build Coastguard Worker }
36*89c4ff92SAndroid Build Coastguard Worker 
37*89c4ff92SAndroid Build Coastguard Worker template<typename F>
CheckSupportRule(F rule,Optional<std::string &> reasonIfUnsupported,const char * reason)38*89c4ff92SAndroid Build Coastguard Worker bool CheckSupportRule(F rule, Optional<std::string&> reasonIfUnsupported, const char* reason)
39*89c4ff92SAndroid Build Coastguard Worker {
40*89c4ff92SAndroid Build Coastguard Worker     bool supported = rule();
41*89c4ff92SAndroid Build Coastguard Worker     if (!supported && reason)
42*89c4ff92SAndroid Build Coastguard Worker     {
43*89c4ff92SAndroid Build Coastguard Worker         reasonIfUnsupported.value() += std::string(reason) + "\n"; // Append the reason on a new line
44*89c4ff92SAndroid Build Coastguard Worker     }
45*89c4ff92SAndroid Build Coastguard Worker     return supported;
46*89c4ff92SAndroid Build Coastguard Worker }
47*89c4ff92SAndroid Build Coastguard Worker 
/// Common base for every layer-support rule.
///
/// Derived rules compute their verdict in their constructor and store it in m_Res. Invoking
/// the rule object simply reports that stored verdict, so a rule can be passed anywhere a
/// bool-returning callable is expected (for example to CheckSupportRule).
struct Rule
{
    /// @return The verdict stored in m_Res.
    bool operator()() const { return m_Res; }

    /// The rule's verdict; true unless a derived rule sets it otherwise.
    bool m_Res = true;
};
57*89c4ff92SAndroid Build Coastguard Worker 
58*89c4ff92SAndroid Build Coastguard Worker template<typename T>
AllTypesAreEqualImpl(T)59*89c4ff92SAndroid Build Coastguard Worker bool AllTypesAreEqualImpl(T)
60*89c4ff92SAndroid Build Coastguard Worker {
61*89c4ff92SAndroid Build Coastguard Worker     return true;
62*89c4ff92SAndroid Build Coastguard Worker }
63*89c4ff92SAndroid Build Coastguard Worker 
64*89c4ff92SAndroid Build Coastguard Worker template<typename T, typename... Rest>
AllTypesAreEqualImpl(T t1,T t2,Rest...rest)65*89c4ff92SAndroid Build Coastguard Worker bool AllTypesAreEqualImpl(T t1, T t2, Rest... rest)
66*89c4ff92SAndroid Build Coastguard Worker {
67*89c4ff92SAndroid Build Coastguard Worker     static_assert(std::is_same<T, TensorInfo>::value, "Type T must be a TensorInfo");
68*89c4ff92SAndroid Build Coastguard Worker 
69*89c4ff92SAndroid Build Coastguard Worker     return (t1.GetDataType() == t2.GetDataType()) && AllTypesAreEqualImpl(t2, rest...);
70*89c4ff92SAndroid Build Coastguard Worker }
71*89c4ff92SAndroid Build Coastguard Worker 
72*89c4ff92SAndroid Build Coastguard Worker struct TypesAreEqual : public Rule
73*89c4ff92SAndroid Build Coastguard Worker {
74*89c4ff92SAndroid Build Coastguard Worker     template<typename ... Ts>
TypesAreEqualarmnn::TypesAreEqual75*89c4ff92SAndroid Build Coastguard Worker     TypesAreEqual(const Ts&... ts)
76*89c4ff92SAndroid Build Coastguard Worker     {
77*89c4ff92SAndroid Build Coastguard Worker         m_Res = AllTypesAreEqualImpl(ts...);
78*89c4ff92SAndroid Build Coastguard Worker     }
79*89c4ff92SAndroid Build Coastguard Worker };
80*89c4ff92SAndroid Build Coastguard Worker 
81*89c4ff92SAndroid Build Coastguard Worker struct QuantizationParametersAreEqual : public Rule
82*89c4ff92SAndroid Build Coastguard Worker {
QuantizationParametersAreEqualarmnn::QuantizationParametersAreEqual83*89c4ff92SAndroid Build Coastguard Worker     QuantizationParametersAreEqual(const TensorInfo& info0, const TensorInfo& info1)
84*89c4ff92SAndroid Build Coastguard Worker     {
85*89c4ff92SAndroid Build Coastguard Worker         m_Res = info0.GetQuantizationScale() == info1.GetQuantizationScale() &&
86*89c4ff92SAndroid Build Coastguard Worker                 info0.GetQuantizationOffset() == info1.GetQuantizationOffset();
87*89c4ff92SAndroid Build Coastguard Worker     }
88*89c4ff92SAndroid Build Coastguard Worker };
89*89c4ff92SAndroid Build Coastguard Worker 
90*89c4ff92SAndroid Build Coastguard Worker struct TypeAnyOf : public Rule
91*89c4ff92SAndroid Build Coastguard Worker {
92*89c4ff92SAndroid Build Coastguard Worker     template<typename Container>
TypeAnyOfarmnn::TypeAnyOf93*89c4ff92SAndroid Build Coastguard Worker     TypeAnyOf(const TensorInfo& info, const Container& c)
94*89c4ff92SAndroid Build Coastguard Worker     {
95*89c4ff92SAndroid Build Coastguard Worker         m_Res = std::any_of(c.begin(), c.end(), [&info](DataType dt)
96*89c4ff92SAndroid Build Coastguard Worker         {
97*89c4ff92SAndroid Build Coastguard Worker             return dt == info.GetDataType();
98*89c4ff92SAndroid Build Coastguard Worker         });
99*89c4ff92SAndroid Build Coastguard Worker     }
100*89c4ff92SAndroid Build Coastguard Worker };
101*89c4ff92SAndroid Build Coastguard Worker 
102*89c4ff92SAndroid Build Coastguard Worker struct TypeIs : public Rule
103*89c4ff92SAndroid Build Coastguard Worker {
TypeIsarmnn::TypeIs104*89c4ff92SAndroid Build Coastguard Worker     TypeIs(const TensorInfo& info, DataType dt)
105*89c4ff92SAndroid Build Coastguard Worker     {
106*89c4ff92SAndroid Build Coastguard Worker         m_Res = dt == info.GetDataType();
107*89c4ff92SAndroid Build Coastguard Worker     }
108*89c4ff92SAndroid Build Coastguard Worker };
109*89c4ff92SAndroid Build Coastguard Worker 
110*89c4ff92SAndroid Build Coastguard Worker struct TypeNotPerAxisQuantized : public Rule
111*89c4ff92SAndroid Build Coastguard Worker {
TypeNotPerAxisQuantizedarmnn::TypeNotPerAxisQuantized112*89c4ff92SAndroid Build Coastguard Worker     TypeNotPerAxisQuantized(const TensorInfo& info)
113*89c4ff92SAndroid Build Coastguard Worker     {
114*89c4ff92SAndroid Build Coastguard Worker         m_Res = !info.IsQuantized() || !info.HasPerAxisQuantization();
115*89c4ff92SAndroid Build Coastguard Worker     }
116*89c4ff92SAndroid Build Coastguard Worker };
117*89c4ff92SAndroid Build Coastguard Worker 
118*89c4ff92SAndroid Build Coastguard Worker struct BiasAndWeightsTypesMatch : public Rule
119*89c4ff92SAndroid Build Coastguard Worker {
BiasAndWeightsTypesMatcharmnn::BiasAndWeightsTypesMatch120*89c4ff92SAndroid Build Coastguard Worker     BiasAndWeightsTypesMatch(const TensorInfo& biases, const TensorInfo& weights)
121*89c4ff92SAndroid Build Coastguard Worker     {
122*89c4ff92SAndroid Build Coastguard Worker         m_Res = biases.GetDataType() == GetBiasTypeFromWeightsType(weights.GetDataType()).value();
123*89c4ff92SAndroid Build Coastguard Worker     }
124*89c4ff92SAndroid Build Coastguard Worker };
125*89c4ff92SAndroid Build Coastguard Worker 
126*89c4ff92SAndroid Build Coastguard Worker struct BiasAndWeightsTypesCompatible : public Rule
127*89c4ff92SAndroid Build Coastguard Worker {
128*89c4ff92SAndroid Build Coastguard Worker     template<typename Container>
BiasAndWeightsTypesCompatiblearmnn::BiasAndWeightsTypesCompatible129*89c4ff92SAndroid Build Coastguard Worker     BiasAndWeightsTypesCompatible(const TensorInfo& info, const Container& c)
130*89c4ff92SAndroid Build Coastguard Worker     {
131*89c4ff92SAndroid Build Coastguard Worker         m_Res = std::any_of(c.begin(), c.end(), [&info](DataType dt)
132*89c4ff92SAndroid Build Coastguard Worker             {
133*89c4ff92SAndroid Build Coastguard Worker                 return dt ==  GetBiasTypeFromWeightsType(info.GetDataType()).value();
134*89c4ff92SAndroid Build Coastguard Worker             });
135*89c4ff92SAndroid Build Coastguard Worker     }
136*89c4ff92SAndroid Build Coastguard Worker };
137*89c4ff92SAndroid Build Coastguard Worker 
138*89c4ff92SAndroid Build Coastguard Worker struct ShapesAreSameRank : public Rule
139*89c4ff92SAndroid Build Coastguard Worker {
ShapesAreSameRankarmnn::ShapesAreSameRank140*89c4ff92SAndroid Build Coastguard Worker     ShapesAreSameRank(const TensorInfo& info0, const TensorInfo& info1)
141*89c4ff92SAndroid Build Coastguard Worker     {
142*89c4ff92SAndroid Build Coastguard Worker         m_Res = info0.GetShape().GetNumDimensions() == info1.GetShape().GetNumDimensions();
143*89c4ff92SAndroid Build Coastguard Worker     }
144*89c4ff92SAndroid Build Coastguard Worker };
145*89c4ff92SAndroid Build Coastguard Worker 
146*89c4ff92SAndroid Build Coastguard Worker struct ShapesAreSameTotalSize : public Rule
147*89c4ff92SAndroid Build Coastguard Worker {
ShapesAreSameTotalSizearmnn::ShapesAreSameTotalSize148*89c4ff92SAndroid Build Coastguard Worker     ShapesAreSameTotalSize(const TensorInfo& info0, const TensorInfo& info1)
149*89c4ff92SAndroid Build Coastguard Worker     {
150*89c4ff92SAndroid Build Coastguard Worker         m_Res = info0.GetNumElements() == info1.GetNumElements();
151*89c4ff92SAndroid Build Coastguard Worker     }
152*89c4ff92SAndroid Build Coastguard Worker };
153*89c4ff92SAndroid Build Coastguard Worker 
154*89c4ff92SAndroid Build Coastguard Worker struct ShapesAreBroadcastCompatible : public Rule
155*89c4ff92SAndroid Build Coastguard Worker {
CalcInputSizearmnn::ShapesAreBroadcastCompatible156*89c4ff92SAndroid Build Coastguard Worker     unsigned int CalcInputSize(const TensorShape& in, const TensorShape& out, unsigned int idx)
157*89c4ff92SAndroid Build Coastguard Worker     {
158*89c4ff92SAndroid Build Coastguard Worker         unsigned int offset = out.GetNumDimensions() - in.GetNumDimensions();
159*89c4ff92SAndroid Build Coastguard Worker         unsigned int sizeIn = (idx < offset) ? 1 : in[idx-offset];
160*89c4ff92SAndroid Build Coastguard Worker         return sizeIn;
161*89c4ff92SAndroid Build Coastguard Worker     }
162*89c4ff92SAndroid Build Coastguard Worker 
ShapesAreBroadcastCompatiblearmnn::ShapesAreBroadcastCompatible163*89c4ff92SAndroid Build Coastguard Worker     ShapesAreBroadcastCompatible(const TensorInfo& in0, const TensorInfo& in1, const TensorInfo& out)
164*89c4ff92SAndroid Build Coastguard Worker     {
165*89c4ff92SAndroid Build Coastguard Worker         const TensorShape& shape0 = in0.GetShape();
166*89c4ff92SAndroid Build Coastguard Worker         const TensorShape& shape1 = in1.GetShape();
167*89c4ff92SAndroid Build Coastguard Worker         const TensorShape& outShape = out.GetShape();
168*89c4ff92SAndroid Build Coastguard Worker 
169*89c4ff92SAndroid Build Coastguard Worker         for (unsigned int i=0; i < outShape.GetNumDimensions() && m_Res; i++)
170*89c4ff92SAndroid Build Coastguard Worker         {
171*89c4ff92SAndroid Build Coastguard Worker             unsigned int sizeOut = outShape[i];
172*89c4ff92SAndroid Build Coastguard Worker             unsigned int sizeIn0 = CalcInputSize(shape0, outShape, i);
173*89c4ff92SAndroid Build Coastguard Worker             unsigned int sizeIn1 = CalcInputSize(shape1, outShape, i);
174*89c4ff92SAndroid Build Coastguard Worker 
175*89c4ff92SAndroid Build Coastguard Worker             m_Res &= ((sizeIn0 == sizeOut) || (sizeIn0 == 1)) &&
176*89c4ff92SAndroid Build Coastguard Worker                      ((sizeIn1 == sizeOut) || (sizeIn1 == 1));
177*89c4ff92SAndroid Build Coastguard Worker         }
178*89c4ff92SAndroid Build Coastguard Worker     }
179*89c4ff92SAndroid Build Coastguard Worker };
180*89c4ff92SAndroid Build Coastguard Worker 
181*89c4ff92SAndroid Build Coastguard Worker struct TensorNumDimensionsAreCorrect : public Rule
182*89c4ff92SAndroid Build Coastguard Worker {
TensorNumDimensionsAreCorrectarmnn::TensorNumDimensionsAreCorrect183*89c4ff92SAndroid Build Coastguard Worker     TensorNumDimensionsAreCorrect(const TensorInfo& info, unsigned int expectedNumDimensions)
184*89c4ff92SAndroid Build Coastguard Worker     {
185*89c4ff92SAndroid Build Coastguard Worker         m_Res = info.GetNumDimensions() == expectedNumDimensions;
186*89c4ff92SAndroid Build Coastguard Worker     }
187*89c4ff92SAndroid Build Coastguard Worker };
188*89c4ff92SAndroid Build Coastguard Worker 
189*89c4ff92SAndroid Build Coastguard Worker struct TensorNumDimensionsAreGreaterOrEqualTo : public Rule
190*89c4ff92SAndroid Build Coastguard Worker {
TensorNumDimensionsAreGreaterOrEqualToarmnn::TensorNumDimensionsAreGreaterOrEqualTo191*89c4ff92SAndroid Build Coastguard Worker     TensorNumDimensionsAreGreaterOrEqualTo(const TensorInfo& info, unsigned int numDimensionsToCompare)
192*89c4ff92SAndroid Build Coastguard Worker     {
193*89c4ff92SAndroid Build Coastguard Worker         m_Res = info.GetNumDimensions() >= numDimensionsToCompare;
194*89c4ff92SAndroid Build Coastguard Worker     }
195*89c4ff92SAndroid Build Coastguard Worker };
196*89c4ff92SAndroid Build Coastguard Worker 
197*89c4ff92SAndroid Build Coastguard Worker } //namespace armnn