// Copyright 2021 The ChromiumOS Authors
// Use of this source code is governed by a BSD-style license that can be
// found in the LICENSE file.

//! Library for implementing vhost-user device executables.
//!
//! This crate provides
//! * `VhostUserDevice` trait, which is a collection of methods to handle vhost-user requests, and
//! * `DeviceRequestHandler` struct, which makes a connection to a VMM and starts an event loop.
//!
//! They are expected to be used as follows:
//!
//! 1. Define a struct and implement `VhostUserDevice` for it.
//! 2. Create a `DeviceRequestHandler` with the backend struct.
//! 3. Drive the `DeviceRequestHandler::run` async fn with an executor.
//!
//! ```ignore
//! struct MyBackend {
//!     /* fields */
//! }
//!
//! impl VhostUserDevice for MyBackend {
//!     /* implement methods */
//! }
//!
//! fn main() -> Result<(), Box<dyn Error>> {
//!     let backend = MyBackend { /* initialize fields */ };
//!     let handler = DeviceRequestHandler::new(backend);
//!     let socket = std::path::Path::new("/path/to/socket");
//!     let ex = cros_async::Executor::new()?;
//!
//!     if let Err(e) = ex.run_until(handler.run(socket, &ex)) {
//!         eprintln!("error happened: {}", e);
//!     }
//!     Ok(())
//! }
//! ```
// Implementation note:
// This code lets us take advantage of the vmm_vhost low-level implementation of the vhost-user
// protocol. DeviceRequestHandler implements the Backend trait from vmm_vhost, and includes some
// common code for setting up guest memory and managing partially configured vrings.
// DeviceRequestHandler::run watches the vhost-user socket and then calls handle_request() when it
// becomes readable. handle_request() reads and parses the message and then calls one of the
// Backend trait methods. These dispatch back to the supplied VhostUserDevice implementation (this
// is what our devices implement).

pub(super) mod sys;

use std::collections::BTreeMap;
use std::convert::From;
use std::fs::File;
use std::num::Wrapping;
#[cfg(any(target_os = "android", target_os = "linux"))]
use std::os::unix::io::AsRawFd;
use std::sync::Arc;

use anyhow::bail;
use anyhow::Context;
#[cfg(any(target_os = "android", target_os = "linux"))]
use base::clear_fd_flags;
use base::error;
use base::trace;
use base::warn;
use base::Event;
use base::Protection;
use base::SafeDescriptor;
use base::SharedMemory;
use base::WorkerThread;
use cros_async::TaskHandle;
use hypervisor::MemCacheType;
use serde::Deserialize;
use serde::Serialize;
use sync::Mutex;
use thiserror::Error as ThisError;
use vm_control::VmMemorySource;
use vm_memory::GuestAddress;
use vm_memory::GuestMemory;
use vm_memory::MemoryRegion;
use vmm_vhost::message::VhostSharedMemoryRegion;
use vmm_vhost::message::VhostUserConfigFlags;
use vmm_vhost::message::VhostUserExternalMapMsg;
use vmm_vhost::message::VhostUserGpuMapMsg;
use vmm_vhost::message::VhostUserInflight;
use vmm_vhost::message::VhostUserMemoryRegion;
use vmm_vhost::message::VhostUserMigrationPhase;
use vmm_vhost::message::VhostUserProtocolFeatures;
use vmm_vhost::message::VhostUserShmemMapMsg;
use vmm_vhost::message::VhostUserShmemMapMsgFlags;
use vmm_vhost::message::VhostUserShmemUnmapMsg;
use vmm_vhost::message::VhostUserSingleMemoryRegion;
use vmm_vhost::message::VhostUserTransferDirection;
use vmm_vhost::message::VhostUserVringAddrFlags;
use vmm_vhost::message::VhostUserVringState;
use vmm_vhost::BackendReq;
use vmm_vhost::Connection;
use vmm_vhost::Error as VhostError;
use vmm_vhost::Frontend;
use vmm_vhost::FrontendClient;
use vmm_vhost::Result as VhostResult;
use vmm_vhost::VHOST_USER_F_PROTOCOL_FEATURES;

use crate::virtio::Interrupt;
use crate::virtio::Queue;
use crate::virtio::QueueConfig;
use crate::virtio::SharedMemoryMapper;
use crate::virtio::SharedMemoryRegion;

/// Keeps a mapping from the VMM's virtual addresses to guest addresses.
/// Used to translate messages from the VMM to guest offsets.
#[derive(Default)]
pub struct MappingInfo {
    pub vmm_addr: u64,
    pub guest_phys: u64,
    pub size: u64,
}

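/// Translates a virtual address in the VMM's address space into a guest physical address using
/// the provided mappings. Returns `VhostError::InvalidMessage` if the address is not covered by
/// any mapping.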
pub fn vmm_va_to_gpa(maps: &[MappingInfo], vmm_va: u64) -> VhostResult<GuestAddress> {
    for map in maps {
        if vmm_va >= map.vmm_addr && vmm_va < map.vmm_addr + map.size {
            return Ok(GuestAddress(vmm_va - map.vmm_addr + map.guest_phys));
        }
    }
    Err(VhostError::InvalidMessage)
}

/// Trait for vhost-user devices. Analogous to the `VirtioDevice` trait.
///
/// In contrast with [[vmm_vhost::Backend]], which closely matches the vhost-user spec, this trait
/// is designed to follow crosvm conventions for implementing devices.
pub trait VhostUserDevice {
    /// The maximum number of queues that this backend can manage.
    fn max_queue_num(&self) -> usize;

    /// The set of feature bits that this backend supports.
    fn features(&self) -> u64;

    /// Acknowledges that this set of features should be enabled.
    ///
    /// Implementations only need to handle device-specific feature bits; the
    /// `DeviceRequestHandler` framework will manage generic vhost and vring features.
    ///
    /// `DeviceRequestHandler` checks for valid features before calling this function, so the
    /// features in `value` will always be a subset of those advertised by `features()`.
    fn ack_features(&mut self, _value: u64) -> anyhow::Result<()> {
        Ok(())
    }

    /// The set of protocol feature bits that this backend supports.
    fn protocol_features(&self) -> VhostUserProtocolFeatures;

    /// Reads this device's configuration space at `offset`.
    fn read_config(&self, offset: u64, dst: &mut [u8]);

    /// Writes `data` to this device's configuration space at `offset`.
    fn write_config(&self, _offset: u64, _data: &[u8]) {}

    /// Indicates that the backend should start processing requests for virtio queue number `idx`.
    /// This method must not block the current thread, so device backends should either spawn an
    /// async task or another thread to handle messages from the queue.
    fn start_queue(&mut self, idx: usize, queue: Queue, mem: GuestMemory) -> anyhow::Result<()>;

    /// Indicates that the backend should stop processing requests for virtio queue number `idx`.
    /// This method should return the queue passed to `start_queue` for the corresponding `idx`.
    /// This method will only be called for queues that were previously started by `start_queue`.
    fn stop_queue(&mut self, idx: usize) -> anyhow::Result<Queue>;

    /// Resets the vhost-user backend.
    fn reset(&mut self);

    /// Returns the device's shared memory region if present.
    fn get_shared_memory_region(&self) -> Option<SharedMemoryRegion> {
        None
    }

    /// Accepts a `VhostBackendReqConnection` to conduct vhost backend to frontend message
    /// handling.
    ///
    /// The backend is given an `Arc` instead of full ownership so that the framework can also use
    /// the connection.
    ///
    /// This method will be called when `VhostUserProtocolFeatures::BACKEND_REQ` is
    /// negotiated.
    fn set_backend_req_connection(&mut self, _conn: Arc<VhostBackendReqConnection>) {}

    /// Enter the "suspended device state" described in the vhost-user spec. See the spec for
    /// requirements.
    ///
    /// One reasonably foolproof way to satisfy the requirements is to stop all worker threads.
    ///
    /// Called after a `stop_queue` call if there are no running queues left. Also called soon
    /// after device creation to ensure the device is acting suspended immediately on construction.
    ///
    /// The next `start_queue` call implicitly exits the "suspended device state".
    ///
    /// * Ok(()) => device successfully suspended
    /// * Err(_) => unrecoverable error
    fn enter_suspended_state(&mut self) -> anyhow::Result<()>;

    /// Snapshot device and return serialized state.
    fn snapshot(&mut self) -> anyhow::Result<serde_json::Value>;

    /// Restore device state from a snapshot.
    fn restore(&mut self, data: serde_json::Value) -> anyhow::Result<()>;
}

/// A virtio ring entry.
struct Vring {
    // The queue config. This doesn't get mutated by the queue workers.
    queue: QueueConfig,
    doorbell: Option<Interrupt>,
    enabled: bool,
}

impl Vring {
    fn new(max_size: u16, features: u64) -> Self {
        Self {
            queue: QueueConfig::new(max_size, features),
            doorbell: None,
            enabled: false,
        }
    }

    fn reset(&mut self) {
        self.queue.reset();
        self.doorbell = None;
        self.enabled = false;
    }
}

/// Ops for running vhost-user over a stream (i.e. regular protocol).
pub(super) struct VhostUserRegularOps;

impl VhostUserRegularOps {
    pub fn set_mem_table(
        contexts: &[VhostUserMemoryRegion],
        files: Vec<File>,
    ) -> VhostResult<(GuestMemory, Vec<MappingInfo>)> {
        if files.len() != contexts.len() {
            return Err(VhostError::InvalidParam);
        }

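        // Build a crosvm memory region for each shared memory descriptor provided by the
        // frontend, mirroring the layout described in the memory table message.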
        let mut regions = Vec::with_capacity(files.len());
        for (region, file) in contexts.iter().zip(files.into_iter()) {
            let region = MemoryRegion::new_from_shm(
                region.memory_size,
                GuestAddress(region.guest_phys_addr),
                region.mmap_offset,
                Arc::new(
                    SharedMemory::from_safe_descriptor(
                        SafeDescriptor::from(file),
                        region.memory_size,
                    )
                    .unwrap(),
                ),
            )
            .map_err(|e| {
                error!("failed to create a memory region: {}", e);
                VhostError::InvalidOperation
            })?;
            regions.push(region);
        }
        let guest_mem = GuestMemory::from_regions(regions).map_err(|e| {
            error!("failed to create guest memory: {}", e);
            VhostError::InvalidOperation
        })?;

        let vmm_maps = contexts
            .iter()
            .map(|region| MappingInfo {
                vmm_addr: region.user_addr,
                guest_phys: region.guest_phys_addr,
                size: region.memory_size,
            })
            .collect();
        Ok((guest_mem, vmm_maps))
    }

    pub fn set_vring_kick(_index: u8, file: Option<File>) -> VhostResult<Event> {
        let file = file.ok_or(VhostError::InvalidParam)?;
        // Remove O_NONBLOCK from kick_fd. Otherwise, uring_executor will fail when we read
        // values via `next_val()` later.
        // This is only required (and can only be done) on Unix platforms.
        #[cfg(any(target_os = "android", target_os = "linux"))]
        if let Err(e) = clear_fd_flags(file.as_raw_fd(), libc::O_NONBLOCK) {
            error!("failed to remove O_NONBLOCK for kick fd: {}", e);
            return Err(VhostError::InvalidParam);
        }
        Ok(Event::from(SafeDescriptor::from(file)))
    }

    pub fn set_vring_call(
        _index: u8,
        file: Option<File>,
        signal_config_change_fn: Box<dyn Fn() + Send + Sync>,
    ) -> VhostResult<Interrupt> {
        let file = file.ok_or(VhostError::InvalidParam)?;
        Ok(Interrupt::new_vhost_user(
            Event::from(SafeDescriptor::from(file)),
            signal_config_change_fn,
        ))
    }
}

/// An adapter that implements `vmm_vhost::Backend` for any type implementing `VhostUserDevice`.
pub struct DeviceRequestHandler<T: VhostUserDevice> {
    vrings: Vec<Vring>,
    owned: bool,
    vmm_maps: Option<Vec<MappingInfo>>,
    mem: Option<GuestMemory>,
    acked_features: u64,
    acked_protocol_features: VhostUserProtocolFeatures,
    backend: T,
    backend_req_connection: Arc<Mutex<VhostBackendReqConnectionState>>,
    // Thread processing active device state FD.
    device_state_thread: Option<DeviceStateThread>,
}

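/// Worker thread driving an in-progress device state transfer: `Save` serializes a snapshot to
/// the state FD, `Load` reads and deserializes one from it.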
enum DeviceStateThread {
    Save(WorkerThread<serde_json::Result<()>>),
    Load(WorkerThread<serde_json::Result<DeviceRequestHandlerSnapshot>>),
}

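/// Serialized state of a `DeviceRequestHandler`, wrapping the backend's device-specific snapshot.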
#[derive(Serialize, Deserialize)]
pub struct DeviceRequestHandlerSnapshot {
    acked_features: u64,
    acked_protocol_features: u64,
    backend: serde_json::Value,
}

impl<T: VhostUserDevice> DeviceRequestHandler<T> {
    /// Creates a vhost-user handler instance for `backend`.
    pub(crate) fn new(mut backend: T) -> Self {
        let mut vrings = Vec::with_capacity(backend.max_queue_num());
        for _ in 0..backend.max_queue_num() {
            vrings.push(Vring::new(Queue::MAX_SIZE, backend.features()));
        }

        // VhostUserDevice implementations must support `enter_suspended_state()`.
        // Call it on startup to ensure it works and to initialize the device in a suspended state.
        backend
            .enter_suspended_state()
            .expect("enter_suspended_state failed on device init");

        DeviceRequestHandler {
            vrings,
            owned: false,
            vmm_maps: None,
            mem: None,
            acked_features: 0,
            acked_protocol_features: VhostUserProtocolFeatures::empty(),
            backend,
            backend_req_connection: Arc::new(Mutex::new(
                VhostBackendReqConnectionState::NoConnection,
            )),
            device_state_thread: None,
        }
    }

    /// Check if all queues are stopped.
    ///
    /// The device can be suspended with `enter_suspended_state()` only when all queues are
    /// stopped.
    fn all_queues_stopped(&self) -> bool {
        self.vrings.iter().all(|vring| !vring.queue.ready())
    }
}

impl<T: VhostUserDevice> AsRef<T> for DeviceRequestHandler<T> {
    fn as_ref(&self) -> &T {
        &self.backend
    }
}

impl<T: VhostUserDevice> AsMut<T> for DeviceRequestHandler<T> {
    fn as_mut(&mut self) -> &mut T {
        &mut self.backend
    }
}

impl<T: VhostUserDevice> vmm_vhost::Backend for DeviceRequestHandler<T> {
    fn set_owner(&mut self) -> VhostResult<()> {
        if self.owned {
            return Err(VhostError::InvalidOperation);
        }
        self.owned = true;
        Ok(())
    }

    fn reset_owner(&mut self) -> VhostResult<()> {
        self.owned = false;
        self.acked_features = 0;
        self.backend.reset();
        Ok(())
    }

    fn get_features(&mut self) -> VhostResult<u64> {
        let features = self.backend.features();
        Ok(features)
    }

    fn set_features(&mut self, features: u64) -> VhostResult<()> {
        if !self.owned {
            return Err(VhostError::InvalidOperation);
        }

        let unexpected_features = features & !self.backend.features();
        if unexpected_features != 0 {
            error!("unexpected set_features {:#x}", unexpected_features);
            return Err(VhostError::InvalidParam);
        }

        if let Err(e) = self.backend.ack_features(features) {
            error!("failed to acknowledge features 0x{:x}: {}", features, e);
            return Err(VhostError::InvalidOperation);
        }

        self.acked_features |= features;

        // If VHOST_USER_F_PROTOCOL_FEATURES has not been negotiated, the ring is initialized in an
        // enabled state.
        // If VHOST_USER_F_PROTOCOL_FEATURES has been negotiated, the ring is initialized in a
        // disabled state.
        // Client must not pass data to/from the backend until ring is enabled by
        // VHOST_USER_SET_VRING_ENABLE with parameter 1, or after it has been disabled by
        // VHOST_USER_SET_VRING_ENABLE with parameter 0.
        let vring_enabled = self.acked_features & 1 << VHOST_USER_F_PROTOCOL_FEATURES != 0;
        for v in &mut self.vrings {
            v.enabled = vring_enabled;
        }

        Ok(())
    }

    fn get_protocol_features(&mut self) -> VhostResult<VhostUserProtocolFeatures> {
        Ok(self.backend.protocol_features())
    }

    fn set_protocol_features(&mut self, features: u64) -> VhostResult<()> {
        let features = match VhostUserProtocolFeatures::from_bits(features) {
            Some(proto_features) => proto_features,
            None => {
                error!(
                    "unsupported bits in VHOST_USER_SET_PROTOCOL_FEATURES: {:#x}",
                    features
                );
                return Err(VhostError::InvalidOperation);
            }
        };
        let supported = self.backend.protocol_features();
        self.acked_protocol_features = features & supported;
        Ok(())
    }

    fn set_mem_table(
        &mut self,
        contexts: &[VhostUserMemoryRegion],
        files: Vec<File>,
    ) -> VhostResult<()> {
        let (guest_mem, vmm_maps) = VhostUserRegularOps::set_mem_table(contexts, files)?;
        self.mem = Some(guest_mem);
        self.vmm_maps = Some(vmm_maps);
        Ok(())
    }

    fn get_queue_num(&mut self) -> VhostResult<u64> {
        Ok(self.vrings.len() as u64)
    }

    fn set_vring_num(&mut self, index: u32, num: u32) -> VhostResult<()> {
        if index as usize >= self.vrings.len() || num == 0 || num > Queue::MAX_SIZE.into() {
            return Err(VhostError::InvalidParam);
        }
        self.vrings[index as usize].queue.set_size(num as u16);

        Ok(())
    }

    fn set_vring_addr(
        &mut self,
        index: u32,
        _flags: VhostUserVringAddrFlags,
        descriptor: u64,
        used: u64,
        available: u64,
        _log: u64,
    ) -> VhostResult<()> {
        if index as usize >= self.vrings.len() {
            return Err(VhostError::InvalidParam);
        }

        let vmm_maps = self.vmm_maps.as_ref().ok_or(VhostError::InvalidParam)?;
        let vring = &mut self.vrings[index as usize];
        vring
            .queue
            .set_desc_table(vmm_va_to_gpa(vmm_maps, descriptor)?);
        vring
            .queue
            .set_avail_ring(vmm_va_to_gpa(vmm_maps, available)?);
        vring.queue.set_used_ring(vmm_va_to_gpa(vmm_maps, used)?);

        Ok(())
    }

    fn set_vring_base(&mut self, index: u32, base: u32) -> VhostResult<()> {
        if index as usize >= self.vrings.len() || base >= Queue::MAX_SIZE.into() {
            return Err(VhostError::InvalidParam);
        }

        let vring = &mut self.vrings[index as usize];
        vring.queue.set_next_avail(Wrapping(base as u16));
        vring.queue.set_next_used(Wrapping(base as u16));

        Ok(())
    }

    fn get_vring_base(&mut self, index: u32) -> VhostResult<VhostUserVringState> {
        let vring = self
            .vrings
            .get_mut(index as usize)
            .ok_or(VhostError::InvalidParam)?;

        // Quotation from vhost-user spec:
        // "The back-end must [...] stop ring upon receiving VHOST_USER_GET_VRING_BASE."
        // We only call `queue.set_ready()` when starting the queue, so if the queue is ready, that
        // means it is started and should be stopped.
        let vring_base = if vring.queue.ready() {
            let queue = match self.backend.stop_queue(index as usize) {
                Ok(q) => q,
                Err(e) => {
                    error!("Failed to stop queue in get_vring_base: {:#}", e);
                    return Err(VhostError::BackendInternalError);
                }
            };

            trace!("stopped queue {index}");

            vring.reset();

            if self.all_queues_stopped() {
                trace!("all queues stopped; entering suspended state");
                self.backend
                    .enter_suspended_state()
                    .map_err(VhostError::EnterSuspendedState)?;
            }

            queue.next_avail_to_process()
        } else {
            0
        };

        Ok(VhostUserVringState::new(index, vring_base.into()))
    }

    fn set_vring_kick(&mut self, index: u8, file: Option<File>) -> VhostResult<()> {
        if index as usize >= self.vrings.len() {
            return Err(VhostError::InvalidParam);
        }

        let vring = &mut self.vrings[index as usize];
        if vring.queue.ready() {
            error!("kick fd cannot be replaced after the queue is started");
            return Err(VhostError::InvalidOperation);
        }

        let kick_evt = VhostUserRegularOps::set_vring_kick(index, file)?;

        // Enable any virtqueue features that were negotiated (like VIRTIO_RING_F_EVENT_IDX).
        vring.queue.ack_features(self.acked_features);
        vring.queue.set_ready(true);

        let mem = self
            .mem
            .as_ref()
            .cloned()
            .ok_or(VhostError::InvalidOperation)?;

        let doorbell = vring.doorbell.clone().ok_or(VhostError::InvalidOperation)?;

        let queue = match vring.queue.activate(&mem, kick_evt, doorbell) {
            Ok(queue) => queue,
            Err(e) => {
                error!("failed to activate vring: {:#}", e);
                return Err(VhostError::BackendInternalError);
            }
        };

        if let Err(e) = self.backend.start_queue(index as usize, queue, mem) {
            error!("Failed to start queue {}: {}", index, e);
            return Err(VhostError::BackendInternalError);
        }

        trace!("started queue {index}");

        Ok(())
    }

    fn set_vring_call(&mut self, index: u8, file: Option<File>) -> VhostResult<()> {
        if index as usize >= self.vrings.len() {
            return Err(VhostError::InvalidParam);
        }

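        // Config-change notifications raised through this Interrupt are forwarded to the
        // frontend over the backend request connection, if one has been established.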
        let backend_req_conn = self.backend_req_connection.clone();
        let signal_config_change_fn = Box::new(move || match &*backend_req_conn.lock() {
            VhostBackendReqConnectionState::Connected(frontend) => {
                if let Err(e) = frontend.send_config_changed() {
                    error!("Failed to notify config change: {:#}", e);
                }
            }
            VhostBackendReqConnectionState::NoConnection => {
                error!("No Backend request connection found");
            }
        });

        let doorbell = VhostUserRegularOps::set_vring_call(index, file, signal_config_change_fn)?;
        self.vrings[index as usize].doorbell = Some(doorbell);
        Ok(())
    }

    fn set_vring_err(&mut self, _index: u8, _fd: Option<File>) -> VhostResult<()> {
        // TODO
        Ok(())
    }

    fn set_vring_enable(&mut self, index: u32, enable: bool) -> VhostResult<()> {
        if index as usize >= self.vrings.len() {
            return Err(VhostError::InvalidParam);
        }

        // This request should be handled only when VHOST_USER_F_PROTOCOL_FEATURES
        // has been negotiated.
        if self.acked_features & 1 << VHOST_USER_F_PROTOCOL_FEATURES == 0 {
            return Err(VhostError::InvalidOperation);
        }

        // Backend must not pass data to/from the ring until ring is enabled by
        // VHOST_USER_SET_VRING_ENABLE with parameter 1, or after it has been disabled by
        // VHOST_USER_SET_VRING_ENABLE with parameter 0.
        self.vrings[index as usize].enabled = enable;

        Ok(())
    }

    fn get_config(
        &mut self,
        offset: u32,
        size: u32,
        _flags: VhostUserConfigFlags,
    ) -> VhostResult<Vec<u8>> {
        let mut data = vec![0; size as usize];
        self.backend.read_config(u64::from(offset), &mut data);
        Ok(data)
    }

    fn set_config(
        &mut self,
        offset: u32,
        buf: &[u8],
        _flags: VhostUserConfigFlags,
    ) -> VhostResult<()> {
        self.backend.write_config(u64::from(offset), buf);
        Ok(())
    }

    fn set_backend_req_fd(&mut self, ep: Connection<BackendReq>) {
        let conn = Arc::new(VhostBackendReqConnection::new(
            FrontendClient::new(ep),
            self.backend.get_shared_memory_region().map(|r| r.id),
        ));

        {
            let mut backend_req_conn = self.backend_req_connection.lock();
            if let VhostBackendReqConnectionState::Connected(_) = &*backend_req_conn {
                warn!("Backend Request Connection already established. Overwriting");
            }
            *backend_req_conn = VhostBackendReqConnectionState::Connected(conn.clone());
        }

        self.backend.set_backend_req_connection(conn);
    }

    fn get_inflight_fd(
        &mut self,
        _inflight: &VhostUserInflight,
    ) -> VhostResult<(VhostUserInflight, File)> {
        unimplemented!("get_inflight_fd");
    }

    fn set_inflight_fd(&mut self, _inflight: &VhostUserInflight, _file: File) -> VhostResult<()> {
        unimplemented!("set_inflight_fd");
    }

    fn get_max_mem_slots(&mut self) -> VhostResult<u64> {
        // TODO
        Ok(0)
    }

    fn add_mem_region(
        &mut self,
        _region: &VhostUserSingleMemoryRegion,
        _fd: File,
    ) -> VhostResult<()> {
        // TODO
        Ok(())
    }

    fn remove_mem_region(&mut self, _region: &VhostUserSingleMemoryRegion) -> VhostResult<()> {
        // TODO
        Ok(())
    }

    fn set_device_state_fd(
        &mut self,
        transfer_direction: VhostUserTransferDirection,
        migration_phase: VhostUserMigrationPhase,
        mut fd: File,
    ) -> VhostResult<Option<File>> {
        if migration_phase != VhostUserMigrationPhase::Stopped {
            return Err(VhostError::InvalidOperation);
        }
        if !self.all_queues_stopped() {
            return Err(VhostError::InvalidOperation);
        }
        if self.device_state_thread.is_some() {
            error!("must call check_device_state before starting new state transfer");
            return Err(VhostError::InvalidOperation);
        }
        // `set_device_state_fd` is designed to allow snapshot/restore to proceed concurrently
        // with other methods, but, for simplicity, we do those operations inline and only spawn a
        // thread to handle the serialization and data transfer (the latter seems necessary to
        // implement the API correctly without, e.g., deadlocking because a pipe is full).
        match transfer_direction {
            VhostUserTransferDirection::Save => {
                // Snapshot the state.
                let snapshot = DeviceRequestHandlerSnapshot {
                    acked_features: self.acked_features,
                    acked_protocol_features: self.acked_protocol_features.bits(),
                    backend: self.backend.snapshot().map_err(VhostError::SnapshotError)?,
                };
                // Spawn thread to write the serialized bytes.
                self.device_state_thread = Some(DeviceStateThread::Save(WorkerThread::start(
                    "device_state_save",
                    move |_kill_event| serde_json::to_writer(&mut fd, &snapshot),
                )));
                Ok(None)
            }
            VhostUserTransferDirection::Load => {
                // Spawn a thread to read the bytes and deserialize. Restore will happen in
                // `check_device_state`.
                self.device_state_thread = Some(DeviceStateThread::Load(WorkerThread::start(
                    "device_state_load",
                    move |_kill_event| serde_json::from_reader(&mut fd),
                )));
                Ok(None)
            }
        }
    }

    fn check_device_state(&mut self) -> VhostResult<()> {
        let Some(thread) = self.device_state_thread.take() else {
            error!("check_device_state: no active state transfer");
            return Err(VhostError::InvalidOperation);
        };
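        // Join the transfer thread. For a save, this only checks that serialization succeeded;
        // for a load, the deserialized snapshot is applied to the handler and the backend.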
        match thread {
            DeviceStateThread::Save(worker) => {
                worker.stop().map_err(|e| {
                    error!("device state save thread failed: {:#}", e);
                    VhostError::BackendInternalError
                })?;
                Ok(())
            }
            DeviceStateThread::Load(worker) => {
                let snapshot = worker.stop().map_err(|e| {
                    error!("device state load thread failed: {:#}", e);
                    VhostError::BackendInternalError
                })?;
                self.acked_features = snapshot.acked_features;
                self.acked_protocol_features =
                    VhostUserProtocolFeatures::from_bits(snapshot.acked_protocol_features)
                        .with_context(|| {
                            format!(
                                "unsupported bits in acked_protocol_features: {:#x}",
                                snapshot.acked_protocol_features
                            )
                        })
                        .map_err(VhostError::RestoreError)?;
                self.backend
                    .restore(snapshot.backend)
                    .map_err(VhostError::RestoreError)?;
                Ok(())
            }
        }
    }

    fn get_shared_memory_regions(&mut self) -> VhostResult<Vec<VhostSharedMemoryRegion>> {
        Ok(if let Some(r) = self.backend.get_shared_memory_region() {
            vec![VhostSharedMemoryRegion::new(r.id, r.length)]
        } else {
            Vec::new()
        })
    }
}

/// Indicates the state of backend request connection
pub enum VhostBackendReqConnectionState {
    /// A backend request connection (`VhostBackendReqConnection`) is established
    Connected(Arc<VhostBackendReqConnection>),
    /// No backend request connection has been established yet
    NoConnection,
}

/// Keeps track of Vhost user backend request connection.
pub struct VhostBackendReqConnection {
    conn: Arc<Mutex<FrontendClient>>,
    shmem_info: Mutex<Option<ShmemInfo>>,
}

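/// Shared memory region id for the device plus the regions currently mapped through it, keyed by
/// offset so they can be unmapped later.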
#[derive(Clone)]
struct ShmemInfo {
    shmid: u8,
    mapped_regions: BTreeMap<u64 /* offset */, u64 /* size */>,
}

impl VhostBackendReqConnection {
    pub fn new(conn: FrontendClient, shmid: Option<u8>) -> Self {
        let shmem_info = Mutex::new(shmid.map(|shmid| ShmemInfo {
            shmid,
            mapped_regions: BTreeMap::new(),
        }));
        Self {
            conn: Arc::new(Mutex::new(conn)),
            shmem_info,
        }
    }

    /// Send `VHOST_USER_CONFIG_CHANGE_MSG` to the frontend
    pub fn send_config_changed(&self) -> anyhow::Result<()> {
        self.conn
            .lock()
            .handle_config_change()
            .context("Could not send config change message")?;
        Ok(())
    }

    /// Create a SharedMemoryMapper trait object from the ShmemInfo.
    pub fn take_shmem_mapper(&self) -> anyhow::Result<Box<dyn SharedMemoryMapper>> {
        let shmem_info = self
            .shmem_info
            .lock()
            .take()
            .context("could not take shared memory mapper information")?;

        Ok(Box::new(VhostShmemMapper {
            conn: self.conn.clone(),
            shmem_info,
        }))
    }
}

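/// `SharedMemoryMapper` implementation that forwards map and unmap requests to the frontend over
/// the backend request connection.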
struct VhostShmemMapper {
    conn: Arc<Mutex<FrontendClient>>,
    shmem_info: ShmemInfo,
}

impl SharedMemoryMapper for VhostShmemMapper {
    fn add_mapping(
        &mut self,
        source: VmMemorySource,
        offset: u64,
        prot: Protection,
        _cache: MemCacheType,
    ) -> anyhow::Result<()> {
        let size = match source {
            VmMemorySource::Vulkan {
                descriptor,
                handle_type,
                memory_idx,
                device_uuid,
                driver_uuid,
                size,
            } => {
                let msg = VhostUserGpuMapMsg::new(
                    self.shmem_info.shmid,
                    offset,
                    size,
                    memory_idx,
                    handle_type,
                    device_uuid,
                    driver_uuid,
                );
                self.conn
                    .lock()
                    .gpu_map(&msg, &descriptor)
                    .context("failed to map memory")?;
                size
            }
            VmMemorySource::ExternalMapping { ptr, size } => {
                let msg = VhostUserExternalMapMsg::new(self.shmem_info.shmid, offset, size, ptr);
                self.conn
                    .lock()
                    .external_map(&msg)
                    .context("failed to map memory")?;
                size
            }
            source => {
                // The last two sources use the same VhostUserShmemMapMsg, continue matching here
                // on the aliased `source` above.
                let (descriptor, fd_offset, size) = match source {
                    VmMemorySource::Descriptor {
                        descriptor,
                        offset,
                        size,
                    } => (descriptor, offset, size),
                    VmMemorySource::SharedMemory(shmem) => {
                        let size = shmem.size();
                        let descriptor = SafeDescriptor::from(shmem);
                        (descriptor, 0, size)
                    }
                    _ => bail!("unsupported source"),
                };
                let flags = VhostUserShmemMapMsgFlags::from(prot);
                let msg = VhostUserShmemMapMsg::new(
                    self.shmem_info.shmid,
                    offset,
                    fd_offset,
                    size,
                    flags,
                );
                self.conn
                    .lock()
                    .shmem_map(&msg, &descriptor)
                    .context("failed to map memory")?;
                size
            }
        };

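        // Remember the mapping so it can later be removed by offset.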
        self.shmem_info.mapped_regions.insert(offset, size);
        Ok(())
    }

    fn remove_mapping(&mut self, offset: u64) -> anyhow::Result<()> {
        let size = self
            .shmem_info
            .mapped_regions
            .remove(&offset)
            .context("unknown offset")?;
        let msg = VhostUserShmemUnmapMsg::new(self.shmem_info.shmid, offset, size);
        self.conn
            .lock()
            .shmem_unmap(&msg)
            .context("failed to unmap memory")
            .map(|_| ())
    }
}

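/// A queue paired with the async task that services it; a convenience for device backends that
/// run one worker task per queue between `start_queue` and `stop_queue`.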
pub(crate) struct WorkerState<T, U> {
    pub(crate) queue_task: TaskHandle<U>,
    pub(crate) queue: T,
}

/// Errors for device operations
#[derive(Debug, ThisError)]
pub enum Error {
    #[error("worker not found when stopping queue")]
    WorkerNotFound,
}

#[cfg(test)]
mod tests {
    use std::sync::mpsc::channel;
    use std::sync::Barrier;

    use anyhow::bail;
    use base::Event;
    use vmm_vhost::BackendServer;
    use vmm_vhost::FrontendReq;
    use zerocopy::AsBytes;
    use zerocopy::FromBytes;
    use zerocopy::FromZeroes;

    use super::*;
    use crate::virtio::vhost_user_frontend::VhostUserFrontend;
    use crate::virtio::DeviceType;
    use crate::virtio::VirtioDevice;

    #[derive(Clone, Copy, Debug, PartialEq, Eq, AsBytes, FromZeroes, FromBytes)]
    #[repr(C, packed(4))]
    struct FakeConfig {
        x: u32,
        y: u64,
    }

    const FAKE_CONFIG_DATA: FakeConfig = FakeConfig { x: 1, y: 2 };

    pub(super) struct FakeBackend {
        avail_features: u64,
        acked_features: u64,
        active_queues: Vec<Option<Queue>>,
        allow_backend_req: bool,
        backend_conn: Option<Arc<VhostBackendReqConnection>>,
    }

    #[derive(Deserialize, Serialize)]
    struct FakeBackendSnapshot {
        data: Vec<u8>,
    }

    impl FakeBackend {
        const MAX_QUEUE_NUM: usize = 16;

        pub(super) fn new() -> Self {
            let mut active_queues = Vec::new();
            active_queues.resize_with(Self::MAX_QUEUE_NUM, Default::default);
            Self {
                avail_features: 1 << VHOST_USER_F_PROTOCOL_FEATURES,
                acked_features: 0,
                active_queues,
                allow_backend_req: false,
                backend_conn: None,
            }
        }
    }

    impl VhostUserDevice for FakeBackend {
        fn max_queue_num(&self) -> usize {
            Self::MAX_QUEUE_NUM
        }

        fn features(&self) -> u64 {
            self.avail_features
        }

        fn ack_features(&mut self, value: u64) -> anyhow::Result<()> {
            let unrequested_features = value & !self.avail_features;
            if unrequested_features != 0 {
                bail!(
                    "invalid protocol features are given: 0x{:x}",
                    unrequested_features
                );
            }
            self.acked_features |= value;
            Ok(())
        }

        fn protocol_features(&self) -> VhostUserProtocolFeatures {
            let mut features =
                VhostUserProtocolFeatures::CONFIG | VhostUserProtocolFeatures::DEVICE_STATE;
            if self.allow_backend_req {
                features |= VhostUserProtocolFeatures::BACKEND_REQ;
            }
            features
        }

        fn read_config(&self, offset: u64, dst: &mut [u8]) {
            dst.copy_from_slice(&FAKE_CONFIG_DATA.as_bytes()[offset as usize..]);
        }

        fn reset(&mut self) {}

        fn start_queue(
            &mut self,
            idx: usize,
            queue: Queue,
            _mem: GuestMemory,
        ) -> anyhow::Result<()> {
            self.active_queues[idx] = Some(queue);
            Ok(())
        }

        fn stop_queue(&mut self, idx: usize) -> anyhow::Result<Queue> {
            Ok(self.active_queues[idx]
                .take()
                .ok_or(Error::WorkerNotFound)?)
        }

        fn set_backend_req_connection(&mut self, conn: Arc<VhostBackendReqConnection>) {
            self.backend_conn = Some(conn);
        }

        fn enter_suspended_state(&mut self) -> anyhow::Result<()> {
            Ok(())
        }

        fn snapshot(&mut self) -> anyhow::Result<serde_json::Value> {
            serde_json::to_value(FakeBackendSnapshot {
                data: vec![1, 2, 3],
            })
            .context("failed to serialize snapshot")
        }

        fn restore(&mut self, data: serde_json::Value) -> anyhow::Result<()> {
            let snapshot: FakeBackendSnapshot =
                serde_json::from_value(data).context("failed to deserialize snapshot")?;
            assert_eq!(snapshot.data, vec![1, 2, 3], "bad snapshot data");
            Ok(())
        }
    }

    #[test]
    fn test_vhost_user_lifecycle() {
        test_vhost_user_lifecycle_parameterized(false);
    }

    #[test]
    #[cfg(not(windows))] // Windows requires more complex connection setup.
    fn test_vhost_user_lifecycle_with_backend_req() {
        test_vhost_user_lifecycle_parameterized(true);
    }

    fn test_vhost_user_lifecycle_parameterized(allow_backend_req: bool) {
        const QUEUES_NUM: usize = 2;

        let (client_connection, server_connection) =
            vmm_vhost::Connection::<FrontendReq>::pair().unwrap();

        let vmm_bar = Arc::new(Barrier::new(2));
        let dev_bar = vmm_bar.clone();

        let (ready_tx, ready_rx) = channel();
        let (shutdown_tx, shutdown_rx) = channel();

        std::thread::spawn(move || {
            // VMM side
            ready_rx.recv().unwrap(); // Ensure the device is ready.

            let mut vmm_device =
                VhostUserFrontend::new(DeviceType::Console, 0, client_connection, None, None)
                    .unwrap();

            println!("read_config");
            let mut buf = vec![0; std::mem::size_of::<FakeConfig>()];
            vmm_device.read_config(0, &mut buf);
            // Check if the obtained config data is correct.
            let config = FakeConfig::read_from(buf.as_bytes()).unwrap();
            assert_eq!(config, FAKE_CONFIG_DATA);

            let activate = |vmm_device: &mut VhostUserFrontend| {
                let mem = GuestMemory::new(&[(GuestAddress(0x0), 0x10000)]).unwrap();
                let interrupt = Interrupt::new_for_test_with_msix();

                let mut queues = BTreeMap::new();
                for idx in 0..QUEUES_NUM {
                    let mut queue = QueueConfig::new(0x10, 0);
                    queue.set_ready(true);
                    let queue = queue
                        .activate(&mem, Event::new().unwrap(), interrupt.clone())
                        .expect("QueueConfig::activate");
                    queues.insert(idx, queue);
                }

                println!("activate");
                vmm_device
                    .activate(mem.clone(), interrupt.clone(), queues)
                    .unwrap();
            };

            activate(&mut vmm_device);

            println!("reset");
            let reset_result = vmm_device.reset();
            assert!(
                reset_result.is_ok(),
                "reset failed: {:#}",
                reset_result.unwrap_err()
            );

            activate(&mut vmm_device);

            println!("virtio_sleep");
            let queues = vmm_device
                .virtio_sleep()
                .unwrap()
                .expect("virtio_sleep unexpectedly returned None");

            println!("virtio_snapshot");
            let snapshot = vmm_device
                .virtio_snapshot()
                .expect("virtio_snapshot failed");
            println!("virtio_restore");
            vmm_device
                .virtio_restore(snapshot)
                .expect("virtio_restore failed");

            println!("virtio_wake");
            let mem = GuestMemory::new(&[(GuestAddress(0x0), 0x10000)]).unwrap();
            let interrupt = Interrupt::new_for_test_with_msix();
            vmm_device
                .virtio_wake(Some((mem, interrupt, queues)))
                .unwrap();

            println!("wait for shutdown signal");
            shutdown_rx.recv().unwrap();

            // The VMM side is supposed to stop before the device side.
            println!("drop");
            drop(vmm_device);

            vmm_bar.wait();
        });

        // Device side
        let mut handler = DeviceRequestHandler::new(FakeBackend::new());
        handler.as_mut().allow_backend_req = allow_backend_req;

        // Notify the VMM thread that the listener is ready.
        ready_tx.send(()).unwrap();

        let mut req_handler = BackendServer::new(server_connection, handler);

        // VhostUserFrontend::new()
        handle_request(&mut req_handler, FrontendReq::SET_OWNER).unwrap();
        handle_request(&mut req_handler, FrontendReq::GET_FEATURES).unwrap();
        handle_request(&mut req_handler, FrontendReq::SET_FEATURES).unwrap();
        handle_request(&mut req_handler, FrontendReq::GET_PROTOCOL_FEATURES).unwrap();
        handle_request(&mut req_handler, FrontendReq::SET_PROTOCOL_FEATURES).unwrap();
        if allow_backend_req {
            handle_request(&mut req_handler, FrontendReq::SET_BACKEND_REQ_FD).unwrap();
        }

        // VhostUserFrontend::read_config()
        handle_request(&mut req_handler, FrontendReq::GET_CONFIG).unwrap();

        // VhostUserFrontend::activate()
        handle_request(&mut req_handler, FrontendReq::SET_MEM_TABLE).unwrap();
        for _ in 0..QUEUES_NUM {
            handle_request(&mut req_handler, FrontendReq::SET_VRING_NUM).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_ADDR).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_BASE).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_CALL).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_KICK).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_ENABLE).unwrap();
        }

        // VhostUserFrontend::reset()
        for _ in 0..QUEUES_NUM {
            handle_request(&mut req_handler, FrontendReq::SET_VRING_ENABLE).unwrap();
            handle_request(&mut req_handler, FrontendReq::GET_VRING_BASE).unwrap();
        }

        // VhostUserFrontend::activate()
        handle_request(&mut req_handler, FrontendReq::SET_MEM_TABLE).unwrap();
        for _ in 0..QUEUES_NUM {
            handle_request(&mut req_handler, FrontendReq::SET_VRING_NUM).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_ADDR).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_BASE).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_CALL).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_KICK).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_ENABLE).unwrap();
        }

        if allow_backend_req {
            // Make sure the connection still works even after reset/reactivate.
            req_handler
                .as_ref()
                .as_ref()
                .backend_conn
                .as_ref()
                .expect("backend_conn missing")
                .send_config_changed()
                .expect("send_config_changed failed");
        }

        // VhostUserFrontend::virtio_sleep()
        for _ in 0..QUEUES_NUM {
            handle_request(&mut req_handler, FrontendReq::SET_VRING_ENABLE).unwrap();
            handle_request(&mut req_handler, FrontendReq::GET_VRING_BASE).unwrap();
        }

        // VhostUserFrontend::virtio_snapshot()
        handle_request(&mut req_handler, FrontendReq::SET_DEVICE_STATE_FD).unwrap();
        handle_request(&mut req_handler, FrontendReq::CHECK_DEVICE_STATE).unwrap();
        // VhostUserFrontend::virtio_restore()
        handle_request(&mut req_handler, FrontendReq::SET_DEVICE_STATE_FD).unwrap();
        handle_request(&mut req_handler, FrontendReq::CHECK_DEVICE_STATE).unwrap();

        // VhostUserFrontend::virtio_wake()
        handle_request(&mut req_handler, FrontendReq::SET_MEM_TABLE).unwrap();
        for _ in 0..QUEUES_NUM {
            handle_request(&mut req_handler, FrontendReq::SET_VRING_NUM).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_ADDR).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_BASE).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_CALL).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_KICK).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_ENABLE).unwrap();
        }

        if allow_backend_req {
            // Make sure the connection still works even after sleep/wake.
            req_handler
                .as_ref()
                .as_ref()
                .backend_conn
                .as_ref()
                .expect("backend_conn missing")
                .send_config_changed()
                .expect("send_config_changed failed");
        }

        // Ask the client to shut down, then wait for it to finish.
        shutdown_tx.send(()).unwrap();
        dev_bar.wait();

        // Verify recv_header fails with `ClientExit` after the client has disconnected.
        match req_handler.recv_header() {
            Err(VhostError::ClientExit) => (),
            r => panic!("expected Err(ClientExit) but got {:?}", r),
        }
    }

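    // Receives the next message from the frontend connection, asserts that it has the expected
    // request type, and dispatches it to the backend handler.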
    fn handle_request<S: vmm_vhost::Backend>(
        handler: &mut BackendServer<S>,
        expected_message_type: FrontendReq,
    ) -> Result<(), VhostError> {
        let (hdr, files) = handler.recv_header()?;
        assert_eq!(hdr.get_code(), Ok(expected_message_type));
        handler.process_message(hdr, files)
    }
}