1 use crate::task::{waker_ref, ArcWake}; 2 use futures_core::future::{FusedFuture, Future}; 3 use futures_core::task::{Context, Poll, Waker}; 4 use slab::Slab; 5 use std::cell::UnsafeCell; 6 use std::fmt; 7 use std::hash::Hasher; 8 use std::pin::Pin; 9 use std::ptr; 10 use std::sync::atomic::AtomicUsize; 11 use std::sync::atomic::Ordering::{Acquire, SeqCst}; 12 use std::sync::{Arc, Mutex, Weak}; 13 14 /// Future for the [`shared`](super::FutureExt::shared) method. 15 #[must_use = "futures do nothing unless you `.await` or poll them"] 16 pub struct Shared<Fut: Future> { 17 inner: Option<Arc<Inner<Fut>>>, 18 waker_key: usize, 19 } 20 21 struct Inner<Fut: Future> { 22 future_or_output: UnsafeCell<FutureOrOutput<Fut>>, 23 notifier: Arc<Notifier>, 24 } 25 26 struct Notifier { 27 state: AtomicUsize, 28 wakers: Mutex<Option<Slab<Option<Waker>>>>, 29 } 30 31 /// A weak reference to a [`Shared`] that can be upgraded much like an `Arc`. 32 pub struct WeakShared<Fut: Future>(Weak<Inner<Fut>>); 33 34 impl<Fut: Future> Clone for WeakShared<Fut> { clone(&self) -> Self35 fn clone(&self) -> Self { 36 Self(self.0.clone()) 37 } 38 } 39 40 impl<Fut: Future> fmt::Debug for Shared<Fut> { fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result41 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 42 f.debug_struct("Shared") 43 .field("inner", &self.inner) 44 .field("waker_key", &self.waker_key) 45 .finish() 46 } 47 } 48 49 impl<Fut: Future> fmt::Debug for Inner<Fut> { fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result50 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 51 f.debug_struct("Inner").finish() 52 } 53 } 54 55 impl<Fut: Future> fmt::Debug for WeakShared<Fut> { fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result56 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { 57 f.debug_struct("WeakShared").finish() 58 } 59 } 60 61 enum FutureOrOutput<Fut: Future> { 62 Future(Fut), 63 Output(Fut::Output), 64 } 65 66 unsafe impl<Fut> Send for Inner<Fut> 67 where 68 Fut: Future + Send, 69 Fut::Output: Send + Sync, 70 { 71 } 72 73 unsafe impl<Fut> Sync for Inner<Fut> 74 where 75 Fut: Future + Send, 76 Fut::Output: Send + Sync, 77 { 78 } 79 80 const IDLE: usize = 0; 81 const POLLING: usize = 1; 82 const COMPLETE: usize = 2; 83 const POISONED: usize = 3; 84 85 const NULL_WAKER_KEY: usize = usize::MAX; 86 87 impl<Fut: Future> Shared<Fut> { new(future: Fut) -> Self88 pub(super) fn new(future: Fut) -> Self { 89 let inner = Inner { 90 future_or_output: UnsafeCell::new(FutureOrOutput::Future(future)), 91 notifier: Arc::new(Notifier { 92 state: AtomicUsize::new(IDLE), 93 wakers: Mutex::new(Some(Slab::new())), 94 }), 95 }; 96 97 Self { inner: Some(Arc::new(inner)), waker_key: NULL_WAKER_KEY } 98 } 99 } 100 101 impl<Fut> Shared<Fut> 102 where 103 Fut: Future, 104 { 105 /// Returns [`Some`] containing a reference to this [`Shared`]'s output if 106 /// it has already been computed by a clone or [`None`] if it hasn't been 107 /// computed yet or this [`Shared`] already returned its output from 108 /// [`poll`](Future::poll). peek(&self) -> Option<&Fut::Output>109 pub fn peek(&self) -> Option<&Fut::Output> { 110 if let Some(inner) = self.inner.as_ref() { 111 match inner.notifier.state.load(SeqCst) { 112 COMPLETE => unsafe { return Some(inner.output()) }, 113 POISONED => panic!("inner future panicked during poll"), 114 _ => {} 115 } 116 } 117 None 118 } 119 120 /// Creates a new [`WeakShared`] for this [`Shared`]. 121 /// 122 /// Returns [`None`] if it has already been polled to completion. downgrade(&self) -> Option<WeakShared<Fut>>123 pub fn downgrade(&self) -> Option<WeakShared<Fut>> { 124 if let Some(inner) = self.inner.as_ref() { 125 return Some(WeakShared(Arc::downgrade(inner))); 126 } 127 None 128 } 129 130 /// Gets the number of strong pointers to this allocation. 131 /// 132 /// Returns [`None`] if it has already been polled to completion. 133 /// 134 /// # Safety 135 /// 136 /// This method by itself is safe, but using it correctly requires extra care. Another thread 137 /// can change the strong count at any time, including potentially between calling this method 138 /// and acting on the result. 139 #[allow(clippy::unnecessary_safety_doc)] strong_count(&self) -> Option<usize>140 pub fn strong_count(&self) -> Option<usize> { 141 self.inner.as_ref().map(|arc| Arc::strong_count(arc)) 142 } 143 144 /// Gets the number of weak pointers to this allocation. 145 /// 146 /// Returns [`None`] if it has already been polled to completion. 147 /// 148 /// # Safety 149 /// 150 /// This method by itself is safe, but using it correctly requires extra care. Another thread 151 /// can change the weak count at any time, including potentially between calling this method 152 /// and acting on the result. 153 #[allow(clippy::unnecessary_safety_doc)] weak_count(&self) -> Option<usize>154 pub fn weak_count(&self) -> Option<usize> { 155 self.inner.as_ref().map(|arc| Arc::weak_count(arc)) 156 } 157 158 /// Hashes the internal state of this `Shared` in a way that's compatible with `ptr_eq`. ptr_hash<H: Hasher>(&self, state: &mut H)159 pub fn ptr_hash<H: Hasher>(&self, state: &mut H) { 160 match self.inner.as_ref() { 161 Some(arc) => { 162 state.write_u8(1); 163 ptr::hash(Arc::as_ptr(arc), state); 164 } 165 None => { 166 state.write_u8(0); 167 } 168 } 169 } 170 171 /// Returns `true` if the two `Shared`s point to the same future (in a vein similar to 172 /// `Arc::ptr_eq`). 173 /// 174 /// Returns `false` if either `Shared` has terminated. ptr_eq(&self, rhs: &Self) -> bool175 pub fn ptr_eq(&self, rhs: &Self) -> bool { 176 let lhs = match self.inner.as_ref() { 177 Some(lhs) => lhs, 178 None => return false, 179 }; 180 let rhs = match rhs.inner.as_ref() { 181 Some(rhs) => rhs, 182 None => return false, 183 }; 184 Arc::ptr_eq(lhs, rhs) 185 } 186 } 187 188 impl<Fut> Inner<Fut> 189 where 190 Fut: Future, 191 { 192 /// Safety: callers must first ensure that `self.inner.state` 193 /// is `COMPLETE` output(&self) -> &Fut::Output194 unsafe fn output(&self) -> &Fut::Output { 195 match unsafe { &*self.future_or_output.get() } { 196 FutureOrOutput::Output(item) => item, 197 FutureOrOutput::Future(_) => unreachable!(), 198 } 199 } 200 } 201 202 impl<Fut> Inner<Fut> 203 where 204 Fut: Future, 205 Fut::Output: Clone, 206 { 207 /// Registers the current task to receive a wakeup when we are awoken. record_waker(&self, waker_key: &mut usize, cx: &mut Context<'_>)208 fn record_waker(&self, waker_key: &mut usize, cx: &mut Context<'_>) { 209 let mut wakers_guard = self.notifier.wakers.lock().unwrap(); 210 211 let wakers = match wakers_guard.as_mut() { 212 Some(wakers) => wakers, 213 None => return, 214 }; 215 216 let new_waker = cx.waker(); 217 218 if *waker_key == NULL_WAKER_KEY { 219 *waker_key = wakers.insert(Some(new_waker.clone())); 220 } else { 221 match wakers[*waker_key] { 222 Some(ref old_waker) if new_waker.will_wake(old_waker) => {} 223 // Could use clone_from here, but Waker doesn't specialize it. 224 ref mut slot => *slot = Some(new_waker.clone()), 225 } 226 } 227 debug_assert!(*waker_key != NULL_WAKER_KEY); 228 } 229 230 /// Safety: callers must first ensure that `inner.state` 231 /// is `COMPLETE` take_or_clone_output(self: Arc<Self>) -> Fut::Output232 unsafe fn take_or_clone_output(self: Arc<Self>) -> Fut::Output { 233 match Arc::try_unwrap(self) { 234 Ok(inner) => match inner.future_or_output.into_inner() { 235 FutureOrOutput::Output(item) => item, 236 FutureOrOutput::Future(_) => unreachable!(), 237 }, 238 Err(inner) => unsafe { inner.output().clone() }, 239 } 240 } 241 } 242 243 impl<Fut> FusedFuture for Shared<Fut> 244 where 245 Fut: Future, 246 Fut::Output: Clone, 247 { is_terminated(&self) -> bool248 fn is_terminated(&self) -> bool { 249 self.inner.is_none() 250 } 251 } 252 253 impl<Fut> Future for Shared<Fut> 254 where 255 Fut: Future, 256 Fut::Output: Clone, 257 { 258 type Output = Fut::Output; 259 poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>260 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { 261 let this = &mut *self; 262 263 let inner = this.inner.take().expect("Shared future polled again after completion"); 264 265 // Fast path for when the wrapped future has already completed 266 if inner.notifier.state.load(Acquire) == COMPLETE { 267 // Safety: We're in the COMPLETE state 268 return unsafe { Poll::Ready(inner.take_or_clone_output()) }; 269 } 270 271 inner.record_waker(&mut this.waker_key, cx); 272 273 match inner 274 .notifier 275 .state 276 .compare_exchange(IDLE, POLLING, SeqCst, SeqCst) 277 .unwrap_or_else(|x| x) 278 { 279 IDLE => { 280 // Lock acquired, fall through 281 } 282 POLLING => { 283 // Another task is currently polling, at this point we just want 284 // to ensure that the waker for this task is registered 285 this.inner = Some(inner); 286 return Poll::Pending; 287 } 288 COMPLETE => { 289 // Safety: We're in the COMPLETE state 290 return unsafe { Poll::Ready(inner.take_or_clone_output()) }; 291 } 292 POISONED => panic!("inner future panicked during poll"), 293 _ => unreachable!(), 294 } 295 296 let waker = waker_ref(&inner.notifier); 297 let mut cx = Context::from_waker(&waker); 298 299 struct Reset<'a> { 300 state: &'a AtomicUsize, 301 did_not_panic: bool, 302 } 303 304 impl Drop for Reset<'_> { 305 fn drop(&mut self) { 306 if !self.did_not_panic { 307 self.state.store(POISONED, SeqCst); 308 } 309 } 310 } 311 312 let mut reset = Reset { state: &inner.notifier.state, did_not_panic: false }; 313 314 let output = { 315 let future = unsafe { 316 match &mut *inner.future_or_output.get() { 317 FutureOrOutput::Future(fut) => Pin::new_unchecked(fut), 318 _ => unreachable!(), 319 } 320 }; 321 322 let poll_result = future.poll(&mut cx); 323 reset.did_not_panic = true; 324 325 match poll_result { 326 Poll::Pending => { 327 if inner.notifier.state.compare_exchange(POLLING, IDLE, SeqCst, SeqCst).is_ok() 328 { 329 // Success 330 drop(reset); 331 this.inner = Some(inner); 332 return Poll::Pending; 333 } else { 334 unreachable!() 335 } 336 } 337 Poll::Ready(output) => output, 338 } 339 }; 340 341 unsafe { 342 *inner.future_or_output.get() = FutureOrOutput::Output(output); 343 } 344 345 inner.notifier.state.store(COMPLETE, SeqCst); 346 347 // Wake all tasks and drop the slab 348 let mut wakers_guard = inner.notifier.wakers.lock().unwrap(); 349 let mut wakers = wakers_guard.take().unwrap(); 350 for waker in wakers.drain().flatten() { 351 waker.wake(); 352 } 353 354 drop(reset); // Make borrow checker happy 355 drop(wakers_guard); 356 357 // Safety: We're in the COMPLETE state 358 unsafe { Poll::Ready(inner.take_or_clone_output()) } 359 } 360 } 361 362 impl<Fut> Clone for Shared<Fut> 363 where 364 Fut: Future, 365 { clone(&self) -> Self366 fn clone(&self) -> Self { 367 Self { inner: self.inner.clone(), waker_key: NULL_WAKER_KEY } 368 } 369 } 370 371 impl<Fut> Drop for Shared<Fut> 372 where 373 Fut: Future, 374 { drop(&mut self)375 fn drop(&mut self) { 376 if self.waker_key != NULL_WAKER_KEY { 377 if let Some(ref inner) = self.inner { 378 if let Ok(mut wakers) = inner.notifier.wakers.lock() { 379 if let Some(wakers) = wakers.as_mut() { 380 wakers.remove(self.waker_key); 381 } 382 } 383 } 384 } 385 } 386 } 387 388 impl ArcWake for Notifier { wake_by_ref(arc_self: &Arc<Self>)389 fn wake_by_ref(arc_self: &Arc<Self>) { 390 let wakers = &mut *arc_self.wakers.lock().unwrap(); 391 if let Some(wakers) = wakers.as_mut() { 392 for (_key, opt_waker) in wakers { 393 if let Some(waker) = opt_waker.take() { 394 waker.wake(); 395 } 396 } 397 } 398 } 399 } 400 401 impl<Fut: Future> WeakShared<Fut> { 402 /// Attempts to upgrade this [`WeakShared`] into a [`Shared`]. 403 /// 404 /// Returns [`None`] if all clones of the [`Shared`] have been dropped or polled 405 /// to completion. upgrade(&self) -> Option<Shared<Fut>>406 pub fn upgrade(&self) -> Option<Shared<Fut>> { 407 Some(Shared { inner: Some(self.0.upgrade()?), waker_key: NULL_WAKER_KEY }) 408 } 409 } 410