1 //
2 // Copyright © 2019 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5
6 #pragma once
7
#include "BackendId.hpp"
#include <armnn/Exceptions.hpp>

#include <cassert>
#include <initializer_list>
#include <new>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
11
12 namespace armnn
13 {
14
// Forward declaration so the aliases below can name the type before its definition.
struct BackendOptions;

/// A single BackendOptions object under another name.
using BackendCapabilities = BackendOptions;

/// NetworkOptions and ModelOptions both name the same type: a list of BackendOptions.
using NetworkOptions = std::vector<BackendOptions>;
using ModelOptions   = std::vector<BackendOptions>;
20
21 /// Struct for the users to pass backend specific options
22 struct BackendOptions
23 {
24 private:
25 template<typename T>
26 struct CheckAllowed
27 {
28 static const bool value = std::is_same<T, int>::value ||
29 std::is_same<T, unsigned int>::value ||
30 std::is_same<T, float>::value ||
31 std::is_same<T, bool>::value ||
32 std::is_same<T, std::string>::value ||
33 std::is_same<T, const char*>::value;
34 };
35 public:
36
37 /// Very basic type safe variant
38 class Var
39 {
40
41 public:
42 /// Constructors
Var(int i)43 explicit Var(int i) : m_Vals(i), m_Type(VarTypes::Integer) {};
Var(unsigned int u)44 explicit Var(unsigned int u) : m_Vals(u), m_Type(VarTypes::UnsignedInteger) {};
Var(float f)45 explicit Var(float f) : m_Vals(f), m_Type(VarTypes::Float) {};
Var(bool b)46 explicit Var(bool b) : m_Vals(b), m_Type(VarTypes::Boolean) {};
Var(const char * s)47 explicit Var(const char* s) : m_Vals(s), m_Type(VarTypes::String) {};
Var(std::string s)48 explicit Var(std::string s) : m_Vals(s), m_Type(VarTypes::String) {};
49
50 /// Disallow implicit conversions from types not explicitly allowed below.
51 template<typename DisallowedType>
Var(DisallowedType)52 Var(DisallowedType)
53 {
54 static_assert(CheckAllowed<DisallowedType>::value, "Type is not allowed for Var<DisallowedType>.");
55 assert(false && "Unreachable code");
56 }
57
58 /// Copy Construct
Var(const Var & other)59 Var(const Var& other)
60 : m_Type(other.m_Type)
61 {
62 switch(m_Type)
63 {
64 case VarTypes::String:
65 {
66 new (&m_Vals.s) std::string(other.m_Vals.s);
67 break;
68 }
69 default:
70 {
71 DoOp(other, [](auto& a, auto& b)
72 {
73 a = b;
74 });
75 break;
76 }
77 }
78 }
79
80 /// Copy operator
operator =(const Var & other)81 Var& operator=(const Var& other)
82 {
83 // Destroy existing string
84 if (m_Type == VarTypes::String)
85 {
86 Destruct(m_Vals.s);
87 }
88
89 m_Type = other.m_Type;
90 switch(m_Type)
91 {
92 case VarTypes::String:
93 {
94
95 new (&m_Vals.s) std::string(other.m_Vals.s);
96 break;
97 }
98 default:
99 {
100 DoOp(other, [](auto& a, auto& b)
101 {
102 a = b;
103 });
104 break;
105 }
106 }
107
108 return *this;
109 };
110
111 /// Type getters
IsBool() const112 bool IsBool() const { return m_Type == VarTypes::Boolean; }
IsInt() const113 bool IsInt() const { return m_Type == VarTypes::Integer; }
IsUnsignedInt() const114 bool IsUnsignedInt() const { return m_Type == VarTypes::UnsignedInteger; }
IsFloat() const115 bool IsFloat() const { return m_Type == VarTypes::Float; }
IsString() const116 bool IsString() const { return m_Type == VarTypes::String; }
117
118 /// Value getters
AsBool() const119 bool AsBool() const { assert(IsBool()); return m_Vals.b; }
AsInt() const120 int AsInt() const { assert(IsInt()); return m_Vals.i; }
AsUnsignedInt() const121 unsigned int AsUnsignedInt() const { assert(IsUnsignedInt()); return m_Vals.u; }
AsFloat() const122 float AsFloat() const { assert(IsFloat()); return m_Vals.f; }
AsString() const123 std::string AsString() const { assert(IsString()); return m_Vals.s; }
ToString()124 std::string ToString()
125 {
126 if (IsBool()) { return AsBool() ? "true" : "false"; }
127 else if (IsInt()) { return std::to_string(AsInt()); }
128 else if (IsUnsignedInt()) { return std::to_string(AsUnsignedInt()); }
129 else if (IsFloat()) { return std::to_string(AsFloat()); }
130 else if (IsString()) { return AsString(); }
131 else
132 {
133 throw armnn::InvalidArgumentException("Unknown data type for string conversion");
134 }
135 }
136
137 /// Destructor
~Var()138 ~Var()
139 {
140 DoOp(*this, [this](auto& a, auto&)
141 {
142 Destruct(a);
143 });
144 }
145 private:
146 template<typename Func>
DoOp(const Var & other,Func func)147 void DoOp(const Var& other, Func func)
148 {
149 if (other.IsBool())
150 {
151 func(m_Vals.b, other.m_Vals.b);
152 }
153 else if (other.IsInt())
154 {
155 func(m_Vals.i, other.m_Vals.i);
156 }
157 else if (other.IsUnsignedInt())
158 {
159 func(m_Vals.u, other.m_Vals.u);
160 }
161 else if (other.IsFloat())
162 {
163 func(m_Vals.f, other.m_Vals.f);
164 }
165 else if (other.IsString())
166 {
167 func(m_Vals.s, other.m_Vals.s);
168 }
169 }
170
171 template<typename Destructable>
Destruct(Destructable & d)172 void Destruct(Destructable& d)
173 {
174 if (std::is_destructible<Destructable>::value)
175 {
176 d.~Destructable();
177 }
178 }
179
180 private:
181 /// Types which can be stored
182 enum class VarTypes
183 {
184 Boolean,
185 Integer,
186 Float,
187 String,
188 UnsignedInteger
189 };
190
191 /// Union of potential type values.
192 union Vals
193 {
194 int i;
195 unsigned int u;
196 float f;
197 bool b;
198 std::string s;
199
Vals()200 Vals(){}
~Vals()201 ~Vals(){}
202
Vals(int i)203 explicit Vals(int i) : i(i) {};
Vals(unsigned int u)204 explicit Vals(unsigned int u) : u(u) {};
Vals(float f)205 explicit Vals(float f) : f(f) {};
Vals(bool b)206 explicit Vals(bool b) : b(b) {};
Vals(const char * s)207 explicit Vals(const char* s) : s(std::string(s)) {}
Vals(std::string s)208 explicit Vals(std::string s) : s(s) {}
209 };
210
211 Vals m_Vals;
212 VarTypes m_Type;
213 };
214
215 struct BackendOption
216 {
217 public:
BackendOptionarmnn::BackendOptions::BackendOption218 BackendOption(std::string name, bool value)
219 : m_Name(name), m_Value(value)
220 {}
BackendOptionarmnn::BackendOptions::BackendOption221 BackendOption(std::string name, int value)
222 : m_Name(name), m_Value(value)
223 {}
BackendOptionarmnn::BackendOptions::BackendOption224 BackendOption(std::string name, unsigned int value)
225 : m_Name(name), m_Value(value)
226 {}
BackendOptionarmnn::BackendOptions::BackendOption227 BackendOption(std::string name, float value)
228 : m_Name(name), m_Value(value)
229 {}
BackendOptionarmnn::BackendOptions::BackendOption230 BackendOption(std::string name, std::string value)
231 : m_Name(name), m_Value(value)
232 {}
BackendOptionarmnn::BackendOptions::BackendOption233 BackendOption(std::string name, const char* value)
234 : m_Name(name), m_Value(value)
235 {}
236
237 template<typename DisallowedType>
BackendOptionarmnn::BackendOptions::BackendOption238 BackendOption(std::string, DisallowedType)
239 : m_Value(0)
240 {
241 static_assert(CheckAllowed<DisallowedType>::value, "Type is not allowed for BackendOption.");
242 assert(false && "Unreachable code");
243 }
244
245 BackendOption(const BackendOption& other) = default;
246 BackendOption(BackendOption&& other) = default;
247 BackendOption& operator=(const BackendOption& other) = default;
248 BackendOption& operator=(BackendOption&& other) = default;
249 ~BackendOption() = default;
250
GetNamearmnn::BackendOptions::BackendOption251 std::string GetName() const { return m_Name; }
GetValuearmnn::BackendOptions::BackendOption252 Var GetValue() const { return m_Value; }
253
254 private:
255 std::string m_Name; ///< Name of the option
256 Var m_Value; ///< Value of the option. (Bool, int, Float, String)
257 };
258
BackendOptionsarmnn::BackendOptions259 explicit BackendOptions(BackendId backend)
260 : m_TargetBackend(backend)
261 {}
262
BackendOptionsarmnn::BackendOptions263 BackendOptions(BackendId backend, std::initializer_list<BackendOption> options)
264 : m_TargetBackend(backend)
265 , m_Options(options)
266 {}
267
268 BackendOptions(const BackendOptions& other) = default;
269 BackendOptions(BackendOptions&& other) = default;
270 BackendOptions& operator=(const BackendOptions& other) = default;
271 BackendOptions& operator=(BackendOptions&& other) = default;
272
AddOptionarmnn::BackendOptions273 void AddOption(BackendOption&& option)
274 {
275 m_Options.push_back(option);
276 }
277
AddOptionarmnn::BackendOptions278 void AddOption(const BackendOption& option)
279 {
280 m_Options.push_back(option);
281 }
282
GetBackendIdarmnn::BackendOptions283 const BackendId& GetBackendId() const noexcept { return m_TargetBackend; }
GetOptionCountarmnn::BackendOptions284 size_t GetOptionCount() const noexcept { return m_Options.size(); }
GetOptionarmnn::BackendOptions285 const BackendOption& GetOption(size_t idx) const { return m_Options[idx]; }
286
287 private:
288 /// The id for the backend to which the options should be passed.
289 BackendId m_TargetBackend;
290
291 /// The array of options to pass to the backend context
292 std::vector<BackendOption> m_Options;
293 };
294
295
296 template <typename F>
ParseOptions(const std::vector<BackendOptions> & options,BackendId backend,F f)297 void ParseOptions(const std::vector<BackendOptions>& options, BackendId backend, F f)
298 {
299 for (auto optionsGroup : options)
300 {
301 if (optionsGroup.GetBackendId() == backend)
302 {
303 for (size_t i=0; i < optionsGroup.GetOptionCount(); i++)
304 {
305 const BackendOptions::BackendOption option = optionsGroup.GetOption(i);
306 f(option.GetName(), option.GetValue());
307 }
308 }
309 }
310 }
311
ParseBooleanBackendOption(const armnn::BackendOptions::Var & value,bool defaultValue)312 inline bool ParseBooleanBackendOption(const armnn::BackendOptions::Var& value, bool defaultValue)
313 {
314 if (value.IsBool())
315 {
316 return value.AsBool();
317 }
318 return defaultValue;
319 }
320
ParseStringBackendOption(const armnn::BackendOptions::Var & value,std::string defaultValue)321 inline std::string ParseStringBackendOption(const armnn::BackendOptions::Var& value, std::string defaultValue)
322 {
323 if (value.IsString())
324 {
325 return value.AsString();
326 }
327 return defaultValue;
328 }
329
ParseIntBackendOption(const armnn::BackendOptions::Var & value,int defaultValue)330 inline int ParseIntBackendOption(const armnn::BackendOptions::Var& value, int defaultValue)
331 {
332 if (value.IsInt())
333 {
334 return value.AsInt();
335 }
336 return defaultValue;
337 }
338
339 } //namespace armnn
340