// Copyright 2021 The ChromiumOS Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

use std::collections::VecDeque;
use std::future::Future;
use std::mem;
use std::sync::mpsc::channel;
use std::sync::mpsc::Receiver;
use std::sync::mpsc::Sender;
use std::sync::Arc;
use std::thread;
use std::thread::JoinHandle;
use std::time::Duration;
use std::time::Instant;

use base::error;
use base::warn;
use futures::channel::oneshot;
use slab::Slab;
use sync::Condvar;
use sync::Mutex;

const DEFAULT_SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(10);

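// Pool state shared between the `BlockingPool` handle and its worker threads.
// `num_idle` tracks workers that are waiting for work and have not yet been
// claimed by a wakeup, while `num_notified` counts wakeups issued by `spawn`
// that a woken worker has not yet consumed.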
struct State {
    tasks: VecDeque<Box<dyn FnOnce() + Send>>,
    num_threads: usize,
    num_idle: usize,
    num_notified: usize,
    worker_threads: Slab<JoinHandle<()>>,
    exited_threads: Option<Receiver<usize>>,
    exit: Sender<usize>,
    shutting_down: bool,
}

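// Worker thread body: repeatedly pop and run queued tasks, park on the condvar
// when the queue is empty, and exit once `keepalive` elapses with no new work.
// On the way out, join the previously exited worker (if any) and report this
// thread's slab index on the exit channel.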
fn run_blocking_thread(idx: usize, inner: Arc<Inner>, exit: Sender<usize>) {
    let mut state = inner.state.lock();
    while !state.shutting_down {
        if let Some(f) = state.tasks.pop_front() {
            drop(state);
            f();
            state = inner.state.lock();
            continue;
        }

        // No more tasks so wait for more work.
        state.num_idle += 1;

        let (guard, result) = inner
            .condvar
            .wait_timeout_while(state, inner.keepalive, |s| {
                !s.shutting_down && s.num_notified == 0
            });
        state = guard;

        // If `state.num_notified > 0` then this was a real wakeup.
        if state.num_notified > 0 {
            state.num_notified -= 1;
            continue;
        }

        // Only decrement the idle count if we timed out. Otherwise, it was decremented when new
        // work was added to `state.tasks`.
        if result.timed_out() {
            state.num_idle = state
                .num_idle
                .checked_sub(1)
                .expect("`num_idle` underflow on timeout");
            break;
        }
    }

    state.num_threads -= 1;

    // If we're shutting down then the BlockingPool will take care of joining all the threads.
    // Otherwise, we need to join the last worker thread that exited here.
    let last_exited_thread = if let Some(exited_threads) = state.exited_threads.as_mut() {
        exited_threads
            .try_recv()
            .map(|idx| state.worker_threads.remove(idx))
            .ok()
    } else {
        None
    };

    // Drop the lock before trying to join the last exited thread.
    drop(state);

    if let Some(handle) = last_exited_thread {
        let _ = handle.join();
    }

    if let Err(e) = exit.send(idx) {
        error!("Failed to send thread exit event on channel: {}", e);
    }
}

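// Data shared, via `Arc`, between the `BlockingPool` handle and every worker
// thread it spawns.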
struct Inner {
    state: Mutex<State>,
    condvar: Condvar,
    max_threads: usize,
    keepalive: Duration,
}

impl Inner {
    pub fn spawn<F, R>(self: &Arc<Self>, f: F) -> impl Future<Output = R>
    where
        F: FnOnce() -> R + Send + 'static,
        R: Send + 'static,
    {
        let mut state = self.state.lock();

        // If we're shutting down then nothing is going to run this task.
        if state.shutting_down {
            error!("spawn called after shutdown");
            return futures::future::Either::Left(async {
                panic!("tried to poll BlockingPool task after shutdown")
            });
        }

        let (send_chan, recv_chan) = oneshot::channel();
        state.tasks.push_back(Box::new(|| {
            let _ = send_chan.send(f());
        }));

        if state.num_idle == 0 {
            // There are no idle threads. Spawn a new one if possible.
            if state.num_threads < self.max_threads {
                state.num_threads += 1;
                let exit = state.exit.clone();
                let entry = state.worker_threads.vacant_entry();
                let idx = entry.key();
                let inner = self.clone();
                entry.insert(
                    thread::Builder::new()
                        .name(format!("blockingPool{}", idx))
                        .spawn(move || run_blocking_thread(idx, inner, exit))
                        .unwrap(),
                );
            }
        } else {
            // We have idle threads, wake one up.
            state.num_idle -= 1;
            state.num_notified += 1;
            self.condvar.notify_one();
        }

        futures::future::Either::Right(async {
            recv_chan
                .await
                .expect("BlockingThread task unexpectedly cancelled")
        })
    }
}

#[derive(Debug, thiserror::Error)]
#[error("{0} BlockingPool threads did not exit in time and will be detached")]
pub struct ShutdownTimedOut(usize);

/// A thread pool for running work that may block.
///
/// It is generally discouraged to do any blocking work inside an async function. However, this is
/// sometimes unavoidable when dealing with interfaces that don't provide async variants. In this
/// case callers may use the `BlockingPool` to run the blocking work on a different thread and
/// `await` its result, which avoids blocking the main thread of the application.
///
/// Since the blocking work is sent to another thread, users should be careful when using the
/// `BlockingPool` for latency-sensitive operations. Additionally, the `BlockingPool` is intended
/// to be used for work that will eventually complete on its own. Users who want to spawn a
/// long-running thread should use `thread::spawn` directly.
///
/// There is no way to cancel work once it has been picked up by one of the worker threads in the
/// `BlockingPool`. Dropping the pool blocks for up to 10 seconds to wait for any active blocking
/// work to finish, while `shutdown` blocks until an optional deadline (or indefinitely if no
/// deadline is given). Any threads still running tasks once that time is up will be detached.
///
/// # Examples
///
/// Spawn a task to run in the `BlockingPool` and await on its result.
///
/// ```edition2018
/// use cros_async::BlockingPool;
///
/// # async fn do_it() {
/// let pool = BlockingPool::default();
///
/// let res = pool.spawn(move || {
///     // Do some CPU-intensive or blocking work here.
///
///     42
/// }).await;
///
/// assert_eq!(res, 42);
/// # }
/// # cros_async::block_on(do_it());
/// ```
pub struct BlockingPool {
    inner: Arc<Inner>,
}

impl BlockingPool {
    /// Create a new `BlockingPool`.
    ///
    /// The `BlockingPool` will never spawn more than `max_threads` threads to do work, regardless
    /// of the number of tasks that are added to it. This value should be set relatively low (for
    /// example, the number of CPUs on the machine) if the pool is intended to run CPU-intensive
    /// work, or it should be set relatively high (128 or more) if the pool is intended to be used
    /// for various IO operations that cannot be completed asynchronously. The default value is
    /// 256.
    ///
    /// Worker threads are spawned on demand when new work is added to the pool and will
    /// automatically exit after being idle for some time, so there is no overhead for setting
    /// `max_threads` to a large value when there is little to no work assigned to the
    /// `BlockingPool`. `keepalive` determines the idle duration after which the worker thread
    /// will exit. The default value is 10 seconds.
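    ///
    /// For illustration, a pool meant mostly for blocking IO might be configured along these
    /// lines (a minimal sketch; the specific values are arbitrary examples, not recommendations):
    ///
    /// ```edition2018
    /// use std::time::Duration;
    /// use cros_async::BlockingPool;
    ///
    /// // Allow many concurrent blocking operations, but let idle workers exit after 5 seconds
    /// // without work.
    /// let pool = BlockingPool::new(128, Duration::from_secs(5));
    /// ```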
    pub fn new(max_threads: usize, keepalive: Duration) -> BlockingPool {
        let (exit, exited_threads) = channel();
        BlockingPool {
            inner: Arc::new(Inner {
                state: Mutex::new(State {
                    tasks: VecDeque::new(),
                    num_threads: 0,
                    num_idle: 0,
                    num_notified: 0,
                    worker_threads: Slab::new(),
                    exited_threads: Some(exited_threads),
                    exit,
                    shutting_down: false,
                }),
                condvar: Condvar::new(),
                max_threads,
                keepalive,
            }),
        }
    }

    /// Like `new`, but pre-allocates capacity for up to `max_threads` worker thread handles.
    pub fn with_capacity(max_threads: usize, keepalive: Duration) -> BlockingPool {
        let (exit, exited_threads) = channel();
        BlockingPool {
            inner: Arc::new(Inner {
                state: Mutex::new(State {
                    tasks: VecDeque::new(),
                    num_threads: 0,
                    num_idle: 0,
                    num_notified: 0,
                    worker_threads: Slab::with_capacity(max_threads),
                    exited_threads: Some(exited_threads),
                    exit,
                    shutting_down: false,
                }),
                condvar: Condvar::new(),
                max_threads,
                keepalive,
            }),
        }
    }

    /// Spawn a task to run in the `BlockingPool`.
    ///
    /// Callers may `await` the returned `Future` to be notified when the work is completed.
    /// Dropping the future will not cancel the task.
    ///
    /// # Panics
    ///
    /// `await`ing the returned future after dropping the `BlockingPool` or calling
    /// `BlockingPool::shutdown` will panic if the work was not completed before the pool was
    /// shut down.
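    ///
    /// For example, a sketch of the drop behavior described above (the channel is only used here
    /// to observe that the closure ran):
    ///
    /// ```edition2018
    /// use cros_async::BlockingPool;
    ///
    /// let pool = BlockingPool::default();
    /// let (tx, rx) = std::sync::mpsc::channel();
    ///
    /// // The returned future is dropped immediately, but the closure still runs on a worker
    /// // thread and sends its result over the channel.
    /// drop(pool.spawn(move || tx.send(42).unwrap()));
    /// assert_eq!(rx.recv().unwrap(), 42);
    /// ```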
    pub fn spawn<F, R>(&self, f: F) -> impl Future<Output = R>
    where
        F: FnOnce() -> R + Send + 'static,
        R: Send + 'static,
    {
        self.inner.spawn(f)
    }

    /// Shut down the `BlockingPool`.
    ///
    /// If `deadline` is provided then this will block until either all worker threads exit or
    /// the deadline is exceeded. If `deadline` is not given then this will block indefinitely
    /// until all worker threads exit. Any work that was added to the `BlockingPool` but not yet
    /// picked up by a worker thread will not complete, and `await`ing the future for that work
    /// will panic.
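    ///
    /// A sketch of a bounded shutdown (the 5 second deadline is only an example):
    ///
    /// ```edition2018
    /// use std::time::{Duration, Instant};
    /// use cros_async::BlockingPool;
    ///
    /// let pool = BlockingPool::default();
    /// // ... spawn and await blocking work ...
    ///
    /// // Give the workers up to 5 seconds to finish before detaching them.
    /// if let Err(e) = pool.shutdown(Some(Instant::now() + Duration::from_secs(5))) {
    ///     eprintln!("{}", e);
    /// }
    /// ```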
    pub fn shutdown(&self, deadline: Option<Instant>) -> Result<(), ShutdownTimedOut> {
        let mut state = self.inner.state.lock();

        if state.shutting_down {
            // We've already shut down this BlockingPool.
            return Ok(());
        }

        state.shutting_down = true;
        let exited_threads = state.exited_threads.take().expect("exited_threads missing");
        let unfinished_tasks = std::mem::take(&mut state.tasks);
        let mut worker_threads = mem::replace(&mut state.worker_threads, Slab::new());
        drop(state);

        self.inner.condvar.notify_all();

        // Cancel any unfinished work after releasing the lock.
        drop(unfinished_tasks);

        // Now wait for all worker threads to exit.
        if let Some(deadline) = deadline {
            let mut now = Instant::now();
            while now < deadline && !worker_threads.is_empty() {
                if let Ok(idx) = exited_threads.recv_timeout(deadline - now) {
                    let _ = worker_threads.remove(idx).join();
                }
                now = Instant::now();
            }

            // Any threads that have not yet joined will just be detached.
            if !worker_threads.is_empty() {
                return Err(ShutdownTimedOut(worker_threads.len()));
            }

            Ok(())
        } else {
            // Block indefinitely until all worker threads exit.
            for handle in worker_threads.drain() {
                let _ = handle.join();
            }

            Ok(())
        }
    }

    #[cfg(test)]
    pub(crate) fn shutting_down(&self) -> bool {
        self.inner.state.lock().shutting_down
    }
}

impl Default for BlockingPool {
    fn default() -> BlockingPool {
        BlockingPool::new(256, Duration::from_secs(10))
    }
}

impl Drop for BlockingPool {
    fn drop(&mut self) {
        if let Err(e) = self.shutdown(Some(Instant::now() + DEFAULT_SHUTDOWN_TIMEOUT)) {
            warn!("{}", e);
        }
    }
}

#[cfg(test)]
mod test {
    use std::sync::Arc;
    use std::sync::Barrier;
    use std::thread;
    use std::time::Duration;
    use std::time::Instant;

    use futures::executor::block_on;
    use futures::stream::FuturesUnordered;
    use futures::StreamExt;
    use sync::Condvar;
    use sync::Mutex;

    use super::super::super::BlockingPool;

    #[test]
    fn blocking_sleep() {
        let pool = BlockingPool::default();

        let res = block_on(pool.spawn(|| 42));
        assert_eq!(res, 42);
    }

    #[test]
    fn drop_doesnt_block() {
        let pool = BlockingPool::default();
        let (tx, rx) = std::sync::mpsc::sync_channel(0);
        // The blocking work should continue even though we drop the future.
        //
        // If we cancelled the work, then the recv call would fail. If we blocked on the work,
        // then the send would never complete because the channel is size zero and so waits for a
        // matching recv call.
        std::mem::drop(pool.spawn(move || tx.send(()).unwrap()));
        rx.recv().unwrap();
    }

    #[test]
    fn fast_tasks_with_short_keepalive() {
        let pool = BlockingPool::new(256, Duration::from_millis(1));

        let streams = FuturesUnordered::new();
        for _ in 0..2 {
            for _ in 0..256 {
                let task = pool.spawn(|| ());
                streams.push(task);
            }

            thread::sleep(Duration::from_millis(1));
        }

        block_on(streams.collect::<Vec<_>>());

        // The test passes if there are no panics, which would happen if one of the worker
        // threads triggered an underflow on `pool.inner.state.num_idle`.
    }

    #[test]
    fn more_tasks_than_threads() {
        let pool = BlockingPool::new(4, Duration::from_secs(10));

        let stream = (0..19)
            .map(|_| pool.spawn(|| thread::sleep(Duration::from_millis(5))))
            .collect::<FuturesUnordered<_>>();

        let results = block_on(stream.collect::<Vec<_>>());
        assert_eq!(results.len(), 19);
    }

    #[test]
    fn shutdown() {
        let pool = BlockingPool::default();

        let stream = (0..19)
            .map(|_| pool.spawn(|| thread::sleep(Duration::from_millis(5))))
            .collect::<FuturesUnordered<_>>();

        let results = block_on(stream.collect::<Vec<_>>());
        assert_eq!(results.len(), 19);

        pool.shutdown(Some(Instant::now() + Duration::from_secs(10)))
            .unwrap();
        let state = pool.inner.state.lock();
        assert_eq!(state.num_threads, 0);
    }

    #[test]
    fn keepalive_timeout() {
        // Set the keepalive to a very low value so that threads will exit soon after they run
        // out of work.
        let pool = BlockingPool::new(7, Duration::from_millis(1));

        let stream = (0..19)
            .map(|_| pool.spawn(|| thread::sleep(Duration::from_millis(5))))
            .collect::<FuturesUnordered<_>>();

        let results = block_on(stream.collect::<Vec<_>>());
        assert_eq!(results.len(), 19);

        // Wait for all threads to exit.
        let deadline = Instant::now() + Duration::from_secs(10);
        while Instant::now() < deadline {
            thread::sleep(Duration::from_millis(100));
            let state = pool.inner.state.lock();
            if state.num_threads == 0 {
                break;
            }
        }

        {
            let state = pool.inner.state.lock();
            assert_eq!(state.num_threads, 0);
            assert_eq!(state.num_idle, 0);
        }
    }

    #[test]
    #[should_panic]
    fn shutdown_with_pending_work() {
        let pool = BlockingPool::new(1, Duration::from_secs(10));

        let mu = Arc::new(Mutex::new(false));
        let cv = Arc::new(Condvar::new());

        // First spawn a thread that blocks the pool.
        let task_mu = mu.clone();
        let task_cv = cv.clone();
        let _blocking_task = pool.spawn(move || {
            let mut ready = task_mu.lock();
            while !*ready {
                ready = task_cv.wait(ready);
            }
        });

        // This task will never finish because we will shut down the pool first.
        let unfinished = pool.spawn(|| 5);

        // Spawn a thread to unblock the work we started earlier once it sees that the pool is
        // shutting down.
        let inner = pool.inner.clone();
        thread::spawn(move || {
            let mut state = inner.state.lock();
            while !state.shutting_down {
                state = inner.condvar.wait(state);
            }

            *mu.lock() = true;
            cv.notify_all();
        });
        pool.shutdown(None).unwrap();

        // This should panic.
        assert_eq!(block_on(unfinished), 5);
    }

    #[test]
    fn unfinished_worker_thread() {
        let pool = BlockingPool::default();

        let ready = Arc::new(Mutex::new(false));
        let cv = Arc::new(Condvar::new());
        let barrier = Arc::new(Barrier::new(2));

        let thread_ready = ready.clone();
        let thread_barrier = barrier.clone();
        let thread_cv = cv.clone();

        let task = pool.spawn(move || {
            thread_barrier.wait();
            let mut ready = thread_ready.lock();
            while !*ready {
                ready = thread_cv.wait(ready);
            }
        });

        // Wait to shut down the pool until after the worker thread has started.
        barrier.wait();
        pool.shutdown(Some(Instant::now() + Duration::from_millis(5)))
            .unwrap_err();

        let num_threads = pool.inner.state.lock().num_threads;
        assert_eq!(num_threads, 1);

        // Now wake up the blocked task so we don't leak the thread.
        *ready.lock() = true;
        cv.notify_all();

        block_on(task);

        let deadline = Instant::now() + Duration::from_secs(10);
        while Instant::now() < deadline {
            thread::sleep(Duration::from_millis(100));
            let state = pool.inner.state.lock();
            if state.num_threads == 0 {
                break;
            }
        }

        {
            let state = pool.inner.state.lock();
            assert_eq!(state.num_threads, 0);
            assert_eq!(state.num_idle, 0);
        }
    }
}