xref: /aosp_15_r20/external/ComputeLibrary/src/runtime/CL/mlgo/HeuristicTree.cpp (revision c217d954acce2dbc11938adb493fc0abd69584f3)
1*c217d954SCole Faust /*
2*c217d954SCole Faust  * Copyright (c) 2021 Arm Limited.
3*c217d954SCole Faust  *
4*c217d954SCole Faust  * SPDX-License-Identifier: MIT
5*c217d954SCole Faust  *
6*c217d954SCole Faust  * Permission is hereby granted, free of charge, to any person obtaining a copy
7*c217d954SCole Faust  * of this software and associated documentation files (the "Software"), to
8*c217d954SCole Faust  * deal in the Software without restriction, including without limitation the
9*c217d954SCole Faust  * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10*c217d954SCole Faust  * sell copies of the Software, and to permit persons to whom the Software is
11*c217d954SCole Faust  * furnished to do so, subject to the following conditions:
12*c217d954SCole Faust  *
13*c217d954SCole Faust  * The above copyright notice and this permission notice shall be included in all
14*c217d954SCole Faust  * copies or substantial portions of the Software.
15*c217d954SCole Faust  *
16*c217d954SCole Faust  * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17*c217d954SCole Faust  * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18*c217d954SCole Faust  * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19*c217d954SCole Faust  * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20*c217d954SCole Faust  * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21*c217d954SCole Faust  * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22*c217d954SCole Faust  * SOFTWARE.
23*c217d954SCole Faust  */
24*c217d954SCole Faust #include "src/runtime/CL/mlgo/HeuristicTree.h"
25*c217d954SCole Faust #include "arm_compute/core/Log.h"
26*c217d954SCole Faust 
27*c217d954SCole Faust #include "support/Cast.h"
28*c217d954SCole Faust 
29*c217d954SCole Faust #include <algorithm>
30*c217d954SCole Faust #include <deque>
31*c217d954SCole Faust #include <set>
32*c217d954SCole Faust namespace arm_compute
33*c217d954SCole Faust {
34*c217d954SCole Faust namespace mlgo
35*c217d954SCole Faust {
36*c217d954SCole Faust namespace
37*c217d954SCole Faust {
evaluate(GEMMShape shape,Condition cond)38*c217d954SCole Faust bool evaluate(GEMMShape shape, Condition cond)
39*c217d954SCole Faust {
40*c217d954SCole Faust     // PRE: all features and ConditionalOps are valid
41*c217d954SCole Faust     constexpr float eps = 0.0001f;
42*c217d954SCole Faust     // Calculate all secondary features
43*c217d954SCole Faust     std::vector<std::pair<std::string, float>> cond_values
44*c217d954SCole Faust     {
45*c217d954SCole Faust         { "m", static_cast<float>(shape.m) },
46*c217d954SCole Faust         { "n", static_cast<float>(shape.n) },
47*c217d954SCole Faust         { "k", static_cast<float>(shape.k) },
48*c217d954SCole Faust         { "b", static_cast<float>(shape.b) },
49*c217d954SCole Faust         { "r_mn", static_cast<float>(shape.m) / shape.n },
50*c217d954SCole Faust         { "r_mk", static_cast<float>(shape.m) / shape.k },
51*c217d954SCole Faust         { "r_nk", static_cast<float>(shape.n) / shape.k },
52*c217d954SCole Faust         { "r_mnk", static_cast<float>(shape.m) / (static_cast<float>(shape.n) / shape.k) },
53*c217d954SCole Faust         { "workload", (static_cast<float>(shape.m) * shape.n * shape.b) / 20.0 }
54*c217d954SCole Faust     };
55*c217d954SCole Faust     auto cond_value_pair_it = std::find_if(cond_values.begin(), cond_values.end(),
56*c217d954SCole Faust                                            [&cond](decltype(*cond_values.begin()) it)
57*c217d954SCole Faust     {
58*c217d954SCole Faust         return it.first == cond.feature;
59*c217d954SCole Faust     });
60*c217d954SCole Faust 
61*c217d954SCole Faust     ARM_COMPUTE_ERROR_ON(cond_value_pair_it == cond_values.end());
62*c217d954SCole Faust     const float cond_value = cond_value_pair_it->second;
63*c217d954SCole Faust     switch(cond.op)
64*c217d954SCole Faust     {
65*c217d954SCole Faust         case ConditionalOp::LT:
66*c217d954SCole Faust         {
67*c217d954SCole Faust             return cond_value < cond.threshold;
68*c217d954SCole Faust         }
69*c217d954SCole Faust         case ConditionalOp::LE:
70*c217d954SCole Faust         {
71*c217d954SCole Faust             return cond_value <= cond.threshold;
72*c217d954SCole Faust         }
73*c217d954SCole Faust         case ConditionalOp::GT:
74*c217d954SCole Faust         {
75*c217d954SCole Faust             return cond_value > cond.threshold;
76*c217d954SCole Faust         }
77*c217d954SCole Faust         case ConditionalOp::GE:
78*c217d954SCole Faust         {
79*c217d954SCole Faust             return cond_value >= cond.threshold;
80*c217d954SCole Faust         }
81*c217d954SCole Faust         case ConditionalOp::EQ:
82*c217d954SCole Faust         default:
83*c217d954SCole Faust         {
84*c217d954SCole Faust             return std::abs(cond_value - cond.threshold) < eps;
85*c217d954SCole Faust         }
86*c217d954SCole Faust     }
87*c217d954SCole Faust }
88*c217d954SCole Faust 
89*c217d954SCole Faust } // namespace
90*c217d954SCole Faust 
91*c217d954SCole Faust constexpr size_t                HeuristicTree::_max_num_nodes;
92*c217d954SCole Faust constexpr size_t                HeuristicTree::_max_query_depth;
93*c217d954SCole Faust constexpr HeuristicTree::NodeID HeuristicTree::_root;
94*c217d954SCole Faust 
HeuristicTree()95*c217d954SCole Faust HeuristicTree::HeuristicTree()
96*c217d954SCole Faust     : HeuristicTree(0, HeuristicType::GEMM_Type, "", DataType::F32)
97*c217d954SCole Faust {
98*c217d954SCole Faust }
99*c217d954SCole Faust 
HeuristicTree(TreeID id,HeuristicType h_type,const std::string & ip_target,DataType data_type)100*c217d954SCole Faust HeuristicTree::HeuristicTree(TreeID id, HeuristicType h_type, const std::string &ip_target, DataType data_type)
101*c217d954SCole Faust     : _id{ id }, _heuristic_type{ h_type }, _ip_target{ ip_target }, _data_type{ data_type }, _tree{}
102*c217d954SCole Faust {
103*c217d954SCole Faust }
104*c217d954SCole Faust 
105*c217d954SCole Faust template <typename T>
query(GEMMShape shape) const106*c217d954SCole Faust std::pair<bool, T> HeuristicTree::query(GEMMShape shape) const
107*c217d954SCole Faust {
108*c217d954SCole Faust     // Root ID = 0;
109*c217d954SCole Faust     auto   cur_node = _tree.at(_root).get();
110*c217d954SCole Faust     size_t depth    = 0;
111*c217d954SCole Faust     while(cur_node->type() != NodeType::Leaf)
112*c217d954SCole Faust     {
113*c217d954SCole Faust         if(depth > _max_query_depth)
114*c217d954SCole Faust         {
115*c217d954SCole Faust             ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Exceeding max query depth: %zu. Is the tree too deep?", _max_query_depth);
116*c217d954SCole Faust             return std::make_pair(false, T{});
117*c217d954SCole Faust         }
118*c217d954SCole Faust         ARM_COMPUTE_ERROR_ON_MSG(cur_node->type() != NodeType::Branch, "Unexpected NodeType");
119*c217d954SCole Faust         auto br_node = utils::cast::polymorphic_downcast<BranchNode *>(cur_node);
120*c217d954SCole Faust         if(evaluate(shape, br_node->condition))
121*c217d954SCole Faust         {
122*c217d954SCole Faust             cur_node = _tree.at(br_node->true_node).get();
123*c217d954SCole Faust         }
124*c217d954SCole Faust         else
125*c217d954SCole Faust         {
126*c217d954SCole Faust             cur_node = _tree.at(br_node->false_node).get();
127*c217d954SCole Faust         }
128*c217d954SCole Faust         ++depth;
129*c217d954SCole Faust     }
130*c217d954SCole Faust     ARM_COMPUTE_ERROR_ON_MSG(cur_node->type() != NodeType::Leaf, "Unexpected NodeType");
131*c217d954SCole Faust     auto l_node = utils::cast::polymorphic_downcast<LeafNode<T> *>(cur_node);
132*c217d954SCole Faust     return std::make_pair(true, l_node->value);
133*c217d954SCole Faust }
134*c217d954SCole Faust 
135*c217d954SCole Faust template <typename T>
add_leaf(NodeID id,T val)136*c217d954SCole Faust bool HeuristicTree::add_leaf(NodeID id, T val)
137*c217d954SCole Faust {
138*c217d954SCole Faust     if(_tree.size() >= _max_num_nodes)
139*c217d954SCole Faust     {
140*c217d954SCole Faust         ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Exceeding the maximum number of nodes allowed %zu", _max_num_nodes);
141*c217d954SCole Faust         return false;
142*c217d954SCole Faust     }
143*c217d954SCole Faust     if(_tree.find(id) != _tree.end())
144*c217d954SCole Faust     {
145*c217d954SCole Faust         ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Cannot add node; node id %zu already exists", id);
146*c217d954SCole Faust         return false;
147*c217d954SCole Faust     }
148*c217d954SCole Faust     _tree[id] = std::make_unique<LeafNode<T>>(id, val);
149*c217d954SCole Faust     return true;
150*c217d954SCole Faust }
151*c217d954SCole Faust 
add_branch(NodeID id,Condition cond,NodeID t_node,NodeID f_node)152*c217d954SCole Faust bool HeuristicTree::add_branch(NodeID id, Condition cond, NodeID t_node, NodeID f_node)
153*c217d954SCole Faust {
154*c217d954SCole Faust     if(_tree.size() >= _max_num_nodes)
155*c217d954SCole Faust     {
156*c217d954SCole Faust         ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Exceeding the maximum number of nodes allowed %zu", _max_num_nodes);
157*c217d954SCole Faust         return false;
158*c217d954SCole Faust     }
159*c217d954SCole Faust 
160*c217d954SCole Faust     const std::set<std::string> supported_features =
161*c217d954SCole Faust     {
162*c217d954SCole Faust         "m", "n", "k", "b", "r_mn", "r_mk", "r_nk", "r_mnk", "workload"
163*c217d954SCole Faust     };
164*c217d954SCole Faust     const auto orig_feature = cond.feature;
165*c217d954SCole Faust     std::transform(cond.feature.begin(), cond.feature.end(), cond.feature.begin(), [](char c)
166*c217d954SCole Faust     {
167*c217d954SCole Faust         return std::tolower(c);
168*c217d954SCole Faust     });
169*c217d954SCole Faust     if(supported_features.find(cond.feature) == supported_features.end())
170*c217d954SCole Faust     {
171*c217d954SCole Faust         ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Unsupported feature %s", orig_feature.c_str());
172*c217d954SCole Faust         return false;
173*c217d954SCole Faust     }
174*c217d954SCole Faust 
175*c217d954SCole Faust     if(_tree.find(id) != _tree.end())
176*c217d954SCole Faust     {
177*c217d954SCole Faust         ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Cannot add node; node id %zu already exists", id);
178*c217d954SCole Faust         return false;
179*c217d954SCole Faust     }
180*c217d954SCole Faust     _tree[id] = std::make_unique<BranchNode>(id, cond, t_node, f_node);
181*c217d954SCole Faust     return true;
182*c217d954SCole Faust }
183*c217d954SCole Faust 
check_if_structurally_correct() const184*c217d954SCole Faust bool HeuristicTree::check_if_structurally_correct() const
185*c217d954SCole Faust {
186*c217d954SCole Faust     std::set<NodeID>   visited;
187*c217d954SCole Faust     std::deque<NodeID> to_visit{ _root };
188*c217d954SCole Faust 
189*c217d954SCole Faust     while(!to_visit.empty())
190*c217d954SCole Faust     {
191*c217d954SCole Faust         auto id = to_visit.front();
192*c217d954SCole Faust         to_visit.pop_front();
193*c217d954SCole Faust         if(_tree.find(id) == _tree.end())
194*c217d954SCole Faust         {
195*c217d954SCole Faust             ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Missing node %zu", id);
196*c217d954SCole Faust             return false;
197*c217d954SCole Faust         }
198*c217d954SCole Faust         auto not_seen_before = visited.insert(id);
199*c217d954SCole Faust         if(!not_seen_before.second)
200*c217d954SCole Faust         {
201*c217d954SCole Faust             ARM_COMPUTE_LOG_INFO_MSG_CORE("Not a tree; contains cycles or loops");
202*c217d954SCole Faust             return false;
203*c217d954SCole Faust         }
204*c217d954SCole Faust         auto cur_node = _tree.at(id).get();
205*c217d954SCole Faust         if(cur_node->type() == NodeType::Branch)
206*c217d954SCole Faust         {
207*c217d954SCole Faust             auto br_node = utils::cast::polymorphic_downcast<BranchNode *>(cur_node);
208*c217d954SCole Faust             to_visit.push_back(br_node->true_node);
209*c217d954SCole Faust             to_visit.push_back(br_node->false_node);
210*c217d954SCole Faust         }
211*c217d954SCole Faust     }
212*c217d954SCole Faust     if(visited.size() != _tree.size())
213*c217d954SCole Faust     {
214*c217d954SCole Faust         ARM_COMPUTE_LOG_INFO_MSG_CORE("Contains disjoint nodes");
215*c217d954SCole Faust         return false;
216*c217d954SCole Faust     }
217*c217d954SCole Faust     return true;
218*c217d954SCole Faust }
219*c217d954SCole Faust 
check()220*c217d954SCole Faust bool HeuristicTree::check()
221*c217d954SCole Faust {
222*c217d954SCole Faust     if(_tree.empty())
223*c217d954SCole Faust     {
224*c217d954SCole Faust         ARM_COMPUTE_LOG_INFO_MSG_CORE("Empty tree encountered");
225*c217d954SCole Faust         return false;
226*c217d954SCole Faust     }
227*c217d954SCole Faust     if(_tree.find(_root) == _tree.end())
228*c217d954SCole Faust     {
229*c217d954SCole Faust         ARM_COMPUTE_LOG_INFO_MSG_WITH_FORMAT_CORE("Missing root. Root must have a Node ID of %zu", _root);
230*c217d954SCole Faust         return false;
231*c217d954SCole Faust     }
232*c217d954SCole Faust     return check_if_structurally_correct();
233*c217d954SCole Faust }
234*c217d954SCole Faust 
235*c217d954SCole Faust /** Explicit template instantiation @relates HeuristicTree */
236*c217d954SCole Faust template std::pair<bool, GEMMType> HeuristicTree::query<GEMMType>(GEMMShape shape) const;
237*c217d954SCole Faust /** Explicit template instantiation @relates HeuristicTree */
238*c217d954SCole Faust template std::pair<bool, GEMMConfigNative> HeuristicTree::query<GEMMConfigNative>(GEMMShape shape) const;
239*c217d954SCole Faust /** Explicit template instantiation @relates HeuristicTree */
240*c217d954SCole Faust template std::pair<bool, GEMMConfigReshapedOnlyRHS> HeuristicTree::query<GEMMConfigReshapedOnlyRHS>(GEMMShape shape) const;
241*c217d954SCole Faust /** Explicit template instantiation @relates HeuristicTree */
242*c217d954SCole Faust template std::pair<bool, GEMMConfigReshaped> HeuristicTree::query<GEMMConfigReshaped>(GEMMShape shape) const;
243*c217d954SCole Faust 
244*c217d954SCole Faust /** Explicit template instantiation @relates HeuristicTree */
245*c217d954SCole Faust template bool HeuristicTree::add_leaf(NodeID id, GEMMType val);
246*c217d954SCole Faust /** Explicit template instantiation @relates HeuristicTree */
247*c217d954SCole Faust template bool HeuristicTree::add_leaf(NodeID id, GEMMConfigNative val);
248*c217d954SCole Faust /** Explicit template instantiation @relates HeuristicTree */
249*c217d954SCole Faust template bool HeuristicTree::add_leaf(NodeID id, GEMMConfigReshapedOnlyRHS val);
250*c217d954SCole Faust /** Explicit template instantiation @relates HeuristicTree */
251*c217d954SCole Faust template bool HeuristicTree::add_leaf(NodeID id, GEMMConfigReshaped val);
252*c217d954SCole Faust 
253*c217d954SCole Faust } // namespace mlgo
254*c217d954SCole Faust 
255*c217d954SCole Faust } // namespace arm_compute
256