1 use crate::primitive::sync::{Arc, Condvar, Mutex};
2 use std::fmt;
3 
/// Enables threads to synchronize the beginning or end of some computation.
///
/// # Wait groups vs barriers
///
/// `WaitGroup` is very similar to [`Barrier`], but there are a few differences:
///
/// * [`Barrier`] needs to know the number of threads at construction, while `WaitGroup` is cloned to
///   register more threads.
///
/// * A [`Barrier`] can be reused even after all threads have synchronized, while a `WaitGroup`
///   synchronizes threads only once.
///
/// * All threads wait for others to reach the [`Barrier`]. With `WaitGroup`, each thread can choose
///   to either wait for other threads or to continue without blocking.
///
/// # Examples
///
/// ```
/// use crossbeam_utils::sync::WaitGroup;
/// use std::thread;
///
/// // Create a new wait group.
/// let wg = WaitGroup::new();
///
/// for _ in 0..4 {
///     // Create another reference to the wait group.
///     let wg = wg.clone();
///
///     thread::spawn(move || {
///         // Do some work.
///
///         // Drop the reference to the wait group.
///         drop(wg);
///     });
/// }
///
/// // Block until all threads have finished their work.
/// wg.wait();
/// # std::thread::sleep(std::time::Duration::from_millis(500)); // wait for background threads closed: https://github.com/rust-lang/miri/issues/1371
/// ```
///
/// [`Barrier`]: std::sync::Barrier
pub struct WaitGroup {
    /// State shared by every handle of the same group (cloning a handle clones this `Arc`):
    /// the live-handle counter and the condition variable that `wait` blocks on.
    inner: Arc<Inner>,
}
49 
/// Inner state of a `WaitGroup`.
///
/// One `Inner` is allocated per group and shared by all of its handles through an `Arc`.
struct Inner {
    /// Woken with `notify_all` when `count` reaches zero; threads in `WaitGroup::wait`
    /// block on it until then.
    cvar: Condvar,
    /// Number of live `WaitGroup` handles. It starts at 1 for the handle returned by
    /// `default`/`new`, is incremented by `clone`, and is decremented by `drop`.
    count: Mutex<usize>,
}
55 
56 impl Default for WaitGroup {
default() -> Self57     fn default() -> Self {
58         Self {
59             inner: Arc::new(Inner {
60                 cvar: Condvar::new(),
61                 count: Mutex::new(1),
62             }),
63         }
64     }
65 }
66 
67 impl WaitGroup {
68     /// Creates a new wait group and returns the single reference to it.
69     ///
70     /// # Examples
71     ///
72     /// ```
73     /// use crossbeam_utils::sync::WaitGroup;
74     ///
75     /// let wg = WaitGroup::new();
76     /// ```
new() -> Self77     pub fn new() -> Self {
78         Self::default()
79     }
80 
81     /// Drops this reference and waits until all other references are dropped.
82     ///
83     /// # Examples
84     ///
85     /// ```
86     /// use crossbeam_utils::sync::WaitGroup;
87     /// use std::thread;
88     ///
89     /// let wg = WaitGroup::new();
90     ///
91     /// thread::spawn({
92     ///     let wg = wg.clone();
93     ///     move || {
94     ///         // Block until both threads have reached `wait()`.
95     ///         wg.wait();
96     ///     }
97     /// });
98     ///
99     /// // Block until both threads have reached `wait()`.
100     /// wg.wait();
101     /// # std::thread::sleep(std::time::Duration::from_millis(500)); // wait for background threads closed: https://github.com/rust-lang/miri/issues/1371
102     /// ```
wait(self)103     pub fn wait(self) {
104         if *self.inner.count.lock().unwrap() == 1 {
105             return;
106         }
107 
108         let inner = self.inner.clone();
109         drop(self);
110 
111         let mut count = inner.count.lock().unwrap();
112         while *count > 0 {
113             count = inner.cvar.wait(count).unwrap();
114         }
115     }
116 }
117 
118 impl Drop for WaitGroup {
drop(&mut self)119     fn drop(&mut self) {
120         let mut count = self.inner.count.lock().unwrap();
121         *count -= 1;
122 
123         if *count == 0 {
124             self.inner.cvar.notify_all();
125         }
126     }
127 }
128 
129 impl Clone for WaitGroup {
clone(&self) -> WaitGroup130     fn clone(&self) -> WaitGroup {
131         let mut count = self.inner.count.lock().unwrap();
132         *count += 1;
133 
134         WaitGroup {
135             inner: self.inner.clone(),
136         }
137     }
138 }
139 
140 impl fmt::Debug for WaitGroup {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result141     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
142         let count: &usize = &*self.inner.count.lock().unwrap();
143         f.debug_struct("WaitGroup").field("count", count).finish()
144     }
145 }
146