1 use crate::future::Future;
2 use crate::runtime::task::core::{Core, Trailer};
3 use crate::runtime::task::{Cell, Harness, Header, Id, Schedule, State};
4 
5 use std::ptr::NonNull;
6 use std::task::{Poll, Waker};
7 
8 /// Raw task handle
9 #[derive(Clone)]
10 pub(crate) struct RawTask {
11     ptr: NonNull<Header>,
12 }
13 
14 pub(super) struct Vtable {
15     /// Polls the future.
16     pub(super) poll: unsafe fn(NonNull<Header>),
17 
18     /// Schedules the task for execution on the runtime.
19     pub(super) schedule: unsafe fn(NonNull<Header>),
20 
21     /// Deallocates the memory.
22     pub(super) dealloc: unsafe fn(NonNull<Header>),
23 
24     /// Reads the task output, if complete.
25     pub(super) try_read_output: unsafe fn(NonNull<Header>, *mut (), &Waker),
26 
27     /// The join handle has been dropped.
28     pub(super) drop_join_handle_slow: unsafe fn(NonNull<Header>),
29 
30     /// An abort handle has been dropped.
31     pub(super) drop_abort_handle: unsafe fn(NonNull<Header>),
32 
33     /// Scheduler is being shutdown.
34     pub(super) shutdown: unsafe fn(NonNull<Header>),
35 
36     /// The number of bytes that the `trailer` field is offset from the header.
37     pub(super) trailer_offset: usize,
38 
39     /// The number of bytes that the `scheduler` field is offset from the header.
40     pub(super) scheduler_offset: usize,
41 
42     /// The number of bytes that the `id` field is offset from the header.
43     pub(super) id_offset: usize,
44 }
45 
46 /// Get the vtable for the requested `T` and `S` generics.
vtable<T: Future, S: Schedule>() -> &'static Vtable47 pub(super) fn vtable<T: Future, S: Schedule>() -> &'static Vtable {
48     &Vtable {
49         poll: poll::<T, S>,
50         schedule: schedule::<S>,
51         dealloc: dealloc::<T, S>,
52         try_read_output: try_read_output::<T, S>,
53         drop_join_handle_slow: drop_join_handle_slow::<T, S>,
54         drop_abort_handle: drop_abort_handle::<T, S>,
55         shutdown: shutdown::<T, S>,
56         trailer_offset: OffsetHelper::<T, S>::TRAILER_OFFSET,
57         scheduler_offset: OffsetHelper::<T, S>::SCHEDULER_OFFSET,
58         id_offset: OffsetHelper::<T, S>::ID_OFFSET,
59     }
60 }
61 
62 /// Calling `get_trailer_offset` directly in vtable doesn't work because it
63 /// prevents the vtable from being promoted to a static reference.
64 ///
65 /// See this thread for more info:
66 /// <https://users.rust-lang.org/t/custom-vtables-with-integers/78508>
67 struct OffsetHelper<T, S>(T, S);
68 impl<T: Future, S: Schedule> OffsetHelper<T, S> {
69     // Pass `size_of`/`align_of` as arguments rather than calling them directly
70     // inside `get_trailer_offset` because trait bounds on generic parameters
71     // of const fn are unstable on our MSRV.
72     const TRAILER_OFFSET: usize = get_trailer_offset(
73         std::mem::size_of::<Header>(),
74         std::mem::size_of::<Core<T, S>>(),
75         std::mem::align_of::<Core<T, S>>(),
76         std::mem::align_of::<Trailer>(),
77     );
78 
79     // The `scheduler` is the first field of `Core`, so it has the same
80     // offset as `Core`.
81     const SCHEDULER_OFFSET: usize = get_core_offset(
82         std::mem::size_of::<Header>(),
83         std::mem::align_of::<Core<T, S>>(),
84     );
85 
86     const ID_OFFSET: usize = get_id_offset(
87         std::mem::size_of::<Header>(),
88         std::mem::align_of::<Core<T, S>>(),
89         std::mem::size_of::<S>(),
90         std::mem::align_of::<Id>(),
91     );
92 }
93 
/// Returns the byte offset of the `Trailer` field in `Cell<T, S>`, following
/// the `#[repr(C)]` layout algorithm.
///
/// The layout is header first, then `Core` padded up to `core_align`, then the
/// trailer padded up to `trailer_align`.
///
/// Pseudo-code for the `#[repr(C)]` algorithm can be found here:
/// <https://doc.rust-lang.org/reference/type-layout.html#reprc-structs>
const fn get_trailer_offset(
    header_size: usize,
    core_size: usize,
    core_align: usize,
    trailer_align: usize,
) -> usize {
    /// Rounds `value` up to the next multiple of `align`. A value that is
    /// already aligned is returned unchanged.
    const fn round_up(value: usize, align: usize) -> usize {
        match value % align {
            0 => value,
            rem => value + (align - rem),
        }
    }

    let core_start = round_up(header_size, core_align);
    let core_end = core_start + core_size;
    round_up(core_end, trailer_align)
}
120 
/// Returns the byte offset of the `Core<T, S>` field in `Cell<T, S>`, following
/// the `#[repr(C)]` layout algorithm.
///
/// `Core` comes right after the header, padded up to `core_align`.
///
/// Pseudo-code for the `#[repr(C)]` algorithm can be found here:
/// <https://doc.rust-lang.org/reference/type-layout.html#reprc-structs>
const fn get_core_offset(header_size: usize, core_align: usize) -> usize {
    let padding = match header_size % core_align {
        0 => 0,
        rem => core_align - rem,
    };
    header_size + padding
}
136 
137 /// Compute the offset of the `Id` field in `Cell<T, S>` using the
138 /// `#[repr(C)]` algorithm.
139 ///
140 /// Pseudo-code for the `#[repr(C)]` algorithm can be found here:
141 /// <https://doc.rust-lang.org/reference/type-layout.html#reprc-structs>
get_id_offset( header_size: usize, core_align: usize, scheduler_size: usize, id_align: usize, ) -> usize142 const fn get_id_offset(
143     header_size: usize,
144     core_align: usize,
145     scheduler_size: usize,
146     id_align: usize,
147 ) -> usize {
148     let mut offset = get_core_offset(header_size, core_align);
149     offset += scheduler_size;
150 
151     let id_misalign = offset % id_align;
152     if id_misalign > 0 {
153         offset += id_align - id_misalign;
154     }
155 
156     offset
157 }
158 
159 impl RawTask {
new<T, S>(task: T, scheduler: S, id: Id) -> RawTask where T: Future, S: Schedule,160     pub(super) fn new<T, S>(task: T, scheduler: S, id: Id) -> RawTask
161     where
162         T: Future,
163         S: Schedule,
164     {
165         let ptr = Box::into_raw(Cell::<_, S>::new(task, scheduler, State::new(), id));
166         let ptr = unsafe { NonNull::new_unchecked(ptr.cast()) };
167 
168         RawTask { ptr }
169     }
170 
from_raw(ptr: NonNull<Header>) -> RawTask171     pub(super) unsafe fn from_raw(ptr: NonNull<Header>) -> RawTask {
172         RawTask { ptr }
173     }
174 
header_ptr(&self) -> NonNull<Header>175     pub(super) fn header_ptr(&self) -> NonNull<Header> {
176         self.ptr
177     }
178 
trailer_ptr(&self) -> NonNull<Trailer>179     pub(super) fn trailer_ptr(&self) -> NonNull<Trailer> {
180         unsafe { Header::get_trailer(self.ptr) }
181     }
182 
183     /// Returns a reference to the task's header.
header(&self) -> &Header184     pub(super) fn header(&self) -> &Header {
185         unsafe { self.ptr.as_ref() }
186     }
187 
188     /// Returns a reference to the task's trailer.
trailer(&self) -> &Trailer189     pub(super) fn trailer(&self) -> &Trailer {
190         unsafe { &*self.trailer_ptr().as_ptr() }
191     }
192 
193     /// Returns a reference to the task's state.
state(&self) -> &State194     pub(super) fn state(&self) -> &State {
195         &self.header().state
196     }
197 
198     /// Safety: mutual exclusion is required to call this function.
poll(self)199     pub(crate) fn poll(self) {
200         let vtable = self.header().vtable;
201         unsafe { (vtable.poll)(self.ptr) }
202     }
203 
schedule(self)204     pub(super) fn schedule(self) {
205         let vtable = self.header().vtable;
206         unsafe { (vtable.schedule)(self.ptr) }
207     }
208 
dealloc(self)209     pub(super) fn dealloc(self) {
210         let vtable = self.header().vtable;
211         unsafe {
212             (vtable.dealloc)(self.ptr);
213         }
214     }
215 
216     /// Safety: `dst` must be a `*mut Poll<super::Result<T::Output>>` where `T`
217     /// is the future stored by the task.
try_read_output(self, dst: *mut (), waker: &Waker)218     pub(super) unsafe fn try_read_output(self, dst: *mut (), waker: &Waker) {
219         let vtable = self.header().vtable;
220         (vtable.try_read_output)(self.ptr, dst, waker);
221     }
222 
drop_join_handle_slow(self)223     pub(super) fn drop_join_handle_slow(self) {
224         let vtable = self.header().vtable;
225         unsafe { (vtable.drop_join_handle_slow)(self.ptr) }
226     }
227 
drop_abort_handle(self)228     pub(super) fn drop_abort_handle(self) {
229         let vtable = self.header().vtable;
230         unsafe { (vtable.drop_abort_handle)(self.ptr) }
231     }
232 
shutdown(self)233     pub(super) fn shutdown(self) {
234         let vtable = self.header().vtable;
235         unsafe { (vtable.shutdown)(self.ptr) }
236     }
237 
238     /// Increment the task's reference count.
239     ///
240     /// Currently, this is used only when creating an `AbortHandle`.
ref_inc(self)241     pub(super) fn ref_inc(self) {
242         self.header().state.ref_inc();
243     }
244 
245     /// Get the queue-next pointer
246     ///
247     /// This is for usage by the injection queue
248     ///
249     /// Safety: make sure only one queue uses this and access is synchronized.
get_queue_next(self) -> Option<RawTask>250     pub(crate) unsafe fn get_queue_next(self) -> Option<RawTask> {
251         self.header()
252             .queue_next
253             .with(|ptr| *ptr)
254             .map(|p| RawTask::from_raw(p))
255     }
256 
257     /// Sets the queue-next pointer
258     ///
259     /// This is for usage by the injection queue
260     ///
261     /// Safety: make sure only one queue uses this and access is synchronized.
set_queue_next(self, val: Option<RawTask>)262     pub(crate) unsafe fn set_queue_next(self, val: Option<RawTask>) {
263         self.header().set_next(val.map(|task| task.ptr));
264     }
265 }
266 
267 impl Copy for RawTask {}
268 
poll<T: Future, S: Schedule>(ptr: NonNull<Header>)269 unsafe fn poll<T: Future, S: Schedule>(ptr: NonNull<Header>) {
270     let harness = Harness::<T, S>::from_raw(ptr);
271     harness.poll();
272 }
273 
schedule<S: Schedule>(ptr: NonNull<Header>)274 unsafe fn schedule<S: Schedule>(ptr: NonNull<Header>) {
275     use crate::runtime::task::{Notified, Task};
276 
277     let scheduler = Header::get_scheduler::<S>(ptr);
278     scheduler
279         .as_ref()
280         .schedule(Notified(Task::from_raw(ptr.cast())));
281 }
282 
dealloc<T: Future, S: Schedule>(ptr: NonNull<Header>)283 unsafe fn dealloc<T: Future, S: Schedule>(ptr: NonNull<Header>) {
284     let harness = Harness::<T, S>::from_raw(ptr);
285     harness.dealloc();
286 }
287 
try_read_output<T: Future, S: Schedule>( ptr: NonNull<Header>, dst: *mut (), waker: &Waker, )288 unsafe fn try_read_output<T: Future, S: Schedule>(
289     ptr: NonNull<Header>,
290     dst: *mut (),
291     waker: &Waker,
292 ) {
293     let out = &mut *(dst as *mut Poll<super::Result<T::Output>>);
294 
295     let harness = Harness::<T, S>::from_raw(ptr);
296     harness.try_read_output(out, waker);
297 }
298 
drop_join_handle_slow<T: Future, S: Schedule>(ptr: NonNull<Header>)299 unsafe fn drop_join_handle_slow<T: Future, S: Schedule>(ptr: NonNull<Header>) {
300     let harness = Harness::<T, S>::from_raw(ptr);
301     harness.drop_join_handle_slow();
302 }
303 
drop_abort_handle<T: Future, S: Schedule>(ptr: NonNull<Header>)304 unsafe fn drop_abort_handle<T: Future, S: Schedule>(ptr: NonNull<Header>) {
305     let harness = Harness::<T, S>::from_raw(ptr);
306     harness.drop_reference();
307 }
308 
shutdown<T: Future, S: Schedule>(ptr: NonNull<Header>)309 unsafe fn shutdown<T: Future, S: Schedule>(ptr: NonNull<Header>) {
310     let harness = Harness::<T, S>::from_raw(ptr);
311     harness.shutdown();
312 }
313