1 use futures_core::future::{FusedFuture, Future};
2 use futures_core::stream::{FusedStream, Stream};
3 use futures_core::task::{Context, Poll};
4 use futures_io::{
5     self as io, AsyncBufRead, AsyncRead, AsyncSeek, AsyncWrite, IoSlice, IoSliceMut, SeekFrom,
6 };
7 use futures_sink::Sink;
8 use pin_project::{pin_project, pinned_drop};
9 use std::pin::Pin;
10 use std::thread::panicking;
11 
12 /// Combinator that asserts that the underlying type is not moved after being polled.
13 ///
14 /// See the `assert_unmoved` methods on:
15 /// * [`FutureTestExt`](crate::future::FutureTestExt::assert_unmoved)
16 /// * [`StreamTestExt`](crate::stream::StreamTestExt::assert_unmoved)
17 /// * [`SinkTestExt`](crate::sink::SinkTestExt::assert_unmoved_sink)
18 /// * [`AsyncReadTestExt`](crate::io::AsyncReadTestExt::assert_unmoved)
19 /// * [`AsyncWriteTestExt`](crate::io::AsyncWriteTestExt::assert_unmoved_write)
20 #[pin_project(PinnedDrop, !Unpin)]
21 #[derive(Debug, Clone)]
22 #[must_use = "futures do nothing unless you `.await` or poll them"]
23 pub struct AssertUnmoved<T> {
24     #[pin]
25     inner: T,
26     this_addr: usize,
27 }
28 
29 impl<T> AssertUnmoved<T> {
new(inner: T) -> Self30     pub(crate) fn new(inner: T) -> Self {
31         Self { inner, this_addr: 0 }
32     }
33 
poll_with<'a, U>(mut self: Pin<&'a mut Self>, f: impl FnOnce(Pin<&'a mut T>) -> U) -> U34     fn poll_with<'a, U>(mut self: Pin<&'a mut Self>, f: impl FnOnce(Pin<&'a mut T>) -> U) -> U {
35         let cur_this = &*self as *const Self as usize;
36         if self.this_addr == 0 {
37             // First time being polled
38             *self.as_mut().project().this_addr = cur_this;
39         } else {
40             assert_eq!(self.this_addr, cur_this, "AssertUnmoved moved between poll calls");
41         }
42         f(self.project().inner)
43     }
44 }
45 
46 impl<Fut: Future> Future for AssertUnmoved<Fut> {
47     type Output = Fut::Output;
48 
poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>49     fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
50         self.poll_with(|f| f.poll(cx))
51     }
52 }
53 
54 impl<Fut: FusedFuture> FusedFuture for AssertUnmoved<Fut> {
is_terminated(&self) -> bool55     fn is_terminated(&self) -> bool {
56         self.inner.is_terminated()
57     }
58 }
59 
60 impl<St: Stream> Stream for AssertUnmoved<St> {
61     type Item = St::Item;
62 
poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>>63     fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
64         self.poll_with(|s| s.poll_next(cx))
65     }
66 }
67 
68 impl<St: FusedStream> FusedStream for AssertUnmoved<St> {
is_terminated(&self) -> bool69     fn is_terminated(&self) -> bool {
70         self.inner.is_terminated()
71     }
72 }
73 
74 impl<Si: Sink<Item>, Item> Sink<Item> for AssertUnmoved<Si> {
75     type Error = Si::Error;
76 
poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>77     fn poll_ready(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
78         self.poll_with(|s| s.poll_ready(cx))
79     }
80 
start_send(self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error>81     fn start_send(self: Pin<&mut Self>, item: Item) -> Result<(), Self::Error> {
82         self.poll_with(|s| s.start_send(item))
83     }
84 
poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>85     fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
86         self.poll_with(|s| s.poll_flush(cx))
87     }
88 
poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>>89     fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
90         self.poll_with(|s| s.poll_close(cx))
91     }
92 }
93 
94 impl<R: AsyncRead> AsyncRead for AssertUnmoved<R> {
poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8], ) -> Poll<io::Result<usize>>95     fn poll_read(
96         self: Pin<&mut Self>,
97         cx: &mut Context<'_>,
98         buf: &mut [u8],
99     ) -> Poll<io::Result<usize>> {
100         self.poll_with(|r| r.poll_read(cx, buf))
101     }
102 
poll_read_vectored( self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &mut [IoSliceMut<'_>], ) -> Poll<io::Result<usize>>103     fn poll_read_vectored(
104         self: Pin<&mut Self>,
105         cx: &mut Context<'_>,
106         bufs: &mut [IoSliceMut<'_>],
107     ) -> Poll<io::Result<usize>> {
108         self.poll_with(|r| r.poll_read_vectored(cx, bufs))
109     }
110 }
111 
112 impl<W: AsyncWrite> AsyncWrite for AssertUnmoved<W> {
poll_write( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll<io::Result<usize>>113     fn poll_write(
114         self: Pin<&mut Self>,
115         cx: &mut Context<'_>,
116         buf: &[u8],
117     ) -> Poll<io::Result<usize>> {
118         self.poll_with(|w| w.poll_write(cx, buf))
119     }
120 
poll_write_vectored( self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[IoSlice<'_>], ) -> Poll<io::Result<usize>>121     fn poll_write_vectored(
122         self: Pin<&mut Self>,
123         cx: &mut Context<'_>,
124         bufs: &[IoSlice<'_>],
125     ) -> Poll<io::Result<usize>> {
126         self.poll_with(|w| w.poll_write_vectored(cx, bufs))
127     }
128 
poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>129     fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
130         self.poll_with(|w| w.poll_flush(cx))
131     }
132 
poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>>133     fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
134         self.poll_with(|w| w.poll_close(cx))
135     }
136 }
137 
138 impl<S: AsyncSeek> AsyncSeek for AssertUnmoved<S> {
poll_seek( self: Pin<&mut Self>, cx: &mut Context<'_>, pos: SeekFrom, ) -> Poll<io::Result<u64>>139     fn poll_seek(
140         self: Pin<&mut Self>,
141         cx: &mut Context<'_>,
142         pos: SeekFrom,
143     ) -> Poll<io::Result<u64>> {
144         self.poll_with(|s| s.poll_seek(cx, pos))
145     }
146 }
147 
148 impl<R: AsyncBufRead> AsyncBufRead for AssertUnmoved<R> {
poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>>149     fn poll_fill_buf(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<&[u8]>> {
150         self.poll_with(|r| r.poll_fill_buf(cx))
151     }
152 
consume(self: Pin<&mut Self>, amt: usize)153     fn consume(self: Pin<&mut Self>, amt: usize) {
154         self.poll_with(|r| r.consume(amt))
155     }
156 }
157 
158 #[pinned_drop]
159 impl<T> PinnedDrop for AssertUnmoved<T> {
drop(self: Pin<&mut Self>)160     fn drop(self: Pin<&mut Self>) {
161         // If the thread is panicking then we can't panic again as that will
162         // cause the process to be aborted.
163         if !panicking() && self.this_addr != 0 {
164             let cur_this = &*self as *const Self as usize;
165             assert_eq!(self.this_addr, cur_this, "AssertUnmoved moved before drop");
166         }
167     }
168 }
169 
#[cfg(test)]
mod tests {
    use futures_core::future::Future;
    use futures_core::task::{Context, Poll};
    use futures_util::future::pending;
    use futures_util::task::noop_waker;
    use std::pin::Pin;

    use super::AssertUnmoved;

    /// `AssertUnmoved<()>` must be `Send + Sync`; storing the address as a plain
    /// `usize` keeps the wrapper free of raw pointers.
    #[test]
    fn assert_send_sync() {
        fn assert<T: Send + Sync>() {}
        assert::<AssertUnmoved<()>>();
    }

    /// Dropping a wrapper that was never polled must not trigger the
    /// "moved before drop" assertion: no address has been recorded yet.
    #[test]
    fn dont_panic_when_not_polled() {
        // This shouldn't panic.
        let future = AssertUnmoved::new(pending::<()>());
        drop(future);
    }

    /// Moving after the first poll must panic exactly once: the second poll panics,
    /// and the drop during unwinding must not panic again (which would abort).
    #[test]
    #[should_panic(expected = "AssertUnmoved moved between poll calls")]
    fn dont_double_panic() {
        // This test should only panic, not abort the process.
        let waker = noop_waker();
        let mut cx = Context::from_waker(&waker);

        // First we allocate the future on the stack and poll it.
        let mut future = AssertUnmoved::new(pending::<()>());
        // NOT sound in general: `future` is moved below after being pinned here.
        // That deliberate violation of the pinning contract is what this test provokes.
        let pinned_future = unsafe { Pin::new_unchecked(&mut future) };
        assert_eq!(pinned_future.poll(&mut cx), Poll::Pending);

        // Next we move it to the heap and poll it again. This second call
        // should panic (as the future is moved), but we shouldn't panic again
        // whilst dropping `AssertUnmoved`.
        let mut future = Box::new(future);
        // Pinning the moved value; the move itself is the condition under test.
        let pinned_boxed_future = unsafe { Pin::new_unchecked(&mut *future) };
        assert_eq!(pinned_boxed_future.poll(&mut cx), Poll::Pending);
    }
}
213