xref: /aosp_15_r20/external/federated-compute/fcp/secagg/server/secagg_scheduler.h (revision 14675a029014e728ec732f129a32e299b2da0601)
/*
 * Copyright 2019 Google LLC
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16 
17 #ifndef FCP_SECAGG_SERVER_SECAGG_SCHEDULER_H_
18 #define FCP_SECAGG_SERVER_SECAGG_SCHEDULER_H_
19 
20 #include <atomic>
21 #include <functional>
22 #include <memory>
23 #include <utility>
24 #include <vector>
25 
26 #include "absl/synchronization/mutex.h"
27 #include "absl/time/time.h"
28 #include "fcp/base/clock.h"
29 #include "fcp/base/monitoring.h"
30 #include "fcp/base/reentrancy_guard.h"
31 #include "fcp/base/scheduler.h"
32 
33 namespace fcp {
34 namespace secagg {
35 
36 // Simple callback waiter that runs the function on Wakeup.
37 class CallbackWaiter : public Clock::Waiter {
38  public:
CallbackWaiter(std::function<void ()> callback)39   explicit CallbackWaiter(std::function<void()> callback)
40       : callback_(std::move(callback)) {}
41 
WakeUp()42   void WakeUp() override { callback_(); }
43 
44  private:
45   std::function<void()> callback_;
46 };
47 
// Interface used to cancel work scheduled through SecAggScheduler
// (implemented, for example, by Accumulator).
class CancellationImpl {
 public:
  // Skips any ParallelGenerateSequentialReduce work that is still pending and
  // blocks until the tasks that are already running have completed. Calling
  // Cancel() again has no additional effect.
  virtual void Cancel() = 0;

  virtual ~CancellationImpl() = default;
};

// Shared handle through which scheduled work can be cancelled.
using CancellationToken = std::shared_ptr<CancellationImpl>;
61 
62 template <typename T>
63 class Accumulator : public CancellationImpl,
64                     public std::enable_shared_from_this<Accumulator<T>> {
65  public:
Accumulator(std::unique_ptr<T> initial_value,std::function<std::unique_ptr<T> (const T &,const T &)> accumulator_func,Scheduler * parallel_scheduler,Scheduler * sequential_scheduler,Clock * clock)66   Accumulator(
67       std::unique_ptr<T> initial_value,
68       std::function<std::unique_ptr<T>(const T&, const T&)> accumulator_func,
69       Scheduler* parallel_scheduler, Scheduler* sequential_scheduler,
70       Clock* clock)
71       : parallel_scheduler_(parallel_scheduler),
72         sequential_scheduler_(sequential_scheduler),
73         accumulated_value_(std::move(initial_value)),
74         accumulator_func_(accumulator_func),
75         clock_(clock) {}
76 
GetParallelScheduleFunc(std::shared_ptr<Accumulator<T>> accumulator,std::function<std::unique_ptr<T> ()> generator)77   inline static std::function<void()> GetParallelScheduleFunc(
78       std::shared_ptr<Accumulator<T>> accumulator,
79       std::function<std::unique_ptr<T>()> generator) {
80     return [accumulator, generator] {
81       // Increment active count if the accumulator is not canceled, otherwise
82       // return without scheduling the task. By active count we mean the total
83       // number of scheduled tasks, both parallel and sequential. To cancel an
84       // accumulator, we wait until that this count is 0.
85       if (!accumulator->MaybeIncrementActiveCount()) {
86         return;
87       }
88       auto partial = generator();
89       FCP_CHECK(partial);
90       // Decrement the count for the parallel task that was just run as
91       // generator().
92       accumulator->DecrementActiveCount();
93       // Schedule sequential part of the generator, only if accumulator is not
94       // cancelled, otherwise return without scheduling it.
95       if (accumulator->IsCancelled()) {
96         return;
97       }
98       accumulator->RunSequential(
99           [=, partial = std::shared_ptr<T>(partial.release())] {
100             ReentrancyGuard guard;
101             FCP_CHECK_STATUS(guard.Check(accumulator->in_sequential_call()));
102             // mark that a task will be
103             // scheduled, if the accumulator is
104             // not canceled.
105             if (!accumulator->MaybeIncrementActiveCount()) {
106               return;
107             }
108             auto new_value = accumulator->accumulator_func_(
109                 *accumulator->accumulated_value_, *partial);
110             FCP_CHECK(new_value);
111             accumulator->accumulated_value_ = std::move(new_value);
112             // At this point the sequantial task has been run, and we (i)
113             // decrement both active and remaining counts and possibly reset the
114             // unobserved work flag, (ii) get the callback, which might be
115             // empty, and (iii) call it if that is not the case.
116             auto callback = accumulator->UpdateCountsAndGetCallback();
117             if (callback) {
118               callback();
119             }
120           });
121     };
122   }
123 
124   // Schedule a parallel generator that includes a delay. The result of the
125   // generator is fed to the accumulator_func
Schedule(std::function<std::unique_ptr<T> ()> generator,absl::Duration delay)126   void Schedule(std::function<std::unique_ptr<T>()> generator,
127                 absl::Duration delay) {
128     // IncrementRemainingCount() keeps track of the number of async tasks
129     // scheduled, and sets a flag when the count goes from 0 to 1, corresponding
130     // to a starting batch of unobserved work.
131     auto shared_this = this->shared_from_this();
132     shared_this->IncrementRemainingCount();
133     clock_->WakeupWithDeadline(
134         clock_->Now() + delay,
135         std::make_shared<CallbackWaiter>([shared_this, generator] {
136           shared_this->RunParallel(
137               Accumulator<T>::GetParallelScheduleFunc(shared_this, generator));
138         }));
139   }
140 
141   // Schedule a parallel generator. The result of the generator is fed to the
142   // accumulator_func
Schedule(std::function<std::unique_ptr<T> ()> generator)143   void Schedule(std::function<std::unique_ptr<T>()> generator) {
144     // IncrementRemainingCount() keeps track of the number of async tasks
145     // scheduled, and sets a flag when the count goes from 0 to 1, corresponding
146     // to a starting batch of unobserved work.
147     auto shared_this = this->shared_from_this();
148     shared_this->IncrementRemainingCount();
149     RunParallel([shared_this, generator] {
150       shared_this->GetParallelScheduleFunc(shared_this, generator)();
151     });
152   }
153 
154   // Returns true if the accumulator doesn't have any remaining tasks,
155   // even if their results have not been observed by a callback.
IsIdle()156   bool IsIdle() {
157     absl::MutexLock lock(&mutex_);
158     return remaining_sequential_tasks_count_ == 0;
159   }
160 
161   // Returns false if no async work has happened since last time this function
162   // was called, or the first time it is called. Otherwise it returns true and
163   // schedules a callback to be called once the scheduler is idle.
SetAsyncObserver(std::function<void ()> async_callback)164   bool SetAsyncObserver(std::function<void()> async_callback) {
165     bool idle;
166     {
167       absl::MutexLock lock(&mutex_);
168       if (!has_unobserved_work_) {
169         return false;
170       }
171       idle = (remaining_sequential_tasks_count_ == 0);
172       if (idle) {
173         // The flag is set to false, and the callback is run as soon as we leave
174         // the mutex's scope.
175         has_unobserved_work_ = false;
176       } else {
177         // The callbak is scheduled for later, as there is ongoing work.
178         async_callback_ = async_callback;
179       }
180     }
181     if (idle) {
182       auto shared_this = this->shared_from_this();
183       RunSequential([async_callback, shared_this] { async_callback(); });
184     }
185     return true;
186   }
187 
188   // Updates the active and remaining task counts, and returns the callback to
189   // be executed, or nullptr if there's pending async work.
UpdateCountsAndGetCallback()190   inline std::function<void()> UpdateCountsAndGetCallback() {
191     absl::MutexLock lock(&mutex_);
192     if (--active_count_ == 0 && is_cancelled_) {
193       inactive_cv_.SignalAll();
194     }
195     --remaining_sequential_tasks_count_;
196     if (remaining_sequential_tasks_count_ == 0 && async_callback_) {
197       has_unobserved_work_ = false;
198       auto callback = async_callback_;
199       async_callback_ = nullptr;
200       return callback;
201     } else {
202       return nullptr;
203     }
204   }
205 
206   // Take the accumulated result and abort any further work. This method can
207   // only be called when the accumulator is idle
GetResultAndCancel()208   std::unique_ptr<T> GetResultAndCancel() {
209     absl::MutexLock lock(&mutex_);
210     FCP_CHECK(active_count_ == 0);
211     is_cancelled_ = true;
212     return std::move(accumulated_value_);
213   }
214 
215   // CancellationImpl implementation
Cancel()216   void Cancel() override {
217     mutex_.Lock();
218     is_cancelled_ = true;
219     while (active_count_ > 0) {
220       inactive_cv_.Wait(&mutex_);
221     }
222     mutex_.Unlock();
223   }
224 
IsCancelled()225   bool IsCancelled() {
226     absl::MutexLock lock(&mutex_);
227     return is_cancelled_;
228   }
229 
MaybeIncrementActiveCount()230   bool MaybeIncrementActiveCount() {
231     absl::MutexLock lock(&mutex_);
232     if (is_cancelled_) {
233       return false;
234     }
235     active_count_++;
236     return true;
237   }
238 
DecrementActiveCount()239   size_t DecrementActiveCount() {
240     absl::MutexLock lock(&mutex_);
241     FCP_CHECK(active_count_ > 0);
242     if (--active_count_ == 0 && is_cancelled_) {
243       inactive_cv_.SignalAll();
244     }
245     return active_count_;
246   }
247 
IncrementRemainingCount()248   void IncrementRemainingCount() {
249     absl::MutexLock lock(&mutex_);
250     has_unobserved_work_ |= (remaining_sequential_tasks_count_ == 0);
251     remaining_sequential_tasks_count_++;
252   }
253 
in_sequential_call()254   std::atomic<bool>* in_sequential_call() { return &in_sequential_call_; }
255 
RunParallel(std::function<void ()> function)256   void inline RunParallel(std::function<void()> function) {
257     parallel_scheduler_->Schedule(function);
258   }
259 
RunSequential(std::function<void ()> function)260   void inline RunSequential(std::function<void()> function) {
261     sequential_scheduler_->Schedule(function);
262   }
263 
264  private:
265   // Scheduler for sequential and parallel tasks, received from the
266   // SecAggScheduler instatiating this class
267   Scheduler* parallel_scheduler_;
268   Scheduler* sequential_scheduler_;
269 
270   // Callback to be executed the next time that the sequential scheduler
271   // becomes idle.
272   std::function<void()> async_callback_ ABSL_GUARDED_BY(mutex_) =
273       std::function<void()>();
274   // Accumulated value - accessed by sequential tasks only.
275   std::unique_ptr<T> accumulated_value_;
276   // Accumulation function - accessed by sequential tasks only.
277   std::function<std::unique_ptr<T>(const T&, const T&)> accumulator_func_;
278   // Clock used for scheduling delays in parallel tasks
279   Clock* clock_;
280   // Remaining number of sequential tasks to be executed - accessed by
281   // sequential tasks only.
282   size_t remaining_sequential_tasks_count_ ABSL_GUARDED_BY(mutex_) = 0;
283   bool has_unobserved_work_ ABSL_GUARDED_BY(mutex_) = false;
284 
285   // Number of active calls to either callback function.
286   size_t active_count_ ABSL_GUARDED_BY(mutex_) = 0;
287   // This is set to true when the run is aborted.
288   bool is_cancelled_ ABSL_GUARDED_BY(mutex_) = false;
289   // Protects active_count_ and cancelled_.
290   absl::Mutex mutex_;
291   // Used to notify cancellation about reaching inactive state;
292   absl::CondVar inactive_cv_;
293   // This is used by ReentrancyGuard to ensure that Sequential tasks are
294   // indeed sequential.
295   std::atomic<bool> in_sequential_call_ = false;
296 };
297 
298 // Implementation of ParallelGenerateSequentialReduce based on fcp::Scheduler.
299 // Takes two Schedulers, one which is responsible for parallel execution and
300 // another for serial execution. Additionally, takes a clock that can be used to
301 // induce delay in task executions.
302 class SecAggScheduler {
303  public:
304   SecAggScheduler(Scheduler* parallel_scheduler,
305                   Scheduler* sequential_scheduler,
306                   Clock* clock = Clock::RealClock())
parallel_scheduler_(parallel_scheduler)307       : parallel_scheduler_(parallel_scheduler),
308         sequential_scheduler_(sequential_scheduler),
309         clock_(clock) {}
310 
311   // SecAggScheduler is neither copyable nor movable.
312   SecAggScheduler(const SecAggScheduler&) = delete;
313   SecAggScheduler& operator=(const SecAggScheduler&) = delete;
314 
315   virtual ~SecAggScheduler() = default;
316 
317   // Schedule a callback to be invoked on the sequential scheduler.
ScheduleCallback(std::function<void ()> callback)318   inline void ScheduleCallback(std::function<void()> callback) {
319     RunSequential(callback);
320   }
321 
322   template <typename T>
CreateAccumulator(std::unique_ptr<T> initial_value,std::function<std::unique_ptr<T> (const T &,const T &)> accumulator_func)323   std::shared_ptr<Accumulator<T>> CreateAccumulator(
324       std::unique_ptr<T> initial_value,
325       std::function<std::unique_ptr<T>(const T&, const T&)> accumulator_func) {
326     return std::make_shared<Accumulator<T>>(
327         std::move(initial_value), accumulator_func, parallel_scheduler_,
328         sequential_scheduler_, clock_);
329   }
330 
331   void WaitUntilIdle();
332 
333  protected:
334   // Virtual for testing
335   virtual void RunSequential(std::function<void()> function);
336 
337  private:
338   Scheduler* parallel_scheduler_;
339   Scheduler* sequential_scheduler_;
340   Clock* clock_;
341 };
342 
343 }  // namespace secagg
344 }  // namespace fcp
345 
346 #endif  // FCP_SECAGG_SERVER_SECAGG_SCHEDULER_H_
347