1 use crate::enter;
2 use crate::unpark_mutex::UnparkMutex;
3 use futures_core::future::Future;
4 use futures_core::task::{Context, Poll};
5 use futures_task::{waker_ref, ArcWake};
6 use futures_task::{FutureObj, Spawn, SpawnError};
7 use futures_util::future::FutureExt;
8 use std::boxed::Box;
9 use std::cmp;
10 use std::fmt;
11 use std::format;
12 use std::io;
13 use std::string::String;
14 use std::sync::atomic::{AtomicUsize, Ordering};
15 use std::sync::mpsc;
16 use std::sync::{Arc, Mutex};
17 use std::thread;
18 
/// A general-purpose thread pool for scheduling tasks that poll futures to
/// completion.
///
/// The thread pool multiplexes any number of tasks onto a fixed number of
/// worker threads.
///
/// This type is a clonable handle to the thread pool itself.
/// Cloning it will only create a new reference, not a new thread pool.
/// When the last handle is dropped (spawned tasks hold handles too), every
/// worker thread is sent a shutdown message.
///
/// This type is only available when the `thread-pool` feature of this
/// library is activated.
#[cfg_attr(docsrs, doc(cfg(feature = "thread-pool")))]
pub struct ThreadPool {
    // State shared with every other handle and with each worker thread.
    state: Arc<PoolState>,
}
34 
/// Thread pool configuration object.
///
/// This type is only available when the `thread-pool` feature of this
/// library is activated.
#[cfg_attr(docsrs, doc(cfg(feature = "thread-pool")))]
pub struct ThreadPoolBuilder {
    // Number of worker threads `create` spawns.
    pool_size: usize,
    // Worker stack size in bytes; `0` leaves Rust's standard stack size.
    stack_size: usize,
    // Optional thread-name prefix; the worker index is appended to it.
    name_prefix: Option<String>,
    // Hook run on each worker (with its index) before it processes any task.
    after_start: Option<Arc<dyn Fn(usize) + Send + Sync>>,
    // Hook run on each worker (with its index) just before it exits.
    before_stop: Option<Arc<dyn Fn(usize) + Send + Sync>>,
}
47 
// Compile-time assertion that `ThreadPool` is `Send + Sync`: the impl below
// fails to compile otherwise. The trait is never used at runtime, which is
// why `dead_code` is allowed.
#[allow(dead_code)]
trait AssertSendSync: Send + Sync {}
impl AssertSendSync for ThreadPool {}
51 
/// State shared by every `ThreadPool` handle and every worker thread.
struct PoolState {
    // Sending half of the job queue; locked by `send` to enqueue a message.
    tx: Mutex<mpsc::Sender<Message>>,
    // Receiving half of the job queue; workers take turns locking it to pull
    // the next message.
    rx: Mutex<mpsc::Receiver<Message>>,
    // Number of live `ThreadPool` handles. When the last one is dropped the
    // workers are sent `Message::Close`.
    cnt: AtomicUsize,
    // Number of worker threads spawned for this pool.
    size: usize,
}
58 
59 impl fmt::Debug for ThreadPool {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result60     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
61         f.debug_struct("ThreadPool").field("size", &self.state.size).finish()
62     }
63 }
64 
65 impl fmt::Debug for ThreadPoolBuilder {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result66     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
67         f.debug_struct("ThreadPoolBuilder")
68             .field("pool_size", &self.pool_size)
69             .field("name_prefix", &self.name_prefix)
70             .finish()
71     }
72 }
73 
/// A message delivered to worker threads through the pool's shared channel.
enum Message {
    /// Poll the contained task on whichever worker receives it.
    Run(Task),
    /// Make the receiving worker leave its loop and stop.
    Close,
}
78 
79 impl ThreadPool {
80     /// Creates a new thread pool with the default configuration.
81     ///
82     /// See documentation for the methods in
83     /// [`ThreadPoolBuilder`](ThreadPoolBuilder) for details on the default
84     /// configuration.
new() -> Result<Self, io::Error>85     pub fn new() -> Result<Self, io::Error> {
86         ThreadPoolBuilder::new().create()
87     }
88 
89     /// Create a default thread pool configuration, which can then be customized.
90     ///
91     /// See documentation for the methods in
92     /// [`ThreadPoolBuilder`](ThreadPoolBuilder) for details on the default
93     /// configuration.
builder() -> ThreadPoolBuilder94     pub fn builder() -> ThreadPoolBuilder {
95         ThreadPoolBuilder::new()
96     }
97 
98     /// Spawns a future that will be run to completion.
99     ///
100     /// > **Note**: This method is similar to `Spawn::spawn_obj`, except that
101     /// >           it is guaranteed to always succeed.
spawn_obj_ok(&self, future: FutureObj<'static, ()>)102     pub fn spawn_obj_ok(&self, future: FutureObj<'static, ()>) {
103         let task = Task {
104             future,
105             wake_handle: Arc::new(WakeHandle { exec: self.clone(), mutex: UnparkMutex::new() }),
106             exec: self.clone(),
107         };
108         self.state.send(Message::Run(task));
109     }
110 
111     /// Spawns a task that polls the given future with output `()` to
112     /// completion.
113     ///
114     /// ```
115     /// # {
116     /// use futures::executor::ThreadPool;
117     ///
118     /// let pool = ThreadPool::new().unwrap();
119     ///
120     /// let future = async { /* ... */ };
121     /// pool.spawn_ok(future);
122     /// # }
123     /// # std::thread::sleep(std::time::Duration::from_millis(500)); // wait for background threads closed: https://github.com/rust-lang/miri/issues/1371
124     /// ```
125     ///
126     /// > **Note**: This method is similar to `SpawnExt::spawn`, except that
127     /// >           it is guaranteed to always succeed.
spawn_ok<Fut>(&self, future: Fut) where Fut: Future<Output = ()> + Send + 'static,128     pub fn spawn_ok<Fut>(&self, future: Fut)
129     where
130         Fut: Future<Output = ()> + Send + 'static,
131     {
132         self.spawn_obj_ok(FutureObj::new(Box::new(future)))
133     }
134 }
135 
136 impl Spawn for ThreadPool {
spawn_obj(&self, future: FutureObj<'static, ()>) -> Result<(), SpawnError>137     fn spawn_obj(&self, future: FutureObj<'static, ()>) -> Result<(), SpawnError> {
138         self.spawn_obj_ok(future);
139         Ok(())
140     }
141 }
142 
143 impl PoolState {
send(&self, msg: Message)144     fn send(&self, msg: Message) {
145         self.tx.lock().unwrap().send(msg).unwrap();
146     }
147 
work( &self, idx: usize, after_start: Option<Arc<dyn Fn(usize) + Send + Sync>>, before_stop: Option<Arc<dyn Fn(usize) + Send + Sync>>, )148     fn work(
149         &self,
150         idx: usize,
151         after_start: Option<Arc<dyn Fn(usize) + Send + Sync>>,
152         before_stop: Option<Arc<dyn Fn(usize) + Send + Sync>>,
153     ) {
154         let _scope = enter().unwrap();
155         if let Some(after_start) = after_start {
156             after_start(idx);
157         }
158         loop {
159             let msg = self.rx.lock().unwrap().recv().unwrap();
160             match msg {
161                 Message::Run(task) => task.run(),
162                 Message::Close => break,
163             }
164         }
165         if let Some(before_stop) = before_stop {
166             before_stop(idx);
167         }
168     }
169 }
170 
171 impl Clone for ThreadPool {
clone(&self) -> Self172     fn clone(&self) -> Self {
173         self.state.cnt.fetch_add(1, Ordering::Relaxed);
174         Self { state: self.state.clone() }
175     }
176 }
177 
178 impl Drop for ThreadPool {
drop(&mut self)179     fn drop(&mut self) {
180         if self.state.cnt.fetch_sub(1, Ordering::Relaxed) == 1 {
181             for _ in 0..self.state.size {
182                 self.state.send(Message::Close);
183             }
184         }
185     }
186 }
187 
188 impl ThreadPoolBuilder {
189     /// Create a default thread pool configuration.
190     ///
191     /// See the other methods on this type for details on the defaults.
new() -> Self192     pub fn new() -> Self {
193         Self {
194             pool_size: cmp::max(1, num_cpus::get()),
195             stack_size: 0,
196             name_prefix: None,
197             after_start: None,
198             before_stop: None,
199         }
200     }
201 
202     /// Set size of a future ThreadPool
203     ///
204     /// The size of a thread pool is the number of worker threads spawned. By
205     /// default, this is equal to the number of CPU cores.
206     ///
207     /// # Panics
208     ///
209     /// Panics if `pool_size == 0`.
pool_size(&mut self, size: usize) -> &mut Self210     pub fn pool_size(&mut self, size: usize) -> &mut Self {
211         assert!(size > 0);
212         self.pool_size = size;
213         self
214     }
215 
216     /// Set stack size of threads in the pool, in bytes.
217     ///
218     /// By default, worker threads use Rust's standard stack size.
stack_size(&mut self, stack_size: usize) -> &mut Self219     pub fn stack_size(&mut self, stack_size: usize) -> &mut Self {
220         self.stack_size = stack_size;
221         self
222     }
223 
224     /// Set thread name prefix of a future ThreadPool.
225     ///
226     /// Thread name prefix is used for generating thread names. For example, if prefix is
227     /// `my-pool-`, then threads in the pool will get names like `my-pool-1` etc.
228     ///
229     /// By default, worker threads are assigned Rust's standard thread name.
name_prefix<S: Into<String>>(&mut self, name_prefix: S) -> &mut Self230     pub fn name_prefix<S: Into<String>>(&mut self, name_prefix: S) -> &mut Self {
231         self.name_prefix = Some(name_prefix.into());
232         self
233     }
234 
235     /// Execute the closure `f` immediately after each worker thread is started,
236     /// but before running any tasks on it.
237     ///
238     /// This hook is intended for bookkeeping and monitoring.
239     /// The closure `f` will be dropped after the `builder` is dropped
240     /// and all worker threads in the pool have executed it.
241     ///
242     /// The closure provided will receive an index corresponding to the worker
243     /// thread it's running on.
after_start<F>(&mut self, f: F) -> &mut Self where F: Fn(usize) + Send + Sync + 'static,244     pub fn after_start<F>(&mut self, f: F) -> &mut Self
245     where
246         F: Fn(usize) + Send + Sync + 'static,
247     {
248         self.after_start = Some(Arc::new(f));
249         self
250     }
251 
252     /// Execute closure `f` just prior to shutting down each worker thread.
253     ///
254     /// This hook is intended for bookkeeping and monitoring.
255     /// The closure `f` will be dropped after the `builder` is dropped
256     /// and all threads in the pool have executed it.
257     ///
258     /// The closure provided will receive an index corresponding to the worker
259     /// thread it's running on.
before_stop<F>(&mut self, f: F) -> &mut Self where F: Fn(usize) + Send + Sync + 'static,260     pub fn before_stop<F>(&mut self, f: F) -> &mut Self
261     where
262         F: Fn(usize) + Send + Sync + 'static,
263     {
264         self.before_stop = Some(Arc::new(f));
265         self
266     }
267 
268     /// Create a [`ThreadPool`](ThreadPool) with the given configuration.
create(&mut self) -> Result<ThreadPool, io::Error>269     pub fn create(&mut self) -> Result<ThreadPool, io::Error> {
270         let (tx, rx) = mpsc::channel();
271         let pool = ThreadPool {
272             state: Arc::new(PoolState {
273                 tx: Mutex::new(tx),
274                 rx: Mutex::new(rx),
275                 cnt: AtomicUsize::new(1),
276                 size: self.pool_size,
277             }),
278         };
279 
280         for counter in 0..self.pool_size {
281             let state = pool.state.clone();
282             let after_start = self.after_start.clone();
283             let before_stop = self.before_stop.clone();
284             let mut thread_builder = thread::Builder::new();
285             if let Some(ref name_prefix) = self.name_prefix {
286                 thread_builder = thread_builder.name(format!("{}{}", name_prefix, counter));
287             }
288             if self.stack_size > 0 {
289                 thread_builder = thread_builder.stack_size(self.stack_size);
290             }
291             thread_builder.spawn(move || state.work(counter, after_start, before_stop))?;
292         }
293         Ok(pool)
294     }
295 }
296 
297 impl Default for ThreadPoolBuilder {
default() -> Self298     fn default() -> Self {
299         Self::new()
300     }
301 }
302 
/// A task responsible for polling a future to completion.
struct Task {
    // The spawned future being driven.
    future: FutureObj<'static, ()>,
    // Handle to the owning pool; while it exists the pool's live-handle count
    // stays above zero, so the workers are not told to close.
    exec: ThreadPool,
    // Wake target shared with the waker handed to the future in `run`.
    wake_handle: Arc<WakeHandle>,
}
309 
/// Wake target for a `Task`; waking it can re-submit the task to the pool.
struct WakeHandle {
    // Coordinates polling and wakeups: `Task::run` hands the task to `wait`,
    // and `wake_by_ref` gets it back from `notify` to re-queue it.
    mutex: UnparkMutex<Task>,
    // Pool the task is re-sent to when woken.
    exec: ThreadPool,
}
314 
impl Task {
    /// Actually run the task (invoking `poll` on the future) on the current
    /// thread.
    ///
    /// The future is polled until it either completes or returns `Pending`
    /// and is successfully handed to the task's `UnparkMutex` via `wait`.
    /// If `wait` gives the task back (a wakeup raced with parking), the
    /// future is polled again right here instead of being re-queued.
    fn run(self) {
        let Self { mut future, wake_handle, mut exec } = self;
        // The waker given to the future points at this task's `WakeHandle`.
        let waker = waker_ref(&wake_handle);
        let mut cx = Context::from_waker(&waker);

        // Safety: The ownership of this `Task` object is evidence that
        // we are in the `POLLING`/`REPOLL` state for the mutex.
        unsafe {
            wake_handle.mutex.start_poll();

            loop {
                let res = future.poll_unpin(&mut cx);
                match res {
                    Poll::Pending => {}
                    // Done. NOTE(review): `complete` presumably marks the mutex
                    // so later wakeups do not re-queue the task; confirm in
                    // `UnparkMutex`.
                    Poll::Ready(()) => return wake_handle.mutex.complete(),
                }
                // Reassemble the task so ownership can be handed to the mutex.
                let task = Self { future, wake_handle: wake_handle.clone(), exec };
                match wake_handle.mutex.wait(task) {
                    // The mutex now owns the task; a later `notify` (see
                    // `WakeHandle::wake_by_ref`) returns it for re-queuing.
                    Ok(()) => return, // we've waited
                    Err(task) => {
                        // someone's notified us
                        future = task.future;
                        exec = task.exec;
                    }
                }
            }
        }
    }
}
347 
348 impl fmt::Debug for Task {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result349     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
350         f.debug_struct("Task").field("contents", &"...").finish()
351     }
352 }
353 
354 impl ArcWake for WakeHandle {
wake_by_ref(arc_self: &Arc<Self>)355     fn wake_by_ref(arc_self: &Arc<Self>) {
356         if let Ok(task) = arc_self.mutex.notify() {
357             arc_self.exec.state.send(Message::Run(task))
358         }
359     }
360 }
361 
#[cfg(test)]
mod tests {
    use super::*;

    /// The `after_start` hook must run once per worker and then be released,
    /// so that a channel sender it captured is dropped.
    #[test]
    fn test_drop_after_start() {
        {
            let (started_tx, started_rx) = mpsc::sync_channel(2);
            let _pool = ThreadPoolBuilder::new()
                .pool_size(2)
                .after_start(move |_| started_tx.send(1).unwrap())
                .create()
                .unwrap();

            // Once the builder is gone and every worker has run (and dropped)
            // its copy of the hook, no sender remains, so draining the
            // receiver terminates after one message per worker.
            let started = started_rx.iter().count();
            assert_eq!(started, 2);
        }
        // wait for background threads closed: https://github.com/rust-lang/miri/issues/1371
        std::thread::sleep(std::time::Duration::from_millis(500));
    }
}
384