1*14675a02SAndroid Build Coastguard Worker /* 2*14675a02SAndroid Build Coastguard Worker * Copyright 2019 Google LLC 3*14675a02SAndroid Build Coastguard Worker * 4*14675a02SAndroid Build Coastguard Worker * Licensed under the Apache License, Version 2.0 (the "License"); 5*14675a02SAndroid Build Coastguard Worker * you may not use this file except in compliance with the License. 6*14675a02SAndroid Build Coastguard Worker * You may obtain a copy of the License at 7*14675a02SAndroid Build Coastguard Worker * 8*14675a02SAndroid Build Coastguard Worker * http://www.apache.org/licenses/LICENSE-2.0 9*14675a02SAndroid Build Coastguard Worker * 10*14675a02SAndroid Build Coastguard Worker * Unless required by applicable law or agreed to in writing, software 11*14675a02SAndroid Build Coastguard Worker * distributed under the License is distributed on an "AS IS" BASIS, 12*14675a02SAndroid Build Coastguard Worker * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13*14675a02SAndroid Build Coastguard Worker * See the License for the specific language governing permissions and 14*14675a02SAndroid Build Coastguard Worker * limitations under the License. 15*14675a02SAndroid Build Coastguard Worker */ 16*14675a02SAndroid Build Coastguard Worker 17*14675a02SAndroid Build Coastguard Worker #ifndef FCP_SECAGG_SERVER_SECAGG_SCHEDULER_H_ 18*14675a02SAndroid Build Coastguard Worker #define FCP_SECAGG_SERVER_SECAGG_SCHEDULER_H_ 19*14675a02SAndroid Build Coastguard Worker 20*14675a02SAndroid Build Coastguard Worker #include <atomic> 21*14675a02SAndroid Build Coastguard Worker #include <functional> 22*14675a02SAndroid Build Coastguard Worker #include <memory> 23*14675a02SAndroid Build Coastguard Worker #include <utility> 24*14675a02SAndroid Build Coastguard Worker #include <vector> 25*14675a02SAndroid Build Coastguard Worker 26*14675a02SAndroid Build Coastguard Worker #include "absl/synchronization/mutex.h" 27*14675a02SAndroid Build Coastguard Worker #include "absl/time/time.h" 28*14675a02SAndroid Build Coastguard Worker #include "fcp/base/clock.h" 29*14675a02SAndroid Build Coastguard Worker #include "fcp/base/monitoring.h" 30*14675a02SAndroid Build Coastguard Worker #include "fcp/base/reentrancy_guard.h" 31*14675a02SAndroid Build Coastguard Worker #include "fcp/base/scheduler.h" 32*14675a02SAndroid Build Coastguard Worker 33*14675a02SAndroid Build Coastguard Worker namespace fcp { 34*14675a02SAndroid Build Coastguard Worker namespace secagg { 35*14675a02SAndroid Build Coastguard Worker 36*14675a02SAndroid Build Coastguard Worker // Simple callback waiter that runs the function on Wakeup. 37*14675a02SAndroid Build Coastguard Worker class CallbackWaiter : public Clock::Waiter { 38*14675a02SAndroid Build Coastguard Worker public: CallbackWaiter(std::function<void ()> callback)39*14675a02SAndroid Build Coastguard Worker explicit CallbackWaiter(std::function<void()> callback) 40*14675a02SAndroid Build Coastguard Worker : callback_(std::move(callback)) {} 41*14675a02SAndroid Build Coastguard Worker WakeUp()42*14675a02SAndroid Build Coastguard Worker void WakeUp() override { callback_(); } 43*14675a02SAndroid Build Coastguard Worker 44*14675a02SAndroid Build Coastguard Worker private: 45*14675a02SAndroid Build Coastguard Worker std::function<void()> callback_; 46*14675a02SAndroid Build Coastguard Worker }; 47*14675a02SAndroid Build Coastguard Worker 48*14675a02SAndroid Build Coastguard Worker // Provides Cancellation mechanism for SevAggScheduler. 49*14675a02SAndroid Build Coastguard Worker class CancellationImpl { 50*14675a02SAndroid Build Coastguard Worker public: 51*14675a02SAndroid Build Coastguard Worker virtual ~CancellationImpl() = default; 52*14675a02SAndroid Build Coastguard Worker 53*14675a02SAndroid Build Coastguard Worker // Calling Cancel results in skipping the remaining, still pending 54*14675a02SAndroid Build Coastguard Worker // ParallelGenerateSequentialReduce. The call blocks waiting for any 55*14675a02SAndroid Build Coastguard Worker // currently active ongoing tasks to complete. Calling Cancel for the second 56*14675a02SAndroid Build Coastguard Worker // time has no additional effect. 57*14675a02SAndroid Build Coastguard Worker virtual void Cancel() = 0; 58*14675a02SAndroid Build Coastguard Worker }; 59*14675a02SAndroid Build Coastguard Worker 60*14675a02SAndroid Build Coastguard Worker using CancellationToken = std::shared_ptr<CancellationImpl>; 61*14675a02SAndroid Build Coastguard Worker 62*14675a02SAndroid Build Coastguard Worker template <typename T> 63*14675a02SAndroid Build Coastguard Worker class Accumulator : public CancellationImpl, 64*14675a02SAndroid Build Coastguard Worker public std::enable_shared_from_this<Accumulator<T>> { 65*14675a02SAndroid Build Coastguard Worker public: Accumulator(std::unique_ptr<T> initial_value,std::function<std::unique_ptr<T> (const T &,const T &)> accumulator_func,Scheduler * parallel_scheduler,Scheduler * sequential_scheduler,Clock * clock)66*14675a02SAndroid Build Coastguard Worker Accumulator( 67*14675a02SAndroid Build Coastguard Worker std::unique_ptr<T> initial_value, 68*14675a02SAndroid Build Coastguard Worker std::function<std::unique_ptr<T>(const T&, const T&)> accumulator_func, 69*14675a02SAndroid Build Coastguard Worker Scheduler* parallel_scheduler, Scheduler* sequential_scheduler, 70*14675a02SAndroid Build Coastguard Worker Clock* clock) 71*14675a02SAndroid Build Coastguard Worker : parallel_scheduler_(parallel_scheduler), 72*14675a02SAndroid Build Coastguard Worker sequential_scheduler_(sequential_scheduler), 73*14675a02SAndroid Build Coastguard Worker accumulated_value_(std::move(initial_value)), 74*14675a02SAndroid Build Coastguard Worker accumulator_func_(accumulator_func), 75*14675a02SAndroid Build Coastguard Worker clock_(clock) {} 76*14675a02SAndroid Build Coastguard Worker GetParallelScheduleFunc(std::shared_ptr<Accumulator<T>> accumulator,std::function<std::unique_ptr<T> ()> generator)77*14675a02SAndroid Build Coastguard Worker inline static std::function<void()> GetParallelScheduleFunc( 78*14675a02SAndroid Build Coastguard Worker std::shared_ptr<Accumulator<T>> accumulator, 79*14675a02SAndroid Build Coastguard Worker std::function<std::unique_ptr<T>()> generator) { 80*14675a02SAndroid Build Coastguard Worker return [accumulator, generator] { 81*14675a02SAndroid Build Coastguard Worker // Increment active count if the accumulator is not canceled, otherwise 82*14675a02SAndroid Build Coastguard Worker // return without scheduling the task. By active count we mean the total 83*14675a02SAndroid Build Coastguard Worker // number of scheduled tasks, both parallel and sequential. To cancel an 84*14675a02SAndroid Build Coastguard Worker // accumulator, we wait until that this count is 0. 85*14675a02SAndroid Build Coastguard Worker if (!accumulator->MaybeIncrementActiveCount()) { 86*14675a02SAndroid Build Coastguard Worker return; 87*14675a02SAndroid Build Coastguard Worker } 88*14675a02SAndroid Build Coastguard Worker auto partial = generator(); 89*14675a02SAndroid Build Coastguard Worker FCP_CHECK(partial); 90*14675a02SAndroid Build Coastguard Worker // Decrement the count for the parallel task that was just run as 91*14675a02SAndroid Build Coastguard Worker // generator(). 92*14675a02SAndroid Build Coastguard Worker accumulator->DecrementActiveCount(); 93*14675a02SAndroid Build Coastguard Worker // Schedule sequential part of the generator, only if accumulator is not 94*14675a02SAndroid Build Coastguard Worker // cancelled, otherwise return without scheduling it. 95*14675a02SAndroid Build Coastguard Worker if (accumulator->IsCancelled()) { 96*14675a02SAndroid Build Coastguard Worker return; 97*14675a02SAndroid Build Coastguard Worker } 98*14675a02SAndroid Build Coastguard Worker accumulator->RunSequential( 99*14675a02SAndroid Build Coastguard Worker [=, partial = std::shared_ptr<T>(partial.release())] { 100*14675a02SAndroid Build Coastguard Worker ReentrancyGuard guard; 101*14675a02SAndroid Build Coastguard Worker FCP_CHECK_STATUS(guard.Check(accumulator->in_sequential_call())); 102*14675a02SAndroid Build Coastguard Worker // mark that a task will be 103*14675a02SAndroid Build Coastguard Worker // scheduled, if the accumulator is 104*14675a02SAndroid Build Coastguard Worker // not canceled. 105*14675a02SAndroid Build Coastguard Worker if (!accumulator->MaybeIncrementActiveCount()) { 106*14675a02SAndroid Build Coastguard Worker return; 107*14675a02SAndroid Build Coastguard Worker } 108*14675a02SAndroid Build Coastguard Worker auto new_value = accumulator->accumulator_func_( 109*14675a02SAndroid Build Coastguard Worker *accumulator->accumulated_value_, *partial); 110*14675a02SAndroid Build Coastguard Worker FCP_CHECK(new_value); 111*14675a02SAndroid Build Coastguard Worker accumulator->accumulated_value_ = std::move(new_value); 112*14675a02SAndroid Build Coastguard Worker // At this point the sequantial task has been run, and we (i) 113*14675a02SAndroid Build Coastguard Worker // decrement both active and remaining counts and possibly reset the 114*14675a02SAndroid Build Coastguard Worker // unobserved work flag, (ii) get the callback, which might be 115*14675a02SAndroid Build Coastguard Worker // empty, and (iii) call it if that is not the case. 116*14675a02SAndroid Build Coastguard Worker auto callback = accumulator->UpdateCountsAndGetCallback(); 117*14675a02SAndroid Build Coastguard Worker if (callback) { 118*14675a02SAndroid Build Coastguard Worker callback(); 119*14675a02SAndroid Build Coastguard Worker } 120*14675a02SAndroid Build Coastguard Worker }); 121*14675a02SAndroid Build Coastguard Worker }; 122*14675a02SAndroid Build Coastguard Worker } 123*14675a02SAndroid Build Coastguard Worker 124*14675a02SAndroid Build Coastguard Worker // Schedule a parallel generator that includes a delay. The result of the 125*14675a02SAndroid Build Coastguard Worker // generator is fed to the accumulator_func Schedule(std::function<std::unique_ptr<T> ()> generator,absl::Duration delay)126*14675a02SAndroid Build Coastguard Worker void Schedule(std::function<std::unique_ptr<T>()> generator, 127*14675a02SAndroid Build Coastguard Worker absl::Duration delay) { 128*14675a02SAndroid Build Coastguard Worker // IncrementRemainingCount() keeps track of the number of async tasks 129*14675a02SAndroid Build Coastguard Worker // scheduled, and sets a flag when the count goes from 0 to 1, corresponding 130*14675a02SAndroid Build Coastguard Worker // to a starting batch of unobserved work. 131*14675a02SAndroid Build Coastguard Worker auto shared_this = this->shared_from_this(); 132*14675a02SAndroid Build Coastguard Worker shared_this->IncrementRemainingCount(); 133*14675a02SAndroid Build Coastguard Worker clock_->WakeupWithDeadline( 134*14675a02SAndroid Build Coastguard Worker clock_->Now() + delay, 135*14675a02SAndroid Build Coastguard Worker std::make_shared<CallbackWaiter>([shared_this, generator] { 136*14675a02SAndroid Build Coastguard Worker shared_this->RunParallel( 137*14675a02SAndroid Build Coastguard Worker Accumulator<T>::GetParallelScheduleFunc(shared_this, generator)); 138*14675a02SAndroid Build Coastguard Worker })); 139*14675a02SAndroid Build Coastguard Worker } 140*14675a02SAndroid Build Coastguard Worker 141*14675a02SAndroid Build Coastguard Worker // Schedule a parallel generator. The result of the generator is fed to the 142*14675a02SAndroid Build Coastguard Worker // accumulator_func Schedule(std::function<std::unique_ptr<T> ()> generator)143*14675a02SAndroid Build Coastguard Worker void Schedule(std::function<std::unique_ptr<T>()> generator) { 144*14675a02SAndroid Build Coastguard Worker // IncrementRemainingCount() keeps track of the number of async tasks 145*14675a02SAndroid Build Coastguard Worker // scheduled, and sets a flag when the count goes from 0 to 1, corresponding 146*14675a02SAndroid Build Coastguard Worker // to a starting batch of unobserved work. 147*14675a02SAndroid Build Coastguard Worker auto shared_this = this->shared_from_this(); 148*14675a02SAndroid Build Coastguard Worker shared_this->IncrementRemainingCount(); 149*14675a02SAndroid Build Coastguard Worker RunParallel([shared_this, generator] { 150*14675a02SAndroid Build Coastguard Worker shared_this->GetParallelScheduleFunc(shared_this, generator)(); 151*14675a02SAndroid Build Coastguard Worker }); 152*14675a02SAndroid Build Coastguard Worker } 153*14675a02SAndroid Build Coastguard Worker 154*14675a02SAndroid Build Coastguard Worker // Returns true if the accumulator doesn't have any remaining tasks, 155*14675a02SAndroid Build Coastguard Worker // even if their results have not been observed by a callback. IsIdle()156*14675a02SAndroid Build Coastguard Worker bool IsIdle() { 157*14675a02SAndroid Build Coastguard Worker absl::MutexLock lock(&mutex_); 158*14675a02SAndroid Build Coastguard Worker return remaining_sequential_tasks_count_ == 0; 159*14675a02SAndroid Build Coastguard Worker } 160*14675a02SAndroid Build Coastguard Worker 161*14675a02SAndroid Build Coastguard Worker // Returns false if no async work has happened since last time this function 162*14675a02SAndroid Build Coastguard Worker // was called, or the first time it is called. Otherwise it returns true and 163*14675a02SAndroid Build Coastguard Worker // schedules a callback to be called once the scheduler is idle. SetAsyncObserver(std::function<void ()> async_callback)164*14675a02SAndroid Build Coastguard Worker bool SetAsyncObserver(std::function<void()> async_callback) { 165*14675a02SAndroid Build Coastguard Worker bool idle; 166*14675a02SAndroid Build Coastguard Worker { 167*14675a02SAndroid Build Coastguard Worker absl::MutexLock lock(&mutex_); 168*14675a02SAndroid Build Coastguard Worker if (!has_unobserved_work_) { 169*14675a02SAndroid Build Coastguard Worker return false; 170*14675a02SAndroid Build Coastguard Worker } 171*14675a02SAndroid Build Coastguard Worker idle = (remaining_sequential_tasks_count_ == 0); 172*14675a02SAndroid Build Coastguard Worker if (idle) { 173*14675a02SAndroid Build Coastguard Worker // The flag is set to false, and the callback is run as soon as we leave 174*14675a02SAndroid Build Coastguard Worker // the mutex's scope. 175*14675a02SAndroid Build Coastguard Worker has_unobserved_work_ = false; 176*14675a02SAndroid Build Coastguard Worker } else { 177*14675a02SAndroid Build Coastguard Worker // The callbak is scheduled for later, as there is ongoing work. 178*14675a02SAndroid Build Coastguard Worker async_callback_ = async_callback; 179*14675a02SAndroid Build Coastguard Worker } 180*14675a02SAndroid Build Coastguard Worker } 181*14675a02SAndroid Build Coastguard Worker if (idle) { 182*14675a02SAndroid Build Coastguard Worker auto shared_this = this->shared_from_this(); 183*14675a02SAndroid Build Coastguard Worker RunSequential([async_callback, shared_this] { async_callback(); }); 184*14675a02SAndroid Build Coastguard Worker } 185*14675a02SAndroid Build Coastguard Worker return true; 186*14675a02SAndroid Build Coastguard Worker } 187*14675a02SAndroid Build Coastguard Worker 188*14675a02SAndroid Build Coastguard Worker // Updates the active and remaining task counts, and returns the callback to 189*14675a02SAndroid Build Coastguard Worker // be executed, or nullptr if there's pending async work. UpdateCountsAndGetCallback()190*14675a02SAndroid Build Coastguard Worker inline std::function<void()> UpdateCountsAndGetCallback() { 191*14675a02SAndroid Build Coastguard Worker absl::MutexLock lock(&mutex_); 192*14675a02SAndroid Build Coastguard Worker if (--active_count_ == 0 && is_cancelled_) { 193*14675a02SAndroid Build Coastguard Worker inactive_cv_.SignalAll(); 194*14675a02SAndroid Build Coastguard Worker } 195*14675a02SAndroid Build Coastguard Worker --remaining_sequential_tasks_count_; 196*14675a02SAndroid Build Coastguard Worker if (remaining_sequential_tasks_count_ == 0 && async_callback_) { 197*14675a02SAndroid Build Coastguard Worker has_unobserved_work_ = false; 198*14675a02SAndroid Build Coastguard Worker auto callback = async_callback_; 199*14675a02SAndroid Build Coastguard Worker async_callback_ = nullptr; 200*14675a02SAndroid Build Coastguard Worker return callback; 201*14675a02SAndroid Build Coastguard Worker } else { 202*14675a02SAndroid Build Coastguard Worker return nullptr; 203*14675a02SAndroid Build Coastguard Worker } 204*14675a02SAndroid Build Coastguard Worker } 205*14675a02SAndroid Build Coastguard Worker 206*14675a02SAndroid Build Coastguard Worker // Take the accumulated result and abort any further work. This method can 207*14675a02SAndroid Build Coastguard Worker // only be called when the accumulator is idle GetResultAndCancel()208*14675a02SAndroid Build Coastguard Worker std::unique_ptr<T> GetResultAndCancel() { 209*14675a02SAndroid Build Coastguard Worker absl::MutexLock lock(&mutex_); 210*14675a02SAndroid Build Coastguard Worker FCP_CHECK(active_count_ == 0); 211*14675a02SAndroid Build Coastguard Worker is_cancelled_ = true; 212*14675a02SAndroid Build Coastguard Worker return std::move(accumulated_value_); 213*14675a02SAndroid Build Coastguard Worker } 214*14675a02SAndroid Build Coastguard Worker 215*14675a02SAndroid Build Coastguard Worker // CancellationImpl implementation Cancel()216*14675a02SAndroid Build Coastguard Worker void Cancel() override { 217*14675a02SAndroid Build Coastguard Worker mutex_.Lock(); 218*14675a02SAndroid Build Coastguard Worker is_cancelled_ = true; 219*14675a02SAndroid Build Coastguard Worker while (active_count_ > 0) { 220*14675a02SAndroid Build Coastguard Worker inactive_cv_.Wait(&mutex_); 221*14675a02SAndroid Build Coastguard Worker } 222*14675a02SAndroid Build Coastguard Worker mutex_.Unlock(); 223*14675a02SAndroid Build Coastguard Worker } 224*14675a02SAndroid Build Coastguard Worker IsCancelled()225*14675a02SAndroid Build Coastguard Worker bool IsCancelled() { 226*14675a02SAndroid Build Coastguard Worker absl::MutexLock lock(&mutex_); 227*14675a02SAndroid Build Coastguard Worker return is_cancelled_; 228*14675a02SAndroid Build Coastguard Worker } 229*14675a02SAndroid Build Coastguard Worker MaybeIncrementActiveCount()230*14675a02SAndroid Build Coastguard Worker bool MaybeIncrementActiveCount() { 231*14675a02SAndroid Build Coastguard Worker absl::MutexLock lock(&mutex_); 232*14675a02SAndroid Build Coastguard Worker if (is_cancelled_) { 233*14675a02SAndroid Build Coastguard Worker return false; 234*14675a02SAndroid Build Coastguard Worker } 235*14675a02SAndroid Build Coastguard Worker active_count_++; 236*14675a02SAndroid Build Coastguard Worker return true; 237*14675a02SAndroid Build Coastguard Worker } 238*14675a02SAndroid Build Coastguard Worker DecrementActiveCount()239*14675a02SAndroid Build Coastguard Worker size_t DecrementActiveCount() { 240*14675a02SAndroid Build Coastguard Worker absl::MutexLock lock(&mutex_); 241*14675a02SAndroid Build Coastguard Worker FCP_CHECK(active_count_ > 0); 242*14675a02SAndroid Build Coastguard Worker if (--active_count_ == 0 && is_cancelled_) { 243*14675a02SAndroid Build Coastguard Worker inactive_cv_.SignalAll(); 244*14675a02SAndroid Build Coastguard Worker } 245*14675a02SAndroid Build Coastguard Worker return active_count_; 246*14675a02SAndroid Build Coastguard Worker } 247*14675a02SAndroid Build Coastguard Worker IncrementRemainingCount()248*14675a02SAndroid Build Coastguard Worker void IncrementRemainingCount() { 249*14675a02SAndroid Build Coastguard Worker absl::MutexLock lock(&mutex_); 250*14675a02SAndroid Build Coastguard Worker has_unobserved_work_ |= (remaining_sequential_tasks_count_ == 0); 251*14675a02SAndroid Build Coastguard Worker remaining_sequential_tasks_count_++; 252*14675a02SAndroid Build Coastguard Worker } 253*14675a02SAndroid Build Coastguard Worker in_sequential_call()254*14675a02SAndroid Build Coastguard Worker std::atomic<bool>* in_sequential_call() { return &in_sequential_call_; } 255*14675a02SAndroid Build Coastguard Worker RunParallel(std::function<void ()> function)256*14675a02SAndroid Build Coastguard Worker void inline RunParallel(std::function<void()> function) { 257*14675a02SAndroid Build Coastguard Worker parallel_scheduler_->Schedule(function); 258*14675a02SAndroid Build Coastguard Worker } 259*14675a02SAndroid Build Coastguard Worker RunSequential(std::function<void ()> function)260*14675a02SAndroid Build Coastguard Worker void inline RunSequential(std::function<void()> function) { 261*14675a02SAndroid Build Coastguard Worker sequential_scheduler_->Schedule(function); 262*14675a02SAndroid Build Coastguard Worker } 263*14675a02SAndroid Build Coastguard Worker 264*14675a02SAndroid Build Coastguard Worker private: 265*14675a02SAndroid Build Coastguard Worker // Scheduler for sequential and parallel tasks, received from the 266*14675a02SAndroid Build Coastguard Worker // SecAggScheduler instatiating this class 267*14675a02SAndroid Build Coastguard Worker Scheduler* parallel_scheduler_; 268*14675a02SAndroid Build Coastguard Worker Scheduler* sequential_scheduler_; 269*14675a02SAndroid Build Coastguard Worker 270*14675a02SAndroid Build Coastguard Worker // Callback to be executed the next time that the sequential scheduler 271*14675a02SAndroid Build Coastguard Worker // becomes idle. 272*14675a02SAndroid Build Coastguard Worker std::function<void()> async_callback_ ABSL_GUARDED_BY(mutex_) = 273*14675a02SAndroid Build Coastguard Worker std::function<void()>(); 274*14675a02SAndroid Build Coastguard Worker // Accumulated value - accessed by sequential tasks only. 275*14675a02SAndroid Build Coastguard Worker std::unique_ptr<T> accumulated_value_; 276*14675a02SAndroid Build Coastguard Worker // Accumulation function - accessed by sequential tasks only. 277*14675a02SAndroid Build Coastguard Worker std::function<std::unique_ptr<T>(const T&, const T&)> accumulator_func_; 278*14675a02SAndroid Build Coastguard Worker // Clock used for scheduling delays in parallel tasks 279*14675a02SAndroid Build Coastguard Worker Clock* clock_; 280*14675a02SAndroid Build Coastguard Worker // Remaining number of sequential tasks to be executed - accessed by 281*14675a02SAndroid Build Coastguard Worker // sequential tasks only. 282*14675a02SAndroid Build Coastguard Worker size_t remaining_sequential_tasks_count_ ABSL_GUARDED_BY(mutex_) = 0; 283*14675a02SAndroid Build Coastguard Worker bool has_unobserved_work_ ABSL_GUARDED_BY(mutex_) = false; 284*14675a02SAndroid Build Coastguard Worker 285*14675a02SAndroid Build Coastguard Worker // Number of active calls to either callback function. 286*14675a02SAndroid Build Coastguard Worker size_t active_count_ ABSL_GUARDED_BY(mutex_) = 0; 287*14675a02SAndroid Build Coastguard Worker // This is set to true when the run is aborted. 288*14675a02SAndroid Build Coastguard Worker bool is_cancelled_ ABSL_GUARDED_BY(mutex_) = false; 289*14675a02SAndroid Build Coastguard Worker // Protects active_count_ and cancelled_. 290*14675a02SAndroid Build Coastguard Worker absl::Mutex mutex_; 291*14675a02SAndroid Build Coastguard Worker // Used to notify cancellation about reaching inactive state; 292*14675a02SAndroid Build Coastguard Worker absl::CondVar inactive_cv_; 293*14675a02SAndroid Build Coastguard Worker // This is used by ReentrancyGuard to ensure that Sequential tasks are 294*14675a02SAndroid Build Coastguard Worker // indeed sequential. 295*14675a02SAndroid Build Coastguard Worker std::atomic<bool> in_sequential_call_ = false; 296*14675a02SAndroid Build Coastguard Worker }; 297*14675a02SAndroid Build Coastguard Worker 298*14675a02SAndroid Build Coastguard Worker // Implementation of ParallelGenerateSequentialReduce based on fcp::Scheduler. 299*14675a02SAndroid Build Coastguard Worker // Takes two Schedulers, one which is responsible for parallel execution and 300*14675a02SAndroid Build Coastguard Worker // another for serial execution. Additionally, takes a clock that can be used to 301*14675a02SAndroid Build Coastguard Worker // induce delay in task executions. 302*14675a02SAndroid Build Coastguard Worker class SecAggScheduler { 303*14675a02SAndroid Build Coastguard Worker public: 304*14675a02SAndroid Build Coastguard Worker SecAggScheduler(Scheduler* parallel_scheduler, 305*14675a02SAndroid Build Coastguard Worker Scheduler* sequential_scheduler, 306*14675a02SAndroid Build Coastguard Worker Clock* clock = Clock::RealClock()) parallel_scheduler_(parallel_scheduler)307*14675a02SAndroid Build Coastguard Worker : parallel_scheduler_(parallel_scheduler), 308*14675a02SAndroid Build Coastguard Worker sequential_scheduler_(sequential_scheduler), 309*14675a02SAndroid Build Coastguard Worker clock_(clock) {} 310*14675a02SAndroid Build Coastguard Worker 311*14675a02SAndroid Build Coastguard Worker // SecAggScheduler is neither copyable nor movable. 312*14675a02SAndroid Build Coastguard Worker SecAggScheduler(const SecAggScheduler&) = delete; 313*14675a02SAndroid Build Coastguard Worker SecAggScheduler& operator=(const SecAggScheduler&) = delete; 314*14675a02SAndroid Build Coastguard Worker 315*14675a02SAndroid Build Coastguard Worker virtual ~SecAggScheduler() = default; 316*14675a02SAndroid Build Coastguard Worker 317*14675a02SAndroid Build Coastguard Worker // Schedule a callback to be invoked on the sequential scheduler. ScheduleCallback(std::function<void ()> callback)318*14675a02SAndroid Build Coastguard Worker inline void ScheduleCallback(std::function<void()> callback) { 319*14675a02SAndroid Build Coastguard Worker RunSequential(callback); 320*14675a02SAndroid Build Coastguard Worker } 321*14675a02SAndroid Build Coastguard Worker 322*14675a02SAndroid Build Coastguard Worker template <typename T> CreateAccumulator(std::unique_ptr<T> initial_value,std::function<std::unique_ptr<T> (const T &,const T &)> accumulator_func)323*14675a02SAndroid Build Coastguard Worker std::shared_ptr<Accumulator<T>> CreateAccumulator( 324*14675a02SAndroid Build Coastguard Worker std::unique_ptr<T> initial_value, 325*14675a02SAndroid Build Coastguard Worker std::function<std::unique_ptr<T>(const T&, const T&)> accumulator_func) { 326*14675a02SAndroid Build Coastguard Worker return std::make_shared<Accumulator<T>>( 327*14675a02SAndroid Build Coastguard Worker std::move(initial_value), accumulator_func, parallel_scheduler_, 328*14675a02SAndroid Build Coastguard Worker sequential_scheduler_, clock_); 329*14675a02SAndroid Build Coastguard Worker } 330*14675a02SAndroid Build Coastguard Worker 331*14675a02SAndroid Build Coastguard Worker void WaitUntilIdle(); 332*14675a02SAndroid Build Coastguard Worker 333*14675a02SAndroid Build Coastguard Worker protected: 334*14675a02SAndroid Build Coastguard Worker // Virtual for testing 335*14675a02SAndroid Build Coastguard Worker virtual void RunSequential(std::function<void()> function); 336*14675a02SAndroid Build Coastguard Worker 337*14675a02SAndroid Build Coastguard Worker private: 338*14675a02SAndroid Build Coastguard Worker Scheduler* parallel_scheduler_; 339*14675a02SAndroid Build Coastguard Worker Scheduler* sequential_scheduler_; 340*14675a02SAndroid Build Coastguard Worker Clock* clock_; 341*14675a02SAndroid Build Coastguard Worker }; 342*14675a02SAndroid Build Coastguard Worker 343*14675a02SAndroid Build Coastguard Worker } // namespace secagg 344*14675a02SAndroid Build Coastguard Worker } // namespace fcp 345*14675a02SAndroid Build Coastguard Worker 346*14675a02SAndroid Build Coastguard Worker #endif // FCP_SECAGG_SERVER_SECAGG_SCHEDULER_H_ 347