/*
 * Copyright (c) 2021 Arm Limited.
 *
 * SPDX-License-Identifier: MIT
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to
 * deal in the Software without restriction, including without limitation the
 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
 * sell copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
23 */ 24 #ifndef SRC_RUNTIME_CL_MLGO_HEURISTIC_TREE_H 25 #define SRC_RUNTIME_CL_MLGO_HEURISTIC_TREE_H 26 27 #include "arm_compute/core/Types.h" 28 #include "src/runtime/CL/mlgo/Common.h" 29 30 #include <map> 31 #include <memory> 32 #include <string> 33 #include <utility> 34 35 namespace arm_compute 36 { 37 namespace mlgo 38 { 39 /** Conditional ops */ 40 enum class ConditionalOp 41 { 42 EQ, /**< Equal */ 43 LT, /**< Less than */ 44 LE, /**< Less than or equal to */ 45 GT, /**< Greater than */ 46 GE, /**< Greater than or equal to */ 47 }; 48 49 /** A branch condition expression evaluating: feature op threshold */ 50 struct Condition 51 { 52 std::string feature; /**< Feature name */ 53 ConditionalOp op; /**< Condtional op */ 54 float threshold; /**< Threshold value */ 55 }; 56 57 /** GEMM Shape used for query */ 58 struct GEMMShape 59 { 60 unsigned int m; /**< Number of rows for the lhs matrix. Lhs matrix NOT transposed */ 61 unsigned int n; /**< Number of columns for the rhs matrix. Rhs matrix NOT transposed */ 62 unsigned int k; /**< Number of rows for the rhs matrix. 
Rhs matrix NOT transposed */ 63 unsigned int b; /**< Batch size */ 64 }; 65 66 /** A binary decision tree based heuristic */ 67 class HeuristicTree 68 { 69 public: 70 using NodeID = size_t; 71 using TreeID = size_t; 72 using Index = std::tuple<HeuristicType, std::string, DataType>; 73 enum class NodeType 74 { 75 Branch, 76 Leaf 77 }; 78 struct Node 79 { 80 virtual NodeType type() const = 0; 81 virtual ~Node() = default; 82 }; 83 84 struct BranchNode : public Node 85 { BranchNodeBranchNode86 BranchNode(NodeID id, Condition cond, NodeID t_node, NodeID f_node) 87 : id{ id }, condition{ cond }, true_node{ t_node }, false_node{ f_node } 88 { 89 } typeBranchNode90 NodeType type() const override 91 { 92 return NodeType::Branch; 93 } 94 NodeID id; 95 Condition condition; 96 NodeID true_node; 97 NodeID false_node; 98 }; 99 100 template <typename T> 101 struct LeafNode : public Node 102 { LeafNodeLeafNode103 LeafNode(NodeID id, T val) 104 : id{ id }, value{ val } 105 { 106 } typeLeafNode107 NodeType type() const override 108 { 109 return NodeType::Leaf; 110 } 111 NodeID id; 112 T value; 113 }; 114 115 public: 116 /** Constructor */ 117 HeuristicTree(); 118 /** Constructor */ 119 HeuristicTree(TreeID id, HeuristicType h_type, const std::string &ip_target, DataType data_type); 120 // Since the HeuristicTree is a handle that owns the the nodes, it is move-only 121 /** Prevent copy construction */ 122 HeuristicTree(const HeuristicTree &) = delete; 123 /** Prevent copy assignment */ 124 HeuristicTree &operator=(const HeuristicTree &) = delete; 125 /** Move constructor */ 126 HeuristicTree(HeuristicTree &&other) noexcept = default; 127 /** Move assignment */ 128 HeuristicTree &operator=(HeuristicTree &&other) = default; 129 130 /** Query a leaf value given a gemm shape 131 * 132 * @tparam T Leaf value type 133 * @param shape A @ref GEMMShape for the query 134 * @return std::pair<bool, T> Outcome contains bool, signalling if the query succeeded or not 135 */ 136 template <typename 
T> 137 std::pair<bool, T> query(GEMMShape shape) const; 138 139 /** Add a leaf node 140 * 141 * @tparam T Leaf value type 142 * @param id Leaf node ID 143 * @param leaf_value Leaf node value 144 * @return bool If the addition succeeded or not 145 */ 146 template <typename T> 147 bool add_leaf(NodeID id, T leaf_value); 148 /** Add a branch node 149 * 150 * @param id Branch node ID 151 * @param cond Branch node @ref Condition 152 * @param true_node True node's ID 153 * @param false_node False node's ID 154 * @return bool If the addition succeeded or not 155 */ 156 bool add_branch(NodeID id, Condition cond, NodeID true_node, NodeID false_node); 157 158 /** Get tree ID 159 * @return TreeID 160 */ id()161 TreeID id() const 162 { 163 return _id; 164 } 165 166 /** Get tree index 167 * @return Index 168 */ index()169 Index index() const 170 { 171 return std::make_tuple(_heuristic_type, _ip_target, _data_type); 172 } 173 174 /** Check if tree is valid 175 * @return bool 176 */ 177 bool check(); 178 179 private: 180 static constexpr size_t _max_query_depth{ 1000 }; // Maximum depth of query 181 static constexpr size_t _max_num_nodes{ 100000 }; // Maximum number of nodes contained by the tree 182 static constexpr NodeID _root{ 0 }; // Root tree ID 183 184 private: 185 bool check_if_structurally_correct() const; 186 187 private: 188 TreeID _id; /**< Heuristic tree ID */ 189 HeuristicType _heuristic_type; /**< Heuristic type */ 190 std::string _ip_target; /**< IP target associated with the tree */ 191 DataType _data_type; /**< Data type associated with the tree */ 192 std::map<NodeID, std::unique_ptr<Node>> _tree; /**< Tree representation */ 193 }; 194 } // namespace mlgo 195 196 } // namespace arm_compute 197 198 #endif //SRC_RUNTIME_CL_MLGO_HEURISTIC_TREE_H