xref: /aosp_15_r20/external/armnn/include/armnn/BackendOptions.hpp (revision 89c4ff92f2867872bb9e2354d150bf0c8c502810)
1 //
2 // Copyright © 2019 Arm Ltd. All rights reserved.
3 // SPDX-License-Identifier: MIT
4 //
5 
6 #pragma once
7 
#include "BackendId.hpp"
#include <armnn/Exceptions.hpp>

#include <cassert>
#include <new>
#include <string>
#include <type_traits>
#include <utility>
#include <vector>
11 
12 namespace armnn
13 {
14 
struct BackendOptions;

/// A backend's capabilities are described with the same name/value structure as its options.
using BackendCapabilities = BackendOptions;

/// Collections of BackendOptions groups; each group targets a single backend.
using ModelOptions   = std::vector<BackendOptions>;
using NetworkOptions = std::vector<BackendOptions>;
20 
21 /// Struct for the users to pass backend specific options
22 struct BackendOptions
23 {
24 private:
25     template<typename T>
26     struct CheckAllowed
27     {
28         static const bool value = std::is_same<T, int>::value ||
29                                   std::is_same<T, unsigned int>::value ||
30                                   std::is_same<T, float>::value ||
31                                   std::is_same<T, bool>::value ||
32                                   std::is_same<T, std::string>::value ||
33                                   std::is_same<T, const char*>::value;
34     };
35 public:
36 
37     /// Very basic type safe variant
38     class Var
39     {
40 
41     public:
42         /// Constructors
Var(int i)43         explicit Var(int i) : m_Vals(i), m_Type(VarTypes::Integer) {};
Var(unsigned int u)44         explicit Var(unsigned int u) : m_Vals(u), m_Type(VarTypes::UnsignedInteger) {};
Var(float f)45         explicit Var(float f) : m_Vals(f), m_Type(VarTypes::Float) {};
Var(bool b)46         explicit Var(bool b) : m_Vals(b), m_Type(VarTypes::Boolean) {};
Var(const char * s)47         explicit Var(const char* s) : m_Vals(s), m_Type(VarTypes::String) {};
Var(std::string s)48         explicit Var(std::string s) : m_Vals(s), m_Type(VarTypes::String) {};
49 
50         /// Disallow implicit conversions from types not explicitly allowed below.
51         template<typename DisallowedType>
Var(DisallowedType)52         Var(DisallowedType)
53         {
54             static_assert(CheckAllowed<DisallowedType>::value, "Type is not allowed for Var<DisallowedType>.");
55             assert(false && "Unreachable code");
56         }
57 
58         /// Copy Construct
Var(const Var & other)59         Var(const Var& other)
60             : m_Type(other.m_Type)
61         {
62             switch(m_Type)
63             {
64                 case VarTypes::String:
65                 {
66                     new (&m_Vals.s) std::string(other.m_Vals.s);
67                     break;
68                 }
69                 default:
70                 {
71                     DoOp(other, [](auto& a, auto& b)
72                         {
73                             a = b;
74                         });
75                     break;
76                 }
77             }
78         }
79 
80         /// Copy operator
operator =(const Var & other)81         Var& operator=(const Var& other)
82         {
83             // Destroy existing string
84             if (m_Type == VarTypes::String)
85             {
86                 Destruct(m_Vals.s);
87             }
88 
89             m_Type = other.m_Type;
90             switch(m_Type)
91             {
92                 case VarTypes::String:
93                 {
94 
95                     new (&m_Vals.s) std::string(other.m_Vals.s);
96                     break;
97                 }
98                 default:
99                 {
100                     DoOp(other, [](auto& a, auto& b)
101                         {
102                             a = b;
103                         });
104                     break;
105                 }
106             }
107 
108             return *this;
109         };
110 
111         /// Type getters
IsBool() const112         bool IsBool() const { return m_Type == VarTypes::Boolean; }
IsInt() const113         bool IsInt() const { return m_Type == VarTypes::Integer; }
IsUnsignedInt() const114         bool IsUnsignedInt() const { return m_Type == VarTypes::UnsignedInteger; }
IsFloat() const115         bool IsFloat() const { return m_Type == VarTypes::Float; }
IsString() const116         bool IsString() const { return m_Type == VarTypes::String; }
117 
118         /// Value getters
AsBool() const119         bool AsBool() const { assert(IsBool()); return m_Vals.b; }
AsInt() const120         int AsInt() const { assert(IsInt()); return m_Vals.i; }
AsUnsignedInt() const121         unsigned int AsUnsignedInt() const { assert(IsUnsignedInt()); return m_Vals.u; }
AsFloat() const122         float AsFloat() const { assert(IsFloat()); return m_Vals.f; }
AsString() const123         std::string AsString() const { assert(IsString()); return m_Vals.s; }
ToString()124         std::string ToString()
125         {
126             if (IsBool()) { return AsBool() ? "true" : "false"; }
127             else if (IsInt()) { return std::to_string(AsInt()); }
128             else if (IsUnsignedInt()) { return std::to_string(AsUnsignedInt()); }
129             else if (IsFloat()) { return std::to_string(AsFloat()); }
130             else if (IsString()) { return AsString(); }
131             else
132             {
133                 throw armnn::InvalidArgumentException("Unknown data type for string conversion");
134             }
135         }
136 
137         /// Destructor
~Var()138         ~Var()
139         {
140             DoOp(*this, [this](auto& a, auto&)
141                 {
142                     Destruct(a);
143                 });
144         }
145     private:
146         template<typename Func>
DoOp(const Var & other,Func func)147         void DoOp(const Var& other, Func func)
148         {
149             if (other.IsBool())
150             {
151                 func(m_Vals.b, other.m_Vals.b);
152             }
153             else if (other.IsInt())
154             {
155                 func(m_Vals.i, other.m_Vals.i);
156             }
157             else if (other.IsUnsignedInt())
158             {
159                 func(m_Vals.u, other.m_Vals.u);
160             }
161             else if (other.IsFloat())
162             {
163                 func(m_Vals.f, other.m_Vals.f);
164             }
165             else if (other.IsString())
166             {
167                 func(m_Vals.s, other.m_Vals.s);
168             }
169         }
170 
171         template<typename Destructable>
Destruct(Destructable & d)172         void Destruct(Destructable& d)
173         {
174             if (std::is_destructible<Destructable>::value)
175             {
176                 d.~Destructable();
177             }
178         }
179 
180     private:
181         /// Types which can be stored
182         enum class VarTypes
183         {
184             Boolean,
185             Integer,
186             Float,
187             String,
188             UnsignedInteger
189         };
190 
191         /// Union of potential type values.
192         union Vals
193         {
194             int i;
195             unsigned int u;
196             float f;
197             bool b;
198             std::string s;
199 
Vals()200             Vals(){}
~Vals()201             ~Vals(){}
202 
Vals(int i)203             explicit Vals(int i) : i(i) {};
Vals(unsigned int u)204             explicit Vals(unsigned int u) : u(u) {};
Vals(float f)205             explicit Vals(float f) : f(f) {};
Vals(bool b)206             explicit Vals(bool b) : b(b) {};
Vals(const char * s)207             explicit Vals(const char* s) : s(std::string(s)) {}
Vals(std::string s)208             explicit Vals(std::string s) : s(s) {}
209        };
210 
211         Vals m_Vals;
212         VarTypes m_Type;
213     };
214 
215     struct BackendOption
216     {
217     public:
BackendOptionarmnn::BackendOptions::BackendOption218         BackendOption(std::string name, bool value)
219             : m_Name(name), m_Value(value)
220         {}
BackendOptionarmnn::BackendOptions::BackendOption221         BackendOption(std::string name, int value)
222             : m_Name(name), m_Value(value)
223         {}
BackendOptionarmnn::BackendOptions::BackendOption224         BackendOption(std::string name, unsigned int value)
225                 : m_Name(name), m_Value(value)
226         {}
BackendOptionarmnn::BackendOptions::BackendOption227         BackendOption(std::string name, float value)
228             : m_Name(name), m_Value(value)
229         {}
BackendOptionarmnn::BackendOptions::BackendOption230         BackendOption(std::string name, std::string value)
231             : m_Name(name), m_Value(value)
232         {}
BackendOptionarmnn::BackendOptions::BackendOption233         BackendOption(std::string name, const char* value)
234             : m_Name(name), m_Value(value)
235         {}
236 
237         template<typename DisallowedType>
BackendOptionarmnn::BackendOptions::BackendOption238         BackendOption(std::string, DisallowedType)
239             : m_Value(0)
240         {
241             static_assert(CheckAllowed<DisallowedType>::value, "Type is not allowed for BackendOption.");
242             assert(false && "Unreachable code");
243         }
244 
245         BackendOption(const BackendOption& other) = default;
246         BackendOption(BackendOption&& other) = default;
247         BackendOption& operator=(const BackendOption& other) = default;
248         BackendOption& operator=(BackendOption&& other) = default;
249         ~BackendOption() = default;
250 
GetNamearmnn::BackendOptions::BackendOption251         std::string GetName() const   { return m_Name; }
GetValuearmnn::BackendOptions::BackendOption252         Var GetValue() const          { return m_Value; }
253 
254     private:
255         std::string m_Name;         ///< Name of the option
256         Var         m_Value;        ///< Value of the option. (Bool, int, Float, String)
257     };
258 
BackendOptionsarmnn::BackendOptions259     explicit BackendOptions(BackendId backend)
260         : m_TargetBackend(backend)
261     {}
262 
BackendOptionsarmnn::BackendOptions263     BackendOptions(BackendId backend, std::initializer_list<BackendOption> options)
264         : m_TargetBackend(backend)
265         , m_Options(options)
266     {}
267 
268     BackendOptions(const BackendOptions& other) = default;
269     BackendOptions(BackendOptions&& other) = default;
270     BackendOptions& operator=(const BackendOptions& other) = default;
271     BackendOptions& operator=(BackendOptions&& other) = default;
272 
AddOptionarmnn::BackendOptions273     void AddOption(BackendOption&& option)
274     {
275         m_Options.push_back(option);
276     }
277 
AddOptionarmnn::BackendOptions278     void AddOption(const BackendOption& option)
279     {
280         m_Options.push_back(option);
281     }
282 
GetBackendIdarmnn::BackendOptions283     const BackendId& GetBackendId() const noexcept { return m_TargetBackend; }
GetOptionCountarmnn::BackendOptions284     size_t GetOptionCount() const noexcept { return m_Options.size(); }
GetOptionarmnn::BackendOptions285     const BackendOption& GetOption(size_t idx) const { return m_Options[idx]; }
286 
287 private:
288     /// The id for the backend to which the options should be passed.
289     BackendId m_TargetBackend;
290 
291     /// The array of options to pass to the backend context
292     std::vector<BackendOption> m_Options;
293 };
294 
295 
296 template <typename F>
ParseOptions(const std::vector<BackendOptions> & options,BackendId backend,F f)297 void ParseOptions(const std::vector<BackendOptions>& options, BackendId backend, F f)
298 {
299     for (auto optionsGroup : options)
300     {
301         if (optionsGroup.GetBackendId() == backend)
302         {
303             for (size_t i=0; i < optionsGroup.GetOptionCount(); i++)
304             {
305                 const BackendOptions::BackendOption option = optionsGroup.GetOption(i);
306                 f(option.GetName(), option.GetValue());
307             }
308         }
309     }
310 }
311 
ParseBooleanBackendOption(const armnn::BackendOptions::Var & value,bool defaultValue)312 inline bool ParseBooleanBackendOption(const armnn::BackendOptions::Var& value, bool defaultValue)
313 {
314     if (value.IsBool())
315     {
316         return value.AsBool();
317     }
318     return defaultValue;
319 }
320 
ParseStringBackendOption(const armnn::BackendOptions::Var & value,std::string defaultValue)321 inline std::string ParseStringBackendOption(const armnn::BackendOptions::Var& value, std::string defaultValue)
322 {
323     if (value.IsString())
324     {
325         return value.AsString();
326     }
327     return defaultValue;
328 }
329 
ParseIntBackendOption(const armnn::BackendOptions::Var & value,int defaultValue)330 inline int ParseIntBackendOption(const armnn::BackendOptions::Var& value, int defaultValue)
331 {
332     if (value.IsInt())
333     {
334         return value.AsInt();
335     }
336     return defaultValue;
337 }
338 
339 } //namespace armnn
340