1 /* 2 * Copyright 2019 Google LLC 3 * 4 * Licensed under the Apache License, Version 2.0 (the "License"); 5 * you may not use this file except in compliance with the License. 6 * You may obtain a copy of the License at 7 * 8 * http://www.apache.org/licenses/LICENSE-2.0 9 * 10 * Unless required by applicable law or agreed to in writing, software 11 * distributed under the License is distributed on an "AS IS" BASIS, 12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 13 * See the License for the specific language governing permissions and 14 * limitations under the License. 15 */ 16 17 #ifndef FCP_SECAGG_SERVER_SECAGG_SCHEDULER_H_ 18 #define FCP_SECAGG_SERVER_SECAGG_SCHEDULER_H_ 19 20 #include <atomic> 21 #include <functional> 22 #include <memory> 23 #include <utility> 24 #include <vector> 25 26 #include "absl/synchronization/mutex.h" 27 #include "absl/time/time.h" 28 #include "fcp/base/clock.h" 29 #include "fcp/base/monitoring.h" 30 #include "fcp/base/reentrancy_guard.h" 31 #include "fcp/base/scheduler.h" 32 33 namespace fcp { 34 namespace secagg { 35 36 // Simple callback waiter that runs the function on Wakeup. 37 class CallbackWaiter : public Clock::Waiter { 38 public: CallbackWaiter(std::function<void ()> callback)39 explicit CallbackWaiter(std::function<void()> callback) 40 : callback_(std::move(callback)) {} 41 WakeUp()42 void WakeUp() override { callback_(); } 43 44 private: 45 std::function<void()> callback_; 46 }; 47 48 // Provides Cancellation mechanism for SevAggScheduler. 49 class CancellationImpl { 50 public: 51 virtual ~CancellationImpl() = default; 52 53 // Calling Cancel results in skipping the remaining, still pending 54 // ParallelGenerateSequentialReduce. The call blocks waiting for any 55 // currently active ongoing tasks to complete. Calling Cancel for the second 56 // time has no additional effect. 57 virtual void Cancel() = 0; 58 }; 59 60 using CancellationToken = std::shared_ptr<CancellationImpl>; 61 62 template <typename T> 63 class Accumulator : public CancellationImpl, 64 public std::enable_shared_from_this<Accumulator<T>> { 65 public: Accumulator(std::unique_ptr<T> initial_value,std::function<std::unique_ptr<T> (const T &,const T &)> accumulator_func,Scheduler * parallel_scheduler,Scheduler * sequential_scheduler,Clock * clock)66 Accumulator( 67 std::unique_ptr<T> initial_value, 68 std::function<std::unique_ptr<T>(const T&, const T&)> accumulator_func, 69 Scheduler* parallel_scheduler, Scheduler* sequential_scheduler, 70 Clock* clock) 71 : parallel_scheduler_(parallel_scheduler), 72 sequential_scheduler_(sequential_scheduler), 73 accumulated_value_(std::move(initial_value)), 74 accumulator_func_(accumulator_func), 75 clock_(clock) {} 76 GetParallelScheduleFunc(std::shared_ptr<Accumulator<T>> accumulator,std::function<std::unique_ptr<T> ()> generator)77 inline static std::function<void()> GetParallelScheduleFunc( 78 std::shared_ptr<Accumulator<T>> accumulator, 79 std::function<std::unique_ptr<T>()> generator) { 80 return [accumulator, generator] { 81 // Increment active count if the accumulator is not canceled, otherwise 82 // return without scheduling the task. By active count we mean the total 83 // number of scheduled tasks, both parallel and sequential. To cancel an 84 // accumulator, we wait until that this count is 0. 85 if (!accumulator->MaybeIncrementActiveCount()) { 86 return; 87 } 88 auto partial = generator(); 89 FCP_CHECK(partial); 90 // Decrement the count for the parallel task that was just run as 91 // generator(). 92 accumulator->DecrementActiveCount(); 93 // Schedule sequential part of the generator, only if accumulator is not 94 // cancelled, otherwise return without scheduling it. 95 if (accumulator->IsCancelled()) { 96 return; 97 } 98 accumulator->RunSequential( 99 [=, partial = std::shared_ptr<T>(partial.release())] { 100 ReentrancyGuard guard; 101 FCP_CHECK_STATUS(guard.Check(accumulator->in_sequential_call())); 102 // mark that a task will be 103 // scheduled, if the accumulator is 104 // not canceled. 105 if (!accumulator->MaybeIncrementActiveCount()) { 106 return; 107 } 108 auto new_value = accumulator->accumulator_func_( 109 *accumulator->accumulated_value_, *partial); 110 FCP_CHECK(new_value); 111 accumulator->accumulated_value_ = std::move(new_value); 112 // At this point the sequantial task has been run, and we (i) 113 // decrement both active and remaining counts and possibly reset the 114 // unobserved work flag, (ii) get the callback, which might be 115 // empty, and (iii) call it if that is not the case. 116 auto callback = accumulator->UpdateCountsAndGetCallback(); 117 if (callback) { 118 callback(); 119 } 120 }); 121 }; 122 } 123 124 // Schedule a parallel generator that includes a delay. The result of the 125 // generator is fed to the accumulator_func Schedule(std::function<std::unique_ptr<T> ()> generator,absl::Duration delay)126 void Schedule(std::function<std::unique_ptr<T>()> generator, 127 absl::Duration delay) { 128 // IncrementRemainingCount() keeps track of the number of async tasks 129 // scheduled, and sets a flag when the count goes from 0 to 1, corresponding 130 // to a starting batch of unobserved work. 131 auto shared_this = this->shared_from_this(); 132 shared_this->IncrementRemainingCount(); 133 clock_->WakeupWithDeadline( 134 clock_->Now() + delay, 135 std::make_shared<CallbackWaiter>([shared_this, generator] { 136 shared_this->RunParallel( 137 Accumulator<T>::GetParallelScheduleFunc(shared_this, generator)); 138 })); 139 } 140 141 // Schedule a parallel generator. The result of the generator is fed to the 142 // accumulator_func Schedule(std::function<std::unique_ptr<T> ()> generator)143 void Schedule(std::function<std::unique_ptr<T>()> generator) { 144 // IncrementRemainingCount() keeps track of the number of async tasks 145 // scheduled, and sets a flag when the count goes from 0 to 1, corresponding 146 // to a starting batch of unobserved work. 147 auto shared_this = this->shared_from_this(); 148 shared_this->IncrementRemainingCount(); 149 RunParallel([shared_this, generator] { 150 shared_this->GetParallelScheduleFunc(shared_this, generator)(); 151 }); 152 } 153 154 // Returns true if the accumulator doesn't have any remaining tasks, 155 // even if their results have not been observed by a callback. IsIdle()156 bool IsIdle() { 157 absl::MutexLock lock(&mutex_); 158 return remaining_sequential_tasks_count_ == 0; 159 } 160 161 // Returns false if no async work has happened since last time this function 162 // was called, or the first time it is called. Otherwise it returns true and 163 // schedules a callback to be called once the scheduler is idle. SetAsyncObserver(std::function<void ()> async_callback)164 bool SetAsyncObserver(std::function<void()> async_callback) { 165 bool idle; 166 { 167 absl::MutexLock lock(&mutex_); 168 if (!has_unobserved_work_) { 169 return false; 170 } 171 idle = (remaining_sequential_tasks_count_ == 0); 172 if (idle) { 173 // The flag is set to false, and the callback is run as soon as we leave 174 // the mutex's scope. 175 has_unobserved_work_ = false; 176 } else { 177 // The callbak is scheduled for later, as there is ongoing work. 178 async_callback_ = async_callback; 179 } 180 } 181 if (idle) { 182 auto shared_this = this->shared_from_this(); 183 RunSequential([async_callback, shared_this] { async_callback(); }); 184 } 185 return true; 186 } 187 188 // Updates the active and remaining task counts, and returns the callback to 189 // be executed, or nullptr if there's pending async work. UpdateCountsAndGetCallback()190 inline std::function<void()> UpdateCountsAndGetCallback() { 191 absl::MutexLock lock(&mutex_); 192 if (--active_count_ == 0 && is_cancelled_) { 193 inactive_cv_.SignalAll(); 194 } 195 --remaining_sequential_tasks_count_; 196 if (remaining_sequential_tasks_count_ == 0 && async_callback_) { 197 has_unobserved_work_ = false; 198 auto callback = async_callback_; 199 async_callback_ = nullptr; 200 return callback; 201 } else { 202 return nullptr; 203 } 204 } 205 206 // Take the accumulated result and abort any further work. This method can 207 // only be called when the accumulator is idle GetResultAndCancel()208 std::unique_ptr<T> GetResultAndCancel() { 209 absl::MutexLock lock(&mutex_); 210 FCP_CHECK(active_count_ == 0); 211 is_cancelled_ = true; 212 return std::move(accumulated_value_); 213 } 214 215 // CancellationImpl implementation Cancel()216 void Cancel() override { 217 mutex_.Lock(); 218 is_cancelled_ = true; 219 while (active_count_ > 0) { 220 inactive_cv_.Wait(&mutex_); 221 } 222 mutex_.Unlock(); 223 } 224 IsCancelled()225 bool IsCancelled() { 226 absl::MutexLock lock(&mutex_); 227 return is_cancelled_; 228 } 229 MaybeIncrementActiveCount()230 bool MaybeIncrementActiveCount() { 231 absl::MutexLock lock(&mutex_); 232 if (is_cancelled_) { 233 return false; 234 } 235 active_count_++; 236 return true; 237 } 238 DecrementActiveCount()239 size_t DecrementActiveCount() { 240 absl::MutexLock lock(&mutex_); 241 FCP_CHECK(active_count_ > 0); 242 if (--active_count_ == 0 && is_cancelled_) { 243 inactive_cv_.SignalAll(); 244 } 245 return active_count_; 246 } 247 IncrementRemainingCount()248 void IncrementRemainingCount() { 249 absl::MutexLock lock(&mutex_); 250 has_unobserved_work_ |= (remaining_sequential_tasks_count_ == 0); 251 remaining_sequential_tasks_count_++; 252 } 253 in_sequential_call()254 std::atomic<bool>* in_sequential_call() { return &in_sequential_call_; } 255 RunParallel(std::function<void ()> function)256 void inline RunParallel(std::function<void()> function) { 257 parallel_scheduler_->Schedule(function); 258 } 259 RunSequential(std::function<void ()> function)260 void inline RunSequential(std::function<void()> function) { 261 sequential_scheduler_->Schedule(function); 262 } 263 264 private: 265 // Scheduler for sequential and parallel tasks, received from the 266 // SecAggScheduler instatiating this class 267 Scheduler* parallel_scheduler_; 268 Scheduler* sequential_scheduler_; 269 270 // Callback to be executed the next time that the sequential scheduler 271 // becomes idle. 272 std::function<void()> async_callback_ ABSL_GUARDED_BY(mutex_) = 273 std::function<void()>(); 274 // Accumulated value - accessed by sequential tasks only. 275 std::unique_ptr<T> accumulated_value_; 276 // Accumulation function - accessed by sequential tasks only. 277 std::function<std::unique_ptr<T>(const T&, const T&)> accumulator_func_; 278 // Clock used for scheduling delays in parallel tasks 279 Clock* clock_; 280 // Remaining number of sequential tasks to be executed - accessed by 281 // sequential tasks only. 282 size_t remaining_sequential_tasks_count_ ABSL_GUARDED_BY(mutex_) = 0; 283 bool has_unobserved_work_ ABSL_GUARDED_BY(mutex_) = false; 284 285 // Number of active calls to either callback function. 286 size_t active_count_ ABSL_GUARDED_BY(mutex_) = 0; 287 // This is set to true when the run is aborted. 288 bool is_cancelled_ ABSL_GUARDED_BY(mutex_) = false; 289 // Protects active_count_ and cancelled_. 290 absl::Mutex mutex_; 291 // Used to notify cancellation about reaching inactive state; 292 absl::CondVar inactive_cv_; 293 // This is used by ReentrancyGuard to ensure that Sequential tasks are 294 // indeed sequential. 295 std::atomic<bool> in_sequential_call_ = false; 296 }; 297 298 // Implementation of ParallelGenerateSequentialReduce based on fcp::Scheduler. 299 // Takes two Schedulers, one which is responsible for parallel execution and 300 // another for serial execution. Additionally, takes a clock that can be used to 301 // induce delay in task executions. 302 class SecAggScheduler { 303 public: 304 SecAggScheduler(Scheduler* parallel_scheduler, 305 Scheduler* sequential_scheduler, 306 Clock* clock = Clock::RealClock()) parallel_scheduler_(parallel_scheduler)307 : parallel_scheduler_(parallel_scheduler), 308 sequential_scheduler_(sequential_scheduler), 309 clock_(clock) {} 310 311 // SecAggScheduler is neither copyable nor movable. 312 SecAggScheduler(const SecAggScheduler&) = delete; 313 SecAggScheduler& operator=(const SecAggScheduler&) = delete; 314 315 virtual ~SecAggScheduler() = default; 316 317 // Schedule a callback to be invoked on the sequential scheduler. ScheduleCallback(std::function<void ()> callback)318 inline void ScheduleCallback(std::function<void()> callback) { 319 RunSequential(callback); 320 } 321 322 template <typename T> CreateAccumulator(std::unique_ptr<T> initial_value,std::function<std::unique_ptr<T> (const T &,const T &)> accumulator_func)323 std::shared_ptr<Accumulator<T>> CreateAccumulator( 324 std::unique_ptr<T> initial_value, 325 std::function<std::unique_ptr<T>(const T&, const T&)> accumulator_func) { 326 return std::make_shared<Accumulator<T>>( 327 std::move(initial_value), accumulator_func, parallel_scheduler_, 328 sequential_scheduler_, clock_); 329 } 330 331 void WaitUntilIdle(); 332 333 protected: 334 // Virtual for testing 335 virtual void RunSequential(std::function<void()> function); 336 337 private: 338 Scheduler* parallel_scheduler_; 339 Scheduler* sequential_scheduler_; 340 Clock* clock_; 341 }; 342 343 } // namespace secagg 344 } // namespace fcp 345 346 #endif // FCP_SECAGG_SERVER_SECAGG_SCHEDULER_H_ 347