1 /*
2 * Copyright (c) 2021 Arm Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
#include "src/runtime/CL/mlgo/HeuristicTree.h"
#include "arm_compute/core/Log.h"

#include "support/Cast.h"

#include <algorithm>
#include <cctype>
#include <cmath>
#include <deque>
#include <memory>
#include <set>
#include <string>
#include <utility>
32 namespace arm_compute
33 {
34 namespace mlgo
35 {
36 namespace
37 {
evaluate(GEMMShape shape,Condition cond)38 bool evaluate(GEMMShape shape, Condition cond)
39 {
40 // PRE: all features and ConditionalOps are valid
41 constexpr float eps = 0.0001f;
42 // Calculate all secondary features
43 std::vector<std::pair<std::string, float>> cond_values
44 {
45 { "m", static_cast<float>(shape.m) },
46 { "n", static_cast<float>(shape.n) },
47 { "k", static_cast<float>(shape.k) },
48 { "b", static_cast<float>(shape.b) },
49 { "r_mn", static_cast<float>(shape.m) / shape.n },
50 { "r_mk", static_cast<float>(shape.m) / shape.k },
51 { "r_nk", static_cast<float>(shape.n) / shape.k },
52 { "r_mnk", static_cast<float>(shape.m) / (static_cast<float>(shape.n) / shape.k) },
53 { "workload", (static_cast<float>(shape.m) * shape.n * shape.b) / 20.0 }
54 };
55 auto cond_value_pair_it = std::find_if(cond_values.begin(), cond_values.end(),
56 [&cond](decltype(*cond_values.begin()) it)
57 {
58 return it.first == cond.feature;
59 });
60
61 ARM_COMPUTE_ERROR_ON(cond_value_pair_it == cond_values.end());
62 const float cond_value = cond_value_pair_it->second;
63 switch(cond.op)
64 {
65 case ConditionalOp::LT:
66 {
67 return cond_value < cond.threshold;
68 }
69 case ConditionalOp::LE:
70 {
71 return cond_value <= cond.threshold;
72 }
73 case ConditionalOp::GT:
74 {
75 return cond_value > cond.threshold;
76 }
77 case ConditionalOp::GE:
78 {
79 return cond_value >= cond.threshold;
80 }
81 case ConditionalOp::EQ:
82 default:
83 {
84 return std::abs(cond_value - cond.threshold) < eps;
85 }
86 }
87 }
88
89 } // namespace
90
91 constexpr size_t HeuristicTree::_max_num_nodes;
92 constexpr size_t HeuristicTree::_max_query_depth;
93 constexpr HeuristicTree::NodeID HeuristicTree::_root;
94
HeuristicTree()95 HeuristicTree::HeuristicTree()
96 : HeuristicTree(0, HeuristicType::GEMM_Type, "", DataType::F32)
97 {
98 }
99
HeuristicTree(TreeID id,HeuristicType h_type,const std::string & ip_target,DataType data_type)100 HeuristicTree::HeuristicTree(TreeID id, HeuristicType h_type, const std::string &ip_target, DataType data_type)
101 : _id{ id }, _heuristic_type{ h_type }, _ip_target{ ip_target }, _data_type{ data_type }, _tree{}
102 {
103 }
104
105 template <typename T>
query(GEMMShape shape) const106 std::pair<bool, T> HeuristicTree::query(GEMMShape shape) const
107 {
108 // Root ID = 0;
109 auto cur_node = _tree.at(_root).get();
110 size_t depth = 0;
111 while(cur_node->type() != NodeType::Leaf)
112 {
113 if(depth > _max_query_depth)
114 {
115 ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Exceeding max query depth: %zu. Is the tree too deep?", _max_query_depth);
116 return std::make_pair(false, T{});
117 }
118 ARM_COMPUTE_ERROR_ON_MSG(cur_node->type() != NodeType::Branch, "Unexpected NodeType");
119 auto br_node = utils::cast::polymorphic_downcast<BranchNode *>(cur_node);
120 if(evaluate(shape, br_node->condition))
121 {
122 cur_node = _tree.at(br_node->true_node).get();
123 }
124 else
125 {
126 cur_node = _tree.at(br_node->false_node).get();
127 }
128 ++depth;
129 }
130 ARM_COMPUTE_ERROR_ON_MSG(cur_node->type() != NodeType::Leaf, "Unexpected NodeType");
131 auto l_node = utils::cast::polymorphic_downcast<LeafNode<T> *>(cur_node);
132 return std::make_pair(true, l_node->value);
133 }
134
135 template <typename T>
add_leaf(NodeID id,T val)136 bool HeuristicTree::add_leaf(NodeID id, T val)
137 {
138 if(_tree.size() >= _max_num_nodes)
139 {
140 ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Exceeding the maximum number of nodes allowed %zu", _max_num_nodes);
141 return false;
142 }
143 if(_tree.find(id) != _tree.end())
144 {
145 ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Cannot add node; node id %zu already exists", id);
146 return false;
147 }
148 _tree[id] = std::make_unique<LeafNode<T>>(id, val);
149 return true;
150 }
151
add_branch(NodeID id,Condition cond,NodeID t_node,NodeID f_node)152 bool HeuristicTree::add_branch(NodeID id, Condition cond, NodeID t_node, NodeID f_node)
153 {
154 if(_tree.size() >= _max_num_nodes)
155 {
156 ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Exceeding the maximum number of nodes allowed %zu", _max_num_nodes);
157 return false;
158 }
159
160 const std::set<std::string> supported_features =
161 {
162 "m", "n", "k", "b", "r_mn", "r_mk", "r_nk", "r_mnk", "workload"
163 };
164 const auto orig_feature = cond.feature;
165 std::transform(cond.feature.begin(), cond.feature.end(), cond.feature.begin(), [](char c)
166 {
167 return std::tolower(c);
168 });
169 if(supported_features.find(cond.feature) == supported_features.end())
170 {
171 ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Unsupported feature %s", orig_feature.c_str());
172 return false;
173 }
174
175 if(_tree.find(id) != _tree.end())
176 {
177 ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Cannot add node; node id %zu already exists", id);
178 return false;
179 }
180 _tree[id] = std::make_unique<BranchNode>(id, cond, t_node, f_node);
181 return true;
182 }
183
check_if_structurally_correct() const184 bool HeuristicTree::check_if_structurally_correct() const
185 {
186 std::set<NodeID> visited;
187 std::deque<NodeID> to_visit{ _root };
188
189 while(!to_visit.empty())
190 {
191 auto id = to_visit.front();
192 to_visit.pop_front();
193 if(_tree.find(id) == _tree.end())
194 {
195 ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Missing node %zu", id);
196 return false;
197 }
198 auto not_seen_before = visited.insert(id);
199 if(!not_seen_before.second)
200 {
201 ARM_COMPUTE_LOG_INFO_MSG_CORE("Not a tree; contains cycles or loops");
202 return false;
203 }
204 auto cur_node = _tree.at(id).get();
205 if(cur_node->type() == NodeType::Branch)
206 {
207 auto br_node = utils::cast::polymorphic_downcast<BranchNode *>(cur_node);
208 to_visit.push_back(br_node->true_node);
209 to_visit.push_back(br_node->false_node);
210 }
211 }
212 if(visited.size() != _tree.size())
213 {
214 ARM_COMPUTE_LOG_INFO_MSG_CORE("Contains disjoint nodes");
215 return false;
216 }
217 return true;
218 }
219
check()220 bool HeuristicTree::check()
221 {
222 if(_tree.empty())
223 {
224 ARM_COMPUTE_LOG_INFO_MSG_CORE("Empty tree encountered");
225 return false;
226 }
227 if(_tree.find(_root) == _tree.end())
228 {
229 ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Missing root. Root must have a Node ID of %zu", _root);
230 return false;
231 }
232 return check_if_structurally_correct();
233 }
234
235 /** Explicit template instantiation @relates HeuristicTree */
236 template std::pair<bool, GEMMType> HeuristicTree::query<GEMMType>(GEMMShape shape) const;
237 /** Explicit template instantiation @relates HeuristicTree */
238 template std::pair<bool, GEMMConfigNative> HeuristicTree::query<GEMMConfigNative>(GEMMShape shape) const;
239 /** Explicit template instantiation @relates HeuristicTree */
240 template std::pair<bool, GEMMConfigReshapedOnlyRHS> HeuristicTree::query<GEMMConfigReshapedOnlyRHS>(GEMMShape shape) const;
241 /** Explicit template instantiation @relates HeuristicTree */
242 template std::pair<bool, GEMMConfigReshaped> HeuristicTree::query<GEMMConfigReshaped>(GEMMShape shape) const;
243
244 /** Explicit template instantiation @relates HeuristicTree */
245 template bool HeuristicTree::add_leaf(NodeID id, GEMMType val);
246 /** Explicit template instantiation @relates HeuristicTree */
247 template bool HeuristicTree::add_leaf(NodeID id, GEMMConfigNative val);
248 /** Explicit template instantiation @relates HeuristicTree */
249 template bool HeuristicTree::add_leaf(NodeID id, GEMMConfigReshapedOnlyRHS val);
250 /** Explicit template instantiation @relates HeuristicTree */
251 template bool HeuristicTree::add_leaf(NodeID id, GEMMConfigReshaped val);
252
253 } // namespace mlgo
254
255 } // namespace arm_compute
256