xref: /aosp_15_r20/external/federated-compute/fcp/secagg/server/secagg_scheduler.h (revision 14675a029014e728ec732f129a32e299b2da0601)
1*14675a02SAndroid Build Coastguard Worker /*
2*14675a02SAndroid Build Coastguard Worker  * Copyright 2019 Google LLC
3*14675a02SAndroid Build Coastguard Worker  *
4*14675a02SAndroid Build Coastguard Worker  * Licensed under the Apache License, Version 2.0 (the "License");
5*14675a02SAndroid Build Coastguard Worker  * you may not use this file except in compliance with the License.
6*14675a02SAndroid Build Coastguard Worker  * You may obtain a copy of the License at
7*14675a02SAndroid Build Coastguard Worker  *
8*14675a02SAndroid Build Coastguard Worker  *      http://www.apache.org/licenses/LICENSE-2.0
9*14675a02SAndroid Build Coastguard Worker  *
10*14675a02SAndroid Build Coastguard Worker  * Unless required by applicable law or agreed to in writing, software
11*14675a02SAndroid Build Coastguard Worker  * distributed under the License is distributed on an "AS IS" BASIS,
12*14675a02SAndroid Build Coastguard Worker  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*14675a02SAndroid Build Coastguard Worker  * See the License for the specific language governing permissions and
14*14675a02SAndroid Build Coastguard Worker  * limitations under the License.
15*14675a02SAndroid Build Coastguard Worker  */
16*14675a02SAndroid Build Coastguard Worker 
17*14675a02SAndroid Build Coastguard Worker #ifndef FCP_SECAGG_SERVER_SECAGG_SCHEDULER_H_
18*14675a02SAndroid Build Coastguard Worker #define FCP_SECAGG_SERVER_SECAGG_SCHEDULER_H_
19*14675a02SAndroid Build Coastguard Worker 
20*14675a02SAndroid Build Coastguard Worker #include <atomic>
21*14675a02SAndroid Build Coastguard Worker #include <functional>
22*14675a02SAndroid Build Coastguard Worker #include <memory>
23*14675a02SAndroid Build Coastguard Worker #include <utility>
24*14675a02SAndroid Build Coastguard Worker #include <vector>
25*14675a02SAndroid Build Coastguard Worker 
26*14675a02SAndroid Build Coastguard Worker #include "absl/synchronization/mutex.h"
27*14675a02SAndroid Build Coastguard Worker #include "absl/time/time.h"
28*14675a02SAndroid Build Coastguard Worker #include "fcp/base/clock.h"
29*14675a02SAndroid Build Coastguard Worker #include "fcp/base/monitoring.h"
30*14675a02SAndroid Build Coastguard Worker #include "fcp/base/reentrancy_guard.h"
31*14675a02SAndroid Build Coastguard Worker #include "fcp/base/scheduler.h"
32*14675a02SAndroid Build Coastguard Worker 
33*14675a02SAndroid Build Coastguard Worker namespace fcp {
34*14675a02SAndroid Build Coastguard Worker namespace secagg {
35*14675a02SAndroid Build Coastguard Worker 
36*14675a02SAndroid Build Coastguard Worker // Simple callback waiter that runs the function on Wakeup.
37*14675a02SAndroid Build Coastguard Worker class CallbackWaiter : public Clock::Waiter {
38*14675a02SAndroid Build Coastguard Worker  public:
CallbackWaiter(std::function<void ()> callback)39*14675a02SAndroid Build Coastguard Worker   explicit CallbackWaiter(std::function<void()> callback)
40*14675a02SAndroid Build Coastguard Worker       : callback_(std::move(callback)) {}
41*14675a02SAndroid Build Coastguard Worker 
WakeUp()42*14675a02SAndroid Build Coastguard Worker   void WakeUp() override { callback_(); }
43*14675a02SAndroid Build Coastguard Worker 
44*14675a02SAndroid Build Coastguard Worker  private:
45*14675a02SAndroid Build Coastguard Worker   std::function<void()> callback_;
46*14675a02SAndroid Build Coastguard Worker };
47*14675a02SAndroid Build Coastguard Worker 
// Abstract cancellation hook for work scheduled through SecAggScheduler.
// Concrete implementations (e.g. Accumulator) are handed out to callers as a
// CancellationToken.
class CancellationImpl {
 public:
  virtual ~CancellationImpl() = default;

  // Skips whatever ParallelGenerateSequentialReduce work is still pending.
  // Blocks until every task that is currently running has completed.
  // Cancelling more than once is harmless: later calls have no additional
  // effect.
  virtual void Cancel() = 0;
};

// Shared-ownership handle used to cancel scheduled work.
using CancellationToken = std::shared_ptr<CancellationImpl>;
61*14675a02SAndroid Build Coastguard Worker 
62*14675a02SAndroid Build Coastguard Worker template <typename T>
63*14675a02SAndroid Build Coastguard Worker class Accumulator : public CancellationImpl,
64*14675a02SAndroid Build Coastguard Worker                     public std::enable_shared_from_this<Accumulator<T>> {
65*14675a02SAndroid Build Coastguard Worker  public:
Accumulator(std::unique_ptr<T> initial_value,std::function<std::unique_ptr<T> (const T &,const T &)> accumulator_func,Scheduler * parallel_scheduler,Scheduler * sequential_scheduler,Clock * clock)66*14675a02SAndroid Build Coastguard Worker   Accumulator(
67*14675a02SAndroid Build Coastguard Worker       std::unique_ptr<T> initial_value,
68*14675a02SAndroid Build Coastguard Worker       std::function<std::unique_ptr<T>(const T&, const T&)> accumulator_func,
69*14675a02SAndroid Build Coastguard Worker       Scheduler* parallel_scheduler, Scheduler* sequential_scheduler,
70*14675a02SAndroid Build Coastguard Worker       Clock* clock)
71*14675a02SAndroid Build Coastguard Worker       : parallel_scheduler_(parallel_scheduler),
72*14675a02SAndroid Build Coastguard Worker         sequential_scheduler_(sequential_scheduler),
73*14675a02SAndroid Build Coastguard Worker         accumulated_value_(std::move(initial_value)),
74*14675a02SAndroid Build Coastguard Worker         accumulator_func_(accumulator_func),
75*14675a02SAndroid Build Coastguard Worker         clock_(clock) {}
76*14675a02SAndroid Build Coastguard Worker 
GetParallelScheduleFunc(std::shared_ptr<Accumulator<T>> accumulator,std::function<std::unique_ptr<T> ()> generator)77*14675a02SAndroid Build Coastguard Worker   inline static std::function<void()> GetParallelScheduleFunc(
78*14675a02SAndroid Build Coastguard Worker       std::shared_ptr<Accumulator<T>> accumulator,
79*14675a02SAndroid Build Coastguard Worker       std::function<std::unique_ptr<T>()> generator) {
80*14675a02SAndroid Build Coastguard Worker     return [accumulator, generator] {
81*14675a02SAndroid Build Coastguard Worker       // Increment active count if the accumulator is not canceled, otherwise
82*14675a02SAndroid Build Coastguard Worker       // return without scheduling the task. By active count we mean the total
83*14675a02SAndroid Build Coastguard Worker       // number of scheduled tasks, both parallel and sequential. To cancel an
84*14675a02SAndroid Build Coastguard Worker       // accumulator, we wait until that this count is 0.
85*14675a02SAndroid Build Coastguard Worker       if (!accumulator->MaybeIncrementActiveCount()) {
86*14675a02SAndroid Build Coastguard Worker         return;
87*14675a02SAndroid Build Coastguard Worker       }
88*14675a02SAndroid Build Coastguard Worker       auto partial = generator();
89*14675a02SAndroid Build Coastguard Worker       FCP_CHECK(partial);
90*14675a02SAndroid Build Coastguard Worker       // Decrement the count for the parallel task that was just run as
91*14675a02SAndroid Build Coastguard Worker       // generator().
92*14675a02SAndroid Build Coastguard Worker       accumulator->DecrementActiveCount();
93*14675a02SAndroid Build Coastguard Worker       // Schedule sequential part of the generator, only if accumulator is not
94*14675a02SAndroid Build Coastguard Worker       // cancelled, otherwise return without scheduling it.
95*14675a02SAndroid Build Coastguard Worker       if (accumulator->IsCancelled()) {
96*14675a02SAndroid Build Coastguard Worker         return;
97*14675a02SAndroid Build Coastguard Worker       }
98*14675a02SAndroid Build Coastguard Worker       accumulator->RunSequential(
99*14675a02SAndroid Build Coastguard Worker           [=, partial = std::shared_ptr<T>(partial.release())] {
100*14675a02SAndroid Build Coastguard Worker             ReentrancyGuard guard;
101*14675a02SAndroid Build Coastguard Worker             FCP_CHECK_STATUS(guard.Check(accumulator->in_sequential_call()));
102*14675a02SAndroid Build Coastguard Worker             // mark that a task will be
103*14675a02SAndroid Build Coastguard Worker             // scheduled, if the accumulator is
104*14675a02SAndroid Build Coastguard Worker             // not canceled.
105*14675a02SAndroid Build Coastguard Worker             if (!accumulator->MaybeIncrementActiveCount()) {
106*14675a02SAndroid Build Coastguard Worker               return;
107*14675a02SAndroid Build Coastguard Worker             }
108*14675a02SAndroid Build Coastguard Worker             auto new_value = accumulator->accumulator_func_(
109*14675a02SAndroid Build Coastguard Worker                 *accumulator->accumulated_value_, *partial);
110*14675a02SAndroid Build Coastguard Worker             FCP_CHECK(new_value);
111*14675a02SAndroid Build Coastguard Worker             accumulator->accumulated_value_ = std::move(new_value);
112*14675a02SAndroid Build Coastguard Worker             // At this point the sequantial task has been run, and we (i)
113*14675a02SAndroid Build Coastguard Worker             // decrement both active and remaining counts and possibly reset the
114*14675a02SAndroid Build Coastguard Worker             // unobserved work flag, (ii) get the callback, which might be
115*14675a02SAndroid Build Coastguard Worker             // empty, and (iii) call it if that is not the case.
116*14675a02SAndroid Build Coastguard Worker             auto callback = accumulator->UpdateCountsAndGetCallback();
117*14675a02SAndroid Build Coastguard Worker             if (callback) {
118*14675a02SAndroid Build Coastguard Worker               callback();
119*14675a02SAndroid Build Coastguard Worker             }
120*14675a02SAndroid Build Coastguard Worker           });
121*14675a02SAndroid Build Coastguard Worker     };
122*14675a02SAndroid Build Coastguard Worker   }
123*14675a02SAndroid Build Coastguard Worker 
124*14675a02SAndroid Build Coastguard Worker   // Schedule a parallel generator that includes a delay. The result of the
125*14675a02SAndroid Build Coastguard Worker   // generator is fed to the accumulator_func
Schedule(std::function<std::unique_ptr<T> ()> generator,absl::Duration delay)126*14675a02SAndroid Build Coastguard Worker   void Schedule(std::function<std::unique_ptr<T>()> generator,
127*14675a02SAndroid Build Coastguard Worker                 absl::Duration delay) {
128*14675a02SAndroid Build Coastguard Worker     // IncrementRemainingCount() keeps track of the number of async tasks
129*14675a02SAndroid Build Coastguard Worker     // scheduled, and sets a flag when the count goes from 0 to 1, corresponding
130*14675a02SAndroid Build Coastguard Worker     // to a starting batch of unobserved work.
131*14675a02SAndroid Build Coastguard Worker     auto shared_this = this->shared_from_this();
132*14675a02SAndroid Build Coastguard Worker     shared_this->IncrementRemainingCount();
133*14675a02SAndroid Build Coastguard Worker     clock_->WakeupWithDeadline(
134*14675a02SAndroid Build Coastguard Worker         clock_->Now() + delay,
135*14675a02SAndroid Build Coastguard Worker         std::make_shared<CallbackWaiter>([shared_this, generator] {
136*14675a02SAndroid Build Coastguard Worker           shared_this->RunParallel(
137*14675a02SAndroid Build Coastguard Worker               Accumulator<T>::GetParallelScheduleFunc(shared_this, generator));
138*14675a02SAndroid Build Coastguard Worker         }));
139*14675a02SAndroid Build Coastguard Worker   }
140*14675a02SAndroid Build Coastguard Worker 
141*14675a02SAndroid Build Coastguard Worker   // Schedule a parallel generator. The result of the generator is fed to the
142*14675a02SAndroid Build Coastguard Worker   // accumulator_func
Schedule(std::function<std::unique_ptr<T> ()> generator)143*14675a02SAndroid Build Coastguard Worker   void Schedule(std::function<std::unique_ptr<T>()> generator) {
144*14675a02SAndroid Build Coastguard Worker     // IncrementRemainingCount() keeps track of the number of async tasks
145*14675a02SAndroid Build Coastguard Worker     // scheduled, and sets a flag when the count goes from 0 to 1, corresponding
146*14675a02SAndroid Build Coastguard Worker     // to a starting batch of unobserved work.
147*14675a02SAndroid Build Coastguard Worker     auto shared_this = this->shared_from_this();
148*14675a02SAndroid Build Coastguard Worker     shared_this->IncrementRemainingCount();
149*14675a02SAndroid Build Coastguard Worker     RunParallel([shared_this, generator] {
150*14675a02SAndroid Build Coastguard Worker       shared_this->GetParallelScheduleFunc(shared_this, generator)();
151*14675a02SAndroid Build Coastguard Worker     });
152*14675a02SAndroid Build Coastguard Worker   }
153*14675a02SAndroid Build Coastguard Worker 
154*14675a02SAndroid Build Coastguard Worker   // Returns true if the accumulator doesn't have any remaining tasks,
155*14675a02SAndroid Build Coastguard Worker   // even if their results have not been observed by a callback.
IsIdle()156*14675a02SAndroid Build Coastguard Worker   bool IsIdle() {
157*14675a02SAndroid Build Coastguard Worker     absl::MutexLock lock(&mutex_);
158*14675a02SAndroid Build Coastguard Worker     return remaining_sequential_tasks_count_ == 0;
159*14675a02SAndroid Build Coastguard Worker   }
160*14675a02SAndroid Build Coastguard Worker 
161*14675a02SAndroid Build Coastguard Worker   // Returns false if no async work has happened since last time this function
162*14675a02SAndroid Build Coastguard Worker   // was called, or the first time it is called. Otherwise it returns true and
163*14675a02SAndroid Build Coastguard Worker   // schedules a callback to be called once the scheduler is idle.
SetAsyncObserver(std::function<void ()> async_callback)164*14675a02SAndroid Build Coastguard Worker   bool SetAsyncObserver(std::function<void()> async_callback) {
165*14675a02SAndroid Build Coastguard Worker     bool idle;
166*14675a02SAndroid Build Coastguard Worker     {
167*14675a02SAndroid Build Coastguard Worker       absl::MutexLock lock(&mutex_);
168*14675a02SAndroid Build Coastguard Worker       if (!has_unobserved_work_) {
169*14675a02SAndroid Build Coastguard Worker         return false;
170*14675a02SAndroid Build Coastguard Worker       }
171*14675a02SAndroid Build Coastguard Worker       idle = (remaining_sequential_tasks_count_ == 0);
172*14675a02SAndroid Build Coastguard Worker       if (idle) {
173*14675a02SAndroid Build Coastguard Worker         // The flag is set to false, and the callback is run as soon as we leave
174*14675a02SAndroid Build Coastguard Worker         // the mutex's scope.
175*14675a02SAndroid Build Coastguard Worker         has_unobserved_work_ = false;
176*14675a02SAndroid Build Coastguard Worker       } else {
177*14675a02SAndroid Build Coastguard Worker         // The callbak is scheduled for later, as there is ongoing work.
178*14675a02SAndroid Build Coastguard Worker         async_callback_ = async_callback;
179*14675a02SAndroid Build Coastguard Worker       }
180*14675a02SAndroid Build Coastguard Worker     }
181*14675a02SAndroid Build Coastguard Worker     if (idle) {
182*14675a02SAndroid Build Coastguard Worker       auto shared_this = this->shared_from_this();
183*14675a02SAndroid Build Coastguard Worker       RunSequential([async_callback, shared_this] { async_callback(); });
184*14675a02SAndroid Build Coastguard Worker     }
185*14675a02SAndroid Build Coastguard Worker     return true;
186*14675a02SAndroid Build Coastguard Worker   }
187*14675a02SAndroid Build Coastguard Worker 
188*14675a02SAndroid Build Coastguard Worker   // Updates the active and remaining task counts, and returns the callback to
189*14675a02SAndroid Build Coastguard Worker   // be executed, or nullptr if there's pending async work.
UpdateCountsAndGetCallback()190*14675a02SAndroid Build Coastguard Worker   inline std::function<void()> UpdateCountsAndGetCallback() {
191*14675a02SAndroid Build Coastguard Worker     absl::MutexLock lock(&mutex_);
192*14675a02SAndroid Build Coastguard Worker     if (--active_count_ == 0 && is_cancelled_) {
193*14675a02SAndroid Build Coastguard Worker       inactive_cv_.SignalAll();
194*14675a02SAndroid Build Coastguard Worker     }
195*14675a02SAndroid Build Coastguard Worker     --remaining_sequential_tasks_count_;
196*14675a02SAndroid Build Coastguard Worker     if (remaining_sequential_tasks_count_ == 0 && async_callback_) {
197*14675a02SAndroid Build Coastguard Worker       has_unobserved_work_ = false;
198*14675a02SAndroid Build Coastguard Worker       auto callback = async_callback_;
199*14675a02SAndroid Build Coastguard Worker       async_callback_ = nullptr;
200*14675a02SAndroid Build Coastguard Worker       return callback;
201*14675a02SAndroid Build Coastguard Worker     } else {
202*14675a02SAndroid Build Coastguard Worker       return nullptr;
203*14675a02SAndroid Build Coastguard Worker     }
204*14675a02SAndroid Build Coastguard Worker   }
205*14675a02SAndroid Build Coastguard Worker 
206*14675a02SAndroid Build Coastguard Worker   // Take the accumulated result and abort any further work. This method can
207*14675a02SAndroid Build Coastguard Worker   // only be called when the accumulator is idle
GetResultAndCancel()208*14675a02SAndroid Build Coastguard Worker   std::unique_ptr<T> GetResultAndCancel() {
209*14675a02SAndroid Build Coastguard Worker     absl::MutexLock lock(&mutex_);
210*14675a02SAndroid Build Coastguard Worker     FCP_CHECK(active_count_ == 0);
211*14675a02SAndroid Build Coastguard Worker     is_cancelled_ = true;
212*14675a02SAndroid Build Coastguard Worker     return std::move(accumulated_value_);
213*14675a02SAndroid Build Coastguard Worker   }
214*14675a02SAndroid Build Coastguard Worker 
215*14675a02SAndroid Build Coastguard Worker   // CancellationImpl implementation
Cancel()216*14675a02SAndroid Build Coastguard Worker   void Cancel() override {
217*14675a02SAndroid Build Coastguard Worker     mutex_.Lock();
218*14675a02SAndroid Build Coastguard Worker     is_cancelled_ = true;
219*14675a02SAndroid Build Coastguard Worker     while (active_count_ > 0) {
220*14675a02SAndroid Build Coastguard Worker       inactive_cv_.Wait(&mutex_);
221*14675a02SAndroid Build Coastguard Worker     }
222*14675a02SAndroid Build Coastguard Worker     mutex_.Unlock();
223*14675a02SAndroid Build Coastguard Worker   }
224*14675a02SAndroid Build Coastguard Worker 
IsCancelled()225*14675a02SAndroid Build Coastguard Worker   bool IsCancelled() {
226*14675a02SAndroid Build Coastguard Worker     absl::MutexLock lock(&mutex_);
227*14675a02SAndroid Build Coastguard Worker     return is_cancelled_;
228*14675a02SAndroid Build Coastguard Worker   }
229*14675a02SAndroid Build Coastguard Worker 
MaybeIncrementActiveCount()230*14675a02SAndroid Build Coastguard Worker   bool MaybeIncrementActiveCount() {
231*14675a02SAndroid Build Coastguard Worker     absl::MutexLock lock(&mutex_);
232*14675a02SAndroid Build Coastguard Worker     if (is_cancelled_) {
233*14675a02SAndroid Build Coastguard Worker       return false;
234*14675a02SAndroid Build Coastguard Worker     }
235*14675a02SAndroid Build Coastguard Worker     active_count_++;
236*14675a02SAndroid Build Coastguard Worker     return true;
237*14675a02SAndroid Build Coastguard Worker   }
238*14675a02SAndroid Build Coastguard Worker 
DecrementActiveCount()239*14675a02SAndroid Build Coastguard Worker   size_t DecrementActiveCount() {
240*14675a02SAndroid Build Coastguard Worker     absl::MutexLock lock(&mutex_);
241*14675a02SAndroid Build Coastguard Worker     FCP_CHECK(active_count_ > 0);
242*14675a02SAndroid Build Coastguard Worker     if (--active_count_ == 0 && is_cancelled_) {
243*14675a02SAndroid Build Coastguard Worker       inactive_cv_.SignalAll();
244*14675a02SAndroid Build Coastguard Worker     }
245*14675a02SAndroid Build Coastguard Worker     return active_count_;
246*14675a02SAndroid Build Coastguard Worker   }
247*14675a02SAndroid Build Coastguard Worker 
IncrementRemainingCount()248*14675a02SAndroid Build Coastguard Worker   void IncrementRemainingCount() {
249*14675a02SAndroid Build Coastguard Worker     absl::MutexLock lock(&mutex_);
250*14675a02SAndroid Build Coastguard Worker     has_unobserved_work_ |= (remaining_sequential_tasks_count_ == 0);
251*14675a02SAndroid Build Coastguard Worker     remaining_sequential_tasks_count_++;
252*14675a02SAndroid Build Coastguard Worker   }
253*14675a02SAndroid Build Coastguard Worker 
in_sequential_call()254*14675a02SAndroid Build Coastguard Worker   std::atomic<bool>* in_sequential_call() { return &in_sequential_call_; }
255*14675a02SAndroid Build Coastguard Worker 
RunParallel(std::function<void ()> function)256*14675a02SAndroid Build Coastguard Worker   void inline RunParallel(std::function<void()> function) {
257*14675a02SAndroid Build Coastguard Worker     parallel_scheduler_->Schedule(function);
258*14675a02SAndroid Build Coastguard Worker   }
259*14675a02SAndroid Build Coastguard Worker 
RunSequential(std::function<void ()> function)260*14675a02SAndroid Build Coastguard Worker   void inline RunSequential(std::function<void()> function) {
261*14675a02SAndroid Build Coastguard Worker     sequential_scheduler_->Schedule(function);
262*14675a02SAndroid Build Coastguard Worker   }
263*14675a02SAndroid Build Coastguard Worker 
264*14675a02SAndroid Build Coastguard Worker  private:
265*14675a02SAndroid Build Coastguard Worker   // Scheduler for sequential and parallel tasks, received from the
266*14675a02SAndroid Build Coastguard Worker   // SecAggScheduler instatiating this class
267*14675a02SAndroid Build Coastguard Worker   Scheduler* parallel_scheduler_;
268*14675a02SAndroid Build Coastguard Worker   Scheduler* sequential_scheduler_;
269*14675a02SAndroid Build Coastguard Worker 
270*14675a02SAndroid Build Coastguard Worker   // Callback to be executed the next time that the sequential scheduler
271*14675a02SAndroid Build Coastguard Worker   // becomes idle.
272*14675a02SAndroid Build Coastguard Worker   std::function<void()> async_callback_ ABSL_GUARDED_BY(mutex_) =
273*14675a02SAndroid Build Coastguard Worker       std::function<void()>();
274*14675a02SAndroid Build Coastguard Worker   // Accumulated value - accessed by sequential tasks only.
275*14675a02SAndroid Build Coastguard Worker   std::unique_ptr<T> accumulated_value_;
276*14675a02SAndroid Build Coastguard Worker   // Accumulation function - accessed by sequential tasks only.
277*14675a02SAndroid Build Coastguard Worker   std::function<std::unique_ptr<T>(const T&, const T&)> accumulator_func_;
278*14675a02SAndroid Build Coastguard Worker   // Clock used for scheduling delays in parallel tasks
279*14675a02SAndroid Build Coastguard Worker   Clock* clock_;
280*14675a02SAndroid Build Coastguard Worker   // Remaining number of sequential tasks to be executed - accessed by
281*14675a02SAndroid Build Coastguard Worker   // sequential tasks only.
282*14675a02SAndroid Build Coastguard Worker   size_t remaining_sequential_tasks_count_ ABSL_GUARDED_BY(mutex_) = 0;
283*14675a02SAndroid Build Coastguard Worker   bool has_unobserved_work_ ABSL_GUARDED_BY(mutex_) = false;
284*14675a02SAndroid Build Coastguard Worker 
285*14675a02SAndroid Build Coastguard Worker   // Number of active calls to either callback function.
286*14675a02SAndroid Build Coastguard Worker   size_t active_count_ ABSL_GUARDED_BY(mutex_) = 0;
287*14675a02SAndroid Build Coastguard Worker   // This is set to true when the run is aborted.
288*14675a02SAndroid Build Coastguard Worker   bool is_cancelled_ ABSL_GUARDED_BY(mutex_) = false;
289*14675a02SAndroid Build Coastguard Worker   // Protects active_count_ and cancelled_.
290*14675a02SAndroid Build Coastguard Worker   absl::Mutex mutex_;
291*14675a02SAndroid Build Coastguard Worker   // Used to notify cancellation about reaching inactive state;
292*14675a02SAndroid Build Coastguard Worker   absl::CondVar inactive_cv_;
293*14675a02SAndroid Build Coastguard Worker   // This is used by ReentrancyGuard to ensure that Sequential tasks are
294*14675a02SAndroid Build Coastguard Worker   // indeed sequential.
295*14675a02SAndroid Build Coastguard Worker   std::atomic<bool> in_sequential_call_ = false;
296*14675a02SAndroid Build Coastguard Worker };
297*14675a02SAndroid Build Coastguard Worker 
298*14675a02SAndroid Build Coastguard Worker // Implementation of ParallelGenerateSequentialReduce based on fcp::Scheduler.
299*14675a02SAndroid Build Coastguard Worker // Takes two Schedulers, one which is responsible for parallel execution and
300*14675a02SAndroid Build Coastguard Worker // another for serial execution. Additionally, takes a clock that can be used to
301*14675a02SAndroid Build Coastguard Worker // induce delay in task executions.
302*14675a02SAndroid Build Coastguard Worker class SecAggScheduler {
303*14675a02SAndroid Build Coastguard Worker  public:
304*14675a02SAndroid Build Coastguard Worker   SecAggScheduler(Scheduler* parallel_scheduler,
305*14675a02SAndroid Build Coastguard Worker                   Scheduler* sequential_scheduler,
306*14675a02SAndroid Build Coastguard Worker                   Clock* clock = Clock::RealClock())
parallel_scheduler_(parallel_scheduler)307*14675a02SAndroid Build Coastguard Worker       : parallel_scheduler_(parallel_scheduler),
308*14675a02SAndroid Build Coastguard Worker         sequential_scheduler_(sequential_scheduler),
309*14675a02SAndroid Build Coastguard Worker         clock_(clock) {}
310*14675a02SAndroid Build Coastguard Worker 
311*14675a02SAndroid Build Coastguard Worker   // SecAggScheduler is neither copyable nor movable.
312*14675a02SAndroid Build Coastguard Worker   SecAggScheduler(const SecAggScheduler&) = delete;
313*14675a02SAndroid Build Coastguard Worker   SecAggScheduler& operator=(const SecAggScheduler&) = delete;
314*14675a02SAndroid Build Coastguard Worker 
315*14675a02SAndroid Build Coastguard Worker   virtual ~SecAggScheduler() = default;
316*14675a02SAndroid Build Coastguard Worker 
317*14675a02SAndroid Build Coastguard Worker   // Schedule a callback to be invoked on the sequential scheduler.
ScheduleCallback(std::function<void ()> callback)318*14675a02SAndroid Build Coastguard Worker   inline void ScheduleCallback(std::function<void()> callback) {
319*14675a02SAndroid Build Coastguard Worker     RunSequential(callback);
320*14675a02SAndroid Build Coastguard Worker   }
321*14675a02SAndroid Build Coastguard Worker 
322*14675a02SAndroid Build Coastguard Worker   template <typename T>
CreateAccumulator(std::unique_ptr<T> initial_value,std::function<std::unique_ptr<T> (const T &,const T &)> accumulator_func)323*14675a02SAndroid Build Coastguard Worker   std::shared_ptr<Accumulator<T>> CreateAccumulator(
324*14675a02SAndroid Build Coastguard Worker       std::unique_ptr<T> initial_value,
325*14675a02SAndroid Build Coastguard Worker       std::function<std::unique_ptr<T>(const T&, const T&)> accumulator_func) {
326*14675a02SAndroid Build Coastguard Worker     return std::make_shared<Accumulator<T>>(
327*14675a02SAndroid Build Coastguard Worker         std::move(initial_value), accumulator_func, parallel_scheduler_,
328*14675a02SAndroid Build Coastguard Worker         sequential_scheduler_, clock_);
329*14675a02SAndroid Build Coastguard Worker   }
330*14675a02SAndroid Build Coastguard Worker 
331*14675a02SAndroid Build Coastguard Worker   void WaitUntilIdle();
332*14675a02SAndroid Build Coastguard Worker 
333*14675a02SAndroid Build Coastguard Worker  protected:
334*14675a02SAndroid Build Coastguard Worker   // Virtual for testing
335*14675a02SAndroid Build Coastguard Worker   virtual void RunSequential(std::function<void()> function);
336*14675a02SAndroid Build Coastguard Worker 
337*14675a02SAndroid Build Coastguard Worker  private:
338*14675a02SAndroid Build Coastguard Worker   Scheduler* parallel_scheduler_;
339*14675a02SAndroid Build Coastguard Worker   Scheduler* sequential_scheduler_;
340*14675a02SAndroid Build Coastguard Worker   Clock* clock_;
341*14675a02SAndroid Build Coastguard Worker };
342*14675a02SAndroid Build Coastguard Worker 
343*14675a02SAndroid Build Coastguard Worker }  // namespace secagg
344*14675a02SAndroid Build Coastguard Worker }  // namespace fcp
345*14675a02SAndroid Build Coastguard Worker 
346*14675a02SAndroid Build Coastguard Worker #endif  // FCP_SECAGG_SERVER_SECAGG_SCHEDULER_H_
347