xref: /aosp_15_r20/external/crosvm/cros_async/src/blocking/pool.rs (revision bb4ee6a4ae7042d18b07a98463b9c8b875e44b39)
1 // Copyright 2021 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 use std::collections::VecDeque;
6 use std::future::Future;
7 use std::mem;
8 use std::sync::mpsc::channel;
9 use std::sync::mpsc::Receiver;
10 use std::sync::mpsc::Sender;
11 use std::sync::Arc;
12 use std::thread;
13 use std::thread::JoinHandle;
14 use std::time::Duration;
15 use std::time::Instant;
16 
17 use base::error;
18 use base::warn;
19 use futures::channel::oneshot;
20 use slab::Slab;
21 use sync::Condvar;
22 use sync::Mutex;
23 
/// How long dropping a `BlockingPool` waits for its worker threads to exit before detaching them.
const DEFAULT_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(10);
25 
/// Mutable bookkeeping for a `BlockingPool`, always accessed through the mutex in `Inner::state`.
struct State {
    /// Queued work that no worker has picked up yet. `Inner::spawn` pushes to the back and worker
    /// threads pop from the front.
    tasks: VecDeque<Box<dyn FnOnce() + Send>>,
    /// Number of worker threads that have been started and have not yet left their run loop.
    num_threads: usize,
    /// Number of workers waiting on the condvar that have not been claimed by a wakeup.
    /// Incremented by a worker before it waits; decremented by `Inner::spawn` when it wakes a
    /// worker, or by the worker itself when its keepalive expires.
    num_idle: usize,
    /// Number of wakeups issued by `Inner::spawn` that no worker has consumed yet.
    num_notified: usize,
    /// Join handles of spawned workers, keyed by the index that was handed to each worker.
    worker_threads: Slab<JoinHandle<()>>,
    /// Receiving end of the exit channel. `BlockingPool::shutdown` takes it, leaving `None`.
    exited_threads: Option<Receiver<usize>>,
    /// Sending end of the exit channel. It is cloned into every worker, which sends its own index
    /// right before it exits.
    exit: Sender<usize>,
    /// Set by `BlockingPool::shutdown`. Once `true`, workers stop looking for tasks and
    /// `Inner::spawn` rejects new work.
    shutting_down: bool,
}
36 
run_blocking_thread(idx: usize, inner: Arc<Inner>, exit: Sender<usize>)37 fn run_blocking_thread(idx: usize, inner: Arc<Inner>, exit: Sender<usize>) {
38     let mut state = inner.state.lock();
39     while !state.shutting_down {
40         if let Some(f) = state.tasks.pop_front() {
41             drop(state);
42             f();
43             state = inner.state.lock();
44             continue;
45         }
46 
47         // No more tasks so wait for more work.
48         state.num_idle += 1;
49 
50         let (guard, result) = inner
51             .condvar
52             .wait_timeout_while(state, inner.keepalive, |s| {
53                 !s.shutting_down && s.num_notified == 0
54             });
55         state = guard;
56 
57         // If `state.num_notified > 0` then this was a real wakeup.
58         if state.num_notified > 0 {
59             state.num_notified -= 1;
60             continue;
61         }
62 
63         // Only decrement the idle count if we timed out. Otherwise, it was decremented when new
64         // work was added to `state.tasks`.
65         if result.timed_out() {
66             state.num_idle = state
67                 .num_idle
68                 .checked_sub(1)
69                 .expect("`num_idle` underflow on timeout");
70             break;
71         }
72     }
73 
74     state.num_threads -= 1;
75 
76     // If we're shutting down then the BlockingPool will take care of joining all the threads.
77     // Otherwise, we need to join the last worker thread that exited here.
78     let last_exited_thread = if let Some(exited_threads) = state.exited_threads.as_mut() {
79         exited_threads
80             .try_recv()
81             .map(|idx| state.worker_threads.remove(idx))
82             .ok()
83     } else {
84         None
85     };
86 
87     // Drop the lock before trying to join the last exited thread.
88     drop(state);
89 
90     if let Some(handle) = last_exited_thread {
91         let _ = handle.join();
92     }
93 
94     if let Err(e) = exit.send(idx) {
95         error!("Failed to send thread exit event on channel: {}", e);
96     }
97 }
98 
/// State shared between a `BlockingPool` handle and all of its worker threads.
struct Inner {
    /// Mutable bookkeeping, guarded by a mutex.
    state: Mutex<State>,
    /// Notified by `spawn` to wake one idle worker and by `BlockingPool::shutdown` to wake all of
    /// them.
    condvar: Condvar,
    /// Upper bound on the number of worker threads `spawn` will start.
    max_threads: usize,
    /// How long an idle worker waits for new work before exiting.
    keepalive: Duration,
}
105 
106 impl Inner {
spawn<F, R>(self: &Arc<Self>, f: F) -> impl Future<Output = R> where F: FnOnce() -> R + Send + 'static, R: Send + 'static,107     pub fn spawn<F, R>(self: &Arc<Self>, f: F) -> impl Future<Output = R>
108     where
109         F: FnOnce() -> R + Send + 'static,
110         R: Send + 'static,
111     {
112         let mut state = self.state.lock();
113 
114         // If we're shutting down then nothing is going to run this task.
115         if state.shutting_down {
116             error!("spawn called after shutdown");
117             return futures::future::Either::Left(async {
118                 panic!("tried to poll BlockingPool task after shutdown")
119             });
120         }
121 
122         let (send_chan, recv_chan) = oneshot::channel();
123         state.tasks.push_back(Box::new(|| {
124             let _ = send_chan.send(f());
125         }));
126 
127         if state.num_idle == 0 {
128             // There are no idle threads.  Spawn a new one if possible.
129             if state.num_threads < self.max_threads {
130                 state.num_threads += 1;
131                 let exit = state.exit.clone();
132                 let entry = state.worker_threads.vacant_entry();
133                 let idx = entry.key();
134                 let inner = self.clone();
135                 entry.insert(
136                     thread::Builder::new()
137                         .name(format!("blockingPool{}", idx))
138                         .spawn(move || run_blocking_thread(idx, inner, exit))
139                         .unwrap(),
140                 );
141             }
142         } else {
143             // We have idle threads, wake one up.
144             state.num_idle -= 1;
145             state.num_notified += 1;
146             self.condvar.notify_one();
147         }
148 
149         futures::future::Either::Right(async {
150             recv_chan
151                 .await
152                 .expect("BlockingThread task unexpectedly cancelled")
153         })
154     }
155 }
156 
/// Error returned by `BlockingPool::shutdown` when the deadline passes before every worker thread
/// has been joined. The wrapped value is the number of worker threads that were still outstanding.
#[derive(Debug, thiserror::Error)]
#[error("{0} BlockingPool threads did not exit in time and will be detached")]
pub struct ShutdownTimedOut(usize);
160 
/// A thread pool for running work that may block.
///
/// It is generally discouraged to do any blocking work inside an async function. However, this is
/// sometimes unavoidable when dealing with interfaces that don't provide async variants. In this
/// case callers may use the `BlockingPool` to run the blocking work on a different thread and
/// `await` its result, which will prevent blocking the main thread of the application.
///
/// Since the blocking work is sent to another thread, users should be careful when using the
/// `BlockingPool` for latency-sensitive operations. Additionally, the `BlockingPool` is intended to
/// be used for work that will eventually complete on its own. Users who want to spawn a thread
/// should just use `thread::spawn` directly.
///
/// There is no way to cancel work once it has been picked up by one of the worker threads in the
/// `BlockingPool`. Dropping the pool will block up to a timeout (10 seconds) to wait for any
/// active blocking work to finish; any threads running tasks that have not completed by that time
/// will be detached. Calling `shutdown` explicitly waits until the deadline passed to it, or
/// indefinitely when no deadline is given.
///
/// # Examples
///
/// Spawn a task to run in the `BlockingPool` and await on its result.
///
/// ```edition2018
/// use cros_async::BlockingPool;
///
/// # async fn do_it() {
///     let pool = BlockingPool::default();
///
///     let res = pool.spawn(move || {
///         // Do some CPU-intensive or blocking work here.
///
///         42
///     }).await;
///
///     assert_eq!(res, 42);
/// # }
/// # cros_async::block_on(do_it());
/// ```
pub struct BlockingPool {
    // Shared with every worker thread, each of which holds its own `Arc` clone.
    inner: Arc<Inner>,
}
202 
203 impl BlockingPool {
204     /// Create a new `BlockingPool`.
205     ///
206     /// The `BlockingPool` will never spawn more than `max_threads` threads to do work, regardless
207     /// of the number of tasks that are added to it. This value should be set relatively low (for
208     /// example, the number of CPUs on the machine) if the pool is intended to run CPU intensive
209     /// work or it should be set relatively high (128 or more) if the pool is intended to be used
210     /// for various IO operations that cannot be completed asynchronously. The default value is 256.
211     ///
212     /// Worker threads are spawned on demand when new work is added to the pool and will
213     /// automatically exit after being idle for some time so there is no overhead for setting
214     /// `max_threads` to a large value when there is little to no work assigned to the
215     /// `BlockingPool`. `keepalive` determines the idle duration after which the worker thread will
216     /// exit. The default value is 10 seconds.
new(max_threads: usize, keepalive: Duration) -> BlockingPool217     pub fn new(max_threads: usize, keepalive: Duration) -> BlockingPool {
218         let (exit, exited_threads) = channel();
219         BlockingPool {
220             inner: Arc::new(Inner {
221                 state: Mutex::new(State {
222                     tasks: VecDeque::new(),
223                     num_threads: 0,
224                     num_idle: 0,
225                     num_notified: 0,
226                     worker_threads: Slab::new(),
227                     exited_threads: Some(exited_threads),
228                     exit,
229                     shutting_down: false,
230                 }),
231                 condvar: Condvar::new(),
232                 max_threads,
233                 keepalive,
234             }),
235         }
236     }
237 
238     /// Like new but with pre-allocating capacity for up to `max_threads`.
with_capacity(max_threads: usize, keepalive: Duration) -> BlockingPool239     pub fn with_capacity(max_threads: usize, keepalive: Duration) -> BlockingPool {
240         let (exit, exited_threads) = channel();
241         BlockingPool {
242             inner: Arc::new(Inner {
243                 state: Mutex::new(State {
244                     tasks: VecDeque::new(),
245                     num_threads: 0,
246                     num_idle: 0,
247                     num_notified: 0,
248                     worker_threads: Slab::with_capacity(max_threads),
249                     exited_threads: Some(exited_threads),
250                     exit,
251                     shutting_down: false,
252                 }),
253                 condvar: Condvar::new(),
254                 max_threads,
255                 keepalive,
256             }),
257         }
258     }
259 
260     /// Spawn a task to run in the `BlockingPool`.
261     ///
262     /// Callers may `await` the returned `Future` to be notified when the work is completed.
263     /// Dropping the future will not cancel the task.
264     ///
265     /// # Panics
266     ///
267     /// `await`ing a `Task` after dropping the `BlockingPool` or calling `BlockingPool::shutdown`
268     /// will panic if the work was not completed before the pool was shut down.
spawn<F, R>(&self, f: F) -> impl Future<Output = R> where F: FnOnce() -> R + Send + 'static, R: Send + 'static,269     pub fn spawn<F, R>(&self, f: F) -> impl Future<Output = R>
270     where
271         F: FnOnce() -> R + Send + 'static,
272         R: Send + 'static,
273     {
274         self.inner.spawn(f)
275     }
276 
277     /// Shut down the `BlockingPool`.
278     ///
279     /// If `deadline` is provided then this will block until either all worker threads exit or the
280     /// deadline is exceeded. If `deadline` is not given then this will block indefinitely until all
281     /// worker threads exit. Any work that was added to the `BlockingPool` but not yet picked up by
282     /// a worker thread will not complete and `await`ing on the `Task` for that work will panic.
shutdown(&self, deadline: Option<Instant>) -> Result<(), ShutdownTimedOut>283     pub fn shutdown(&self, deadline: Option<Instant>) -> Result<(), ShutdownTimedOut> {
284         let mut state = self.inner.state.lock();
285 
286         if state.shutting_down {
287             // We've already shut down this BlockingPool.
288             return Ok(());
289         }
290 
291         state.shutting_down = true;
292         let exited_threads = state.exited_threads.take().expect("exited_threads missing");
293         let unfinished_tasks = std::mem::take(&mut state.tasks);
294         let mut worker_threads = mem::replace(&mut state.worker_threads, Slab::new());
295         drop(state);
296 
297         self.inner.condvar.notify_all();
298 
299         // Cancel any unfinished work after releasing the lock.
300         drop(unfinished_tasks);
301 
302         // Now wait for all worker threads to exit.
303         if let Some(deadline) = deadline {
304             let mut now = Instant::now();
305             while now < deadline && !worker_threads.is_empty() {
306                 if let Ok(idx) = exited_threads.recv_timeout(deadline - now) {
307                     let _ = worker_threads.remove(idx).join();
308                 }
309                 now = Instant::now();
310             }
311 
312             // Any threads that have not yet joined will just be detached.
313             if !worker_threads.is_empty() {
314                 return Err(ShutdownTimedOut(worker_threads.len()));
315             }
316 
317             Ok(())
318         } else {
319             // Block indefinitely until all worker threads exit.
320             for handle in worker_threads.drain() {
321                 let _ = handle.join();
322             }
323 
324             Ok(())
325         }
326     }
327 
328     #[cfg(test)]
shutting_down(&self) -> bool329     pub(crate) fn shutting_down(&self) -> bool {
330         self.inner.state.lock().shutting_down
331     }
332 }
333 
334 impl Default for BlockingPool {
default() -> BlockingPool335     fn default() -> BlockingPool {
336         BlockingPool::new(256, Duration::from_secs(10))
337     }
338 }
339 
340 impl Drop for BlockingPool {
drop(&mut self)341     fn drop(&mut self) {
342         if let Err(e) = self.shutdown(Some(Instant::now() + DEFAULT_SHUTDOWN_TIMEOUT)) {
343             warn!("{}", e);
344         }
345     }
346 }
347 
#[cfg(test)]
mod test {
    use std::sync::Arc;
    use std::sync::Barrier;
    use std::thread;
    use std::time::Duration;
    use std::time::Instant;

    use futures::executor::block_on;
    use futures::stream::FuturesUnordered;
    use futures::StreamExt;
    use sync::Condvar;
    use sync::Mutex;

    use super::super::super::BlockingPool;

    /// Polls the pool every 100ms until it reports no live worker threads, giving up after ten
    /// seconds. Callers assert on the final state themselves.
    fn wait_for_threads_to_exit(pool: &BlockingPool) {
        let deadline = Instant::now() + Duration::from_secs(10);
        while Instant::now() < deadline {
            thread::sleep(Duration::from_millis(100));
            let state = pool.inner.state.lock();
            if state.num_threads == 0 {
                break;
            }
        }
    }

    /// A spawned closure's return value is delivered through the returned future.
    #[test]
    fn blocking_sleep() {
        let pool = BlockingPool::default();

        let res = block_on(pool.spawn(|| 42));
        assert_eq!(res, 42);
    }

    /// Dropping the future neither cancels the task nor blocks the caller.
    #[test]
    fn drop_doesnt_block() {
        let pool = BlockingPool::default();
        let (tx, rx) = std::sync::mpsc::sync_channel(0);
        // The blocking work should continue even though we drop the future.
        //
        // If we cancelled the work, then the recv call would fail. If we blocked on the work, then
        // the send would never complete because the channel is size zero and so waits for a
        // matching recv call.
        std::mem::drop(pool.spawn(move || tx.send(()).unwrap()));
        rx.recv().unwrap();
    }

    /// Workers timing out while new work arrives must not corrupt the idle bookkeeping.
    #[test]
    fn fast_tasks_with_short_keepalive() {
        let pool = BlockingPool::new(256, Duration::from_millis(1));

        let streams = FuturesUnordered::new();
        for _ in 0..2 {
            for _ in 0..256 {
                let task = pool.spawn(|| ());
                streams.push(task);
            }

            thread::sleep(Duration::from_millis(1));
        }

        block_on(streams.collect::<Vec<_>>());

        // The test passes if there are no panics, which would happen if one of the worker threads
        // triggered an underflow on `pool.inner.state.num_idle`.
    }

    /// Tasks beyond `max_threads` are queued and still run to completion.
    #[test]
    fn more_tasks_than_threads() {
        let pool = BlockingPool::new(4, Duration::from_secs(10));

        let stream = (0..19)
            .map(|_| pool.spawn(|| thread::sleep(Duration::from_millis(5))))
            .collect::<FuturesUnordered<_>>();

        let results = block_on(stream.collect::<Vec<_>>());
        assert_eq!(results.len(), 19);
    }

    /// Shutting down an idle pool joins every worker before the deadline.
    #[test]
    fn shutdown() {
        let pool = BlockingPool::default();

        let stream = (0..19)
            .map(|_| pool.spawn(|| thread::sleep(Duration::from_millis(5))))
            .collect::<FuturesUnordered<_>>();

        let results = block_on(stream.collect::<Vec<_>>());
        assert_eq!(results.len(), 19);

        pool.shutdown(Some(Instant::now() + Duration::from_secs(10)))
            .unwrap();
        let state = pool.inner.state.lock();
        assert_eq!(state.num_threads, 0);
    }

    /// Idle workers exit on their own once the keepalive expires.
    #[test]
    fn keepalive_timeout() {
        // Set the keepalive to a very low value so that threads will exit soon after they run out
        // of work.
        let pool = BlockingPool::new(7, Duration::from_millis(1));

        let stream = (0..19)
            .map(|_| pool.spawn(|| thread::sleep(Duration::from_millis(5))))
            .collect::<FuturesUnordered<_>>();

        let results = block_on(stream.collect::<Vec<_>>());
        assert_eq!(results.len(), 19);

        // Wait for all threads to exit.
        wait_for_threads_to_exit(&pool);

        {
            let state = pool.inner.state.lock();
            assert_eq!(state.num_threads, 0);
            assert_eq!(state.num_idle, 0);
        }
    }

    /// Work still queued at shutdown is cancelled, and awaiting it panics.
    ///
    /// The expected message pins the panic to the cancelled task so that an unrelated failure
    /// earlier in the test (for example in `shutdown`) cannot make it pass.
    #[test]
    #[should_panic(expected = "BlockingThread task unexpectedly cancelled")]
    fn shutdown_with_pending_work() {
        let pool = BlockingPool::new(1, Duration::from_secs(10));

        let mu = Arc::new(Mutex::new(false));
        let cv = Arc::new(Condvar::new());

        // First spawn a thread that blocks the pool.
        let task_mu = mu.clone();
        let task_cv = cv.clone();
        let _blocking_task = pool.spawn(move || {
            let mut ready = task_mu.lock();
            while !*ready {
                ready = task_cv.wait(ready);
            }
        });

        // This task will never finish because we will shut down the pool first.
        let unfinished = pool.spawn(|| 5);

        // Spawn a thread to unblock the work we started earlier once it sees that the pool is
        // shutting down.
        let inner = pool.inner.clone();
        thread::spawn(move || {
            let mut state = inner.state.lock();
            while !state.shutting_down {
                state = inner.condvar.wait(state);
            }

            *mu.lock() = true;
            cv.notify_all();
        });
        pool.shutdown(None).unwrap();

        // This should panic.
        assert_eq!(block_on(unfinished), 5);
    }

    /// A worker that outlives the shutdown deadline is reported, detached, and still cleans up
    /// after itself once its task finishes.
    #[test]
    fn unfinished_worker_thread() {
        let pool = BlockingPool::default();

        let ready = Arc::new(Mutex::new(false));
        let cv = Arc::new(Condvar::new());
        let barrier = Arc::new(Barrier::new(2));

        let thread_ready = ready.clone();
        let thread_barrier = barrier.clone();
        let thread_cv = cv.clone();

        let task = pool.spawn(move || {
            thread_barrier.wait();
            let mut ready = thread_ready.lock();
            while !*ready {
                ready = thread_cv.wait(ready);
            }
        });

        // Wait to shut down the pool until after the worker thread has started.
        barrier.wait();
        pool.shutdown(Some(Instant::now() + Duration::from_millis(5)))
            .unwrap_err();

        let num_threads = pool.inner.state.lock().num_threads;
        assert_eq!(num_threads, 1);

        // Now wake up the blocked task so we don't leak the thread.
        *ready.lock() = true;
        cv.notify_all();

        block_on(task);

        wait_for_threads_to_exit(&pool);

        {
            let state = pool.inner.state.lock();
            assert_eq!(state.num_threads, 0);
            assert_eq!(state.num_idle, 0);
        }
    }
}
553