xref: /aosp_15_r20/external/crosvm/devices/src/virtio/vhost/worker.rs (revision bb4ee6a4ae7042d18b07a98463b9c8b875e44b39)
1 // Copyright 2017 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 use std::collections::BTreeMap;
6 
7 use base::error;
8 use base::Error as SysError;
9 use base::Event;
10 use base::EventToken;
11 use base::Tube;
12 use base::WaitContext;
13 use libc::EIO;
14 use serde::Deserialize;
15 use serde::Serialize;
16 use vhost::Vhost;
17 use vm_memory::GuestMemory;
18 
19 use super::control_socket::VhostDevRequest;
20 use super::control_socket::VhostDevResponse;
21 use super::Error;
22 use super::Result;
23 use crate::virtio::Interrupt;
24 use crate::virtio::Queue;
25 use crate::virtio::VIRTIO_F_ACCESS_PLATFORM;
26 
/// Saved base value for a single vring, used to resume a queue from a known position.
///
/// `Worker::init` looks up the entry whose `index` equals the queue index and hands its
/// `base` to `set_vring_base` on the vhost handle.
#[derive(Clone, Serialize, Deserialize)]
pub struct VringBase {
    /// Index of the queue this entry belongs to.
    pub index: usize,
    /// Value passed to `set_vring_base` for that queue.
    pub base: u16,
}
32 
/// Worker that takes care of running the vhost device.
///
/// The worker configures the vhost handle for every queue (`init`) and then runs an event
/// loop (`run`) that relays vhost interrupt events and services control requests.
pub struct Worker<T: Vhost> {
    /// Used to signal used queues, to resample the interrupt and to look up the MSI-X
    /// configuration.
    interrupt: Interrupt,
    /// Queues handled by this worker, keyed by queue index.
    pub queues: BTreeMap<usize, Queue>,
    /// Vhost handle on which features, the memory table and the vrings are configured.
    pub vhost_handle: T,
    /// Events indexed by queue index. They are installed as the vring call event whenever an
    /// MSI-X irqfd is not used directly, and the event loop in `run` waits on them.
    pub vhost_interrupt: Vec<Event>,
    /// Acknowledged feature bits (presumably acked by the guest driver); `init` intersects
    /// them with the features reported by the vhost handle.
    acked_features: u64,
    /// Optional tube on which `VhostDevRequest`s are received and `VhostDevResponse`s are
    /// sent. Without it, vring calls always go through `vhost_interrupt`.
    pub response_tube: Option<Tube>,
}
42 
43 impl<T: Vhost> Worker<T> {
new( queues: BTreeMap<usize, Queue>, vhost_handle: T, vhost_interrupt: Vec<Event>, interrupt: Interrupt, acked_features: u64, response_tube: Option<Tube>, ) -> Worker<T>44     pub fn new(
45         queues: BTreeMap<usize, Queue>,
46         vhost_handle: T,
47         vhost_interrupt: Vec<Event>,
48         interrupt: Interrupt,
49         acked_features: u64,
50         response_tube: Option<Tube>,
51     ) -> Worker<T> {
52         Worker {
53             interrupt,
54             queues,
55             vhost_handle,
56             vhost_interrupt,
57             acked_features,
58             response_tube,
59         }
60     }
61 
init<F1>( &mut self, mem: GuestMemory, queue_sizes: &[u16], activate_vqs: F1, queue_vrings_base: Option<Vec<VringBase>>, ) -> Result<()> where F1: FnOnce(&T) -> Result<()>,62     pub fn init<F1>(
63         &mut self,
64         mem: GuestMemory,
65         queue_sizes: &[u16],
66         activate_vqs: F1,
67         queue_vrings_base: Option<Vec<VringBase>>,
68     ) -> Result<()>
69     where
70         F1: FnOnce(&T) -> Result<()>,
71     {
72         let avail_features = self
73             .vhost_handle
74             .get_features()
75             .map_err(Error::VhostGetFeatures)?;
76 
77         let mut features = self.acked_features & avail_features;
78         if self.acked_features & (1u64 << VIRTIO_F_ACCESS_PLATFORM) != 0 {
79             // The vhost API is a bit poorly named, this flag in the context of vhost
80             // means that it will do address translation via its IOTLB APIs. If the
81             // underlying virtio device doesn't use viommu, it doesn't need vhost
82             // translation.
83             features &= !(1u64 << VIRTIO_F_ACCESS_PLATFORM);
84         }
85 
86         self.vhost_handle
87             .set_features(features)
88             .map_err(Error::VhostSetFeatures)?;
89 
90         self.vhost_handle
91             .set_mem_table(&mem)
92             .map_err(Error::VhostSetMemTable)?;
93 
94         for (&queue_index, queue) in self.queues.iter() {
95             self.vhost_handle
96                 .set_vring_num(queue_index, queue.size())
97                 .map_err(Error::VhostSetVringNum)?;
98 
99             self.vhost_handle
100                 .set_vring_addr(
101                     &mem,
102                     queue_sizes[queue_index],
103                     queue.size(),
104                     queue_index,
105                     0,
106                     queue.desc_table(),
107                     queue.used_ring(),
108                     queue.avail_ring(),
109                     None,
110                 )
111                 .map_err(Error::VhostSetVringAddr)?;
112             if let Some(vrings_base) = &queue_vrings_base {
113                 let base = if let Some(vring_base) = vrings_base
114                     .iter()
115                     .find(|vring_base| vring_base.index == queue_index)
116                 {
117                     vring_base.base
118                 } else {
119                     return Err(Error::VringBaseMissing);
120                 };
121                 self.vhost_handle
122                     .set_vring_base(queue_index, base)
123                     .map_err(Error::VhostSetVringBase)?;
124             } else {
125                 self.vhost_handle
126                     .set_vring_base(queue_index, 0)
127                     .map_err(Error::VhostSetVringBase)?;
128             }
129             self.set_vring_call_for_entry(queue_index, queue.vector() as usize)?;
130             self.vhost_handle
131                 .set_vring_kick(queue_index, queue.event())
132                 .map_err(Error::VhostSetVringKick)?;
133         }
134 
135         activate_vqs(&self.vhost_handle)?;
136         Ok(())
137     }
138 
run<F1>(&mut self, cleanup_vqs: F1, kill_evt: Event) -> Result<()> where F1: FnOnce(&T) -> Result<()>,139     pub fn run<F1>(&mut self, cleanup_vqs: F1, kill_evt: Event) -> Result<()>
140     where
141         F1: FnOnce(&T) -> Result<()>,
142     {
143         #[derive(EventToken)]
144         enum Token {
145             VhostIrqi { index: usize },
146             InterruptResample,
147             Kill,
148             ControlNotify,
149         }
150 
151         let wait_ctx: WaitContext<Token> = WaitContext::build_with(&[(&kill_evt, Token::Kill)])
152             .map_err(Error::CreateWaitContext)?;
153 
154         for (index, vhost_int) in self.vhost_interrupt.iter().enumerate() {
155             wait_ctx
156                 .add(vhost_int, Token::VhostIrqi { index })
157                 .map_err(Error::CreateWaitContext)?;
158         }
159         if let Some(socket) = &self.response_tube {
160             wait_ctx
161                 .add(socket, Token::ControlNotify)
162                 .map_err(Error::CreateWaitContext)?;
163         }
164         if let Some(resample_evt) = self.interrupt.get_resample_evt() {
165             wait_ctx
166                 .add(resample_evt, Token::InterruptResample)
167                 .map_err(Error::CreateWaitContext)?;
168         }
169 
170         'wait: loop {
171             let events = wait_ctx.wait().map_err(Error::WaitError)?;
172 
173             for event in events.iter().filter(|e| e.is_readable) {
174                 match event.token {
175                     Token::VhostIrqi { index } => {
176                         self.vhost_interrupt[index]
177                             .wait()
178                             .map_err(Error::VhostIrqRead)?;
179                         self.interrupt
180                             .signal_used_queue(self.queues[&index].vector());
181                     }
182                     Token::InterruptResample => {
183                         self.interrupt.interrupt_resample();
184                     }
185                     Token::Kill => {
186                         let _ = kill_evt.wait();
187                         break 'wait;
188                     }
189                     Token::ControlNotify => {
190                         if let Some(socket) = &self.response_tube {
191                             match socket.recv() {
192                                 Ok(VhostDevRequest::MsixEntryChanged(index)) => {
193                                     let mut qindex = 0;
194                                     for (&queue_index, queue) in self.queues.iter() {
195                                         if queue.vector() == index as u16 {
196                                             qindex = queue_index;
197                                             break;
198                                         }
199                                     }
200                                     let response =
201                                         match self.set_vring_call_for_entry(qindex, index) {
202                                             Ok(()) => VhostDevResponse::Ok,
203                                             Err(e) => {
204                                                 error!(
205                                                 "Set vring call failed for masked entry {}: {:?}",
206                                                 index, e
207                                             );
208                                                 VhostDevResponse::Err(SysError::new(EIO))
209                                             }
210                                         };
211                                     if let Err(e) = socket.send(&response) {
212                                         error!("Vhost failed to send VhostMsixEntryMasked Response for entry {}: {:?}", index, e);
213                                     }
214                                 }
215                                 Ok(VhostDevRequest::MsixChanged) => {
216                                     let response = match self.set_vring_calls() {
217                                         Ok(()) => VhostDevResponse::Ok,
218                                         Err(e) => {
219                                             error!("Set vring calls failed: {:?}", e);
220                                             VhostDevResponse::Err(SysError::new(EIO))
221                                         }
222                                     };
223                                     if let Err(e) = socket.send(&response) {
224                                         error!(
225                                             "Vhost failed to send VhostMsixMasked Response: {:?}",
226                                             e
227                                         );
228                                     }
229                                 }
230                                 Err(e) => {
231                                     error!("Vhost failed to receive Control request: {:?}", e);
232                                 }
233                             }
234                         }
235                     }
236                 }
237             }
238         }
239         cleanup_vqs(&self.vhost_handle)?;
240         Ok(())
241     }
242 
set_vring_call_for_entry(&self, queue_index: usize, vector: usize) -> Result<()>243     fn set_vring_call_for_entry(&self, queue_index: usize, vector: usize) -> Result<()> {
244         // No response_socket means it doesn't have any control related
245         // with the msix. Due to this, cannot use the direct irq fd but
246         // should fall back to indirect irq fd.
247         if self.response_tube.is_some() {
248             if let Some(msix_config) = self.interrupt.get_msix_config() {
249                 let msix_config = msix_config.lock();
250                 let msix_masked = msix_config.masked();
251                 if msix_masked {
252                     return Ok(());
253                 }
254                 if !msix_config.table_masked(vector) {
255                     if let Some(irqfd) = msix_config.get_irqfd(vector) {
256                         self.vhost_handle
257                             .set_vring_call(queue_index, irqfd)
258                             .map_err(Error::VhostSetVringCall)?;
259                     } else {
260                         self.vhost_handle
261                             .set_vring_call(queue_index, &self.vhost_interrupt[queue_index])
262                             .map_err(Error::VhostSetVringCall)?;
263                     }
264                     return Ok(());
265                 }
266             }
267         }
268 
269         self.vhost_handle
270             .set_vring_call(queue_index, &self.vhost_interrupt[queue_index])
271             .map_err(Error::VhostSetVringCall)?;
272         Ok(())
273     }
274 
set_vring_calls(&self) -> Result<()>275     fn set_vring_calls(&self) -> Result<()> {
276         if let Some(msix_config) = self.interrupt.get_msix_config() {
277             let msix_config = msix_config.lock();
278             if msix_config.masked() {
279                 for (&queue_index, _) in self.queues.iter() {
280                     self.vhost_handle
281                         .set_vring_call(queue_index, &self.vhost_interrupt[queue_index])
282                         .map_err(Error::VhostSetVringCall)?;
283                 }
284             } else {
285                 for (&queue_index, queue) in self.queues.iter() {
286                     let vector = queue.vector() as usize;
287                     if !msix_config.table_masked(vector) {
288                         if let Some(irqfd) = msix_config.get_irqfd(vector) {
289                             self.vhost_handle
290                                 .set_vring_call(queue_index, irqfd)
291                                 .map_err(Error::VhostSetVringCall)?;
292                         } else {
293                             self.vhost_handle
294                                 .set_vring_call(queue_index, &self.vhost_interrupt[queue_index])
295                                 .map_err(Error::VhostSetVringCall)?;
296                         }
297                     } else {
298                         self.vhost_handle
299                             .set_vring_call(queue_index, &self.vhost_interrupt[queue_index])
300                             .map_err(Error::VhostSetVringCall)?;
301                     }
302                 }
303             }
304         }
305         Ok(())
306     }
307 }
308