xref: /aosp_15_r20/external/swiftshader/third_party/marl/include/marl/dag.h (revision 03ce13f70fcc45d86ee91b7ee4cab1936a95046e)
1 // Copyright 2020 The Marl Authors.
2 //
3 // Licensed under the Apache License, Version 2.0 (the "License");
4 // you may not use this file except in compliance with the License.
5 // You may obtain a copy of the License at
6 //
7 //     https://www.apache.org/licenses/LICENSE-2.0
8 //
9 // Unless required by applicable law or agreed to in writing, software
10 // distributed under the License is distributed on an "AS IS" BASIS,
11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12 // See the License for the specific language governing permissions and
13 // limitations under the License.
14 
15 // marl::DAG<> provides an ahead of time, declarative, directed acyclic
16 // task graph.
17 
18 #ifndef marl_dag_h
19 #define marl_dag_h
20 
21 #include "containers.h"
22 #include "export.h"
23 #include "memory.h"
24 #include "scheduler.h"
25 #include "waitgroup.h"
26 
27 namespace marl {
28 namespace detail {
29 using DAGCounter = std::atomic<uint32_t>;
30 template <typename T>
31 struct DAGRunContext {
32   T data;
33   Allocator::unique_ptr<DAGCounter> counters;
34 
35   template <typename F>
invokeDAGRunContext36   MARL_NO_EXPORT inline void invoke(F&& f) {
37     f(data);
38   }
39 };
40 template <>
41 struct DAGRunContext<void> {
42   Allocator::unique_ptr<DAGCounter> counters;
43 
44   template <typename F>
45   MARL_NO_EXPORT inline void invoke(F&& f) {
46     f();
47   }
48 };
49 template <typename T>
50 struct DAGWork {
51   using type = std::function<void(T)>;
52 };
53 template <>
54 struct DAGWork<void> {
55   using type = std::function<void()>;
56 };
57 }  // namespace detail
58 
59 ///////////////////////////////////////////////////////////////////////////////
60 // Forward declarations
61 ///////////////////////////////////////////////////////////////////////////////
// The runnable graph returned by DAGBuilder<T>::build().
template <class T>
class DAG;

// Incrementally constructs a DAG<T>.
template <class T>
class DAGBuilder;

// Handle to a single node of a DAGBuilder<T>.
template <class T>
class DAGNodeBuilder;
70 
71 ///////////////////////////////////////////////////////////////////////////////
72 // DAGBase<T>
73 ///////////////////////////////////////////////////////////////////////////////
74 
75 // DAGBase is derived by DAG<T> and DAG<void>. It has no public API.
76 template <typename T>
77 class DAGBase {
78  protected:
79   friend DAGBuilder<T>;
80   friend DAGNodeBuilder<T>;
81 
82   using RunContext = detail::DAGRunContext<T>;
83   using Counter = detail::DAGCounter;
84   using NodeIndex = size_t;
85   using Work = typename detail::DAGWork<T>::type;
86   static const constexpr size_t NumReservedNodes = 32;
87   static const constexpr size_t NumReservedNumOuts = 4;
88   static const constexpr size_t InvalidCounterIndex = ~static_cast<size_t>(0);
89   static const constexpr NodeIndex RootIndex = 0;
90   static const constexpr NodeIndex InvalidNodeIndex =
91       ~static_cast<NodeIndex>(0);
92 
93   // DAG work node.
94   struct Node {
95     MARL_NO_EXPORT inline Node() = default;
96     MARL_NO_EXPORT inline Node(Work&& work);
97     MARL_NO_EXPORT inline Node(const Work& work);
98 
99     // The work to perform for this node in the graph.
100     Work work;
101 
102     // counterIndex if valid, is the index of the counter in the RunContext for
103     // this node. The counter is decremented for each completed dependency task
104     // (ins), and once it reaches 0, this node will be invoked.
105     size_t counterIndex = InvalidCounterIndex;
106 
107     // Indices for all downstream nodes.
108     containers::vector<NodeIndex, NumReservedNumOuts> outs;
109   };
110 
111   // initCounters() allocates and initializes the ctx->coutners from
112   // initialCounters.
113   MARL_NO_EXPORT inline void initCounters(RunContext* ctx,
114                                           Allocator* allocator);
115 
116   // notify() is called each time a dependency task (ins) has completed for the
117   // node with the given index.
118   // If all dependency tasks have completed (or this is the root node) then
119   // notify() returns true and the caller should then call invoke().
120   MARL_NO_EXPORT inline bool notify(RunContext*, NodeIndex);
121 
122   // invoke() calls the work function for the node with the given index, then
123   // calls notify() and possibly invoke() for all the dependee nodes.
124   MARL_NO_EXPORT inline void invoke(RunContext*, NodeIndex, WaitGroup*);
125 
126   // nodes is the full list of the nodes in the graph.
127   // nodes[0] is always the root node, which has no dependencies (ins).
128   containers::vector<Node, NumReservedNodes> nodes;
129 
130   // initialCounters is a list of initial counter values to be copied to
131   // RunContext::counters on DAG<>::run().
132   // initialCounters is indexed by Node::counterIndex, and only contains counts
133   // for nodes that have at least 2 dependencies (ins) - because of this the
134   // number of entries in initialCounters may be fewer than nodes.
135   containers::vector<uint32_t, NumReservedNodes> initialCounters;
136 };
137 
138 template <typename T>
139 DAGBase<T>::Node::Node(Work&& work) : work(std::move(work)) {}
140 
141 template <typename T>
142 DAGBase<T>::Node::Node(const Work& work) : work(work) {}
143 
144 template <typename T>
145 void DAGBase<T>::initCounters(RunContext* ctx, Allocator* allocator) {
146   auto numCounters = initialCounters.size();
147   ctx->counters = allocator->make_unique_n<Counter>(numCounters);
148   for (size_t i = 0; i < numCounters; i++) {
149     ctx->counters.get()[i] = {initialCounters[i]};
150   }
151 }
152 
153 template <typename T>
154 bool DAGBase<T>::notify(RunContext* ctx, NodeIndex nodeIdx) {
155   Node* node = &nodes[nodeIdx];
156 
157   // If we have multiple dependencies, decrement the counter and check whether
158   // we've reached 0.
159   if (node->counterIndex == InvalidCounterIndex) {
160     return true;
161   }
162   auto counters = ctx->counters.get();
163   auto counter = --counters[node->counterIndex];
164   return counter == 0;
165 }
166 
167 template <typename T>
168 void DAGBase<T>::invoke(RunContext* ctx, NodeIndex nodeIdx, WaitGroup* wg) {
169   Node* node = &nodes[nodeIdx];
170 
171   // Run this node's work.
172   if (node->work) {
173     ctx->invoke(node->work);
174   }
175 
176   // Then call notify() on all dependees (outs), and invoke() those that
177   // returned true.
178   // We buffer the node to invoke (toInvoke) so we can schedule() all but the
179   // last node to invoke(), and directly call the last invoke() on this thread.
180   // This is done to avoid the overheads of scheduling when a direct call would
181   // suffice.
182   NodeIndex toInvoke = InvalidNodeIndex;
183   for (NodeIndex idx : node->outs) {
184     if (notify(ctx, idx)) {
185       if (toInvoke != InvalidNodeIndex) {
186         wg->add(1);
187         // Schedule while promoting the WaitGroup capture from a pointer
188         // reference to a value. This ensures that the WaitGroup isn't dropped
189         // while in use.
190         schedule(
191             [=](WaitGroup wg) {
192               invoke(ctx, toInvoke, &wg);
193               wg.done();
194             },
195             *wg);
196       }
197       toInvoke = idx;
198     }
199   }
200   if (toInvoke != InvalidNodeIndex) {
201     invoke(ctx, toInvoke, wg);
202   }
203 }
204 
205 ///////////////////////////////////////////////////////////////////////////////
206 // DAGNodeBuilder<T>
207 ///////////////////////////////////////////////////////////////////////////////
208 
209 // DAGNodeBuilder is the builder interface for a DAG node.
210 template <typename T>
211 class DAGNodeBuilder {
212   using NodeIndex = typename DAGBase<T>::NodeIndex;
213 
214  public:
215   // then() builds and returns a new DAG node that will be invoked after this
216   // node has completed.
217   //
218   // F is a function that will be called when the new DAG node is invoked, with
219   // the signature:
220   //   void(T)   when T is not void
221   // or
222   //   void()    when T is void
223   template <typename F>
224   MARL_NO_EXPORT inline DAGNodeBuilder then(F&&);
225 
226  private:
227   friend DAGBuilder<T>;
228   MARL_NO_EXPORT inline DAGNodeBuilder(DAGBuilder<T>*, NodeIndex);
229   DAGBuilder<T>* builder;
230   NodeIndex index;
231 };
232 
233 template <typename T>
234 DAGNodeBuilder<T>::DAGNodeBuilder(DAGBuilder<T>* builder, NodeIndex index)
235     : builder(builder), index(index) {}
236 
237 template <typename T>
238 template <typename F>
239 DAGNodeBuilder<T> DAGNodeBuilder<T>::then(F&& work) {
240   auto node = builder->node(std::forward<F>(work));
241   builder->addDependency(*this, node);
242   return node;
243 }
244 
245 ///////////////////////////////////////////////////////////////////////////////
246 // DAGBuilder<T>
247 ///////////////////////////////////////////////////////////////////////////////
248 template <typename T>
249 class DAGBuilder {
250  public:
251   // DAGBuilder constructor
252   MARL_NO_EXPORT inline DAGBuilder(Allocator* allocator = Allocator::Default);
253 
254   // root() returns the root DAG node.
255   MARL_NO_EXPORT inline DAGNodeBuilder<T> root();
256 
257   // node() builds and returns a new DAG node with no initial dependencies.
258   // The returned node must be attached to the graph in order to invoke F or any
259   // of the dependees of this returned node.
260   //
261   // F is a function that will be called when the new DAG node is invoked, with
262   // the signature:
263   //   void(T)   when T is not void
264   // or
265   //   void()    when T is void
266   template <typename F>
267   MARL_NO_EXPORT inline DAGNodeBuilder<T> node(F&& work);
268 
269   // node() builds and returns a new DAG node that depends on all the tasks in
270   // after to be completed before invoking F.
271   //
272   // F is a function that will be called when the new DAG node is invoked, with
273   // the signature:
274   //   void(T)   when T is not void
275   // or
276   //   void()    when T is void
277   template <typename F>
278   MARL_NO_EXPORT inline DAGNodeBuilder<T> node(
279       F&& work,
280       std::initializer_list<DAGNodeBuilder<T>> after);
281 
282   // addDependency() adds parent as dependency on child. All dependencies of
283   // child must have completed before child is invoked.
284   MARL_NO_EXPORT inline void addDependency(DAGNodeBuilder<T> parent,
285                                            DAGNodeBuilder<T> child);
286 
287   // build() constructs and returns the DAG. No other methods of this class may
288   // be called after calling build().
289   MARL_NO_EXPORT inline Allocator::unique_ptr<DAG<T>> build();
290 
291  private:
292   static const constexpr size_t NumReservedNumIns = 4;
293   using Node = typename DAG<T>::Node;
294 
295   // The DAG being built.
296   Allocator::unique_ptr<DAG<T>> dag;
297 
298   // Number of dependencies (ins) for each node in dag->nodes.
299   containers::vector<uint32_t, NumReservedNumIns> numIns;
300 };
301 
302 template <typename T>
303 DAGBuilder<T>::DAGBuilder(Allocator* allocator /* = Allocator::Default */)
304     : dag(allocator->make_unique<DAG<T>>()), numIns(allocator) {
305   // Add root
306   dag->nodes.emplace_back(Node{});
307   numIns.emplace_back(0);
308 }
309 
310 template <typename T>
311 DAGNodeBuilder<T> DAGBuilder<T>::root() {
312   return DAGNodeBuilder<T>{this, DAGBase<T>::RootIndex};
313 }
314 
315 template <typename T>
316 template <typename F>
317 DAGNodeBuilder<T> DAGBuilder<T>::node(F&& work) {
318   return node(std::forward<F>(work), {});
319 }
320 
321 template <typename T>
322 template <typename F>
323 DAGNodeBuilder<T> DAGBuilder<T>::node(
324     F&& work,
325     std::initializer_list<DAGNodeBuilder<T>> after) {
326   MARL_ASSERT(numIns.size() == dag->nodes.size(),
327               "NodeBuilder vectors out of sync");
328   auto index = dag->nodes.size();
329   numIns.emplace_back(0);
330   dag->nodes.emplace_back(Node{std::forward<F>(work)});
331   auto node = DAGNodeBuilder<T>{this, index};
332   for (auto in : after) {
333     addDependency(in, node);
334   }
335   return node;
336 }
337 
338 template <typename T>
339 void DAGBuilder<T>::addDependency(DAGNodeBuilder<T> parent,
340                                   DAGNodeBuilder<T> child) {
341   numIns[child.index]++;
342   dag->nodes[parent.index].outs.push_back(child.index);
343 }
344 
345 template <typename T>
346 Allocator::unique_ptr<DAG<T>> DAGBuilder<T>::build() {
347   auto numNodes = dag->nodes.size();
348   MARL_ASSERT(numIns.size() == dag->nodes.size(),
349               "NodeBuilder vectors out of sync");
350   for (size_t i = 0; i < numNodes; i++) {
351     if (numIns[i] > 1) {
352       auto& node = dag->nodes[i];
353       node.counterIndex = dag->initialCounters.size();
354       dag->initialCounters.push_back(numIns[i]);
355     }
356   }
357   return std::move(dag);
358 }
359 
360 ///////////////////////////////////////////////////////////////////////////////
361 // DAG<T>
362 ///////////////////////////////////////////////////////////////////////////////
363 template <typename T = void>
364 class DAG : public DAGBase<T> {
365  public:
366   using Builder = DAGBuilder<T>;
367   using NodeBuilder = DAGNodeBuilder<T>;
368 
369   // run() invokes the function of each node in the graph of the DAG, passing
370   // data to each, starting with the root node. All dependencies need to have
371   // completed their function before dependees will be invoked.
372   MARL_NO_EXPORT inline void run(T& data,
373                                  Allocator* allocator = Allocator::Default);
374 };
375 
376 template <typename T>
377 void DAG<T>::run(T& arg, Allocator* allocator /* = Allocator::Default */) {
378   typename DAGBase<T>::RunContext ctx{arg};
379   this->initCounters(&ctx, allocator);
380   WaitGroup wg;
381   this->invoke(&ctx, this->RootIndex, &wg);
382   wg.wait();
383 }
384 
385 ///////////////////////////////////////////////////////////////////////////////
386 // DAG<void>
387 ///////////////////////////////////////////////////////////////////////////////
388 template <>
389 class DAG<void> : public DAGBase<void> {
390  public:
391   using Builder = DAGBuilder<void>;
392   using NodeBuilder = DAGNodeBuilder<void>;
393 
394   // run() invokes the function of each node in the graph of the DAG, starting
395   // with the root node. All dependencies need to have completed their function
396   // before dependees will be invoked.
397   MARL_NO_EXPORT inline void run(Allocator* allocator = Allocator::Default);
398 };
399 
400 void DAG<void>::run(Allocator* allocator /* = Allocator::Default */) {
401   typename DAGBase<void>::RunContext ctx{};
402   this->initCounters(&ctx, allocator);
403   WaitGroup wg;
404   this->invoke(&ctx, this->RootIndex, &wg);
405   wg.wait();
406 }
407 
408 }  // namespace marl
409 
410 #endif  // marl_dag_h
411