1 // Copyright 2014 The Rust Project Developers. See the COPYRIGHT
2 // file at the top-level directory of this distribution and at
3 // http://rust-lang.org/COPYRIGHT.
4 //
5 // Licensed under the Apache License, Version 2.0 <LICENSE-APACHE or
6 // http://www.apache.org/licenses/LICENSE-2.0> or the MIT license
7 // <LICENSE-MIT or http://opensource.org/licenses/MIT>, at your
8 // option. This file may not be copied, modified, or distributed
9 // except according to those terms.
10 
11 //! A thread pool used to execute functions in parallel.
12 //!
13 //! Spawns a specified number of worker threads and replenishes the pool if any worker threads
14 //! panic.
15 //!
16 //! # Examples
17 //!
18 //! ## Synchronized with a channel
19 //!
//! Each job sends one message over the channel, and the messages are then collected with `take()`.
21 //!
22 //! ```
23 //! use threadpool::ThreadPool;
24 //! use std::sync::mpsc::channel;
25 //!
26 //! let n_workers = 4;
27 //! let n_jobs = 8;
28 //! let pool = ThreadPool::new(n_workers);
29 //!
30 //! let (tx, rx) = channel();
31 //! for _ in 0..n_jobs {
32 //!     let tx = tx.clone();
33 //!     pool.execute(move|| {
34 //!         tx.send(1).expect("channel will be there waiting for the pool");
35 //!     });
36 //! }
37 //!
38 //! assert_eq!(rx.iter().take(n_jobs).fold(0, |a, b| a + b), 8);
39 //! ```
40 //!
41 //! ## Synchronized with a barrier
42 //!
43 //! Keep in mind, if a barrier synchronizes more jobs than you have workers in the pool,
44 //! you will end up with a [deadlock](https://en.wikipedia.org/wiki/Deadlock)
45 //! at the barrier which is [not considered unsafe](
46 //! https://doc.rust-lang.org/reference/behavior-not-considered-unsafe.html).
47 //!
48 //! ```
49 //! use threadpool::ThreadPool;
50 //! use std::sync::{Arc, Barrier};
51 //! use std::sync::atomic::{AtomicUsize, Ordering};
52 //!
53 //! // create at least as many workers as jobs or you will deadlock yourself
54 //! let n_workers = 42;
55 //! let n_jobs = 23;
56 //! let pool = ThreadPool::new(n_workers);
57 //! let an_atomic = Arc::new(AtomicUsize::new(0));
58 //!
59 //! assert!(n_jobs <= n_workers, "too many jobs, will deadlock");
60 //!
61 //! // create a barrier that waits for all jobs plus the starter thread
62 //! let barrier = Arc::new(Barrier::new(n_jobs + 1));
63 //! for _ in 0..n_jobs {
64 //!     let barrier = barrier.clone();
65 //!     let an_atomic = an_atomic.clone();
66 //!
67 //!     pool.execute(move|| {
68 //!         // do the heavy work
69 //!         an_atomic.fetch_add(1, Ordering::Relaxed);
70 //!
71 //!         // then wait for the other threads
72 //!         barrier.wait();
73 //!     });
74 //! }
75 //!
76 //! // wait for the threads to finish the work
77 //! barrier.wait();
78 //! assert_eq!(an_atomic.load(Ordering::SeqCst), /* n_jobs = */ 23);
79 //! ```
80 
81 extern crate num_cpus;
82 
83 use std::fmt;
84 use std::sync::atomic::{AtomicUsize, Ordering};
85 use std::sync::mpsc::{channel, Receiver, Sender};
86 use std::sync::{Arc, Condvar, Mutex};
87 use std::thread;
88 
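// `FnBox` works around the fact that a boxed `FnOnce` closure could not be
// called directly on the older Rust versions this crate targets: `call_box`
// consumes the `Box` and invokes the closure exactly once.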
89 trait FnBox {
    fn call_box(self: Box<Self>);
91 }
92 
93 impl<F: FnOnce()> FnBox for F {
    fn call_box(self: Box<F>) {
95         (*self)()
96     }
97 }
98 
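// A boxed, sendable job waiting to be run by a worker thread.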
99 type Thunk<'a> = Box<FnBox + Send + 'a>;
100 
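// Guard owned by every worker thread. If the thread unwinds while the
// sentinel is still active (i.e. a job panicked), `Drop` fixes up the shared
// counters and spawns a replacement worker; `cancel()` disarms it when the
// worker shuts down normally.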
101 struct Sentinel<'a> {
102     shared_data: &'a Arc<ThreadPoolSharedData>,
103     active: bool,
104 }
105 
106 impl<'a> Sentinel<'a> {
    fn new(shared_data: &'a Arc<ThreadPoolSharedData>) -> Sentinel<'a> {
108         Sentinel {
109             shared_data: shared_data,
110             active: true,
111         }
112     }
113 
114     /// Cancel and destroy this sentinel.
    fn cancel(mut self) {
116         self.active = false;
117     }
118 }
119 
120 impl<'a> Drop for Sentinel<'a> {
    fn drop(&mut self) {
122         if self.active {
123             self.shared_data.active_count.fetch_sub(1, Ordering::SeqCst);
124             if thread::panicking() {
125                 self.shared_data.panic_count.fetch_add(1, Ordering::SeqCst);
126             }
127             self.shared_data.no_work_notify_all();
128             spawn_in_pool(self.shared_data.clone())
129         }
130     }
131 }
132 
133 /// [`ThreadPool`] factory, which can be used in order to configure the properties of the
134 /// [`ThreadPool`].
135 ///
136 /// The three configuration options available:
137 ///
/// * `num_threads`: maximum number of threads that will be alive at any given moment in the built
///   [`ThreadPool`]
140 /// * `thread_name`: thread name for each of the threads spawned by the built [`ThreadPool`]
141 /// * `thread_stack_size`: stack size (in bytes) for each of the threads spawned by the built
142 ///   [`ThreadPool`]
143 ///
144 /// [`ThreadPool`]: struct.ThreadPool.html
145 ///
146 /// # Examples
147 ///
/// Build a [`ThreadPool`] that uses a maximum of eight threads simultaneously, each with
/// an 8 MB stack size:
150 ///
151 /// ```
152 /// let pool = threadpool::Builder::new()
153 ///     .num_threads(8)
154 ///     .thread_stack_size(8_000_000)
155 ///     .build();
156 /// ```
157 #[derive(Clone, Default)]
158 pub struct Builder {
159     num_threads: Option<usize>,
160     thread_name: Option<String>,
161     thread_stack_size: Option<usize>,
162 }
163 
164 impl Builder {
165     /// Initiate a new [`Builder`].
166     ///
167     /// [`Builder`]: struct.Builder.html
168     ///
169     /// # Examples
170     ///
171     /// ```
172     /// let builder = threadpool::Builder::new();
173     /// ```
    pub fn new() -> Builder {
175         Builder {
176             num_threads: None,
177             thread_name: None,
178             thread_stack_size: None,
179         }
180     }
181 
    /// Set the maximum number of worker-threads that will be alive at any given moment in the
    /// built [`ThreadPool`]. If not specified, the number of threads defaults to the number of CPUs.
184     ///
185     /// [`ThreadPool`]: struct.ThreadPool.html
186     ///
187     /// # Panics
188     ///
189     /// This method will panic if `num_threads` is 0.
190     ///
191     /// # Examples
192     ///
193     /// No more than eight threads will be alive simultaneously for this pool:
194     ///
195     /// ```
196     /// use std::thread;
197     ///
198     /// let pool = threadpool::Builder::new()
199     ///     .num_threads(8)
200     ///     .build();
201     ///
202     /// for _ in 0..100 {
203     ///     pool.execute(|| {
204     ///         println!("Hello from a worker thread!")
205     ///     })
206     /// }
207     /// ```
    pub fn num_threads(mut self, num_threads: usize) -> Builder {
209         assert!(num_threads > 0);
210         self.num_threads = Some(num_threads);
211         self
212     }
213 
214     /// Set the thread name for each of the threads spawned by the built [`ThreadPool`]. If not
215     /// specified, threads spawned by the thread pool will be unnamed.
216     ///
217     /// [`ThreadPool`]: struct.ThreadPool.html
218     ///
219     /// # Examples
220     ///
221     /// Each thread spawned by this pool will have the name "foo":
222     ///
223     /// ```
224     /// use std::thread;
225     ///
226     /// let pool = threadpool::Builder::new()
227     ///     .thread_name("foo".into())
228     ///     .build();
229     ///
230     /// for _ in 0..100 {
231     ///     pool.execute(|| {
232     ///         assert_eq!(thread::current().name(), Some("foo"));
233     ///     })
234     /// }
235     /// ```
    pub fn thread_name(mut self, name: String) -> Builder {
237         self.thread_name = Some(name);
238         self
239     }
240 
241     /// Set the stack size (in bytes) for each of the threads spawned by the built [`ThreadPool`].
242     /// If not specified, threads spawned by the threadpool will have a stack size [as specified in
243     /// the `std::thread` documentation][thread].
244     ///
245     /// [thread]: https://doc.rust-lang.org/nightly/std/thread/index.html#stack-size
246     /// [`ThreadPool`]: struct.ThreadPool.html
247     ///
248     /// # Examples
249     ///
250     /// Each thread spawned by this pool will have a 4 MB stack:
251     ///
252     /// ```
253     /// let pool = threadpool::Builder::new()
254     ///     .thread_stack_size(4_000_000)
255     ///     .build();
256     ///
257     /// for _ in 0..100 {
258     ///     pool.execute(|| {
259     ///         println!("This thread has a 4 MB stack size!");
260     ///     })
261     /// }
262     /// ```
    pub fn thread_stack_size(mut self, size: usize) -> Builder {
264         self.thread_stack_size = Some(size);
265         self
266     }
267 
268     /// Finalize the [`Builder`] and build the [`ThreadPool`].
269     ///
270     /// [`Builder`]: struct.Builder.html
271     /// [`ThreadPool`]: struct.ThreadPool.html
272     ///
273     /// # Examples
274     ///
275     /// ```
276     /// let pool = threadpool::Builder::new()
277     ///     .num_threads(8)
278     ///     .thread_stack_size(4_000_000)
279     ///     .build();
280     /// ```
    pub fn build(self) -> ThreadPool {
282         let (tx, rx) = channel::<Thunk<'static>>();
283 
284         let num_threads = self.num_threads.unwrap_or_else(num_cpus::get);
285 
286         let shared_data = Arc::new(ThreadPoolSharedData {
287             name: self.thread_name,
288             job_receiver: Mutex::new(rx),
289             empty_condvar: Condvar::new(),
290             empty_trigger: Mutex::new(()),
291             join_generation: AtomicUsize::new(0),
292             queued_count: AtomicUsize::new(0),
293             active_count: AtomicUsize::new(0),
294             max_thread_count: AtomicUsize::new(num_threads),
295             panic_count: AtomicUsize::new(0),
296             stack_size: self.thread_stack_size,
297         });
298 
299         // Threadpool threads
300         for _ in 0..num_threads {
301             spawn_in_pool(shared_data.clone());
302         }
303 
304         ThreadPool {
305             jobs: tx,
306             shared_data: shared_data,
307         }
308     }
309 }
310 
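// State shared between the `ThreadPool` handle(s) and all worker threads.
// Workers pull jobs from `job_receiver`; the counters plus the
// `empty_trigger`/`empty_condvar` pair provide the bookkeeping behind `join()`.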
311 struct ThreadPoolSharedData {
312     name: Option<String>,
313     job_receiver: Mutex<Receiver<Thunk<'static>>>,
314     empty_trigger: Mutex<()>,
315     empty_condvar: Condvar,
316     join_generation: AtomicUsize,
317     queued_count: AtomicUsize,
318     active_count: AtomicUsize,
319     max_thread_count: AtomicUsize,
320     panic_count: AtomicUsize,
321     stack_size: Option<usize>,
322 }
323 
324 impl ThreadPoolSharedData {
    fn has_work(&self) -> bool {
326         self.queued_count.load(Ordering::SeqCst) > 0 || self.active_count.load(Ordering::SeqCst) > 0
327     }
328 
329     /// Notify all observers joining this pool if there is no more work to do.
    fn no_work_notify_all(&self) {
331         if !self.has_work() {
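            // Taking (and immediately releasing) the mutex synchronizes with
            // joiners between their `has_work()` check and their wait on the
            // condvar, so the notification below cannot be missed.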
332             *self
333                 .empty_trigger
334                 .lock()
335                 .expect("Unable to notify all joining threads");
336             self.empty_condvar.notify_all();
337         }
338     }
339 }
340 
341 /// Abstraction of a thread pool for basic parallelism.
342 pub struct ThreadPool {
343     // How the threadpool communicates with subthreads.
344     //
345     // This is the only such Sender, so when it is dropped all subthreads will
346     // quit.
347     jobs: Sender<Thunk<'static>>,
348     shared_data: Arc<ThreadPoolSharedData>,
349 }
350 
351 impl ThreadPool {
    /// Creates a new thread pool capable of executing `num_threads` jobs concurrently.
353     ///
354     /// # Panics
355     ///
356     /// This function will panic if `num_threads` is 0.
357     ///
358     /// # Examples
359     ///
360     /// Create a new thread pool capable of executing four jobs concurrently:
361     ///
362     /// ```
363     /// use threadpool::ThreadPool;
364     ///
365     /// let pool = ThreadPool::new(4);
366     /// ```
    pub fn new(num_threads: usize) -> ThreadPool {
368         Builder::new().num_threads(num_threads).build()
369     }
370 
    /// Creates a new thread pool capable of executing `num_threads` jobs concurrently.
372     /// Each thread will have the [name][thread name] `name`.
373     ///
374     /// # Panics
375     ///
376     /// This function will panic if `num_threads` is 0.
377     ///
378     /// # Examples
379     ///
380     /// ```rust
381     /// use std::thread;
382     /// use threadpool::ThreadPool;
383     ///
384     /// let pool = ThreadPool::with_name("worker".into(), 2);
385     /// for _ in 0..2 {
386     ///     pool.execute(|| {
387     ///         assert_eq!(
388     ///             thread::current().name(),
389     ///             Some("worker")
390     ///         );
391     ///     });
392     /// }
393     /// pool.join();
394     /// ```
395     ///
396     /// [thread name]: https://doc.rust-lang.org/std/thread/struct.Thread.html#method.name
    pub fn with_name(name: String, num_threads: usize) -> ThreadPool {
398         Builder::new()
399             .num_threads(num_threads)
400             .thread_name(name)
401             .build()
402     }
403 
404     /// **Deprecated: Use [`ThreadPool::with_name`](#method.with_name)**
405     #[inline(always)]
406     #[deprecated(since = "1.4.0", note = "use ThreadPool::with_name")]
    pub fn new_with_name(name: String, num_threads: usize) -> ThreadPool {
408         Self::with_name(name, num_threads)
409     }
410 
411     /// Executes the function `job` on a thread in the pool.
412     ///
413     /// # Examples
414     ///
415     /// Execute four jobs on a thread pool that can run two jobs concurrently:
416     ///
417     /// ```
418     /// use threadpool::ThreadPool;
419     ///
420     /// let pool = ThreadPool::new(2);
421     /// pool.execute(|| println!("hello"));
422     /// pool.execute(|| println!("world"));
423     /// pool.execute(|| println!("foo"));
424     /// pool.execute(|| println!("bar"));
425     /// pool.join();
426     /// ```
    pub fn execute<F>(&self, job: F)
428     where
429         F: FnOnce() + Send + 'static,
430     {
431         self.shared_data.queued_count.fetch_add(1, Ordering::SeqCst);
432         self.jobs
433             .send(Box::new(job))
434             .expect("ThreadPool::execute unable to send job into queue.");
435     }
436 
    /// Returns the number of jobs waiting to be executed in the pool.
438     ///
439     /// # Examples
440     ///
441     /// ```
442     /// use threadpool::ThreadPool;
443     /// use std::time::Duration;
444     /// use std::thread::sleep;
445     ///
446     /// let pool = ThreadPool::new(2);
447     /// for _ in 0..10 {
448     ///     pool.execute(|| {
449     ///         sleep(Duration::from_secs(100));
450     ///     });
451     /// }
452     ///
453     /// sleep(Duration::from_secs(1)); // wait for threads to start
454     /// assert_eq!(8, pool.queued_count());
455     /// ```
    pub fn queued_count(&self) -> usize {
457         self.shared_data.queued_count.load(Ordering::Relaxed)
458     }
459 
460     /// Returns the number of currently active threads.
461     ///
462     /// # Examples
463     ///
464     /// ```
465     /// use threadpool::ThreadPool;
466     /// use std::time::Duration;
467     /// use std::thread::sleep;
468     ///
469     /// let pool = ThreadPool::new(4);
470     /// for _ in 0..10 {
471     ///     pool.execute(move || {
472     ///         sleep(Duration::from_secs(100));
473     ///     });
474     /// }
475     ///
476     /// sleep(Duration::from_secs(1)); // wait for threads to start
477     /// assert_eq!(4, pool.active_count());
478     /// ```
    pub fn active_count(&self) -> usize {
480         self.shared_data.active_count.load(Ordering::SeqCst)
481     }
482 
483     /// Returns the maximum number of threads the pool will execute concurrently.
484     ///
485     /// # Examples
486     ///
487     /// ```
488     /// use threadpool::ThreadPool;
489     ///
490     /// let mut pool = ThreadPool::new(4);
491     /// assert_eq!(4, pool.max_count());
492     ///
493     /// pool.set_num_threads(8);
494     /// assert_eq!(8, pool.max_count());
495     /// ```
    pub fn max_count(&self) -> usize {
497         self.shared_data.max_thread_count.load(Ordering::Relaxed)
498     }
499 
500     /// Returns the number of panicked threads over the lifetime of the pool.
501     ///
502     /// # Examples
503     ///
504     /// ```
505     /// use threadpool::ThreadPool;
506     ///
507     /// let pool = ThreadPool::new(4);
508     /// for n in 0..10 {
509     ///     pool.execute(move || {
510     ///         // simulate a panic
511     ///         if n % 2 == 0 {
512     ///             panic!()
513     ///         }
514     ///     });
515     /// }
516     /// pool.join();
517     ///
518     /// assert_eq!(5, pool.panic_count());
519     /// ```
    pub fn panic_count(&self) -> usize {
521         self.shared_data.panic_count.load(Ordering::Relaxed)
522     }
523 
524     /// **Deprecated: Use [`ThreadPool::set_num_threads`](#method.set_num_threads)**
525     #[deprecated(since = "1.3.0", note = "use ThreadPool::set_num_threads")]
    pub fn set_threads(&mut self, num_threads: usize) {
527         self.set_num_threads(num_threads)
528     }
529 
530     /// Sets the number of worker-threads to use as `num_threads`.
531     /// Can be used to change the threadpool size during runtime.
532     /// Will not abort already running or waiting threads.
533     ///
534     /// # Panics
535     ///
536     /// This function will panic if `num_threads` is 0.
537     ///
538     /// # Examples
539     ///
540     /// ```
541     /// use threadpool::ThreadPool;
542     /// use std::time::Duration;
543     /// use std::thread::sleep;
544     ///
545     /// let mut pool = ThreadPool::new(4);
546     /// for _ in 0..10 {
547     ///     pool.execute(move || {
548     ///         sleep(Duration::from_secs(100));
549     ///     });
550     /// }
551     ///
552     /// sleep(Duration::from_secs(1)); // wait for threads to start
553     /// assert_eq!(4, pool.active_count());
554     /// assert_eq!(6, pool.queued_count());
555     ///
556     /// // Increase thread capacity of the pool
557     /// pool.set_num_threads(8);
558     ///
559     /// sleep(Duration::from_secs(1)); // wait for new threads to start
560     /// assert_eq!(8, pool.active_count());
561     /// assert_eq!(2, pool.queued_count());
562     ///
563     /// // Decrease thread capacity of the pool
564     /// // No active threads are killed
565     /// pool.set_num_threads(4);
566     ///
567     /// assert_eq!(8, pool.active_count());
568     /// assert_eq!(2, pool.queued_count());
569     /// ```
    pub fn set_num_threads(&mut self, num_threads: usize) {
571         assert!(num_threads >= 1);
572         let prev_num_threads = self
573             .shared_data
574             .max_thread_count
575             .swap(num_threads, Ordering::Release);
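        // When shrinking the pool, `checked_sub` yields `None` and nothing is
        // spawned; surplus workers exit on their own once they observe the
        // lowered `max_thread_count` before picking up their next job.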
576         if let Some(num_spawn) = num_threads.checked_sub(prev_num_threads) {
577             // Spawn new threads
578             for _ in 0..num_spawn {
579                 spawn_in_pool(self.shared_data.clone());
580             }
581         }
582     }
583 
584     /// Block the current thread until all jobs in the pool have been executed.
585     ///
586     /// Calling `join` on an empty pool will cause an immediate return.
587     /// `join` may be called from multiple threads concurrently.
588     /// A `join` is an atomic point in time. All threads joining before the join
589     /// event will exit together even if the pool is processing new jobs by the
590     /// time they get scheduled.
591     ///
592     /// Calling `join` from a thread within the pool will cause a deadlock. This
593     /// behavior is considered safe.
594     ///
595     /// # Examples
596     ///
597     /// ```
598     /// use threadpool::ThreadPool;
599     /// use std::sync::Arc;
600     /// use std::sync::atomic::{AtomicUsize, Ordering};
601     ///
602     /// let pool = ThreadPool::new(8);
603     /// let test_count = Arc::new(AtomicUsize::new(0));
604     ///
605     /// for _ in 0..42 {
606     ///     let test_count = test_count.clone();
607     ///     pool.execute(move || {
608     ///         test_count.fetch_add(1, Ordering::Relaxed);
609     ///     });
610     /// }
611     ///
612     /// pool.join();
613     /// assert_eq!(42, test_count.load(Ordering::Relaxed));
614     /// ```
    pub fn join(&self) {
616         // fast path requires no mutex
        if !self.shared_data.has_work() {
            return;
619         }
620 
621         let generation = self.shared_data.join_generation.load(Ordering::SeqCst);
622         let mut lock = self.shared_data.empty_trigger.lock().unwrap();
623 
624         while generation == self.shared_data.join_generation.load(Ordering::Relaxed)
625             && self.shared_data.has_work()
626         {
627             lock = self.shared_data.empty_condvar.wait(lock).unwrap();
628         }
629 
630         // increase generation if we are the first thread to come out of the loop
631         self.shared_data.join_generation.compare_and_swap(
632             generation,
633             generation.wrapping_add(1),
634             Ordering::SeqCst,
635         );
636     }
637 }
638 
639 impl Clone for ThreadPool {
640     /// Cloning a pool will create a new handle to the pool.
641     /// The behavior is similar to [Arc](https://doc.rust-lang.org/stable/std/sync/struct.Arc.html).
642     ///
643     /// We could for example submit jobs from multiple threads concurrently.
644     ///
645     /// ```
646     /// use threadpool::ThreadPool;
647     /// use std::thread;
648     /// use std::sync::mpsc::channel;
649     ///
650     /// let pool = ThreadPool::with_name("clone example".into(), 2);
651     ///
652     /// let results = (0..2)
653     ///     .map(|i| {
654     ///         let pool = pool.clone();
655     ///         thread::spawn(move || {
656     ///             let (tx, rx) = channel();
657     ///             for i in 1..12 {
658     ///                 let tx = tx.clone();
659     ///                 pool.execute(move || {
660     ///                     tx.send(i).expect("channel will be waiting");
661     ///                 });
662     ///             }
663     ///             drop(tx);
664     ///             if i == 0 {
665     ///                 rx.iter().fold(0, |accumulator, element| accumulator + element)
666     ///             } else {
667     ///                 rx.iter().fold(1, |accumulator, element| accumulator * element)
668     ///             }
669     ///         })
670     ///     })
671     ///     .map(|join_handle| join_handle.join().expect("collect results from threads"))
672     ///     .collect::<Vec<usize>>();
673     ///
674     /// assert_eq!(vec![66, 39916800], results);
675     /// ```
    fn clone(&self) -> ThreadPool {
677         ThreadPool {
678             jobs: self.jobs.clone(),
679             shared_data: self.shared_data.clone(),
680         }
681     }
682 }
683 
684 /// Create a thread pool with one thread per CPU.
685 /// On machines with hyperthreading,
686 /// this will create one thread per hyperthread.
687 impl Default for ThreadPool {
    fn default() -> Self {
689         ThreadPool::new(num_cpus::get())
690     }
691 }
692 
693 impl fmt::Debug for ThreadPool {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
695         f.debug_struct("ThreadPool")
696             .field("name", &self.shared_data.name)
697             .field("queued_count", &self.queued_count())
698             .field("active_count", &self.active_count())
699             .field("max_count", &self.max_count())
700             .finish()
701     }
702 }
703 
704 impl PartialEq for ThreadPool {
705     /// Check if you are working with the same pool
706     ///
707     /// ```
708     /// use threadpool::ThreadPool;
709     ///
710     /// let a = ThreadPool::new(2);
711     /// let b = ThreadPool::new(2);
712     ///
713     /// assert_eq!(a, a);
714     /// assert_eq!(b, b);
715     ///
716     /// # // TODO: change this to assert_ne in the future
717     /// assert!(a != b);
718     /// assert!(b != a);
719     /// ```
    fn eq(&self, other: &ThreadPool) -> bool {
721         let a: &ThreadPoolSharedData = &*self.shared_data;
722         let b: &ThreadPoolSharedData = &*other.shared_data;
723         a as *const ThreadPoolSharedData == b as *const ThreadPoolSharedData
        // with Rust 1.17 and later:
725         // Arc::ptr_eq(&self.shared_data, &other.shared_data)
726     }
727 }
728 impl Eq for ThreadPool {}
729 
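// Spawn a single worker thread bound to `shared_data`. The worker loops,
// pulling jobs off the shared channel, until either the pool has shrunk below
// its slot or the job sender has been dropped; a `Sentinel` replaces the
// thread if a job panics.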
fn spawn_in_pool(shared_data: Arc<ThreadPoolSharedData>) {
731     let mut builder = thread::Builder::new();
732     if let Some(ref name) = shared_data.name {
733         builder = builder.name(name.clone());
734     }
735     if let Some(ref stack_size) = shared_data.stack_size {
736         builder = builder.stack_size(stack_size.to_owned());
737     }
738     builder
739         .spawn(move || {
740             // Will spawn a new thread on panic unless it is cancelled.
741             let sentinel = Sentinel::new(&shared_data);
742 
743             loop {
744                 // Shutdown this thread if the pool has become smaller
745                 let thread_counter_val = shared_data.active_count.load(Ordering::Acquire);
746                 let max_thread_count_val = shared_data.max_thread_count.load(Ordering::Relaxed);
747                 if thread_counter_val >= max_thread_count_val {
748                     break;
749                 }
750                 let message = {
751                     // Only lock jobs for the time it takes
752                     // to get a job, not run it.
753                     let lock = shared_data
754                         .job_receiver
755                         .lock()
756                         .expect("Worker thread unable to lock job_receiver");
757                     lock.recv()
758                 };
759 
760                 let job = match message {
761                     Ok(job) => job,
762                     // The ThreadPool was dropped.
763                     Err(..) => break,
764                 };
                // Do not allow instruction reordering (IR) around the job execution
766                 shared_data.active_count.fetch_add(1, Ordering::SeqCst);
767                 shared_data.queued_count.fetch_sub(1, Ordering::SeqCst);
768 
769                 job.call_box();
770 
771                 shared_data.active_count.fetch_sub(1, Ordering::SeqCst);
772                 shared_data.no_work_notify_all();
773             }
774 
775             sentinel.cancel();
776         })
777         .unwrap();
778 }
779 
780 #[cfg(test)]
781 mod test {
782     use super::{Builder, ThreadPool};
783     use std::sync::atomic::{AtomicUsize, Ordering};
784     use std::sync::mpsc::{channel, sync_channel};
785     use std::sync::{Arc, Barrier};
786     use std::thread::{self, sleep};
787     use std::time::Duration;
788 
789     const TEST_TASKS: usize = 4;
790 
791     #[test]
    fn test_set_num_threads_increasing() {
793         let new_thread_amount = TEST_TASKS + 8;
794         let mut pool = ThreadPool::new(TEST_TASKS);
795         for _ in 0..TEST_TASKS {
796             pool.execute(move || sleep(Duration::from_secs(23)));
797         }
798         sleep(Duration::from_secs(1));
799         assert_eq!(pool.active_count(), TEST_TASKS);
800 
801         pool.set_num_threads(new_thread_amount);
802 
803         for _ in 0..(new_thread_amount - TEST_TASKS) {
804             pool.execute(move || sleep(Duration::from_secs(23)));
805         }
806         sleep(Duration::from_secs(1));
807         assert_eq!(pool.active_count(), new_thread_amount);
808 
809         pool.join();
810     }
811 
812     #[test]
    fn test_set_num_threads_decreasing() {
814         let new_thread_amount = 2;
815         let mut pool = ThreadPool::new(TEST_TASKS);
816         for _ in 0..TEST_TASKS {
817             pool.execute(move || {
818                 assert_eq!(1, 1);
819             });
820         }
821         pool.set_num_threads(new_thread_amount);
822         for _ in 0..new_thread_amount {
823             pool.execute(move || sleep(Duration::from_secs(23)));
824         }
825         sleep(Duration::from_secs(1));
826         assert_eq!(pool.active_count(), new_thread_amount);
827 
828         pool.join();
829     }
830 
831     #[test]
    fn test_active_count() {
833         let pool = ThreadPool::new(TEST_TASKS);
834         for _ in 0..2 * TEST_TASKS {
835             pool.execute(move || loop {
836                 sleep(Duration::from_secs(10))
837             });
838         }
839         sleep(Duration::from_secs(1));
840         let active_count = pool.active_count();
841         assert_eq!(active_count, TEST_TASKS);
842         let initialized_count = pool.max_count();
843         assert_eq!(initialized_count, TEST_TASKS);
844     }
845 
846     #[test]
    fn test_works() {
848         let pool = ThreadPool::new(TEST_TASKS);
849 
850         let (tx, rx) = channel();
851         for _ in 0..TEST_TASKS {
852             let tx = tx.clone();
853             pool.execute(move || {
854                 tx.send(1).unwrap();
855             });
856         }
857 
858         assert_eq!(rx.iter().take(TEST_TASKS).fold(0, |a, b| a + b), TEST_TASKS);
859     }
860 
861     #[test]
862     #[should_panic]
    fn test_zero_tasks_panic() {
864         ThreadPool::new(0);
865     }
866 
867     #[test]
    fn test_recovery_from_subtask_panic() {
869         let pool = ThreadPool::new(TEST_TASKS);
870 
871         // Panic all the existing threads.
872         for _ in 0..TEST_TASKS {
873             pool.execute(move || panic!("Ignore this panic, it must!"));
874         }
875         pool.join();
876 
877         assert_eq!(pool.panic_count(), TEST_TASKS);
878 
879         // Ensure new threads were spawned to compensate.
880         let (tx, rx) = channel();
881         for _ in 0..TEST_TASKS {
882             let tx = tx.clone();
883             pool.execute(move || {
884                 tx.send(1).unwrap();
885             });
886         }
887 
888         assert_eq!(rx.iter().take(TEST_TASKS).fold(0, |a, b| a + b), TEST_TASKS);
889     }
890 
891     #[test]
    fn test_should_not_panic_on_drop_if_subtasks_panic_after_drop() {
893         let pool = ThreadPool::new(TEST_TASKS);
894         let waiter = Arc::new(Barrier::new(TEST_TASKS + 1));
895 
896         // Panic all the existing threads in a bit.
897         for _ in 0..TEST_TASKS {
898             let waiter = waiter.clone();
899             pool.execute(move || {
900                 waiter.wait();
901                 panic!("Ignore this panic, it should!");
902             });
903         }
904 
905         drop(pool);
906 
907         // Kick off the failure.
908         waiter.wait();
909     }
910 
911     #[test]
    fn test_massive_task_creation() {
913         let test_tasks = 4_200_000;
914 
915         let pool = ThreadPool::new(TEST_TASKS);
916         let b0 = Arc::new(Barrier::new(TEST_TASKS + 1));
917         let b1 = Arc::new(Barrier::new(TEST_TASKS + 1));
918 
919         let (tx, rx) = channel();
920 
921         for i in 0..test_tasks {
922             let tx = tx.clone();
923             let (b0, b1) = (b0.clone(), b1.clone());
924 
925             pool.execute(move || {
926                 // Wait until the pool has been filled once.
927                 if i < TEST_TASKS {
928                     b0.wait();
929                     // wait so the pool can be measured
930                     b1.wait();
931                 }
932 
933                 tx.send(1).is_ok();
934             });
935         }
936 
937         b0.wait();
938         assert_eq!(pool.active_count(), TEST_TASKS);
939         b1.wait();
940 
941         assert_eq!(rx.iter().take(test_tasks).fold(0, |a, b| a + b), test_tasks);
942         pool.join();
943 
944         let atomic_active_count = pool.active_count();
945         assert!(
946             atomic_active_count == 0,
947             "atomic_active_count: {}",
948             atomic_active_count
949         );
950     }
951 
952     #[test]
    fn test_shrink() {
954         let test_tasks_begin = TEST_TASKS + 2;
955 
956         let mut pool = ThreadPool::new(test_tasks_begin);
957         let b0 = Arc::new(Barrier::new(test_tasks_begin + 1));
958         let b1 = Arc::new(Barrier::new(test_tasks_begin + 1));
959 
960         for _ in 0..test_tasks_begin {
961             let (b0, b1) = (b0.clone(), b1.clone());
962             pool.execute(move || {
963                 b0.wait();
964                 b1.wait();
965             });
966         }
967 
968         let b2 = Arc::new(Barrier::new(TEST_TASKS + 1));
969         let b3 = Arc::new(Barrier::new(TEST_TASKS + 1));
970 
971         for _ in 0..TEST_TASKS {
972             let (b2, b3) = (b2.clone(), b3.clone());
973             pool.execute(move || {
974                 b2.wait();
975                 b3.wait();
976             });
977         }
978 
979         b0.wait();
980         pool.set_num_threads(TEST_TASKS);
981 
982         assert_eq!(pool.active_count(), test_tasks_begin);
983         b1.wait();
984 
985         b2.wait();
986         assert_eq!(pool.active_count(), TEST_TASKS);
987         b3.wait();
988     }
989 
990     #[test]
    fn test_name() {
992         let name = "test";
993         let mut pool = ThreadPool::with_name(name.to_owned(), 2);
994         let (tx, rx) = sync_channel(0);
995 
        // the initial threads should share the name "test"
997         for _ in 0..2 {
998             let tx = tx.clone();
999             pool.execute(move || {
1000                 let name = thread::current().name().unwrap().to_owned();
1001                 tx.send(name).unwrap();
1002             });
1003         }
1004 
        // a newly spawned thread should share the name "test" too.
1006         pool.set_num_threads(3);
1007         let tx_clone = tx.clone();
1008         pool.execute(move || {
1009             let name = thread::current().name().unwrap().to_owned();
1010             tx_clone.send(name).unwrap();
1011             panic!();
1012         });
1013 
        // the thread spawned to replace the panicked one should share the name "test" too.
1015         pool.execute(move || {
1016             let name = thread::current().name().unwrap().to_owned();
1017             tx.send(name).unwrap();
1018         });
1019 
1020         for thread_name in rx.iter().take(4) {
1021             assert_eq!(name, thread_name);
1022         }
1023     }
1024 
1025     #[test]
    fn test_debug() {
1027         let pool = ThreadPool::new(4);
1028         let debug = format!("{:?}", pool);
1029         assert_eq!(
1030             debug,
1031             "ThreadPool { name: None, queued_count: 0, active_count: 0, max_count: 4 }"
1032         );
1033 
1034         let pool = ThreadPool::with_name("hello".into(), 4);
1035         let debug = format!("{:?}", pool);
1036         assert_eq!(
1037             debug,
1038             "ThreadPool { name: Some(\"hello\"), queued_count: 0, active_count: 0, max_count: 4 }"
1039         );
1040 
1041         let pool = ThreadPool::new(4);
1042         pool.execute(move || sleep(Duration::from_secs(5)));
1043         sleep(Duration::from_secs(1));
1044         let debug = format!("{:?}", pool);
1045         assert_eq!(
1046             debug,
1047             "ThreadPool { name: None, queued_count: 0, active_count: 1, max_count: 4 }"
1048         );
1049     }
1050 
1051     #[test]
    fn test_repeate_join() {
1053         let pool = ThreadPool::with_name("repeate join test".into(), 8);
1054         let test_count = Arc::new(AtomicUsize::new(0));
1055 
1056         for _ in 0..42 {
1057             let test_count = test_count.clone();
1058             pool.execute(move || {
1059                 sleep(Duration::from_secs(2));
1060                 test_count.fetch_add(1, Ordering::Release);
1061             });
1062         }
1063 
1064         println!("{:?}", pool);
1065         pool.join();
1066         assert_eq!(42, test_count.load(Ordering::Acquire));
1067 
1068         for _ in 0..42 {
1069             let test_count = test_count.clone();
1070             pool.execute(move || {
1071                 sleep(Duration::from_secs(2));
1072                 test_count.fetch_add(1, Ordering::Relaxed);
1073             });
1074         }
1075         pool.join();
1076         assert_eq!(84, test_count.load(Ordering::Relaxed));
1077     }
1078 
1079     #[test]
    fn test_multi_join() {
1081         use std::sync::mpsc::TryRecvError::*;
1082 
1083         // Toggle the following lines to debug the deadlock
1084         fn error(_s: String) {
1085             //use ::std::io::Write;
1086             //let stderr = ::std::io::stderr();
1087             //let mut stderr = stderr.lock();
1088             //stderr.write(&_s.as_bytes()).is_ok();
1089         }
1090 
1091         let pool0 = ThreadPool::with_name("multi join pool0".into(), 4);
1092         let pool1 = ThreadPool::with_name("multi join pool1".into(), 4);
1093         let (tx, rx) = channel();
1094 
1095         for i in 0..8 {
1096             let pool1 = pool1.clone();
1097             let pool0_ = pool0.clone();
1098             let tx = tx.clone();
1099             pool0.execute(move || {
1100                 pool1.execute(move || {
1101                     error(format!("p1: {} -=- {:?}\n", i, pool0_));
1102                     pool0_.join();
1103                     error(format!("p1: send({})\n", i));
1104                     tx.send(i).expect("send i from pool1 -> main");
1105                 });
1106                 error(format!("p0: {}\n", i));
1107             });
1108         }
1109         drop(tx);
1110 
1111         assert_eq!(rx.try_recv(), Err(Empty));
1112         error(format!("{:?}\n{:?}\n", pool0, pool1));
1113         pool0.join();
1114         error(format!("pool0.join() complete =-= {:?}", pool1));
1115         pool1.join();
1116         error("pool1.join() complete\n".into());
1117         assert_eq!(
1118             rx.iter().fold(0, |acc, i| acc + i),
1119             0 + 1 + 2 + 3 + 4 + 5 + 6 + 7
1120         );
1121     }
1122 
1123     #[test]
    fn test_empty_pool() {
        // Joining an empty pool must return immediately
1126         let pool = ThreadPool::new(4);
1127 
1128         pool.join();
1129 
1130         assert!(true);
1131     }
1132 
1133     #[test]
    fn test_no_fun_or_joy() {
1135         // What happens when you keep adding jobs after a join
1136 
1137         fn sleepy_function() {
1138             sleep(Duration::from_secs(6));
1139         }
1140 
1141         let pool = ThreadPool::with_name("no fun or joy".into(), 8);
1142 
1143         pool.execute(sleepy_function);
1144 
1145         let p_t = pool.clone();
1146         thread::spawn(move || {
1147             (0..23).map(|_| p_t.execute(sleepy_function)).count();
1148         });
1149 
1150         pool.join();
1151     }
1152 
1153     #[test]
    fn test_clone() {
1155         let pool = ThreadPool::with_name("clone example".into(), 2);
1156 
1157         // This batch of jobs will occupy the pool for some time
1158         for _ in 0..6 {
1159             pool.execute(move || {
1160                 sleep(Duration::from_secs(2));
1161             });
1162         }
1163 
1164         // The following jobs will be inserted into the pool in a random fashion
1165         let t0 = {
1166             let pool = pool.clone();
1167             thread::spawn(move || {
1168                 // wait for the first batch of tasks to finish
1169                 pool.join();
1170 
1171                 let (tx, rx) = channel();
1172                 for i in 0..42 {
1173                     let tx = tx.clone();
1174                     pool.execute(move || {
1175                         tx.send(i).expect("channel will be waiting");
1176                     });
1177                 }
1178                 drop(tx);
1179                 rx.iter()
1180                     .fold(0, |accumulator, element| accumulator + element)
1181             })
1182         };
1183         let t1 = {
1184             let pool = pool.clone();
1185             thread::spawn(move || {
1186                 // wait for the first batch of tasks to finish
1187                 pool.join();
1188 
1189                 let (tx, rx) = channel();
1190                 for i in 1..12 {
1191                     let tx = tx.clone();
1192                     pool.execute(move || {
1193                         tx.send(i).expect("channel will be waiting");
1194                     });
1195                 }
1196                 drop(tx);
1197                 rx.iter()
1198                     .fold(1, |accumulator, element| accumulator * element)
1199             })
1200         };
1201 
1202         assert_eq!(
1203             861,
1204             t0.join()
1205                 .expect("thread 0 will return after calculating additions",)
1206         );
1207         assert_eq!(
1208             39916800,
1209             t1.join()
1210                 .expect("thread 1 will return after calculating multiplications",)
1211         );
1212     }
1213 
1214     #[test]
    fn test_sync_shared_data() {
1216         fn assert_sync<T: Sync>() {}
1217         assert_sync::<super::ThreadPoolSharedData>();
1218     }
1219 
1220     #[test]
    fn test_send_shared_data() {
1222         fn assert_send<T: Send>() {}
1223         assert_send::<super::ThreadPoolSharedData>();
1224     }
1225 
1226     #[test]
    fn test_send() {
1228         fn assert_send<T: Send>() {}
1229         assert_send::<ThreadPool>();
1230     }
1231 
1232     #[test]
    fn test_cloned_eq() {
1234         let a = ThreadPool::new(2);
1235 
1236         assert_eq!(a, a.clone());
1237     }
1238 
1239     #[test]
    /// The scenario: joining threads should not stay stuck once their wave
    /// of joins has completed. Once one thread joining on a pool has
    /// succeeded, the other threads joining on the same pool must get out as
    /// well, even if the pool is used for other jobs while the first group is
    /// finishing its join.
1245     ///
1246     /// In this example this means the waiting threads will exit the join in
1247     /// groups of four because the waiter pool has four workers.
    fn test_join_wavesurfer() {
1249         let n_cycles = 4;
1250         let n_workers = 4;
1251         let (tx, rx) = channel();
1252         let builder = Builder::new()
1253             .num_threads(n_workers)
1254             .thread_name("join wavesurfer".into());
1255         let p_waiter = builder.clone().build();
1256         let p_clock = builder.build();
1257 
1258         let barrier = Arc::new(Barrier::new(3));
1259         let wave_clock = Arc::new(AtomicUsize::new(0));
1260         let clock_thread = {
1261             let barrier = barrier.clone();
1262             let wave_clock = wave_clock.clone();
1263             thread::spawn(move || {
1264                 barrier.wait();
1265                 for wave_num in 0..n_cycles {
1266                     wave_clock.store(wave_num, Ordering::SeqCst);
1267                     sleep(Duration::from_secs(1));
1268                 }
1269             })
1270         };
1271 
1272         {
1273             let barrier = barrier.clone();
1274             p_clock.execute(move || {
1275                 barrier.wait();
1276                 // this sleep is for stabilisation on weaker platforms
1277                 sleep(Duration::from_millis(100));
1278             });
1279         }
1280 
1281         // prepare three waves of jobs
1282         for i in 0..3 * n_workers {
1283             let p_clock = p_clock.clone();
1284             let tx = tx.clone();
1285             let wave_clock = wave_clock.clone();
1286             p_waiter.execute(move || {
1287                 let now = wave_clock.load(Ordering::SeqCst);
1288                 p_clock.join();
1289                 // submit jobs for the second wave
1290                 p_clock.execute(|| sleep(Duration::from_secs(1)));
1291                 let clock = wave_clock.load(Ordering::SeqCst);
1292                 tx.send((now, clock, i)).unwrap();
1293             });
1294         }
1295         println!("all scheduled at {}", wave_clock.load(Ordering::SeqCst));
1296         barrier.wait();
1297 
1298         p_clock.join();
1299         //p_waiter.join();
1300 
1301         drop(tx);
1302         let mut hist = vec![0; n_cycles];
1303         let mut data = vec![];
1304         for (now, after, i) in rx.iter() {
1305             let mut dur = after - now;
1306             if dur >= n_cycles - 1 {
1307                 dur = n_cycles - 1;
1308             }
1309             hist[dur] += 1;
1310 
1311             data.push((now, after, i));
1312         }
1313         for (i, n) in hist.iter().enumerate() {
1314             println!(
1315                 "\t{}: {} {}",
1316                 i,
1317                 n,
1318                 &*(0..*n).fold("".to_owned(), |s, _| s + "*")
1319             );
1320         }
1321         assert!(data.iter().all(|&(cycle, stop, i)| if i < n_workers {
1322             cycle == stop
1323         } else {
1324             cycle < stop
1325         }));
1326 
1327         clock_thread.join().unwrap();
1328     }
1329 }
1330