xref: /aosp_15_r20/external/armnn/src/backends/backendsCommon/LayerSupportRules.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
#include <armnn/utility/Assert.hpp>

#include <algorithm>
#include <string>
#include <type_traits>
10 
11 namespace armnn
12 {
13 
GetBiasTypeFromWeightsType(armnn::Optional<armnn::DataType> weightsType)14 inline armnn::Optional<armnn::DataType> GetBiasTypeFromWeightsType(armnn::Optional<armnn::DataType> weightsType)
15 {
16     if (!weightsType)
17     {
18         return weightsType;
19     }
20 
21     switch(weightsType.value())
22     {
23         case armnn::DataType::Float16:
24         case armnn::DataType::Float32:
25             return weightsType;
26         case armnn::DataType::QAsymmS8:
27         case armnn::DataType::QAsymmU8:
28         case armnn::DataType::QSymmS8:
29         case armnn::DataType::QSymmS16:
30             return armnn::DataType::Signed32;
31         default:
32             ARMNN_ASSERT_MSG(false, "GetBiasTypeFromWeightsType(): Unsupported data type.");
33     }
34     return armnn::EmptyOptional();
35 }
36 
37 template<typename F>
CheckSupportRule(F rule,Optional<std::string &> reasonIfUnsupported,const char * reason)38 bool CheckSupportRule(F rule, Optional<std::string&> reasonIfUnsupported, const char* reason)
39 {
40     bool supported = rule();
41     if (!supported && reason)
42     {
43         reasonIfUnsupported.value() += std::string(reason) + "\n"; // Append the reason on a new line
44     }
45     return supported;
46 }
47 
/// Common base of all layer-support rules.
/// A concrete rule works out its verdict in its constructor and stores it in m_Res;
/// invoking the rule object simply returns that stored verdict.
struct Rule
{
    /// Verdict of the rule. Starts out as true ("supported") until a derived constructor
    /// overwrites it.
    bool m_Res{true};

    /// @return The stored verdict: true if the rule is satisfied.
    bool operator()() const { return m_Res; }
};
57 
/// Terminating overload of AllTypesAreEqualImpl: a single value trivially agrees with itself.
/// The argument is taken by const reference so the final value of the recursion is not copied.
template<typename T>
bool AllTypesAreEqualImpl(const T&)
{
    return true;
}
63 
64 template<typename T, typename... Rest>
AllTypesAreEqualImpl(T t1,T t2,Rest...rest)65 bool AllTypesAreEqualImpl(T t1, T t2, Rest... rest)
66 {
67     static_assert(std::is_same<T, TensorInfo>::value, "Type T must be a TensorInfo");
68 
69     return (t1.GetDataType() == t2.GetDataType()) && AllTypesAreEqualImpl(t2, rest...);
70 }
71 
72 struct TypesAreEqual : public Rule
73 {
74     template<typename ... Ts>
TypesAreEqualarmnn::TypesAreEqual75     TypesAreEqual(const Ts&... ts)
76     {
77         m_Res = AllTypesAreEqualImpl(ts...);
78     }
79 };
80 
81 struct QuantizationParametersAreEqual : public Rule
82 {
QuantizationParametersAreEqualarmnn::QuantizationParametersAreEqual83     QuantizationParametersAreEqual(const TensorInfo& info0, const TensorInfo& info1)
84     {
85         m_Res = info0.GetQuantizationScale() == info1.GetQuantizationScale() &&
86                 info0.GetQuantizationOffset() == info1.GetQuantizationOffset();
87     }
88 };
89 
90 struct TypeAnyOf : public Rule
91 {
92     template<typename Container>
TypeAnyOfarmnn::TypeAnyOf93     TypeAnyOf(const TensorInfo& info, const Container& c)
94     {
95         m_Res = std::any_of(c.begin(), c.end(), [&info](DataType dt)
96         {
97             return dt == info.GetDataType();
98         });
99     }
100 };
101 
102 struct TypeIs : public Rule
103 {
TypeIsarmnn::TypeIs104     TypeIs(const TensorInfo& info, DataType dt)
105     {
106         m_Res = dt == info.GetDataType();
107     }
108 };
109 
110 struct TypeNotPerAxisQuantized : public Rule
111 {
TypeNotPerAxisQuantizedarmnn::TypeNotPerAxisQuantized112     TypeNotPerAxisQuantized(const TensorInfo& info)
113     {
114         m_Res = !info.IsQuantized() || !info.HasPerAxisQuantization();
115     }
116 };
117 
118 struct BiasAndWeightsTypesMatch : public Rule
119 {
BiasAndWeightsTypesMatcharmnn::BiasAndWeightsTypesMatch120     BiasAndWeightsTypesMatch(const TensorInfo& biases, const TensorInfo& weights)
121     {
122         m_Res = biases.GetDataType() == GetBiasTypeFromWeightsType(weights.GetDataType()).value();
123     }
124 };
125 
126 struct BiasAndWeightsTypesCompatible : public Rule
127 {
128     template<typename Container>
BiasAndWeightsTypesCompatiblearmnn::BiasAndWeightsTypesCompatible129     BiasAndWeightsTypesCompatible(const TensorInfo& info, const Container& c)
130     {
131         m_Res = std::any_of(c.begin(), c.end(), [&info](DataType dt)
132             {
133                 return dt ==  GetBiasTypeFromWeightsType(info.GetDataType()).value();
134             });
135     }
136 };
137 
138 struct ShapesAreSameRank : public Rule
139 {
ShapesAreSameRankarmnn::ShapesAreSameRank140     ShapesAreSameRank(const TensorInfo& info0, const TensorInfo& info1)
141     {
142         m_Res = info0.GetShape().GetNumDimensions() == info1.GetShape().GetNumDimensions();
143     }
144 };
145 
146 struct ShapesAreSameTotalSize : public Rule
147 {
ShapesAreSameTotalSizearmnn::ShapesAreSameTotalSize148     ShapesAreSameTotalSize(const TensorInfo& info0, const TensorInfo& info1)
149     {
150         m_Res = info0.GetNumElements() == info1.GetNumElements();
151     }
152 };
153 
154 struct ShapesAreBroadcastCompatible : public Rule
155 {
CalcInputSizearmnn::ShapesAreBroadcastCompatible156     unsigned int CalcInputSize(const TensorShape& in, const TensorShape& out, unsigned int idx)
157     {
158         unsigned int offset = out.GetNumDimensions() - in.GetNumDimensions();
159         unsigned int sizeIn = (idx < offset) ? 1 : in[idx-offset];
160         return sizeIn;
161     }
162 
ShapesAreBroadcastCompatiblearmnn::ShapesAreBroadcastCompatible163     ShapesAreBroadcastCompatible(const TensorInfo& in0, const TensorInfo& in1, const TensorInfo& out)
164     {
165         const TensorShape& shape0 = in0.GetShape();
166         const TensorShape& shape1 = in1.GetShape();
167         const TensorShape& outShape = out.GetShape();
168 
169         for (unsigned int i=0; i < outShape.GetNumDimensions() && m_Res; i++)
170         {
171             unsigned int sizeOut = outShape[i];
172             unsigned int sizeIn0 = CalcInputSize(shape0, outShape, i);
173             unsigned int sizeIn1 = CalcInputSize(shape1, outShape, i);
174 
175             m_Res &= ((sizeIn0 == sizeOut) || (sizeIn0 == 1)) &&
176                      ((sizeIn1 == sizeOut) || (sizeIn1 == 1));
177         }
178     }
179 };
180 
181 struct TensorNumDimensionsAreCorrect : public Rule
182 {
TensorNumDimensionsAreCorrectarmnn::TensorNumDimensionsAreCorrect183     TensorNumDimensionsAreCorrect(const TensorInfo& info, unsigned int expectedNumDimensions)
184     {
185         m_Res = info.GetNumDimensions() == expectedNumDimensions;
186     }
187 };
188 
189 struct TensorNumDimensionsAreGreaterOrEqualTo : public Rule
190 {
TensorNumDimensionsAreGreaterOrEqualToarmnn::TensorNumDimensionsAreGreaterOrEqualTo191     TensorNumDimensionsAreGreaterOrEqualTo(const TensorInfo& info, unsigned int numDimensionsToCompare)
192     {
193         m_Res = info.GetNumDimensions() >= numDimensionsToCompare;
194     }
195 };
196 
197 } //namespace armnn