// Copyright 2023 The ChromiumOS Authors // Use of this source code is governed by a BSD-style license that can be // found in the LICENSE file. use std::future::Future; use std::pin::Pin; use std::sync::Arc; use std::sync::OnceLock; use base::AsRawDescriptors; use base::RawDescriptor; use tokio::runtime::Runtime; use tokio::task::LocalSet; use crate::sys::platform::tokio_source::TokioSource; use crate::AsyncError; use crate::AsyncResult; use crate::ExecutorTrait; use crate::IntoAsync; use crate::IoSource; use crate::TaskHandle; mod send_wrapper { use std::thread; #[derive(Clone)] pub(super) struct SendWrapper { instance: T, thread_id: thread::ThreadId, } impl SendWrapper { pub(super) fn new(instance: T) -> SendWrapper { SendWrapper { instance, thread_id: thread::current().id(), } } } // SAFETY: panics when the value is accessed on the wrong thread. unsafe impl Send for SendWrapper {} // SAFETY: panics when the value is accessed on the wrong thread. unsafe impl Sync for SendWrapper {} impl Drop for SendWrapper { fn drop(&mut self) { if self.thread_id != thread::current().id() { panic!("SendWrapper value was dropped on the wrong thread"); } } } impl std::ops::Deref for SendWrapper { type Target = T; fn deref(&self) -> &T { if self.thread_id != thread::current().id() { panic!("SendWrapper value was accessed on the wrong thread"); } &self.instance } } } #[derive(Clone)] pub struct TokioExecutor { runtime: Arc, local_set: Arc>>, } impl TokioExecutor { pub fn new() -> AsyncResult { Ok(TokioExecutor { runtime: Arc::new(Runtime::new().map_err(AsyncError::Io)?), local_set: Arc::new(OnceLock::new()), }) } } impl ExecutorTrait for TokioExecutor { fn async_from<'a, F: IntoAsync + 'a>(&self, f: F) -> AsyncResult> { Ok(IoSource::Tokio(TokioSource::new( f, self.runtime.handle().clone(), )?)) } fn run_until(&self, f: F) -> AsyncResult { let local_set = self .local_set .get_or_init(|| send_wrapper::SendWrapper::new(LocalSet::new())); Ok(self .runtime .block_on(async { local_set.run_until(f).await })) } fn spawn(&self, f: F) -> TaskHandle where F: Future + Send + 'static, F::Output: Send + 'static, { TaskHandle::Tokio(TokioTaskHandle { join_handle: Some(self.runtime.spawn(f)), }) } fn spawn_blocking(&self, f: F) -> TaskHandle where F: FnOnce() -> R + Send + 'static, R: Send + 'static, { TaskHandle::Tokio(TokioTaskHandle { join_handle: Some(self.runtime.spawn_blocking(f)), }) } fn spawn_local(&self, f: F) -> TaskHandle where F: Future + 'static, F::Output: 'static, { let local_set = self .local_set .get_or_init(|| send_wrapper::SendWrapper::new(LocalSet::new())); TaskHandle::Tokio(TokioTaskHandle { join_handle: Some(local_set.spawn_local(f)), }) } } impl AsRawDescriptors for TokioExecutor { fn as_raw_descriptors(&self) -> Vec { todo!(); } } pub struct TokioTaskHandle { join_handle: Option>, } impl TokioTaskHandle { pub async fn cancel(mut self) -> Option { match self.join_handle.take() { Some(handle) => { handle.abort(); handle.await.ok() } None => None, } } pub fn detach(mut self) { self.join_handle.take(); } } impl Future for TokioTaskHandle { type Output = R; fn poll(self: Pin<&mut Self>, cx: &mut std::task::Context) -> std::task::Poll { let self_mut = self.get_mut(); Pin::new(self_mut.join_handle.as_mut().unwrap()) .poll(cx) .map(|v| v.unwrap()) } } impl std::ops::Drop for TokioTaskHandle { fn drop(&mut self) { if let Some(handle) = self.join_handle.take() { handle.abort() } } }