1 use std::future::Future;
2 use std::mem;
3 use std::pin::Pin;
4 use std::task::{Context, Poll};
5 
6 use pin_project_lite::pin_project;
7 use tokio::sync::watch;
8 
channel() -> (Signal, Watch)9 pub(crate) fn channel() -> (Signal, Watch) {
10     let (tx, rx) = watch::channel(());
11     (Signal { tx }, Watch { rx })
12 }
13 
14 pub(crate) struct Signal {
15     tx: watch::Sender<()>,
16 }
17 
/// Future returned by `Signal::drain`, wrapping the boxed "all receivers
/// closed" future.
pub(crate) struct Draining(Pin<Box<dyn Future<Output = ()> + Sync + Send>>);
19 
20 #[derive(Clone)]
21 pub(crate) struct Watch {
22     rx: watch::Receiver<()>,
23 }
24 
pin_project! {
    /// Future returned by `Watch::watch`: polls the wrapped future and, once the
    /// drain notification completes, hands it to the `on_drain` callback.
    #[allow(missing_debug_implementations)]
    pub struct Watching<F, FN> {
        // The wrapped future; structurally pinned so it can be passed to the
        // callback as `Pin<&mut F>`.
        #[pin]
        future: F,
        // Holds the callback until the drain notification is observed.
        state: State<FN>,
        // Boxed future that completes when the watch receiver reports a change
        // (its `changed()` result, including errors, is ignored).
        watch: Pin<Box<dyn Future<Output = ()> + Send + Sync>>,
        // Extra receiver kept alive for as long as this value exists, so the
        // `Draining` future cannot resolve until this is dropped.
        _rx: watch::Receiver<()>,
    }
}
35 
/// Drain-tracking state of a `Watching` future.
enum State<Callback> {
    /// No drain observed yet; holds the callback to run when one arrives.
    Watch(Callback),
    /// The drain was observed, or the callback is temporarily taken out
    /// while the notification future is being polled.
    Draining,
}
40 
41 impl Signal {
drain(self) -> Draining42     pub(crate) fn drain(self) -> Draining {
43         let _ = self.tx.send(());
44         Draining(Box::pin(async move { self.tx.closed().await }))
45     }
46 }
47 
48 impl Future for Draining {
49     type Output = ();
50 
poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>51     fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
52         Pin::new(&mut self.as_mut().0).poll(cx)
53     }
54 }
55 
56 impl Watch {
watch<F, FN>(self, future: F, on_drain: FN) -> Watching<F, FN> where F: Future, FN: FnOnce(Pin<&mut F>),57     pub(crate) fn watch<F, FN>(self, future: F, on_drain: FN) -> Watching<F, FN>
58     where
59         F: Future,
60         FN: FnOnce(Pin<&mut F>),
61     {
62         let Self { mut rx } = self;
63         let _rx = rx.clone();
64         Watching {
65             future,
66             state: State::Watch(on_drain),
67             watch: Box::pin(async move {
68                 let _ = rx.changed().await;
69             }),
70             // Keep the receiver alive until the future completes, so that
71             // dropping it can signal that draining has completed.
72             _rx,
73         }
74     }
75 }
76 
77 impl<F, FN> Future for Watching<F, FN>
78 where
79     F: Future,
80     FN: FnOnce(Pin<&mut F>),
81 {
82     type Output = F::Output;
83 
poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output>84     fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
85         let mut me = self.project();
86         loop {
87             match mem::replace(me.state, State::Draining) {
88                 State::Watch(on_drain) => {
89                     match Pin::new(&mut me.watch).poll(cx) {
90                         Poll::Ready(()) => {
91                             // Drain has been triggered!
92                             on_drain(me.future.as_mut());
93                         }
94                         Poll::Pending => {
95                             *me.state = State::Watch(on_drain);
96                             return me.future.poll(cx);
97                         }
98                     }
99                 }
100                 State::Draining => return me.future.poll(cx),
101             }
102         }
103     }
104 }
105 
#[cfg(test)]
mod tests {
    use super::*;

    /// Hand-driven future used to observe how `Watching` polls its inner
    /// future and whether the drain callback ran.
    struct TestMe {
        // Set to `true` by the `on_drain` callbacks in these tests.
        draining: bool,
        // When `true`, the next poll resolves.
        finished: bool,
        // Number of times `poll` has been called.
        poll_cnt: usize,
    }

    impl TestMe {
        /// An unfinished, not-yet-draining future with no recorded polls.
        fn fresh() -> Self {
            TestMe {
                draining: false,
                finished: false,
                poll_cnt: 0,
            }
        }
    }

    impl Future for TestMe {
        type Output = ();

        fn poll(mut self: Pin<&mut Self>, _: &mut Context<'_>) -> Poll<Self::Output> {
            self.poll_cnt += 1;
            if !self.finished {
                return Poll::Pending;
            }
            Poll::Ready(())
        }
    }

    /// Polls an `Unpin` future once with the given context.
    fn poll_once<T: Future + Unpin>(fut: &mut T, cx: &mut Context<'_>) -> Poll<T::Output> {
        Pin::new(fut).poll(cx)
    }

    #[test]
    fn watch() {
        let mut mock = tokio_test::task::spawn(());
        mock.enter(|cx, _| {
            let (signal, watcher) = channel();
            let mut watching = watcher.watch(TestMe::fresh(), |mut inner| {
                inner.draining = true;
            });

            assert_eq!(watching.future.poll_cnt, 0);

            // Every poll before the drain reaches the inner future.
            for expected in 1..=2 {
                assert!(poll_once(&mut watching, cx).is_pending());
                assert_eq!(watching.future.poll_cnt, expected);
            }

            let mut draining = signal.drain();
            // The drain is signaled, but it is not observed until the next poll.
            assert!(!watching.future.draining);
            assert_eq!(watching.future.poll_cnt, 2);

            // This poll notices the signal, runs the callback, and still polls
            // the inner future.
            assert!(poll_once(&mut watching, cx).is_pending());
            assert_eq!(watching.future.poll_cnt, 3);
            assert!(watching.future.draining);

            // The drain cannot finish while a watcher is still alive.
            assert!(poll_once(&mut draining, cx).is_pending());

            // Let the inner future complete, then release the watcher.
            watching.future.finished = true;
            assert!(poll_once(&mut watching, cx).is_ready());
            assert_eq!(watching.future.poll_cnt, 4);
            drop(watching);

            assert!(poll_once(&mut draining, cx).is_ready());
        })
    }

    #[test]
    fn watch_clones() {
        let mut mock = tokio_test::task::spawn(());
        mock.enter(|cx, _| {
            let (signal, watcher) = channel();

            let first = watcher.clone().watch(TestMe::fresh(), |mut inner| {
                inner.draining = true;
            });
            let second = watcher.watch(TestMe::fresh(), |mut inner| {
                inner.draining = true;
            });

            let mut draining = signal.drain();

            // Two watchers are still outstanding.
            assert!(poll_once(&mut draining, cx).is_pending());

            // Dropping one of them is not enough...
            drop(first);
            assert!(poll_once(&mut draining, cx).is_pending());

            // ...but once the last one is gone, the drain completes.
            drop(second);
            assert!(poll_once(&mut draining, cx).is_ready());
        });
    }
}
219