1 use crate::enter; 2 use crate::unpark_mutex::UnparkMutex; 3 use futures_core::future::Future; 4 use futures_core::task::{Context, Poll}; 5 use futures_task::{waker_ref, ArcWake}; 6 use futures_task::{FutureObj, Spawn, SpawnError}; 7 use futures_util::future::FutureExt; 8 use std::boxed::Box; 9 use std::cmp; 10 use std::fmt; 11 use std::format; 12 use std::io; 13 use std::string::String; 14 use std::sync::atomic::{AtomicUsize, Ordering}; 15 use std::sync::mpsc; 16 use std::sync::{Arc, Mutex}; 17 use std::thread; 18 19 /// A general-purpose thread pool for scheduling tasks that poll futures to 20 /// completion. 21 /// 22 /// The thread pool multiplexes any number of tasks onto a fixed number of 23 /// worker threads. 24 /// 25 /// This type is a clonable handle to the threadpool itself. 26 /// Cloning it will only create a new reference, not a new threadpool. 27 /// 28 /// This type is only available when the `thread-pool` feature of this 29 /// library is activated. 30 #[cfg_attr(docsrs, doc(cfg(feature = "thread-pool")))] 31 pub struct ThreadPool { 32 state: Arc<PoolState>, 33 } 34 35 /// Thread pool configuration object. 36 /// 37 /// This type is only available when the `thread-pool` feature of this 38 /// library is activated. 39 #[cfg_attr(docsrs, doc(cfg(feature = "thread-pool")))] 40 pub struct ThreadPoolBuilder { 41 pool_size: usize, 42 stack_size: usize, 43 name_prefix: Option<String>, 44 after_start: Option<Arc<dyn Fn(usize) + Send + Sync>>, 45 before_stop: Option<Arc<dyn Fn(usize) + Send + Sync>>, 46 } 47 48 #[allow(dead_code)] 49 trait AssertSendSync: Send + Sync {} 50 impl AssertSendSync for ThreadPool {} 51 52 struct PoolState { 53 tx: Mutex<mpsc::Sender<Message>>, 54 rx: Mutex<mpsc::Receiver<Message>>, 55 cnt: AtomicUsize, 56 size: usize, 57 } 58 59 impl fmt::Debug for ThreadPool { fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result60 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 61 f.debug_struct("ThreadPool").field("size", &self.state.size).finish() 62 } 63 } 64 65 impl fmt::Debug for ThreadPoolBuilder { fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result66 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 67 f.debug_struct("ThreadPoolBuilder") 68 .field("pool_size", &self.pool_size) 69 .field("name_prefix", &self.name_prefix) 70 .finish() 71 } 72 } 73 74 enum Message { 75 Run(Task), 76 Close, 77 } 78 79 impl ThreadPool { 80 /// Creates a new thread pool with the default configuration. 81 /// 82 /// See documentation for the methods in 83 /// [`ThreadPoolBuilder`](ThreadPoolBuilder) for details on the default 84 /// configuration. new() -> Result<Self, io::Error>85 pub fn new() -> Result<Self, io::Error> { 86 ThreadPoolBuilder::new().create() 87 } 88 89 /// Create a default thread pool configuration, which can then be customized. 90 /// 91 /// See documentation for the methods in 92 /// [`ThreadPoolBuilder`](ThreadPoolBuilder) for details on the default 93 /// configuration. builder() -> ThreadPoolBuilder94 pub fn builder() -> ThreadPoolBuilder { 95 ThreadPoolBuilder::new() 96 } 97 98 /// Spawns a future that will be run to completion. 99 /// 100 /// > **Note**: This method is similar to `Spawn::spawn_obj`, except that 101 /// > it is guaranteed to always succeed. spawn_obj_ok(&self, future: FutureObj<'static, ()>)102 pub fn spawn_obj_ok(&self, future: FutureObj<'static, ()>) { 103 let task = Task { 104 future, 105 wake_handle: Arc::new(WakeHandle { exec: self.clone(), mutex: UnparkMutex::new() }), 106 exec: self.clone(), 107 }; 108 self.state.send(Message::Run(task)); 109 } 110 111 /// Spawns a task that polls the given future with output `()` to 112 /// completion. 113 /// 114 /// ``` 115 /// # { 116 /// use futures::executor::ThreadPool; 117 /// 118 /// let pool = ThreadPool::new().unwrap(); 119 /// 120 /// let future = async { /* ... */ }; 121 /// pool.spawn_ok(future); 122 /// # } 123 /// # std::thread::sleep(std::time::Duration::from_millis(500)); // wait for background threads closed: https://github.com/rust-lang/miri/issues/1371 124 /// ``` 125 /// 126 /// > **Note**: This method is similar to `SpawnExt::spawn`, except that 127 /// > it is guaranteed to always succeed. spawn_ok<Fut>(&self, future: Fut) where Fut: Future<Output = ()> + Send + 'static,128 pub fn spawn_ok<Fut>(&self, future: Fut) 129 where 130 Fut: Future<Output = ()> + Send + 'static, 131 { 132 self.spawn_obj_ok(FutureObj::new(Box::new(future))) 133 } 134 } 135 136 impl Spawn for ThreadPool { spawn_obj(&self, future: FutureObj<'static, ()>) -> Result<(), SpawnError>137 fn spawn_obj(&self, future: FutureObj<'static, ()>) -> Result<(), SpawnError> { 138 self.spawn_obj_ok(future); 139 Ok(()) 140 } 141 } 142 143 impl PoolState { send(&self, msg: Message)144 fn send(&self, msg: Message) { 145 self.tx.lock().unwrap().send(msg).unwrap(); 146 } 147 work( &self, idx: usize, after_start: Option<Arc<dyn Fn(usize) + Send + Sync>>, before_stop: Option<Arc<dyn Fn(usize) + Send + Sync>>, )148 fn work( 149 &self, 150 idx: usize, 151 after_start: Option<Arc<dyn Fn(usize) + Send + Sync>>, 152 before_stop: Option<Arc<dyn Fn(usize) + Send + Sync>>, 153 ) { 154 let _scope = enter().unwrap(); 155 if let Some(after_start) = after_start { 156 after_start(idx); 157 } 158 loop { 159 let msg = self.rx.lock().unwrap().recv().unwrap(); 160 match msg { 161 Message::Run(task) => task.run(), 162 Message::Close => break, 163 } 164 } 165 if let Some(before_stop) = before_stop { 166 before_stop(idx); 167 } 168 } 169 } 170 171 impl Clone for ThreadPool { clone(&self) -> Self172 fn clone(&self) -> Self { 173 self.state.cnt.fetch_add(1, Ordering::Relaxed); 174 Self { state: self.state.clone() } 175 } 176 } 177 178 impl Drop for ThreadPool { drop(&mut self)179 fn drop(&mut self) { 180 if self.state.cnt.fetch_sub(1, Ordering::Relaxed) == 1 { 181 for _ in 0..self.state.size { 182 self.state.send(Message::Close); 183 } 184 } 185 } 186 } 187 188 impl ThreadPoolBuilder { 189 /// Create a default thread pool configuration. 190 /// 191 /// See the other methods on this type for details on the defaults. new() -> Self192 pub fn new() -> Self { 193 Self { 194 pool_size: cmp::max(1, num_cpus::get()), 195 stack_size: 0, 196 name_prefix: None, 197 after_start: None, 198 before_stop: None, 199 } 200 } 201 202 /// Set size of a future ThreadPool 203 /// 204 /// The size of a thread pool is the number of worker threads spawned. By 205 /// default, this is equal to the number of CPU cores. 206 /// 207 /// # Panics 208 /// 209 /// Panics if `pool_size == 0`. pool_size(&mut self, size: usize) -> &mut Self210 pub fn pool_size(&mut self, size: usize) -> &mut Self { 211 assert!(size > 0); 212 self.pool_size = size; 213 self 214 } 215 216 /// Set stack size of threads in the pool, in bytes. 217 /// 218 /// By default, worker threads use Rust's standard stack size. stack_size(&mut self, stack_size: usize) -> &mut Self219 pub fn stack_size(&mut self, stack_size: usize) -> &mut Self { 220 self.stack_size = stack_size; 221 self 222 } 223 224 /// Set thread name prefix of a future ThreadPool. 225 /// 226 /// Thread name prefix is used for generating thread names. For example, if prefix is 227 /// `my-pool-`, then threads in the pool will get names like `my-pool-1` etc. 228 /// 229 /// By default, worker threads are assigned Rust's standard thread name. name_prefix<S: Into<String>>(&mut self, name_prefix: S) -> &mut Self230 pub fn name_prefix<S: Into<String>>(&mut self, name_prefix: S) -> &mut Self { 231 self.name_prefix = Some(name_prefix.into()); 232 self 233 } 234 235 /// Execute the closure `f` immediately after each worker thread is started, 236 /// but before running any tasks on it. 237 /// 238 /// This hook is intended for bookkeeping and monitoring. 239 /// The closure `f` will be dropped after the `builder` is dropped 240 /// and all worker threads in the pool have executed it. 241 /// 242 /// The closure provided will receive an index corresponding to the worker 243 /// thread it's running on. after_start<F>(&mut self, f: F) -> &mut Self where F: Fn(usize) + Send + Sync + 'static,244 pub fn after_start<F>(&mut self, f: F) -> &mut Self 245 where 246 F: Fn(usize) + Send + Sync + 'static, 247 { 248 self.after_start = Some(Arc::new(f)); 249 self 250 } 251 252 /// Execute closure `f` just prior to shutting down each worker thread. 253 /// 254 /// This hook is intended for bookkeeping and monitoring. 255 /// The closure `f` will be dropped after the `builder` is dropped 256 /// and all threads in the pool have executed it. 257 /// 258 /// The closure provided will receive an index corresponding to the worker 259 /// thread it's running on. before_stop<F>(&mut self, f: F) -> &mut Self where F: Fn(usize) + Send + Sync + 'static,260 pub fn before_stop<F>(&mut self, f: F) -> &mut Self 261 where 262 F: Fn(usize) + Send + Sync + 'static, 263 { 264 self.before_stop = Some(Arc::new(f)); 265 self 266 } 267 268 /// Create a [`ThreadPool`](ThreadPool) with the given configuration. create(&mut self) -> Result<ThreadPool, io::Error>269 pub fn create(&mut self) -> Result<ThreadPool, io::Error> { 270 let (tx, rx) = mpsc::channel(); 271 let pool = ThreadPool { 272 state: Arc::new(PoolState { 273 tx: Mutex::new(tx), 274 rx: Mutex::new(rx), 275 cnt: AtomicUsize::new(1), 276 size: self.pool_size, 277 }), 278 }; 279 280 for counter in 0..self.pool_size { 281 let state = pool.state.clone(); 282 let after_start = self.after_start.clone(); 283 let before_stop = self.before_stop.clone(); 284 let mut thread_builder = thread::Builder::new(); 285 if let Some(ref name_prefix) = self.name_prefix { 286 thread_builder = thread_builder.name(format!("{}{}", name_prefix, counter)); 287 } 288 if self.stack_size > 0 { 289 thread_builder = thread_builder.stack_size(self.stack_size); 290 } 291 thread_builder.spawn(move || state.work(counter, after_start, before_stop))?; 292 } 293 Ok(pool) 294 } 295 } 296 297 impl Default for ThreadPoolBuilder { default() -> Self298 fn default() -> Self { 299 Self::new() 300 } 301 } 302 303 /// A task responsible for polling a future to completion. 304 struct Task { 305 future: FutureObj<'static, ()>, 306 exec: ThreadPool, 307 wake_handle: Arc<WakeHandle>, 308 } 309 310 struct WakeHandle { 311 mutex: UnparkMutex<Task>, 312 exec: ThreadPool, 313 } 314 315 impl Task { 316 /// Actually run the task (invoking `poll` on the future) on the current 317 /// thread. run(self)318 fn run(self) { 319 let Self { mut future, wake_handle, mut exec } = self; 320 let waker = waker_ref(&wake_handle); 321 let mut cx = Context::from_waker(&waker); 322 323 // Safety: The ownership of this `Task` object is evidence that 324 // we are in the `POLLING`/`REPOLL` state for the mutex. 325 unsafe { 326 wake_handle.mutex.start_poll(); 327 328 loop { 329 let res = future.poll_unpin(&mut cx); 330 match res { 331 Poll::Pending => {} 332 Poll::Ready(()) => return wake_handle.mutex.complete(), 333 } 334 let task = Self { future, wake_handle: wake_handle.clone(), exec }; 335 match wake_handle.mutex.wait(task) { 336 Ok(()) => return, // we've waited 337 Err(task) => { 338 // someone's notified us 339 future = task.future; 340 exec = task.exec; 341 } 342 } 343 } 344 } 345 } 346 } 347 348 impl fmt::Debug for Task { fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result349 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 350 f.debug_struct("Task").field("contents", &"...").finish() 351 } 352 } 353 354 impl ArcWake for WakeHandle { wake_by_ref(arc_self: &Arc<Self>)355 fn wake_by_ref(arc_self: &Arc<Self>) { 356 if let Ok(task) = arc_self.mutex.notify() { 357 arc_self.exec.state.send(Message::Run(task)) 358 } 359 } 360 } 361 362 #[cfg(test)] 363 mod tests { 364 use super::*; 365 366 #[test] test_drop_after_start()367 fn test_drop_after_start() { 368 { 369 let (tx, rx) = mpsc::sync_channel(2); 370 let _cpu_pool = ThreadPoolBuilder::new() 371 .pool_size(2) 372 .after_start(move |_| tx.send(1).unwrap()) 373 .create() 374 .unwrap(); 375 376 // After ThreadPoolBuilder is deconstructed, the tx should be dropped 377 // so that we can use rx as an iterator. 378 let count = rx.into_iter().count(); 379 assert_eq!(count, 2); 380 } 381 std::thread::sleep(std::time::Duration::from_millis(500)); // wait for background threads closed: https://github.com/rust-lang/miri/issues/1371 382 } 383 } 384