//! This module has containers for storing the tasks spawned on a scheduler. The
//! `OwnedTasks` container is thread-safe but can only store tasks that
//! implement `Send`. The `LocalOwnedTasks` container is not thread-safe, but can
//! store non-`Send` tasks.
//!
//! The collections can be closed to prevent new tasks from being added while
//! the scheduler that owns the collection is shutting down.
8 
9 use crate::future::Future;
10 use crate::loom::cell::UnsafeCell;
11 use crate::runtime::task::{JoinHandle, LocalNotified, Notified, Schedule, Task};
12 use crate::util::linked_list::{Link, LinkedList};
13 use crate::util::sharded_list;
14 
15 use crate::loom::sync::atomic::{AtomicBool, Ordering};
16 use std::marker::PhantomData;
17 use std::num::NonZeroU64;
18 
// The id produced by the `get_next_id` functions below is used to verify
// whether a given task is stored in this `OwnedTasks` or in some other one.
// The counter starts at one so that `None` can be used for tasks not owned by
// any list.
//
// The safety checks in this file can technically be violated if the counter
// overflows, but the checks are not supposed to ever fail unless there is a
// bug in Tokio, so we accept that certain bugs would not be caught if the two
// mixed-up runtimes happen to have the same id.
27 
28 cfg_has_atomic_u64! {
29     use std::sync::atomic::AtomicU64;
30 
31     static NEXT_OWNED_TASKS_ID: AtomicU64 = AtomicU64::new(1);
32 
33     fn get_next_id() -> NonZeroU64 {
34         loop {
35             let id = NEXT_OWNED_TASKS_ID.fetch_add(1, Ordering::Relaxed);
36             if let Some(id) = NonZeroU64::new(id) {
37                 return id;
38             }
39         }
40     }
41 }
42 
43 cfg_not_has_atomic_u64! {
44     use std::sync::atomic::AtomicU32;
45 
46     static NEXT_OWNED_TASKS_ID: AtomicU32 = AtomicU32::new(1);
47 
48     fn get_next_id() -> NonZeroU64 {
49         loop {
50             let id = NEXT_OWNED_TASKS_ID.fetch_add(1, Ordering::Relaxed);
51             if let Some(id) = NonZeroU64::new(u64::from(id)) {
52                 return id;
53             }
54         }
55     }
56 }
57 
/// Thread-safe collection of the tasks bound to a scheduler.
///
/// Tasks are stored in a sharded list. Once the collection has been closed,
/// newly bound tasks are shut down instead of being stored.
pub(crate) struct OwnedTasks<S: 'static> {
    /// Sharded list holding every task currently bound to this collection.
    list: List<S>,
    /// Unique id written into each bound task's header as its owner id; used
    /// to check that a task really belongs to this collection.
    pub(crate) id: NonZeroU64,
    /// Set by `close_and_shutdown_all`; read under the shard lock when a task
    /// is bound.
    closed: AtomicBool,
}
63 
/// The sharded intrusive list type used by `OwnedTasks` to store its tasks.
type List<S> = sharded_list::ShardedList<Task<S>, <Task<S> as Link>::Target>;
65 
/// Single-threaded counterpart of `OwnedTasks` that can also store non-`Send`
/// tasks.
///
/// The `PhantomData<*const ()>` marker makes this type neither `Send` nor
/// `Sync`, which is what allows the unsynchronized access through the
/// `UnsafeCell`.
pub(crate) struct LocalOwnedTasks<S: 'static> {
    /// The task list and closed flag; only accessed through `with_inner`.
    inner: UnsafeCell<OwnedTasksInner<S>>,
    /// Unique id written into each bound task's header as its owner id.
    pub(crate) id: NonZeroU64,
    /// Opts this type out of the `Send` and `Sync` auto traits.
    _not_send_or_sync: PhantomData<*const ()>,
}
71 
/// Mutable state of a `LocalOwnedTasks`, kept behind its `UnsafeCell`.
struct OwnedTasksInner<S: 'static> {
    /// Intrusive linked list of the tasks bound to the collection.
    list: LinkedList<Task<S>, <Task<S> as Link>::Target>,
    /// Set by `close_and_shutdown_all`; once true, newly bound tasks are shut
    /// down instead of being added to `list`.
    closed: bool,
}
76 
77 impl<S: 'static> OwnedTasks<S> {
new(num_cores: usize) -> Self78     pub(crate) fn new(num_cores: usize) -> Self {
79         let shard_size = Self::gen_shared_list_size(num_cores);
80         Self {
81             list: List::new(shard_size),
82             closed: AtomicBool::new(false),
83             id: get_next_id(),
84         }
85     }
86 
87     /// Binds the provided task to this `OwnedTasks` instance. This fails if the
88     /// `OwnedTasks` has been closed.
bind<T>( &self, task: T, scheduler: S, id: super::Id, ) -> (JoinHandle<T::Output>, Option<Notified<S>>) where S: Schedule, T: Future + Send + 'static, T::Output: Send + 'static,89     pub(crate) fn bind<T>(
90         &self,
91         task: T,
92         scheduler: S,
93         id: super::Id,
94     ) -> (JoinHandle<T::Output>, Option<Notified<S>>)
95     where
96         S: Schedule,
97         T: Future + Send + 'static,
98         T::Output: Send + 'static,
99     {
100         let (task, notified, join) = super::new_task(task, scheduler, id);
101         let notified = unsafe { self.bind_inner(task, notified) };
102         (join, notified)
103     }
104 
105     /// Bind a task that isn't safe to transfer across thread boundaries.
106     ///
107     /// # Safety
108     /// Only use this in `LocalRuntime` where the task cannot move
bind_local<T>( &self, task: T, scheduler: S, id: super::Id, ) -> (JoinHandle<T::Output>, Option<Notified<S>>) where S: Schedule, T: Future + 'static, T::Output: 'static,109     pub(crate) unsafe fn bind_local<T>(
110         &self,
111         task: T,
112         scheduler: S,
113         id: super::Id,
114     ) -> (JoinHandle<T::Output>, Option<Notified<S>>)
115     where
116         S: Schedule,
117         T: Future + 'static,
118         T::Output: 'static,
119     {
120         let (task, notified, join) = super::new_task(task, scheduler, id);
121         let notified = unsafe { self.bind_inner(task, notified) };
122         (join, notified)
123     }
124 
125     /// The part of `bind` that's the same for every type of future.
bind_inner(&self, task: Task<S>, notified: Notified<S>) -> Option<Notified<S>> where S: Schedule,126     unsafe fn bind_inner(&self, task: Task<S>, notified: Notified<S>) -> Option<Notified<S>>
127     where
128         S: Schedule,
129     {
130         unsafe {
131             // safety: We just created the task, so we have exclusive access
132             // to the field.
133             task.header().set_owner_id(self.id);
134         }
135 
136         let shard = self.list.lock_shard(&task);
137         // Check the closed flag in the lock for ensuring all that tasks
138         // will shut down after the OwnedTasks has been closed.
139         if self.closed.load(Ordering::Acquire) {
140             drop(shard);
141             task.shutdown();
142             return None;
143         }
144         shard.push(task);
145         Some(notified)
146     }
147 
148     /// Asserts that the given task is owned by this `OwnedTasks` and convert it to
149     /// a `LocalNotified`, giving the thread permission to poll this task.
150     #[inline]
assert_owner(&self, task: Notified<S>) -> LocalNotified<S>151     pub(crate) fn assert_owner(&self, task: Notified<S>) -> LocalNotified<S> {
152         debug_assert_eq!(task.header().get_owner_id(), Some(self.id));
153         // safety: All tasks bound to this OwnedTasks are Send, so it is safe
154         // to poll it on this thread no matter what thread we are on.
155         LocalNotified {
156             task: task.0,
157             _not_send: PhantomData,
158         }
159     }
160 
161     /// Shuts down all tasks in the collection. This call also closes the
162     /// collection, preventing new items from being added.
163     ///
164     /// The parameter start determines which shard this method will start at.
165     /// Using different values for each worker thread reduces contention.
close_and_shutdown_all(&self, start: usize) where S: Schedule,166     pub(crate) fn close_and_shutdown_all(&self, start: usize)
167     where
168         S: Schedule,
169     {
170         self.closed.store(true, Ordering::Release);
171         for i in start..self.get_shard_size() + start {
172             loop {
173                 let task = self.list.pop_back(i);
174                 match task {
175                     Some(task) => {
176                         task.shutdown();
177                     }
178                     None => break,
179                 }
180             }
181         }
182     }
183 
184     #[inline]
get_shard_size(&self) -> usize185     pub(crate) fn get_shard_size(&self) -> usize {
186         self.list.shard_size()
187     }
188 
num_alive_tasks(&self) -> usize189     pub(crate) fn num_alive_tasks(&self) -> usize {
190         self.list.len()
191     }
192 
193     cfg_64bit_metrics! {
194         pub(crate) fn spawned_tasks_count(&self) -> u64 {
195             self.list.added()
196         }
197     }
198 
remove(&self, task: &Task<S>) -> Option<Task<S>>199     pub(crate) fn remove(&self, task: &Task<S>) -> Option<Task<S>> {
200         // If the task's owner ID is `None` then it is not part of any list and
201         // doesn't need removing.
202         let task_id = task.header().get_owner_id()?;
203 
204         assert_eq!(task_id, self.id);
205 
206         // safety: We just checked that the provided task is not in some other
207         // linked list.
208         unsafe { self.list.remove(task.header_ptr()) }
209     }
210 
is_empty(&self) -> bool211     pub(crate) fn is_empty(&self) -> bool {
212         self.list.is_empty()
213     }
214 
215     /// Generates the size of the sharded list based on the number of worker threads.
216     ///
217     /// The sharded lock design can effectively alleviate
218     /// lock contention performance problems caused by high concurrency.
219     ///
220     /// However, as the number of shards increases, the memory continuity between
221     /// nodes in the intrusive linked list will diminish. Furthermore,
222     /// the construction time of the sharded list will also increase with a higher number of shards.
223     ///
224     /// Due to the above reasons, we set a maximum value for the shared list size,
225     /// denoted as `MAX_SHARED_LIST_SIZE`.
gen_shared_list_size(num_cores: usize) -> usize226     fn gen_shared_list_size(num_cores: usize) -> usize {
227         const MAX_SHARED_LIST_SIZE: usize = 1 << 16;
228         usize::min(MAX_SHARED_LIST_SIZE, num_cores.next_power_of_two() * 4)
229     }
230 }
231 
232 cfg_taskdump! {
233     impl<S: 'static> OwnedTasks<S> {
234         /// Locks the tasks, and calls `f` on an iterator over them.
235         pub(crate) fn for_each<F>(&self, f: F)
236         where
237             F: FnMut(&Task<S>),
238         {
239             self.list.for_each(f);
240         }
241     }
242 }
243 
244 impl<S: 'static> LocalOwnedTasks<S> {
new() -> Self245     pub(crate) fn new() -> Self {
246         Self {
247             inner: UnsafeCell::new(OwnedTasksInner {
248                 list: LinkedList::new(),
249                 closed: false,
250             }),
251             id: get_next_id(),
252             _not_send_or_sync: PhantomData,
253         }
254     }
255 
bind<T>( &self, task: T, scheduler: S, id: super::Id, ) -> (JoinHandle<T::Output>, Option<Notified<S>>) where S: Schedule, T: Future + 'static, T::Output: 'static,256     pub(crate) fn bind<T>(
257         &self,
258         task: T,
259         scheduler: S,
260         id: super::Id,
261     ) -> (JoinHandle<T::Output>, Option<Notified<S>>)
262     where
263         S: Schedule,
264         T: Future + 'static,
265         T::Output: 'static,
266     {
267         let (task, notified, join) = super::new_task(task, scheduler, id);
268 
269         unsafe {
270             // safety: We just created the task, so we have exclusive access
271             // to the field.
272             task.header().set_owner_id(self.id);
273         }
274 
275         if self.is_closed() {
276             drop(notified);
277             task.shutdown();
278             (join, None)
279         } else {
280             self.with_inner(|inner| {
281                 inner.list.push_front(task);
282             });
283             (join, Some(notified))
284         }
285     }
286 
287     /// Shuts down all tasks in the collection. This call also closes the
288     /// collection, preventing new items from being added.
close_and_shutdown_all(&self) where S: Schedule,289     pub(crate) fn close_and_shutdown_all(&self)
290     where
291         S: Schedule,
292     {
293         self.with_inner(|inner| inner.closed = true);
294 
295         while let Some(task) = self.with_inner(|inner| inner.list.pop_back()) {
296             task.shutdown();
297         }
298     }
299 
remove(&self, task: &Task<S>) -> Option<Task<S>>300     pub(crate) fn remove(&self, task: &Task<S>) -> Option<Task<S>> {
301         // If the task's owner ID is `None` then it is not part of any list and
302         // doesn't need removing.
303         let task_id = task.header().get_owner_id()?;
304 
305         assert_eq!(task_id, self.id);
306 
307         self.with_inner(|inner|
308             // safety: We just checked that the provided task is not in some
309             // other linked list.
310             unsafe { inner.list.remove(task.header_ptr()) })
311     }
312 
313     /// Asserts that the given task is owned by this `LocalOwnedTasks` and convert
314     /// it to a `LocalNotified`, giving the thread permission to poll this task.
315     #[inline]
assert_owner(&self, task: Notified<S>) -> LocalNotified<S>316     pub(crate) fn assert_owner(&self, task: Notified<S>) -> LocalNotified<S> {
317         assert_eq!(task.header().get_owner_id(), Some(self.id));
318 
319         // safety: The task was bound to this LocalOwnedTasks, and the
320         // LocalOwnedTasks is not Send or Sync, so we are on the right thread
321         // for polling this task.
322         LocalNotified {
323             task: task.0,
324             _not_send: PhantomData,
325         }
326     }
327 
328     #[inline]
with_inner<F, T>(&self, f: F) -> T where F: FnOnce(&mut OwnedTasksInner<S>) -> T,329     fn with_inner<F, T>(&self, f: F) -> T
330     where
331         F: FnOnce(&mut OwnedTasksInner<S>) -> T,
332     {
333         // safety: This type is not Sync, so concurrent calls of this method
334         // can't happen.  Furthermore, all uses of this method in this file make
335         // sure that they don't call `with_inner` recursively.
336         self.inner.with_mut(|ptr| unsafe { f(&mut *ptr) })
337     }
338 
is_closed(&self) -> bool339     pub(crate) fn is_closed(&self) -> bool {
340         self.with_inner(|inner| inner.closed)
341     }
342 
is_empty(&self) -> bool343     pub(crate) fn is_empty(&self) -> bool {
344         self.with_inner(|inner| inner.list.is_empty())
345     }
346 }
347 
#[cfg(test)]
mod tests {
    use super::*;

    // Other tests may allocate ids concurrently, so the only property we can
    // rely on is that ids observed by this thread strictly increase.
    #[test]
    fn test_id_not_broken() {
        let ids: Vec<_> = std::iter::repeat_with(get_next_id).take(1001).collect();

        for pair in ids.windows(2) {
            assert!(pair[0] < pair[1]);
        }
    }
}
365