1 use crate::runtime::task::{
2     self, unowned, Id, JoinHandle, OwnedTasks, Schedule, Task, TaskHarnessScheduleHooks,
3 };
4 use crate::runtime::tests::NoopSchedule;
5 
6 use std::collections::VecDeque;
7 use std::future::Future;
8 use std::sync::atomic::{AtomicBool, Ordering};
9 use std::sync::{Arc, Mutex};
10 
/// Observer half of an [`AssertDrop`] pair.
///
/// Lets a test check whether the paired `AssertDrop` value has been dropped.
struct AssertDropHandle {
    is_dropped: Arc<AtomicBool>,
}
impl AssertDropHandle {
    /// Asserts that the paired `AssertDrop` has been dropped.
    ///
    /// # Panics
    ///
    /// Panics (reporting the caller's location) if it is still alive.
    #[track_caller]
    fn assert_dropped(&self) {
        assert!(
            self.is_dropped.load(Ordering::SeqCst),
            "expected the AssertDrop value to have been dropped"
        );
    }

    /// Asserts that the paired `AssertDrop` has NOT been dropped yet.
    ///
    /// # Panics
    ///
    /// Panics (reporting the caller's location) if it was already dropped.
    #[track_caller]
    fn assert_not_dropped(&self) {
        assert!(
            !self.is_dropped.load(Ordering::SeqCst),
            "expected the AssertDrop value to still be alive"
        );
    }
}

/// A value that sets a shared flag when it is dropped.
///
/// The tests in this file move it into a task's future. That lets them see,
/// through the matching [`AssertDropHandle`], when the future was dropped.
struct AssertDrop {
    is_dropped: Arc<AtomicBool>,
}
impl AssertDrop {
    /// Creates a drop tracker together with the handle used to observe it.
    fn new() -> (Self, AssertDropHandle) {
        let shared = Arc::new(AtomicBool::new(false));
        (
            AssertDrop {
                is_dropped: Arc::clone(&shared),
            },
            // The handle takes ownership of the remaining reference, so a
            // second clone is unnecessary.
            AssertDropHandle { is_dropped: shared },
        )
    }
}
impl Drop for AssertDrop {
    fn drop(&mut self) {
        self.is_dropped.store(true, Ordering::SeqCst);
    }
}
47 
// Dropping the `Notified` on its own does not shut the task down. The future
// is dropped only once the last reference (here, the `JoinHandle`) is gone.
#[test]
fn create_drop1() {
    let (guard, probe) = AssertDrop::new();
    let (task, join_handle) = unowned(
        // Never polled: only its destruction is observed.
        async move {
            drop(guard);
            unreachable!()
        },
        NoopSchedule,
        Id::next(),
    );

    drop(task);
    probe.assert_not_dropped();

    drop(join_handle);
    probe.assert_dropped();
}
66 
// Same as `create_drop1`, but the `JoinHandle` goes first. The future lives
// until the `Notified` is dropped as well.
#[test]
fn create_drop2() {
    let (guard, probe) = AssertDrop::new();
    let (task, join_handle) = unowned(
        async move {
            drop(guard);
            unreachable!()
        },
        NoopSchedule,
        Id::next(),
    );

    drop(join_handle);
    probe.assert_not_dropped();

    drop(task);
    probe.assert_dropped();
}
83 
// An `AbortHandle` counts as a reference. The future survives dropping both
// the `JoinHandle` and the `Notified`, and is dropped with the abort handle.
#[test]
fn drop_abort_handle1() {
    let (guard, probe) = AssertDrop::new();
    let (task, join_handle) = unowned(
        async move {
            drop(guard);
            unreachable!()
        },
        NoopSchedule,
        Id::next(),
    );
    let abort_handle = join_handle.abort_handle();

    drop(join_handle);
    probe.assert_not_dropped();

    drop(task);
    probe.assert_not_dropped();

    drop(abort_handle);
    probe.assert_dropped();
}
103 
// Same references as `drop_abort_handle1`, released in a different order.
// The `JoinHandle` is the last one out.
#[test]
fn drop_abort_handle2() {
    let (guard, probe) = AssertDrop::new();
    let (task, join_handle) = unowned(
        async move {
            drop(guard);
            unreachable!()
        },
        NoopSchedule,
        Id::next(),
    );
    let abort_handle = join_handle.abort_handle();

    drop(task);
    probe.assert_not_dropped();

    drop(abort_handle);
    probe.assert_not_dropped();

    drop(join_handle);
    probe.assert_dropped();
}
123 
// Each clone of an `AbortHandle` is its own reference. The future is dropped
// only after the last clone goes away.
#[test]
fn drop_abort_handle_clone() {
    let (guard, probe) = AssertDrop::new();
    let (task, join_handle) = unowned(
        async move {
            drop(guard);
            unreachable!()
        },
        NoopSchedule,
        Id::next(),
    );
    let first_abort = join_handle.abort_handle();
    let second_abort = first_abort.clone();

    drop(join_handle);
    probe.assert_not_dropped();

    drop(task);
    probe.assert_not_dropped();

    drop(first_abort);
    probe.assert_not_dropped();

    drop(second_abort);
    probe.assert_dropped();
}
146 
// Calling `shutdown` on the `Notified` drops the future right away, here
// after the `JoinHandle` has already been released.
#[test]
fn create_shutdown1() {
    let (guard, probe) = AssertDrop::new();
    let (task, join_handle) = unowned(
        async move {
            drop(guard);
            unreachable!()
        },
        NoopSchedule,
        Id::next(),
    );

    drop(join_handle);
    probe.assert_not_dropped();

    task.shutdown();
    probe.assert_dropped();
}
164 
// `shutdown` through the `Notified` drops the future even while the
// `JoinHandle` is still alive. The handle can be released afterwards.
#[test]
fn create_shutdown2() {
    let (guard, probe) = AssertDrop::new();
    let (task, join_handle) = unowned(
        async move {
            drop(guard);
            unreachable!()
        },
        NoopSchedule,
        Id::next(),
    );

    probe.assert_not_dropped();
    task.shutdown();
    probe.assert_dropped();

    drop(join_handle);
}
181 
// An unowned task wrapping a trivial future can be run directly.
#[test]
fn unowned_poll() {
    // Keep only the `Notified`. The `JoinHandle` is part of the temporary tuple
    // and is dropped at the end of this statement.
    let runnable = unowned(async {}, NoopSchedule, Id::next()).0;
    runnable.run();
}
187 
// A task that yields once has to be run twice before it completes.
#[test]
fn schedule() {
    with(|runtime| {
        runtime.spawn(async {
            crate::task::yield_now().await;
        });

        let runs = runtime.tick();
        assert_eq!(2, runs);
        runtime.shutdown();
    });
}
199 
// Shut down a runtime after running at most one task. The only task yields
// forever and so has not completed.
#[test]
fn shutdown() {
    with(|runtime| {
        runtime.spawn(async {
            loop {
                crate::task::yield_now().await;
            }
        });

        runtime.tick_max(1);

        runtime.shutdown();
    });
}
214 
// Shut down a runtime whose never-ending task was spawned but never run.
#[test]
fn shutdown_immediately() {
    with(|runtime| {
        runtime.spawn(async {
            loop {
                crate::task::yield_now().await;
            }
        });

        runtime.shutdown();
    });
}
227 
// Test for https://github.com/tokio-rs/tokio/issues/6729
//
// Regression test setup:
// - One task awaits `Subscriber::wait` twice on a shared `State`.
// - A second task sets the version to 2 and then to 0. Version 0 makes
//   `poll_update` return `Ready(None)`.
// After at most 10 scheduler ticks the run queue must be empty.
// NOTE(review): the test name suggests the issue involves a niche in the
// spawned future's layout. See the linked issue for the exact failure mode
// before restructuring these types.
#[test]
fn spawn_niche_in_task() {
    use std::future::poll_fn;
    use std::task::{Context, Poll, Waker};

    with(|rt| {
        let state = Arc::new(Mutex::new(State::new()));

        // Starts with observed version 1, the same as `State::new()`'s initial
        // version. Its first `wait` therefore stays Pending until the version
        // changes.
        let mut subscriber = Subscriber::new(Arc::clone(&state), 1);
        rt.spawn(async move {
            subscriber.wait().await;
            subscriber.wait().await;
        });

        rt.spawn(async move {
            // Each call takes the lock separately and wakes the waiters
            // registered so far.
            state.lock().unwrap().set_version(2);
            state.lock().unwrap().set_version(0);
        });

        rt.tick_max(10);
        assert!(rt.is_empty());
        rt.shutdown();
    });

    // Waits on a shared `State` for a version newer than the one it last saw.
    pub(crate) struct Subscriber {
        state: Arc<Mutex<State>>,
        observed_version: u64,
        // Written by `State::poll_update` but never read in this test.
        waker_key: Option<usize>,
    }

    impl Subscriber {
        pub(crate) fn new(state: Arc<Mutex<State>>, version: u64) -> Self {
            Self {
                state,
                observed_version: version,
                waker_key: None,
            }
        }

        // Completes once `poll_update` is Ready. That happens on a newer
        // version or on version 0. The `Option` payload is discarded.
        pub(crate) async fn wait(&mut self) {
            poll_fn(|cx| {
                self.state
                    .lock()
                    .unwrap()
                    .poll_update(&mut self.observed_version, &mut self.waker_key, cx)
                    .map(|_| ())
            })
            .await;
        }
    }

    // Versioned value plus the wakers of subscribers waiting for a change.
    struct State {
        version: u64,
        wakers: Vec<Waker>,
    }

    impl State {
        pub(crate) fn new() -> Self {
            Self {
                version: 1,
                wakers: Vec::new(),
            }
        }

        // Returns:
        // - `Ready(None)` if the version is 0;
        // - `Ready(Some(()))`, after recording the new version, if the version
        //   is newer than `observed_version`;
        // - `Pending` otherwise, after storing the caller's waker.
        pub(crate) fn poll_update(
            &mut self,
            observed_version: &mut u64,
            waker_key: &mut Option<usize>,
            cx: &Context<'_>,
        ) -> Poll<Option<()>> {
            if self.version == 0 {
                *waker_key = None;
                Poll::Ready(None)
            } else if *observed_version < self.version {
                *waker_key = None;
                *observed_version = self.version;
                Poll::Ready(Some(()))
            } else {
                self.wakers.push(cx.waker().clone());
                // NOTE(review): after the push this is the waker's index + 1,
                // not its index. It is harmless here because the key is never
                // read.
                *waker_key = Some(self.wakers.len());
                Poll::Pending
            }
        }

        // Stores the new version, then wakes and removes every stored waker.
        pub(crate) fn set_version(&mut self, version: u64) {
            self.version = version;
            for waker in self.wakers.drain(..) {
                waker.wake();
            }
        }
    }
}
321 
// A task's future owns a value whose destructor spawns a new task on the same
// runtime. The future is dropped while the runtime shuts down. The final
// assertion checks that the destructor ran, and so tried to spawn, by the time
// `with` returned.
#[test]
fn spawn_during_shutdown() {
    static DID_SPAWN: AtomicBool = AtomicBool::new(false);

    // Holds a runtime handle and spawns an empty task on it when dropped.
    struct SpawnWhenDropped(Runtime);

    impl Drop for SpawnWhenDropped {
        fn drop(&mut self) {
            DID_SPAWN.store(true, Ordering::SeqCst);
            self.0.spawn(async {});
        }
    }

    with(|runtime| {
        let spawner = runtime.clone();
        runtime.spawn(async move {
            // Named (not `_`) so it lives as long as the future does.
            let _guard = SpawnWhenDropped(spawner);

            loop {
                crate::task::yield_now().await;
            }
        });

        runtime.tick_max(1);
        runtime.shutdown();
    });

    assert!(DID_SPAWN.load(Ordering::SeqCst));
}
350 
with(f: impl FnOnce(Runtime))351 fn with(f: impl FnOnce(Runtime)) {
352     struct Reset;
353 
354     impl Drop for Reset {
355         fn drop(&mut self) {
356             let _rt = CURRENT.try_lock().unwrap().take();
357         }
358     }
359 
360     let _reset = Reset;
361 
362     let rt = Runtime(Arc::new(Inner {
363         owned: OwnedTasks::new(16),
364         core: Mutex::new(Core {
365             queue: VecDeque::new(),
366         }),
367     }));
368 
369     *CURRENT.try_lock().unwrap() = Some(rt.clone());
370     f(rt)
371 }
372 
/// Minimal scheduler used by these tests. It is a cheaply cloneable handle to
/// shared [`Inner`] state.
#[derive(Clone)]
struct Runtime(Arc<Inner>);

/// State shared by every clone of a [`Runtime`].
// Rust drops fields in declaration order: `core` first, then `owned`.
struct Inner {
    /// Run queue. The helper methods access it with `try_lock`.
    core: Mutex<Core>,
    /// Tasks bound to this runtime through `OwnedTasks::bind`.
    owned: OwnedTasks<Runtime>,
}

/// Mutable scheduler state protected by `Inner::core`.
struct Core {
    /// Scheduled tasks. `schedule` pushes to the back; `next_task` pops from
    /// the front.
    queue: VecDeque<task::Notified<Runtime>>,
}

/// Runtime registered by `with` for the duration of a test. It is cleared
/// again when `with` exits.
static CURRENT: Mutex<Option<Runtime>> = Mutex::new(None);
386 
387 impl Runtime {
spawn<T>(&self, future: T) -> JoinHandle<T::Output> where T: 'static + Send + Future, T::Output: 'static + Send,388     fn spawn<T>(&self, future: T) -> JoinHandle<T::Output>
389     where
390         T: 'static + Send + Future,
391         T::Output: 'static + Send,
392     {
393         let (handle, notified) = self.0.owned.bind(future, self.clone(), Id::next());
394 
395         if let Some(notified) = notified {
396             self.schedule(notified);
397         }
398 
399         handle
400     }
401 
tick(&self) -> usize402     fn tick(&self) -> usize {
403         self.tick_max(usize::MAX)
404     }
405 
tick_max(&self, max: usize) -> usize406     fn tick_max(&self, max: usize) -> usize {
407         let mut n = 0;
408 
409         while !self.is_empty() && n < max {
410             let task = self.next_task();
411             n += 1;
412             let task = self.0.owned.assert_owner(task);
413             task.run();
414         }
415 
416         n
417     }
418 
is_empty(&self) -> bool419     fn is_empty(&self) -> bool {
420         self.0.core.try_lock().unwrap().queue.is_empty()
421     }
422 
next_task(&self) -> task::Notified<Runtime>423     fn next_task(&self) -> task::Notified<Runtime> {
424         self.0.core.try_lock().unwrap().queue.pop_front().unwrap()
425     }
426 
shutdown(&self)427     fn shutdown(&self) {
428         let mut core = self.0.core.try_lock().unwrap();
429 
430         self.0.owned.close_and_shutdown_all(0);
431 
432         while let Some(task) = core.queue.pop_back() {
433             drop(task);
434         }
435 
436         drop(core);
437         assert!(self.0.owned.is_empty());
438     }
439 }
440 
441 impl Schedule for Runtime {
release(&self, task: &Task<Self>) -> Option<Task<Self>>442     fn release(&self, task: &Task<Self>) -> Option<Task<Self>> {
443         self.0.owned.remove(task)
444     }
445 
schedule(&self, task: task::Notified<Self>)446     fn schedule(&self, task: task::Notified<Self>) {
447         self.0.core.try_lock().unwrap().queue.push_back(task);
448     }
449 
hooks(&self) -> TaskHarnessScheduleHooks450     fn hooks(&self) -> TaskHarnessScheduleHooks {
451         TaskHarnessScheduleHooks {
452             task_terminate_callback: None,
453         }
454     }
455 }
456