1 use crate::task::{waker_ref, ArcWake};
2 use futures_core::future::{FusedFuture, Future};
3 use futures_core::task::{Context, Poll, Waker};
4 use slab::Slab;
5 use std::cell::UnsafeCell;
6 use std::fmt;
7 use std::hash::Hasher;
8 use std::pin::Pin;
9 use std::ptr;
10 use std::sync::atomic::AtomicUsize;
11 use std::sync::atomic::Ordering::{Acquire, SeqCst};
12 use std::sync::{Arc, Mutex, Weak};
13 
/// Future for the [`shared`](super::FutureExt::shared) method.
///
/// Every clone of a `Shared` refers to the same wrapped future. Whichever clone polls
/// it drives it on behalf of all of them. Each clone resolves to a clone of the output,
/// or to the output itself when it is the last strong handle.
#[must_use = "futures do nothing unless you `.await` or poll them"]
pub struct Shared<Fut: Future> {
    // State shared by all clones; set to `None` once this handle has returned `Ready`.
    inner: Option<Arc<Inner<Fut>>>,
    // This handle's slot in the notifier's waker slab, or `NULL_WAKER_KEY` if it has
    // not registered a waker yet.
    waker_key: usize,
}
20 
/// State shared by all clones of a [`Shared`].
struct Inner<Fut: Future> {
    // The wrapped future until it completes, then its output. It is accessed through
    // raw pointers; `notifier.state` decides who may touch it and when.
    // NOTE: field order fixes drop order (this field is dropped before `notifier`).
    future_or_output: UnsafeCell<FutureOrOutput<Fut>>,
    // Polling state and registered wakers. It lives in its own `Arc` so it can back the
    // waker handed to the wrapped future while it is polled.
    notifier: Arc<Notifier>,
}
25 
/// Coordinates polling and wakeups across the clones of a [`Shared`].
struct Notifier {
    // One of `IDLE`, `POLLING`, `COMPLETE` or `POISONED`.
    state: AtomicUsize,
    // Wakers of waiting clones, keyed by each clone's `waker_key`. A slot becomes
    // `None` after it has been woken; the whole slab is taken (left `None`) once the
    // wrapped future completes.
    wakers: Mutex<Option<Slab<Option<Waker>>>>,
}
30 
/// A weak reference to a [`Shared`] that can be upgraded much like an `Arc`.
///
/// A `WeakShared` does not keep the shared state alive: upgrading fails once every
/// strong [`Shared`] handle is gone.
pub struct WeakShared<Fut: Future>(Weak<Inner<Fut>>);
33 
34 impl<Fut: Future> Clone for WeakShared<Fut> {
clone(&self) -> Self35     fn clone(&self) -> Self {
36         Self(self.0.clone())
37     }
38 }
39 
40 impl<Fut: Future> fmt::Debug for Shared<Fut> {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result41     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
42         f.debug_struct("Shared")
43             .field("inner", &self.inner)
44             .field("waker_key", &self.waker_key)
45             .finish()
46     }
47 }
48 
49 impl<Fut: Future> fmt::Debug for Inner<Fut> {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result50     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
51         f.debug_struct("Inner").finish()
52     }
53 }
54 
55 impl<Fut: Future> fmt::Debug for WeakShared<Fut> {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result56     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
57         f.debug_struct("WeakShared").finish()
58     }
59 }
60 
/// Contents of the shared cell: the wrapped future until it completes, then its output.
enum FutureOrOutput<Fut: Future> {
    // Not yet completed; only the clone holding the `POLLING` state polls it.
    Future(Fut),
    // Completed; read by reference or cloned/moved out once the state is `COMPLETE`.
    Output(Fut::Output),
}
65 
// SAFETY: `Inner` may move to another thread together with its `Arc`. The future is
// polled or dropped on whichever thread wins the `IDLE -> POLLING` transition, so
// `Fut: Send` is required. The output can be moved out by the last handle
// (`Arc::try_unwrap` in `take_or_clone_output`), so it needs `Send`. Other threads can
// still reach it by reference (`peek`, cloning), so it also needs `Sync`.
unsafe impl<Fut> Send for Inner<Fut>
where
    Fut: Future + Send,
    Fut::Output: Send + Sync,
{
}
72 
// SAFETY: the `UnsafeCell` makes `Inner` `!Sync` by default. Sharing `&Inner` is sound
// because `notifier.state` serializes access to the cell. Only one thread at a time
// holds `POLLING` and gets `&mut` access to the future, which only needs `Fut: Send`,
// not `Sync`. After `COMPLETE` the output is only read through `&`, possibly from many
// threads at once (`peek`, `clone`), so `Fut::Output: Sync`. It may also be moved out
// by the last handle, so `Fut::Output: Send`.
unsafe impl<Fut> Sync for Inner<Fut>
where
    Fut: Future + Send,
    Fut::Output: Send + Sync,
{
}
79 
/// `Notifier::state`: no clone is polling the wrapped future right now.
const IDLE: usize = 0;
/// `Notifier::state`: one clone is polling the wrapped future; others only register wakers.
const POLLING: usize = 1;
/// `Notifier::state`: the output has been stored in `Inner::future_or_output`.
const COMPLETE: usize = 2;
/// `Notifier::state`: the wrapped future panicked during a poll; later polls panic too.
const POISONED: usize = 3;

/// Sentinel `Shared::waker_key` meaning "no waker slot registered for this handle".
const NULL_WAKER_KEY: usize = usize::MAX;
86 
87 impl<Fut: Future> Shared<Fut> {
new(future: Fut) -> Self88     pub(super) fn new(future: Fut) -> Self {
89         let inner = Inner {
90             future_or_output: UnsafeCell::new(FutureOrOutput::Future(future)),
91             notifier: Arc::new(Notifier {
92                 state: AtomicUsize::new(IDLE),
93                 wakers: Mutex::new(Some(Slab::new())),
94             }),
95         };
96 
97         Self { inner: Some(Arc::new(inner)), waker_key: NULL_WAKER_KEY }
98     }
99 }
100 
101 impl<Fut> Shared<Fut>
102 where
103     Fut: Future,
104 {
105     /// Returns [`Some`] containing a reference to this [`Shared`]'s output if
106     /// it has already been computed by a clone or [`None`] if it hasn't been
107     /// computed yet or this [`Shared`] already returned its output from
108     /// [`poll`](Future::poll).
peek(&self) -> Option<&Fut::Output>109     pub fn peek(&self) -> Option<&Fut::Output> {
110         if let Some(inner) = self.inner.as_ref() {
111             match inner.notifier.state.load(SeqCst) {
112                 COMPLETE => unsafe { return Some(inner.output()) },
113                 POISONED => panic!("inner future panicked during poll"),
114                 _ => {}
115             }
116         }
117         None
118     }
119 
120     /// Creates a new [`WeakShared`] for this [`Shared`].
121     ///
122     /// Returns [`None`] if it has already been polled to completion.
downgrade(&self) -> Option<WeakShared<Fut>>123     pub fn downgrade(&self) -> Option<WeakShared<Fut>> {
124         if let Some(inner) = self.inner.as_ref() {
125             return Some(WeakShared(Arc::downgrade(inner)));
126         }
127         None
128     }
129 
130     /// Gets the number of strong pointers to this allocation.
131     ///
132     /// Returns [`None`] if it has already been polled to completion.
133     ///
134     /// # Safety
135     ///
136     /// This method by itself is safe, but using it correctly requires extra care. Another thread
137     /// can change the strong count at any time, including potentially between calling this method
138     /// and acting on the result.
139     #[allow(clippy::unnecessary_safety_doc)]
strong_count(&self) -> Option<usize>140     pub fn strong_count(&self) -> Option<usize> {
141         self.inner.as_ref().map(|arc| Arc::strong_count(arc))
142     }
143 
144     /// Gets the number of weak pointers to this allocation.
145     ///
146     /// Returns [`None`] if it has already been polled to completion.
147     ///
148     /// # Safety
149     ///
150     /// This method by itself is safe, but using it correctly requires extra care. Another thread
151     /// can change the weak count at any time, including potentially between calling this method
152     /// and acting on the result.
153     #[allow(clippy::unnecessary_safety_doc)]
weak_count(&self) -> Option<usize>154     pub fn weak_count(&self) -> Option<usize> {
155         self.inner.as_ref().map(|arc| Arc::weak_count(arc))
156     }
157 
158     /// Hashes the internal state of this `Shared` in a way that's compatible with `ptr_eq`.
ptr_hash<H: Hasher>(&self, state: &mut H)159     pub fn ptr_hash<H: Hasher>(&self, state: &mut H) {
160         match self.inner.as_ref() {
161             Some(arc) => {
162                 state.write_u8(1);
163                 ptr::hash(Arc::as_ptr(arc), state);
164             }
165             None => {
166                 state.write_u8(0);
167             }
168         }
169     }
170 
171     /// Returns `true` if the two `Shared`s point to the same future (in a vein similar to
172     /// `Arc::ptr_eq`).
173     ///
174     /// Returns `false` if either `Shared` has terminated.
ptr_eq(&self, rhs: &Self) -> bool175     pub fn ptr_eq(&self, rhs: &Self) -> bool {
176         let lhs = match self.inner.as_ref() {
177             Some(lhs) => lhs,
178             None => return false,
179         };
180         let rhs = match rhs.inner.as_ref() {
181             Some(rhs) => rhs,
182             None => return false,
183         };
184         Arc::ptr_eq(lhs, rhs)
185     }
186 }
187 
188 impl<Fut> Inner<Fut>
189 where
190     Fut: Future,
191 {
192     /// Safety: callers must first ensure that `self.inner.state`
193     /// is `COMPLETE`
output(&self) -> &Fut::Output194     unsafe fn output(&self) -> &Fut::Output {
195         match unsafe { &*self.future_or_output.get() } {
196             FutureOrOutput::Output(item) => item,
197             FutureOrOutput::Future(_) => unreachable!(),
198         }
199     }
200 }
201 
202 impl<Fut> Inner<Fut>
203 where
204     Fut: Future,
205     Fut::Output: Clone,
206 {
207     /// Registers the current task to receive a wakeup when we are awoken.
record_waker(&self, waker_key: &mut usize, cx: &mut Context<'_>)208     fn record_waker(&self, waker_key: &mut usize, cx: &mut Context<'_>) {
209         let mut wakers_guard = self.notifier.wakers.lock().unwrap();
210 
211         let wakers = match wakers_guard.as_mut() {
212             Some(wakers) => wakers,
213             None => return,
214         };
215 
216         let new_waker = cx.waker();
217 
218         if *waker_key == NULL_WAKER_KEY {
219             *waker_key = wakers.insert(Some(new_waker.clone()));
220         } else {
221             match wakers[*waker_key] {
222                 Some(ref old_waker) if new_waker.will_wake(old_waker) => {}
223                 // Could use clone_from here, but Waker doesn't specialize it.
224                 ref mut slot => *slot = Some(new_waker.clone()),
225             }
226         }
227         debug_assert!(*waker_key != NULL_WAKER_KEY);
228     }
229 
230     /// Safety: callers must first ensure that `inner.state`
231     /// is `COMPLETE`
take_or_clone_output(self: Arc<Self>) -> Fut::Output232     unsafe fn take_or_clone_output(self: Arc<Self>) -> Fut::Output {
233         match Arc::try_unwrap(self) {
234             Ok(inner) => match inner.future_or_output.into_inner() {
235                 FutureOrOutput::Output(item) => item,
236                 FutureOrOutput::Future(_) => unreachable!(),
237             },
238             Err(inner) => unsafe { inner.output().clone() },
239         }
240     }
241 }
242 
243 impl<Fut> FusedFuture for Shared<Fut>
244 where
245     Fut: Future,
246     Fut::Output: Clone,
247 {
is_terminated(&self) -> bool248     fn is_terminated(&self) -> bool {
249         self.inner.is_none()
250     }
251 }
252 
impl<Fut> Future for Shared<Fut>
where
    Fut: Future,
    Fut::Output: Clone,
{
    type Output = Fut::Output;

    /// Polls the shared future on behalf of this clone.
    ///
    /// At most one clone drives the wrapped future at a time:
    /// - The clone that wins the `IDLE -> POLLING` transition polls it with a waker
    ///   backed by the shared `Notifier`, so a wakeup from the wrapped future reaches
    ///   every registered clone.
    /// - Clones that observe `POLLING` only register their waker and return `Pending`.
    /// - When the wrapped future yields, its output is stored, the state becomes
    ///   `COMPLETE`, and every registered waker is woken.
    ///
    /// Each clone resolves to the output. It is moved out if this is the last strong
    /// handle, and cloned otherwise.
    ///
    /// # Panics
    ///
    /// Panics if this clone is polled again after returning `Ready`. Also panics if the
    /// wrapped future panicked during an earlier poll (state `POISONED`).
    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        let this = &mut *self;

        // Move the shared state out of this handle. Every `Pending` return below puts it
        // back; a `Ready` return leaves `this.inner` as `None` (terminated).
        let inner = this.inner.take().expect("Shared future polled again after completion");

        // Fast path for when the wrapped future has already completed. The `Acquire`
        // load pairs with the `SeqCst` store of `COMPLETE` below, which makes the stored
        // output visible here.
        if inner.notifier.state.load(Acquire) == COMPLETE {
            // Safety: We're in the COMPLETE state
            return unsafe { Poll::Ready(inner.take_or_clone_output()) };
        }

        // Register before competing for the polling role, so a clone that loses the
        // race below still has a waker in the slab.
        inner.record_waker(&mut this.waker_key, cx);

        // Try to become the clone that polls the wrapped future (`IDLE -> POLLING`).
        // On failure, `unwrap_or_else` yields the state that was observed instead.
        match inner
            .notifier
            .state
            .compare_exchange(IDLE, POLLING, SeqCst, SeqCst)
            .unwrap_or_else(|x| x)
        {
            IDLE => {
                // Lock acquired, fall through
            }
            POLLING => {
                // Another task is currently polling, at this point we just want
                // to ensure that the waker for this task is registered
                this.inner = Some(inner);
                return Poll::Pending;
            }
            COMPLETE => {
                // Safety: We're in the COMPLETE state
                return unsafe { Poll::Ready(inner.take_or_clone_output()) };
            }
            POISONED => panic!("inner future panicked during poll"),
            _ => unreachable!(),
        }

        // Poll the wrapped future with a waker built from the shared notifier. That way a
        // wakeup fans out to every registered clone (presumably via
        // `Notifier::wake_by_ref`; `waker_ref` is defined in `crate::task`).
        let waker = waker_ref(&inner.notifier);
        let mut cx = Context::from_waker(&waker);

        // Drop guard: if `Fut::poll` unwinds, the state is set to `POISONED`. Later polls
        // and `peek` then panic instead of touching a half-polled future.
        struct Reset<'a> {
            state: &'a AtomicUsize,
            did_not_panic: bool,
        }

        impl Drop for Reset<'_> {
            fn drop(&mut self) {
                if !self.did_not_panic {
                    self.state.store(POISONED, SeqCst);
                }
            }
        }

        let mut reset = Reset { state: &inner.notifier.state, did_not_panic: false };

        let output = {
            // SAFETY: this clone won `IDLE -> POLLING`, so no other clone touches
            // `future_or_output` until the state changes again. The future lives inside
            // the `Arc` allocation and is never moved while in the `Future` variant; it
            // is only replaced (dropped in place) below. That makes
            // `Pin::new_unchecked` sound.
            let future = unsafe {
                match &mut *inner.future_or_output.get() {
                    FutureOrOutput::Future(fut) => Pin::new_unchecked(fut),
                    _ => unreachable!(),
                }
            };

            let poll_result = future.poll(&mut cx);
            // `Fut::poll` returned normally, so the guard must no longer poison.
            reset.did_not_panic = true;

            match poll_result {
                Poll::Pending => {
                    // Only the polling clone changes the state while it is `POLLING`, so
                    // handing the role back to `IDLE` cannot fail.
                    if inner.notifier.state.compare_exchange(POLLING, IDLE, SeqCst, SeqCst).is_ok()
                    {
                        // Success
                        drop(reset);
                        this.inner = Some(inner);
                        return Poll::Pending;
                    } else {
                        unreachable!()
                    }
                }
                Poll::Ready(output) => output,
            }
        };

        // Replace the future with its output; the old future is dropped in place.
        unsafe {
            *inner.future_or_output.get() = FutureOrOutput::Output(output);
        }

        // Publish the output to other clones (see the `Acquire` fast path above).
        inner.notifier.state.store(COMPLETE, SeqCst);

        // Wake all tasks and drop the slab. Leaving `None` behind makes later
        // `record_waker` calls and `Drop` impls skip their slab bookkeeping.
        let mut wakers_guard = inner.notifier.wakers.lock().unwrap();
        let mut wakers = wakers_guard.take().unwrap();
        for waker in wakers.drain().flatten() {
            waker.wake();
        }

        drop(reset); // End the guard's borrow of `inner` before `inner` is moved below.
        drop(wakers_guard);

        // Safety: We're in the COMPLETE state
        unsafe { Poll::Ready(inner.take_or_clone_output()) }
    }
}
361 
362 impl<Fut> Clone for Shared<Fut>
363 where
364     Fut: Future,
365 {
clone(&self) -> Self366     fn clone(&self) -> Self {
367         Self { inner: self.inner.clone(), waker_key: NULL_WAKER_KEY }
368     }
369 }
370 
371 impl<Fut> Drop for Shared<Fut>
372 where
373     Fut: Future,
374 {
drop(&mut self)375     fn drop(&mut self) {
376         if self.waker_key != NULL_WAKER_KEY {
377             if let Some(ref inner) = self.inner {
378                 if let Ok(mut wakers) = inner.notifier.wakers.lock() {
379                     if let Some(wakers) = wakers.as_mut() {
380                         wakers.remove(self.waker_key);
381                     }
382                 }
383             }
384         }
385     }
386 }
387 
388 impl ArcWake for Notifier {
wake_by_ref(arc_self: &Arc<Self>)389     fn wake_by_ref(arc_self: &Arc<Self>) {
390         let wakers = &mut *arc_self.wakers.lock().unwrap();
391         if let Some(wakers) = wakers.as_mut() {
392             for (_key, opt_waker) in wakers {
393                 if let Some(waker) = opt_waker.take() {
394                     waker.wake();
395                 }
396             }
397         }
398     }
399 }
400 
401 impl<Fut: Future> WeakShared<Fut> {
402     /// Attempts to upgrade this [`WeakShared`] into a [`Shared`].
403     ///
404     /// Returns [`None`] if all clones of the [`Shared`] have been dropped or polled
405     /// to completion.
upgrade(&self) -> Option<Shared<Fut>>406     pub fn upgrade(&self) -> Option<Shared<Fut>> {
407         Some(Shared { inner: Some(self.0.upgrade()?), waker_key: NULL_WAKER_KEY })
408     }
409 }
410