1 use futures_core::future::{FusedFuture, Future}; 2 use futures_core::stream::{FusedStream, Stream}; 3 use futures_core::task::{Context, Poll}; 4 use futures_io::{ 5 self as io, AsyncBufRead, AsyncRead, AsyncSeek, AsyncWrite, IoSlice, IoSliceMut, SeekFrom, 6 }; 7 use futures_sink::Sink; 8 use pin_project::{pin_project, pinned_drop}; 9 use std::pin::Pin; 10 use std::thread::panicking; 11 12 /// Combinator that asserts that the underlying type is not moved after being polled. 13 /// 14 /// See the `assert_unmoved` methods on: 15 /// * [`FutureTestExt`](crate::future::FutureTestExt::assert_unmoved) 16 /// * [`StreamTestExt`](crate::stream::StreamTestExt::assert_unmoved) 17 /// * [`SinkTestExt`](crate::sink::SinkTestExt::assert_unmoved_sink) 18 /// * [`AsyncReadTestExt`](crate::io::AsyncReadTestExt::assert_unmoved) 19 /// * [`AsyncWriteTestExt`](crate::io::AsyncWriteTestExt::assert_unmoved_write) 20 #[pin_project(PinnedDrop, !Unpin)] 21 #[derive(Debug, Clone)] 22 #[must_use = "futures do nothing unless you `.await` or poll them"] 23 pub struct AssertUnmoved<T> { 24 #[pin] 25 inner: T, 26 this_addr: usize, 27 } 28 29 impl<T> AssertUnmoved<T> { new(inner: T) -> Self30 pub(crate) fn new(inner: T) -> Self { 31 Self { inner, this_addr: 0 } 32 } 33 poll_with<'a, U>(mut self: Pin<&'a mut Self>, f: impl FnOnce(Pin<&'a mut T>) -> U) -> U34 fn poll_with<'a, U>(mut self: Pin<&'a mut Self>, f: impl FnOnce(Pin<&'a mut T>) -> U) -> U { 35 let cur_this = &*self as *const Self as usize; 36 if self.this_addr == 0 { 37 // First time being polled 38 *self.as_mut().project().this_addr = cur_this; 39 } else { 40 assert_eq!(self.this_addr, cur_this, "AssertUnmoved moved between poll calls"); 41 } 42 f(self.project().inner) 43 } 44 } 45 46 impl<Fut: Future> Future for AssertUnmoved<Fut> { 47 type Output = Fut::Output; 48 poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>49 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> { 50 self.poll_with(|f| f.poll(cx)) 51 } 52 } 53 54 impl<Fut: FusedFuture> FusedFuture for AssertUnmoved<Fut> { is_terminated(&self) -> bool55 fn is_terminated(&self) -> bool { 56 self.inner.is_terminated() 57 } 58 } 59 60 impl<St: Stream> Stream for AssertUnmoved<St> { 61 type Item = St::Item; 62 poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>>63 fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> { 64 self.poll_with(|s| s.poll_next(cx)) 65 } 66 } 67 68 impl<St: FusedStream> FusedStream for AssertUnmoved<St> { is_terminated(&self) -> bool69 fn is_terminated(&self) -> bool { 70 self.inner.is_terminated() 71 } 72 } 73 74 impl<Si: Sink<Item>, Item> Sink<Item> for AssertUnmoved<Si> { 75 type Error = Si::Error; 76 poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>77 fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { 78 self.poll_with(|s| s.poll_ready(cx)) 79 } 80 start_send(self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error>81 fn start_send(self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> { 82 self.poll_with(|s| s.start_send(item)) 83 } 84 poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>85 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { 86 self.poll_with(|s| s.poll_flush(cx)) 87 } 88 poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>89 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> { 90 self.poll_with(|s| s.poll_close(cx)) 91 } 92 } 93 94 impl<R: AsyncRead> AsyncRead for AssertUnmoved<R> { poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8], ) -> Poll<io::Result<usize>>95 fn poll_read( 96 self: Pin<&mut Self>, 97 cx: &mut Context<'_>, 98 buf: &mut [u8], 99 ) -> Poll<io::Result<usize>> { 100 self.poll_with(|r| r.poll_read(cx, buf)) 101 } 102 poll_read_vectored( self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &mut [IoSliceMut<'_>], ) -> Poll<io::Result<usize>>103 fn poll_read_vectored( 104 self: Pin<&mut Self>, 105 cx: &mut Context<'_>, 106 bufs: &mut [IoSliceMut<'_>], 107 ) -> Poll<io::Result<usize>> { 108 self.poll_with(|r| r.poll_read_vectored(cx, bufs)) 109 } 110 } 111 112 impl<W: AsyncWrite> AsyncWrite for AssertUnmoved<W> { poll_write( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll<io::Result<usize>>113 fn poll_write( 114 self: Pin<&mut Self>, 115 cx: &mut Context<'_>, 116 buf: &[u8], 117 ) -> Poll<io::Result<usize>> { 118 self.poll_with(|w| w.poll_write(cx, buf)) 119 } 120 poll_write_vectored( self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[IoSlice<'_>], ) -> Poll<io::Result<usize>>121 fn poll_write_vectored( 122 self: Pin<&mut Self>, 123 cx: &mut Context<'_>, 124 bufs: &[IoSlice<'_>], 125 ) -> Poll<io::Result<usize>> { 126 self.poll_with(|w| w.poll_write_vectored(cx, bufs)) 127 } 128 poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>129 fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { 130 self.poll_with(|w| w.poll_flush(cx)) 131 } 132 poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>133 fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> { 134 self.poll_with(|w| w.poll_close(cx)) 135 } 136 } 137 138 impl<S: AsyncSeek> AsyncSeek for AssertUnmoved<S> { poll_seek( self: Pin<&mut Self>, cx: &mut Context<'_>, pos: SeekFrom, ) -> Poll<io::Result<u64>>139 fn poll_seek( 140 self: Pin<&mut Self>, 141 cx: &mut Context<'_>, 142 pos: SeekFrom, 143 ) -> Poll<io::Result<u64>> { 144 self.poll_with(|s| s.poll_seek(cx, pos)) 145 } 146 } 147 148 impl<R: AsyncBufRead> AsyncBufRead for AssertUnmoved<R> { poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>>149 fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> { 150 self.poll_with(|r| r.poll_fill_buf(cx)) 151 } 152 consume(self: Pin<&mut Self>, amt: usize)153 fn consume(self: Pin<&mut Self>, amt: usize) { 154 self.poll_with(|r| r.consume(amt)) 155 } 156 } 157 158 #[pinned_drop] 159 impl<T> PinnedDrop for AssertUnmoved<T> { drop(self: Pin<&mut Self>)160 fn drop(self: Pin<&mut Self>) { 161 // If the thread is panicking then we can't panic again as that will 162 // cause the process to be aborted. 163 if !panicking() && self.this_addr != 0 { 164 let cur_this = &*self as *const Self as usize; 165 assert_eq!(self.this_addr, cur_this, "AssertUnmoved moved before drop"); 166 } 167 } 168 } 169 170 #[cfg(test)] 171 mod tests { 172 use futures_core::future::Future; 173 use futures_core::task::{Context, Poll}; 174 use futures_util::future::pending; 175 use futures_util::task::noop_waker; 176 use std::pin::Pin; 177 178 use super::AssertUnmoved; 179 180 #[test] assert_send_sync()181 fn assert_send_sync() { 182 fn assert<T: Send + Sync>() {} 183 assert::<AssertUnmoved<()>>(); 184 } 185 186 #[test] dont_panic_when_not_polled()187 fn dont_panic_when_not_polled() { 188 // This shouldn't panic. 189 let future = AssertUnmoved::new(pending::<()>()); 190 drop(future); 191 } 192 193 #[test] 194 #[should_panic(expected = "AssertUnmoved moved between poll calls")] dont_double_panic()195 fn dont_double_panic() { 196 // This test should only panic, not abort the process. 197 let waker = noop_waker(); 198 let mut cx = Context::from_waker(&waker); 199 200 // First we allocate the future on the stack and poll it. 201 let mut future = AssertUnmoved::new(pending::<()>()); 202 let pinned_future = unsafe { Pin::new_unchecked(&mut future) }; 203 assert_eq!(pinned_future.poll(&mut cx), Poll::Pending); 204 205 // Next we move it back to the heap and poll it again. This second call 206 // should panic (as the future is moved), but we shouldn't panic again 207 // whilst dropping `AssertUnmoved`. 208 let mut future = Box::new(future); 209 let pinned_boxed_future = unsafe { Pin::new_unchecked(&mut *future) }; 210 assert_eq!(pinned_boxed_future.poll(&mut cx), Poll::Pending); 211 } 212 } 213