1 // Copyright 2020 The Marl Authors. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // https://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 // marl::DAG<> provides an ahead of time, declarative, directed acyclic 16 // task graph. 17 18 #ifndef marl_dag_h 19 #define marl_dag_h 20 21 #include "containers.h" 22 #include "export.h" 23 #include "memory.h" 24 #include "scheduler.h" 25 #include "waitgroup.h" 26 27 namespace marl { 28 namespace detail { 29 using DAGCounter = std::atomic<uint32_t>; 30 template <typename T> 31 struct DAGRunContext { 32 T data; 33 Allocator::unique_ptr<DAGCounter> counters; 34 35 template <typename F> invokeDAGRunContext36 MARL_NO_EXPORT inline void invoke(F&& f) { 37 f(data); 38 } 39 }; 40 template <> 41 struct DAGRunContext<void> { 42 Allocator::unique_ptr<DAGCounter> counters; 43 44 template <typename F> 45 MARL_NO_EXPORT inline void invoke(F&& f) { 46 f(); 47 } 48 }; 49 template <typename T> 50 struct DAGWork { 51 using type = std::function<void(T)>; 52 }; 53 template <> 54 struct DAGWork<void> { 55 using type = std::function<void()>; 56 }; 57 } // namespace detail 58 59 /////////////////////////////////////////////////////////////////////////////// 60 // Forward declarations 61 /////////////////////////////////////////////////////////////////////////////// 62 template <typename T> 63 class DAG; 64 65 template <typename T> 66 class DAGBuilder; 67 68 template <typename T> 69 class DAGNodeBuilder; 70 71 /////////////////////////////////////////////////////////////////////////////// 72 // DAGBase<T> 73 /////////////////////////////////////////////////////////////////////////////// 74 75 // DAGBase is derived by DAG<T> and DAG<void>. It has no public API. 76 template <typename T> 77 class DAGBase { 78 protected: 79 friend DAGBuilder<T>; 80 friend DAGNodeBuilder<T>; 81 82 using RunContext = detail::DAGRunContext<T>; 83 using Counter = detail::DAGCounter; 84 using NodeIndex = size_t; 85 using Work = typename detail::DAGWork<T>::type; 86 static const constexpr size_t NumReservedNodes = 32; 87 static const constexpr size_t NumReservedNumOuts = 4; 88 static const constexpr size_t InvalidCounterIndex = ~static_cast<size_t>(0); 89 static const constexpr NodeIndex RootIndex = 0; 90 static const constexpr NodeIndex InvalidNodeIndex = 91 ~static_cast<NodeIndex>(0); 92 93 // DAG work node. 94 struct Node { 95 MARL_NO_EXPORT inline Node() = default; 96 MARL_NO_EXPORT inline Node(Work&& work); 97 MARL_NO_EXPORT inline Node(const Work& work); 98 99 // The work to perform for this node in the graph. 100 Work work; 101 102 // counterIndex if valid, is the index of the counter in the RunContext for 103 // this node. The counter is decremented for each completed dependency task 104 // (ins), and once it reaches 0, this node will be invoked. 105 size_t counterIndex = InvalidCounterIndex; 106 107 // Indices for all downstream nodes. 108 containers::vector<NodeIndex, NumReservedNumOuts> outs; 109 }; 110 111 // initCounters() allocates and initializes the ctx->coutners from 112 // initialCounters. 113 MARL_NO_EXPORT inline void initCounters(RunContext* ctx, 114 Allocator* allocator); 115 116 // notify() is called each time a dependency task (ins) has completed for the 117 // node with the given index. 118 // If all dependency tasks have completed (or this is the root node) then 119 // notify() returns true and the caller should then call invoke(). 120 MARL_NO_EXPORT inline bool notify(RunContext*, NodeIndex); 121 122 // invoke() calls the work function for the node with the given index, then 123 // calls notify() and possibly invoke() for all the dependee nodes. 124 MARL_NO_EXPORT inline void invoke(RunContext*, NodeIndex, WaitGroup*); 125 126 // nodes is the full list of the nodes in the graph. 127 // nodes[0] is always the root node, which has no dependencies (ins). 128 containers::vector<Node, NumReservedNodes> nodes; 129 130 // initialCounters is a list of initial counter values to be copied to 131 // RunContext::counters on DAG<>::run(). 132 // initialCounters is indexed by Node::counterIndex, and only contains counts 133 // for nodes that have at least 2 dependencies (ins) - because of this the 134 // number of entries in initialCounters may be fewer than nodes. 135 containers::vector<uint32_t, NumReservedNodes> initialCounters; 136 }; 137 138 template <typename T> 139 DAGBase<T>::Node::Node(Work&& work) : work(std::move(work)) {} 140 141 template <typename T> 142 DAGBase<T>::Node::Node(const Work& work) : work(work) {} 143 144 template <typename T> 145 void DAGBase<T>::initCounters(RunContext* ctx, Allocator* allocator) { 146 auto numCounters = initialCounters.size(); 147 ctx->counters = allocator->make_unique_n<Counter>(numCounters); 148 for (size_t i = 0; i < numCounters; i++) { 149 ctx->counters.get()[i] = {initialCounters[i]}; 150 } 151 } 152 153 template <typename T> 154 bool DAGBase<T>::notify(RunContext* ctx, NodeIndex nodeIdx) { 155 Node* node = &nodes[nodeIdx]; 156 157 // If we have multiple dependencies, decrement the counter and check whether 158 // we've reached 0. 159 if (node->counterIndex == InvalidCounterIndex) { 160 return true; 161 } 162 auto counters = ctx->counters.get(); 163 auto counter = --counters[node->counterIndex]; 164 return counter == 0; 165 } 166 167 template <typename T> 168 void DAGBase<T>::invoke(RunContext* ctx, NodeIndex nodeIdx, WaitGroup* wg) { 169 Node* node = &nodes[nodeIdx]; 170 171 // Run this node's work. 172 if (node->work) { 173 ctx->invoke(node->work); 174 } 175 176 // Then call notify() on all dependees (outs), and invoke() those that 177 // returned true. 178 // We buffer the node to invoke (toInvoke) so we can schedule() all but the 179 // last node to invoke(), and directly call the last invoke() on this thread. 180 // This is done to avoid the overheads of scheduling when a direct call would 181 // suffice. 182 NodeIndex toInvoke = InvalidNodeIndex; 183 for (NodeIndex idx : node->outs) { 184 if (notify(ctx, idx)) { 185 if (toInvoke != InvalidNodeIndex) { 186 wg->add(1); 187 // Schedule while promoting the WaitGroup capture from a pointer 188 // reference to a value. This ensures that the WaitGroup isn't dropped 189 // while in use. 190 schedule( 191 [=](WaitGroup wg) { 192 invoke(ctx, toInvoke, &wg); 193 wg.done(); 194 }, 195 *wg); 196 } 197 toInvoke = idx; 198 } 199 } 200 if (toInvoke != InvalidNodeIndex) { 201 invoke(ctx, toInvoke, wg); 202 } 203 } 204 205 /////////////////////////////////////////////////////////////////////////////// 206 // DAGNodeBuilder<T> 207 /////////////////////////////////////////////////////////////////////////////// 208 209 // DAGNodeBuilder is the builder interface for a DAG node. 210 template <typename T> 211 class DAGNodeBuilder { 212 using NodeIndex = typename DAGBase<T>::NodeIndex; 213 214 public: 215 // then() builds and returns a new DAG node that will be invoked after this 216 // node has completed. 217 // 218 // F is a function that will be called when the new DAG node is invoked, with 219 // the signature: 220 // void(T) when T is not void 221 // or 222 // void() when T is void 223 template <typename F> 224 MARL_NO_EXPORT inline DAGNodeBuilder then(F&&); 225 226 private: 227 friend DAGBuilder<T>; 228 MARL_NO_EXPORT inline DAGNodeBuilder(DAGBuilder<T>*, NodeIndex); 229 DAGBuilder<T>* builder; 230 NodeIndex index; 231 }; 232 233 template <typename T> 234 DAGNodeBuilder<T>::DAGNodeBuilder(DAGBuilder<T>* builder, NodeIndex index) 235 : builder(builder), index(index) {} 236 237 template <typename T> 238 template <typename F> 239 DAGNodeBuilder<T> DAGNodeBuilder<T>::then(F&& work) { 240 auto node = builder->node(std::forward<F>(work)); 241 builder->addDependency(*this, node); 242 return node; 243 } 244 245 /////////////////////////////////////////////////////////////////////////////// 246 // DAGBuilder<T> 247 /////////////////////////////////////////////////////////////////////////////// 248 template <typename T> 249 class DAGBuilder { 250 public: 251 // DAGBuilder constructor 252 MARL_NO_EXPORT inline DAGBuilder(Allocator* allocator = Allocator::Default); 253 254 // root() returns the root DAG node. 255 MARL_NO_EXPORT inline DAGNodeBuilder<T> root(); 256 257 // node() builds and returns a new DAG node with no initial dependencies. 258 // The returned node must be attached to the graph in order to invoke F or any 259 // of the dependees of this returned node. 260 // 261 // F is a function that will be called when the new DAG node is invoked, with 262 // the signature: 263 // void(T) when T is not void 264 // or 265 // void() when T is void 266 template <typename F> 267 MARL_NO_EXPORT inline DAGNodeBuilder<T> node(F&& work); 268 269 // node() builds and returns a new DAG node that depends on all the tasks in 270 // after to be completed before invoking F. 271 // 272 // F is a function that will be called when the new DAG node is invoked, with 273 // the signature: 274 // void(T) when T is not void 275 // or 276 // void() when T is void 277 template <typename F> 278 MARL_NO_EXPORT inline DAGNodeBuilder<T> node( 279 F&& work, 280 std::initializer_list<DAGNodeBuilder<T>> after); 281 282 // addDependency() adds parent as dependency on child. All dependencies of 283 // child must have completed before child is invoked. 284 MARL_NO_EXPORT inline void addDependency(DAGNodeBuilder<T> parent, 285 DAGNodeBuilder<T> child); 286 287 // build() constructs and returns the DAG. No other methods of this class may 288 // be called after calling build(). 289 MARL_NO_EXPORT inline Allocator::unique_ptr<DAG<T>> build(); 290 291 private: 292 static const constexpr size_t NumReservedNumIns = 4; 293 using Node = typename DAG<T>::Node; 294 295 // The DAG being built. 296 Allocator::unique_ptr<DAG<T>> dag; 297 298 // Number of dependencies (ins) for each node in dag->nodes. 299 containers::vector<uint32_t, NumReservedNumIns> numIns; 300 }; 301 302 template <typename T> 303 DAGBuilder<T>::DAGBuilder(Allocator* allocator /* = Allocator::Default */) 304 : dag(allocator->make_unique<DAG<T>>()), numIns(allocator) { 305 // Add root 306 dag->nodes.emplace_back(Node{}); 307 numIns.emplace_back(0); 308 } 309 310 template <typename T> 311 DAGNodeBuilder<T> DAGBuilder<T>::root() { 312 return DAGNodeBuilder<T>{this, DAGBase<T>::RootIndex}; 313 } 314 315 template <typename T> 316 template <typename F> 317 DAGNodeBuilder<T> DAGBuilder<T>::node(F&& work) { 318 return node(std::forward<F>(work), {}); 319 } 320 321 template <typename T> 322 template <typename F> 323 DAGNodeBuilder<T> DAGBuilder<T>::node( 324 F&& work, 325 std::initializer_list<DAGNodeBuilder<T>> after) { 326 MARL_ASSERT(numIns.size() == dag->nodes.size(), 327 "NodeBuilder vectors out of sync"); 328 auto index = dag->nodes.size(); 329 numIns.emplace_back(0); 330 dag->nodes.emplace_back(Node{std::forward<F>(work)}); 331 auto node = DAGNodeBuilder<T>{this, index}; 332 for (auto in : after) { 333 addDependency(in, node); 334 } 335 return node; 336 } 337 338 template <typename T> 339 void DAGBuilder<T>::addDependency(DAGNodeBuilder<T> parent, 340 DAGNodeBuilder<T> child) { 341 numIns[child.index]++; 342 dag->nodes[parent.index].outs.push_back(child.index); 343 } 344 345 template <typename T> 346 Allocator::unique_ptr<DAG<T>> DAGBuilder<T>::build() { 347 auto numNodes = dag->nodes.size(); 348 MARL_ASSERT(numIns.size() == dag->nodes.size(), 349 "NodeBuilder vectors out of sync"); 350 for (size_t i = 0; i < numNodes; i++) { 351 if (numIns[i] > 1) { 352 auto& node = dag->nodes[i]; 353 node.counterIndex = dag->initialCounters.size(); 354 dag->initialCounters.push_back(numIns[i]); 355 } 356 } 357 return std::move(dag); 358 } 359 360 /////////////////////////////////////////////////////////////////////////////// 361 // DAG<T> 362 /////////////////////////////////////////////////////////////////////////////// 363 template <typename T = void> 364 class DAG : public DAGBase<T> { 365 public: 366 using Builder = DAGBuilder<T>; 367 using NodeBuilder = DAGNodeBuilder<T>; 368 369 // run() invokes the function of each node in the graph of the DAG, passing 370 // data to each, starting with the root node. All dependencies need to have 371 // completed their function before dependees will be invoked. 372 MARL_NO_EXPORT inline void run(T& data, 373 Allocator* allocator = Allocator::Default); 374 }; 375 376 template <typename T> 377 void DAG<T>::run(T& arg, Allocator* allocator /* = Allocator::Default */) { 378 typename DAGBase<T>::RunContext ctx{arg}; 379 this->initCounters(&ctx, allocator); 380 WaitGroup wg; 381 this->invoke(&ctx, this->RootIndex, &wg); 382 wg.wait(); 383 } 384 385 /////////////////////////////////////////////////////////////////////////////// 386 // DAG<void> 387 /////////////////////////////////////////////////////////////////////////////// 388 template <> 389 class DAG<void> : public DAGBase<void> { 390 public: 391 using Builder = DAGBuilder<void>; 392 using NodeBuilder = DAGNodeBuilder<void>; 393 394 // run() invokes the function of each node in the graph of the DAG, starting 395 // with the root node. All dependencies need to have completed their function 396 // before dependees will be invoked. 397 MARL_NO_EXPORT inline void run(Allocator* allocator = Allocator::Default); 398 }; 399 400 void DAG<void>::run(Allocator* allocator /* = Allocator::Default */) { 401 typename DAGBase<void>::RunContext ctx{}; 402 this->initCounters(&ctx, allocator); 403 WaitGroup wg; 404 this->invoke(&ctx, this->RootIndex, &wg); 405 wg.wait(); 406 } 407 408 } // namespace marl 409 410 #endif // marl_dag_h 411