xref: /aosp_15_r20/external/cronet/base/functional/concurrent_callbacks.h (revision 6777b5387eb2ff775bb5750e3f5d96f37fb7352b)
1 // Copyright 2024 The Chromium Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 #ifndef BASE_FUNCTIONAL_CONCURRENT_CALLBACKS_H_
6 #define BASE_FUNCTIONAL_CONCURRENT_CALLBACKS_H_
7 
8 #include <memory>
9 #include <type_traits>
10 #include <vector>
11 
12 #include "base/functional/bind.h"
13 #include "base/functional/callback.h"
14 #include "base/location.h"
15 #include "base/memory/raw_ptr.h"
16 #include "base/task/bind_post_task.h"
17 #include "base/task/sequenced_task_runner.h"
18 
19 // OVERVIEW:
20 //
// ConcurrentCallbacks<T> is an alternative to BarrierCallback<T>. It dispenses
// OnceCallbacks via CreateCallback() and invokes the callback passed to Done()
// after all previously dispensed callbacks have run.
24 //
// ConcurrentCallbacks<T> is intended to be used instead of BarrierCallback<T>
// when the number of callbacks is not known before the first callback is
// needed to start a task, or when the count would otherwise have to be derived
// by hand from the code and would thus be subject to human error.
29 //
30 // IMPORTANT NOTES:
31 //
32 // - ConcurrentCallbacks<T> is NOT thread safe.
// - The done callback is never run synchronously; it is posted (via
//   PostTask()) to the sequence on which Done() was invoked.
// - ConcurrentCallbacks<T> cannot be used after Done() is called; a CHECK
//   enforces this.
37 //
38 // TYPICAL USAGE:
39 //
40 // class Example {
41 //   void OnRequestsReceived(std::vector<Request> requests) {
42 //     base::ConcurrentCallbacks<Result> concurrent;
43 //
44 //     for (Request& request : requests) {
45 //       if (IsValidRequest(request)) {
46 //         StartRequest(std::move(request), concurrent.CreateCallback());
47 //       }
48 //     }
49 //
50 //     std::move(concurrent).Done(
51 //         base::BindOnce(&Example::OnRequestsComplete, GetWeakPtr()));
52 //   }
53 //
54 //   void StartRequest(Request request,
55 //                     base::OnceCallback<void(Result)> callback) {
56 //     // Process the request asynchronously and call callback with a Result.
57 //   }
58 //
59 //   void OnRequestsComplete(std::vector<Result> results) {
60 //     // Invoked after all requests are completed and receives the results of
61 //     // all of them.
62 //   }
63 // };
64 
65 namespace base {
66 
67 template <typename T>
68 class ConcurrentCallbacks {
69  public:
70   using Results = std::vector<std::remove_cvref_t<T>>;
71 
ConcurrentCallbacks()72   ConcurrentCallbacks() {
73     auto info_owner = std::make_unique<Info>();
74     info_ = info_owner.get();
75     info_run_callback_ = BindRepeating(&Info::Run, std::move(info_owner));
76   }
77 
78   // Create a callback for the done callback to wait for.
CreateCallback()79   [[nodiscard]] OnceCallback<void(T)> CreateCallback() {
80     CHECK(info_);
81     ++info_->pending_;
82     return info_run_callback_;
83   }
84 
85   // Finish creating concurrent callbacks and provide done callback to run once
86   // all prior callbacks have executed.
87   // `this` is no longer usable after calling Done(), must be called with
88   // std::move().
89   void Done(OnceCallback<void(Results)> done_callback,
90             const Location& location = FROM_HERE) && {
91     CHECK(info_);
92     info_->done_callback_ =
93         BindPostTask(SequencedTaskRunner::GetCurrentDefault(),
94                      std::move(done_callback), location);
95     if (info_->pending_ == 0u) {
96       std::move(info_->done_callback_).Run(std::move(info_->results_));
97     }
98     info_ = nullptr;
99   }
100 
101  private:
102   class Info {
103    public:
104     Info() = default;
105 
Run(T value)106     void Run(T value) {
107       CHECK_GT(pending_, 0u);
108       --pending_;
109       results_.push_back(std::move(value));
110       if (done_callback_ && pending_ == 0u) {
111         std::move(done_callback_).Run(std::move(results_));
112       }
113     }
114 
115     size_t pending_ = 0u;
116     Results results_;
117     OnceCallback<void(Results)> done_callback_;
118   };
119 
120   RepeatingCallback<void(T)> info_run_callback_;
121   // info_ is owned by info_run_callback_.
122   raw_ptr<Info> info_;
123 };
124 
125 }  // namespace base
126 
127 #endif  // BASE_FUNCTIONAL_CONCURRENT_CALLBACKS_H_
128