1 //
2 // Copyright © 2017 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #pragma once
7
8 #include <armnn/utility/Assert.hpp>
9 #include <algorithm>
10
11 namespace armnn
12 {
13
GetBiasTypeFromWeightsType(armnn::Optional<armnn::DataType> weightsType)14 inline armnn::Optional<armnn::DataType> GetBiasTypeFromWeightsType(armnn::Optional<armnn::DataType> weightsType)
15 {
16 if (!weightsType)
17 {
18 return weightsType;
19 }
20
21 switch(weightsType.value())
22 {
23 case armnn::DataType::Float16:
24 case armnn::DataType::Float32:
25 return weightsType;
26 case armnn::DataType::QAsymmS8:
27 case armnn::DataType::QAsymmU8:
28 case armnn::DataType::QSymmS8:
29 case armnn::DataType::QSymmS16:
30 return armnn::DataType::Signed32;
31 default:
32 ARMNN_ASSERT_MSG(false, "GetBiasTypeFromWeightsType(): Unsupported data type.");
33 }
34 return armnn::EmptyOptional();
35 }
36
37 template<typename F>
CheckSupportRule(F rule,Optional<std::string &> reasonIfUnsupported,const char * reason)38 bool CheckSupportRule(F rule, Optional<std::string&> reasonIfUnsupported, const char* reason)
39 {
40 bool supported = rule();
41 if (!supported && reason)
42 {
43 reasonIfUnsupported.value() += std::string(reason) + "\n"; // Append the reason on a new line
44 }
45 return supported;
46 }
47
/// Base type for support rules: a callable wrapper around a boolean verdict.
/// Derived rules in this file assign m_Res in their constructors.
struct Rule
{
    /// Verdict of the rule; true unless something overwrites it.
    bool m_Res = true;

    /// @return The stored verdict.
    bool operator()() const
    {
        return m_Res;
    }
};
57
/// Terminating overload of the recursive comparison: a single remaining
/// argument has nothing left to be compared against, so the result is true.
template<typename T>
bool AllTypesAreEqualImpl(T /*last*/)
{
    return true;
}
63
64 template<typename T, typename... Rest>
AllTypesAreEqualImpl(T t1,T t2,Rest...rest)65 bool AllTypesAreEqualImpl(T t1, T t2, Rest... rest)
66 {
67 static_assert(std::is_same<T, TensorInfo>::value, "Type T must be a TensorInfo");
68
69 return (t1.GetDataType() == t2.GetDataType()) && AllTypesAreEqualImpl(t2, rest...);
70 }
71
72 struct TypesAreEqual : public Rule
73 {
74 template<typename ... Ts>
TypesAreEqualarmnn::TypesAreEqual75 TypesAreEqual(const Ts&... ts)
76 {
77 m_Res = AllTypesAreEqualImpl(ts...);
78 }
79 };
80
81 struct QuantizationParametersAreEqual : public Rule
82 {
QuantizationParametersAreEqualarmnn::QuantizationParametersAreEqual83 QuantizationParametersAreEqual(const TensorInfo& info0, const TensorInfo& info1)
84 {
85 m_Res = info0.GetQuantizationScale() == info1.GetQuantizationScale() &&
86 info0.GetQuantizationOffset() == info1.GetQuantizationOffset();
87 }
88 };
89
90 struct TypeAnyOf : public Rule
91 {
92 template<typename Container>
TypeAnyOfarmnn::TypeAnyOf93 TypeAnyOf(const TensorInfo& info, const Container& c)
94 {
95 m_Res = std::any_of(c.begin(), c.end(), [&info](DataType dt)
96 {
97 return dt == info.GetDataType();
98 });
99 }
100 };
101
102 struct TypeIs : public Rule
103 {
TypeIsarmnn::TypeIs104 TypeIs(const TensorInfo& info, DataType dt)
105 {
106 m_Res = dt == info.GetDataType();
107 }
108 };
109
110 struct TypeNotPerAxisQuantized : public Rule
111 {
TypeNotPerAxisQuantizedarmnn::TypeNotPerAxisQuantized112 TypeNotPerAxisQuantized(const TensorInfo& info)
113 {
114 m_Res = !info.IsQuantized() || !info.HasPerAxisQuantization();
115 }
116 };
117
118 struct BiasAndWeightsTypesMatch : public Rule
119 {
BiasAndWeightsTypesMatcharmnn::BiasAndWeightsTypesMatch120 BiasAndWeightsTypesMatch(const TensorInfo& biases, const TensorInfo& weights)
121 {
122 m_Res = biases.GetDataType() == GetBiasTypeFromWeightsType(weights.GetDataType()).value();
123 }
124 };
125
126 struct BiasAndWeightsTypesCompatible : public Rule
127 {
128 template<typename Container>
BiasAndWeightsTypesCompatiblearmnn::BiasAndWeightsTypesCompatible129 BiasAndWeightsTypesCompatible(const TensorInfo& info, const Container& c)
130 {
131 m_Res = std::any_of(c.begin(), c.end(), [&info](DataType dt)
132 {
133 return dt == GetBiasTypeFromWeightsType(info.GetDataType()).value();
134 });
135 }
136 };
137
138 struct ShapesAreSameRank : public Rule
139 {
ShapesAreSameRankarmnn::ShapesAreSameRank140 ShapesAreSameRank(const TensorInfo& info0, const TensorInfo& info1)
141 {
142 m_Res = info0.GetShape().GetNumDimensions() == info1.GetShape().GetNumDimensions();
143 }
144 };
145
146 struct ShapesAreSameTotalSize : public Rule
147 {
ShapesAreSameTotalSizearmnn::ShapesAreSameTotalSize148 ShapesAreSameTotalSize(const TensorInfo& info0, const TensorInfo& info1)
149 {
150 m_Res = info0.GetNumElements() == info1.GetNumElements();
151 }
152 };
153
154 struct ShapesAreBroadcastCompatible : public Rule
155 {
CalcInputSizearmnn::ShapesAreBroadcastCompatible156 unsigned int CalcInputSize(const TensorShape& in, const TensorShape& out, unsigned int idx)
157 {
158 unsigned int offset = out.GetNumDimensions() - in.GetNumDimensions();
159 unsigned int sizeIn = (idx < offset) ? 1 : in[idx-offset];
160 return sizeIn;
161 }
162
ShapesAreBroadcastCompatiblearmnn::ShapesAreBroadcastCompatible163 ShapesAreBroadcastCompatible(const TensorInfo& in0, const TensorInfo& in1, const TensorInfo& out)
164 {
165 const TensorShape& shape0 = in0.GetShape();
166 const TensorShape& shape1 = in1.GetShape();
167 const TensorShape& outShape = out.GetShape();
168
169 for (unsigned int i=0; i < outShape.GetNumDimensions() && m_Res; i++)
170 {
171 unsigned int sizeOut = outShape[i];
172 unsigned int sizeIn0 = CalcInputSize(shape0, outShape, i);
173 unsigned int sizeIn1 = CalcInputSize(shape1, outShape, i);
174
175 m_Res &= ((sizeIn0 == sizeOut) || (sizeIn0 == 1)) &&
176 ((sizeIn1 == sizeOut) || (sizeIn1 == 1));
177 }
178 }
179 };
180
181 struct TensorNumDimensionsAreCorrect : public Rule
182 {
TensorNumDimensionsAreCorrectarmnn::TensorNumDimensionsAreCorrect183 TensorNumDimensionsAreCorrect(const TensorInfo& info, unsigned int expectedNumDimensions)
184 {
185 m_Res = info.GetNumDimensions() == expectedNumDimensions;
186 }
187 };
188
189 struct TensorNumDimensionsAreGreaterOrEqualTo : public Rule
190 {
TensorNumDimensionsAreGreaterOrEqualToarmnn::TensorNumDimensionsAreGreaterOrEqualTo191 TensorNumDimensionsAreGreaterOrEqualTo(const TensorInfo& info, unsigned int numDimensionsToCompare)
192 {
193 m_Res = info.GetNumDimensions() >= numDimensionsToCompare;
194 }
195 };
196
197 } //namespace armnn