xref: /aosp_15_r20/external/crosvm/base_tokio/src/sys/windows/tube.rs (revision bb4ee6a4ae7042d18b07a98463b9c8b875e44b39)
1 // Copyright 2024 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 use base::warn;
6 use base::AsRawDescriptor;
7 use base::Descriptor;
8 use base::Error;
9 use base::Event;
10 use base::Tube;
11 use base::TubeError;
12 use base::TubeResult;
13 use tokio::sync::mpsc;
14 use tokio::sync::oneshot;
15 use winapi::um::ioapiset::CancelIoEx;
16 
/// An async version of `Tube`.
///
/// Implementation note: We don't trust `base::Tube::recv` to behave in a non-blocking manner even
/// when the read notifier is signalled, so we offload the actual `send` and `recv` calls onto a
/// blocking thread.
pub struct TubeTokio {
    // Blocking task that owns the tube and runs queued commands against it; its output is the
    // tube itself, returned once the command loop ends.
    worker: tokio::task::JoinHandle<Tube>,
    // Queue of commands to run against the tube on the worker thread. Dropping this sender lets
    // the worker's `blocking_recv` loop end after the queue drains.
    cmd_tx: mpsc::Sender<Box<dyn FnOnce(&Tube) + Send>>,
    // Clone of the tube's read notifier; `recv` waits on it before queuing a read.
    read_notifier: Event,
    // Tube's RawDescriptor; `into_inner` passes it to `CancelIoEx`.
    tube_descriptor: Descriptor,
}
29 }
30 
31 impl TubeTokio {
new(mut tube: Tube) -> anyhow::Result<Self>32     pub fn new(mut tube: Tube) -> anyhow::Result<Self> {
33         let read_notifier = tube.get_read_notifier_event().try_clone()?;
34         let tube_descriptor = Descriptor(tube.as_raw_descriptor());
35 
36         let (cmd_tx, mut cmd_rx) = mpsc::channel::<Box<dyn FnOnce(&Tube) + Send>>(1);
37         let worker = tokio::task::spawn_blocking(move || {
38             while let Some(f) = cmd_rx.blocking_recv() {
39                 f(&mut tube)
40             }
41             tube
42         });
43         Ok(Self {
44             worker,
45             cmd_tx,
46             read_notifier,
47             tube_descriptor,
48         })
49     }
50 
into_inner(self) -> Tube51     pub async fn into_inner(self) -> Tube {
52         drop(self.cmd_tx);
53 
54         // Attempt to cancel any blocking IO the worker thread is doing so that we don't get stuck
55         // here if a `recv` call blocked. This is racy since we don't know if the queue'd up IO
56         // requests have actually started yet.
57         //
58         // SAFETY: The descriptor should still be valid since we own the tube in the blocking task.
59         if unsafe { CancelIoEx(self.tube_descriptor.0, std::ptr::null_mut()) } == 0 {
60             warn!(
61                 "Cancel IO for handle:{:?} failed with {}",
62                 self.tube_descriptor.0,
63                 Error::last()
64             );
65         }
66 
67         self.worker.await.expect("failed to join tube worker")
68     }
69 
send<T: serde::Serialize + Send + 'static>(&mut self, msg: T) -> TubeResult<()>70     pub async fn send<T: serde::Serialize + Send + 'static>(&mut self, msg: T) -> TubeResult<()> {
71         // It is unlikely the tube is full given crosvm usage patterns, so request the blocking
72         // send immediately.
73         let (tx, rx) = oneshot::channel();
74         self.cmd_tx
75             .send(Box::new(move |tube| {
76                 let _ = tx.send(tube.send(&msg));
77             }))
78             .await
79             .expect("worker missing");
80         rx.await.map_err(|_| TubeError::OperationCancelled)??;
81         Ok(())
82     }
83 
recv<T: serde::de::DeserializeOwned + Send + 'static>(&mut self) -> TubeResult<T>84     pub async fn recv<T: serde::de::DeserializeOwned + Send + 'static>(&mut self) -> TubeResult<T> {
85         // `Tube`'s read notifier event is a manual-reset event and `Tube::recv` wants to
86         // handle the reset, so we bypass `EventAsync`.
87         base::sys::windows::async_wait_for_single_object(&self.read_notifier)
88             .await
89             .map_err(TubeError::Recv)?;
90 
91         let (tx, rx) = oneshot::channel();
92         self.cmd_tx
93             .send(Box::new(move |tube| {
94                 let _ = tx.send(tube.recv());
95             }))
96             .await
97             .expect("worker missing");
98         rx.await.map_err(|_| TubeError::OperationCancelled)?
99     }
100 }
101