xref: /aosp_15_r20/external/federated-compute/fcp/base/scheduler.cc (revision 14675a029014e728ec732f129a32e299b2da0601)
1*14675a02SAndroid Build Coastguard Worker /*
2*14675a02SAndroid Build Coastguard Worker  * Copyright 2018 Google LLC
3*14675a02SAndroid Build Coastguard Worker  *
4*14675a02SAndroid Build Coastguard Worker  * Licensed under the Apache License, Version 2.0 (the "License");
5*14675a02SAndroid Build Coastguard Worker  * you may not use this file except in compliance with the License.
6*14675a02SAndroid Build Coastguard Worker  * You may obtain a copy of the License at
7*14675a02SAndroid Build Coastguard Worker  *
8*14675a02SAndroid Build Coastguard Worker  *      http://www.apache.org/licenses/LICENSE-2.0
9*14675a02SAndroid Build Coastguard Worker  *
10*14675a02SAndroid Build Coastguard Worker  * Unless required by applicable law or agreed to in writing, software
11*14675a02SAndroid Build Coastguard Worker  * distributed under the License is distributed on an "AS IS" BASIS,
12*14675a02SAndroid Build Coastguard Worker  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13*14675a02SAndroid Build Coastguard Worker  * See the License for the specific language governing permissions and
14*14675a02SAndroid Build Coastguard Worker  * limitations under the License.
15*14675a02SAndroid Build Coastguard Worker  */
16*14675a02SAndroid Build Coastguard Worker 
#include "fcp/base/scheduler.h"

#include <array>
#include <atomic>
#include <cstddef>
#include <deque>
#include <functional>
#include <memory>
#include <queue>
#include <thread>  // NOLINT(build/c++11)
#include <utility>
#include <vector>

#include "absl/synchronization/blocking_counter.h"
#include "absl/synchronization/mutex.h"
29*14675a02SAndroid Build Coastguard Worker 
30*14675a02SAndroid Build Coastguard Worker namespace fcp {
31*14675a02SAndroid Build Coastguard Worker 
32*14675a02SAndroid Build Coastguard Worker namespace {
33*14675a02SAndroid Build Coastguard Worker 
// Tracks whether an object is still alive.
//
// The tracker owns a heap-allocated flag through a shared pointer
// (SharedMarker). The flag is true while the tracker exists and is set to
// false by the tracker's destructor. Because the flag is shared, a copy of the
// marker captured in a lambda stays valid after the tracker is gone, which
// gives asynchronous code a clean way to CHECK-fail if the owning object is
// accessed post destruction.
//
// The flag is a std::atomic<bool> because the destructor and the code that
// reads a captured marker may run on different threads; with a plain bool the
// post-destruction check would itself be a data race.
class LifetimeTracker {
 public:
  using SharedMarker = std::shared_ptr<std::atomic<bool>>;

  LifetimeTracker() : marker_(std::make_shared<std::atomic<bool>>(true)) {}
  virtual ~LifetimeTracker() { marker_->store(false); }

  // Returns the shared liveness marker. Copy it (e.g. into a lambda capture)
  // to observe this object's lifetime from elsewhere.
  const SharedMarker& marker() const { return marker_; }

 private:
  SharedMarker marker_;
};
49*14675a02SAndroid Build Coastguard Worker 
50*14675a02SAndroid Build Coastguard Worker // Implementation of workers.
51*14675a02SAndroid Build Coastguard Worker class WorkerImpl : public Worker, public LifetimeTracker {
52*14675a02SAndroid Build Coastguard Worker  public:
WorkerImpl(Scheduler * scheduler)53*14675a02SAndroid Build Coastguard Worker   explicit WorkerImpl(Scheduler* scheduler) : scheduler_(scheduler) {}
54*14675a02SAndroid Build Coastguard Worker 
55*14675a02SAndroid Build Coastguard Worker   ~WorkerImpl() override = default;
56*14675a02SAndroid Build Coastguard Worker 
Schedule(std::function<void ()> task)57*14675a02SAndroid Build Coastguard Worker   void Schedule(std::function<void()> task) override {
58*14675a02SAndroid Build Coastguard Worker     absl::MutexLock lock(&busy_);
59*14675a02SAndroid Build Coastguard Worker     steps_.emplace_back(std::move(task));
60*14675a02SAndroid Build Coastguard Worker     MaybeRunNext();
61*14675a02SAndroid Build Coastguard Worker   }
62*14675a02SAndroid Build Coastguard Worker 
63*14675a02SAndroid Build Coastguard Worker  private:
MaybeRunNext()64*14675a02SAndroid Build Coastguard Worker   void MaybeRunNext() ABSL_EXCLUSIVE_LOCKS_REQUIRED(busy_) {
65*14675a02SAndroid Build Coastguard Worker     if (running_ || steps_.empty()) {
66*14675a02SAndroid Build Coastguard Worker       // Already running, and next task will be executed when finished, or
67*14675a02SAndroid Build Coastguard Worker       // nothing to run.
68*14675a02SAndroid Build Coastguard Worker       return;
69*14675a02SAndroid Build Coastguard Worker     }
70*14675a02SAndroid Build Coastguard Worker     auto task = std::move(steps_.front());
71*14675a02SAndroid Build Coastguard Worker     steps_.pop_front();
72*14675a02SAndroid Build Coastguard Worker     running_ = true;
73*14675a02SAndroid Build Coastguard Worker     auto wrapped_task = MoveToLambda(std::move(task));
74*14675a02SAndroid Build Coastguard Worker     auto marker = this->marker();
75*14675a02SAndroid Build Coastguard Worker     scheduler_->Schedule([this, marker, wrapped_task] {
76*14675a02SAndroid Build Coastguard Worker       // Call the Task which is stored in wrapped_task.value.
77*14675a02SAndroid Build Coastguard Worker       (*wrapped_task)();
78*14675a02SAndroid Build Coastguard Worker 
79*14675a02SAndroid Build Coastguard Worker       // Run the next task.
80*14675a02SAndroid Build Coastguard Worker       FCP_CHECK(*marker) << "Worker destroyed before all tasks finished";
81*14675a02SAndroid Build Coastguard Worker       {
82*14675a02SAndroid Build Coastguard Worker         // Try run next task if any.
83*14675a02SAndroid Build Coastguard Worker         absl::MutexLock lock(&this->busy_);
84*14675a02SAndroid Build Coastguard Worker         this->running_ = false;
85*14675a02SAndroid Build Coastguard Worker         this->MaybeRunNext();
86*14675a02SAndroid Build Coastguard Worker       }
87*14675a02SAndroid Build Coastguard Worker     });
88*14675a02SAndroid Build Coastguard Worker   }
89*14675a02SAndroid Build Coastguard Worker 
90*14675a02SAndroid Build Coastguard Worker   Scheduler* scheduler_;
91*14675a02SAndroid Build Coastguard Worker   absl::Mutex busy_;
92*14675a02SAndroid Build Coastguard Worker   bool running_ ABSL_GUARDED_BY(busy_) = false;
93*14675a02SAndroid Build Coastguard Worker   std::deque<std::function<void()>> steps_ ABSL_GUARDED_BY(busy_);
94*14675a02SAndroid Build Coastguard Worker };
95*14675a02SAndroid Build Coastguard Worker 
96*14675a02SAndroid Build Coastguard Worker // Implementation of thread pools.
97*14675a02SAndroid Build Coastguard Worker class ThreadPoolScheduler : public Scheduler {
98*14675a02SAndroid Build Coastguard Worker  public:
ThreadPoolScheduler(std::size_t thread_count)99*14675a02SAndroid Build Coastguard Worker   explicit ThreadPoolScheduler(std::size_t thread_count)
100*14675a02SAndroid Build Coastguard Worker       : idle_condition_(absl::Condition(IdleCondition, this)),
101*14675a02SAndroid Build Coastguard Worker         active_count_(thread_count) {
102*14675a02SAndroid Build Coastguard Worker     FCP_CHECK(thread_count > 0) << "invalid thread_count";
103*14675a02SAndroid Build Coastguard Worker 
104*14675a02SAndroid Build Coastguard Worker     // Create threads.
105*14675a02SAndroid Build Coastguard Worker     for (int i = 0; i < thread_count; ++i) {
106*14675a02SAndroid Build Coastguard Worker       threads_.emplace_back(std::thread([this] { this->PerThreadActivity(); }));
107*14675a02SAndroid Build Coastguard Worker     }
108*14675a02SAndroid Build Coastguard Worker   }
109*14675a02SAndroid Build Coastguard Worker 
~ThreadPoolScheduler()110*14675a02SAndroid Build Coastguard Worker   ~ThreadPoolScheduler() override {
111*14675a02SAndroid Build Coastguard Worker     {
112*14675a02SAndroid Build Coastguard Worker       absl::MutexLock lock(&busy_);
113*14675a02SAndroid Build Coastguard Worker       FCP_CHECK(IdleCondition(this))
114*14675a02SAndroid Build Coastguard Worker           << "Thread pool must be idle at destruction time";
115*14675a02SAndroid Build Coastguard Worker 
116*14675a02SAndroid Build Coastguard Worker       threads_should_join_ = true;
117*14675a02SAndroid Build Coastguard Worker       work_available_cond_var_.SignalAll();
118*14675a02SAndroid Build Coastguard Worker     }
119*14675a02SAndroid Build Coastguard Worker 
120*14675a02SAndroid Build Coastguard Worker     for (auto& thread : threads_) {
121*14675a02SAndroid Build Coastguard Worker       FCP_CHECK(thread.joinable()) << "Attempted to destroy a threadpool from "
122*14675a02SAndroid Build Coastguard Worker                                       "one of its running threads";
123*14675a02SAndroid Build Coastguard Worker       thread.join();
124*14675a02SAndroid Build Coastguard Worker     }
125*14675a02SAndroid Build Coastguard Worker   }
126*14675a02SAndroid Build Coastguard Worker 
Schedule(std::function<void ()> task)127*14675a02SAndroid Build Coastguard Worker   void Schedule(std::function<void()> task) override {
128*14675a02SAndroid Build Coastguard Worker     absl::MutexLock lock(&busy_);
129*14675a02SAndroid Build Coastguard Worker     todo_.push(std::move(task));
130*14675a02SAndroid Build Coastguard Worker     // Wake up a *single* thread to handle this task.
131*14675a02SAndroid Build Coastguard Worker     work_available_cond_var_.Signal();
132*14675a02SAndroid Build Coastguard Worker   }
133*14675a02SAndroid Build Coastguard Worker 
WaitUntilIdle()134*14675a02SAndroid Build Coastguard Worker   void WaitUntilIdle() override {
135*14675a02SAndroid Build Coastguard Worker     busy_.LockWhen(idle_condition_);
136*14675a02SAndroid Build Coastguard Worker     busy_.Unlock();
137*14675a02SAndroid Build Coastguard Worker   }
138*14675a02SAndroid Build Coastguard Worker 
IdleCondition(ThreadPoolScheduler * pool)139*14675a02SAndroid Build Coastguard Worker   static bool IdleCondition(ThreadPoolScheduler* pool)
140*14675a02SAndroid Build Coastguard Worker       ABSL_EXCLUSIVE_LOCKS_REQUIRED(pool->busy_) {
141*14675a02SAndroid Build Coastguard Worker     return pool->todo_.empty() && pool->active_count_ == 0;
142*14675a02SAndroid Build Coastguard Worker   }
143*14675a02SAndroid Build Coastguard Worker 
PerThreadActivity()144*14675a02SAndroid Build Coastguard Worker   void PerThreadActivity() {
145*14675a02SAndroid Build Coastguard Worker     for (;;) {
146*14675a02SAndroid Build Coastguard Worker       std::function<void()> task;
147*14675a02SAndroid Build Coastguard Worker       {
148*14675a02SAndroid Build Coastguard Worker         absl::MutexLock lock(&busy_);
149*14675a02SAndroid Build Coastguard Worker         --active_count_;
150*14675a02SAndroid Build Coastguard Worker         while (todo_.empty()) {
151*14675a02SAndroid Build Coastguard Worker           if (threads_should_join_) {
152*14675a02SAndroid Build Coastguard Worker             return;
153*14675a02SAndroid Build Coastguard Worker           }
154*14675a02SAndroid Build Coastguard Worker 
155*14675a02SAndroid Build Coastguard Worker           work_available_cond_var_.Wait(&busy_);
156*14675a02SAndroid Build Coastguard Worker         }
157*14675a02SAndroid Build Coastguard Worker 
158*14675a02SAndroid Build Coastguard Worker         // Destructor invariant
159*14675a02SAndroid Build Coastguard Worker         FCP_CHECK(!threads_should_join_);
160*14675a02SAndroid Build Coastguard Worker         task = std::move(todo_.front());
161*14675a02SAndroid Build Coastguard Worker         todo_.pop();
162*14675a02SAndroid Build Coastguard Worker         ++active_count_;
163*14675a02SAndroid Build Coastguard Worker       }
164*14675a02SAndroid Build Coastguard Worker 
165*14675a02SAndroid Build Coastguard Worker       task();
166*14675a02SAndroid Build Coastguard Worker     }
167*14675a02SAndroid Build Coastguard Worker   }
168*14675a02SAndroid Build Coastguard Worker 
169*14675a02SAndroid Build Coastguard Worker   // A vector of threads allocated for execution.
170*14675a02SAndroid Build Coastguard Worker   std::vector<std::thread> threads_;
171*14675a02SAndroid Build Coastguard Worker 
172*14675a02SAndroid Build Coastguard Worker   // A CondVar used to signal availability of tasks.
173*14675a02SAndroid Build Coastguard Worker   //
174*14675a02SAndroid Build Coastguard Worker   // We would prefer to use the more declarative absl::Condition instead,
175*14675a02SAndroid Build Coastguard Worker   // however, this one only allows to wake up all threads if a new task is
176*14675a02SAndroid Build Coastguard Worker   // available -- but we want to wake up only one in this case.
177*14675a02SAndroid Build Coastguard Worker   absl::CondVar work_available_cond_var_;
178*14675a02SAndroid Build Coastguard Worker 
179*14675a02SAndroid Build Coastguard Worker   // See IdleCondition
180*14675a02SAndroid Build Coastguard Worker   absl::Condition idle_condition_;
181*14675a02SAndroid Build Coastguard Worker 
182*14675a02SAndroid Build Coastguard Worker   // A mutex protecting mutable state in this class.
183*14675a02SAndroid Build Coastguard Worker   absl::Mutex busy_;
184*14675a02SAndroid Build Coastguard Worker 
185*14675a02SAndroid Build Coastguard Worker   // Set when worker threads should join instead of waiting for work.
186*14675a02SAndroid Build Coastguard Worker   bool threads_should_join_ ABSL_GUARDED_BY(busy_) = false;
187*14675a02SAndroid Build Coastguard Worker 
188*14675a02SAndroid Build Coastguard Worker   // Queue of tasks with work to do.
189*14675a02SAndroid Build Coastguard Worker   std::queue<std::function<void()>> todo_ ABSL_GUARDED_BY(busy_);
190*14675a02SAndroid Build Coastguard Worker 
191*14675a02SAndroid Build Coastguard Worker   // The number of threads currently doing work in this pool.
192*14675a02SAndroid Build Coastguard Worker   std::size_t active_count_ ABSL_GUARDED_BY(busy_);
193*14675a02SAndroid Build Coastguard Worker };
194*14675a02SAndroid Build Coastguard Worker 
195*14675a02SAndroid Build Coastguard Worker }  // namespace
196*14675a02SAndroid Build Coastguard Worker 
CreateWorker()197*14675a02SAndroid Build Coastguard Worker std::unique_ptr<Worker> Scheduler::CreateWorker() {
198*14675a02SAndroid Build Coastguard Worker   return std::make_unique<WorkerImpl>(this);
199*14675a02SAndroid Build Coastguard Worker }
200*14675a02SAndroid Build Coastguard Worker 
CreateThreadPoolScheduler(std::size_t thread_count)201*14675a02SAndroid Build Coastguard Worker std::unique_ptr<Scheduler> CreateThreadPoolScheduler(std::size_t thread_count) {
202*14675a02SAndroid Build Coastguard Worker   return std::make_unique<ThreadPoolScheduler>(thread_count);
203*14675a02SAndroid Build Coastguard Worker }
204*14675a02SAndroid Build Coastguard Worker 
205*14675a02SAndroid Build Coastguard Worker }  // namespace fcp
206