xref: /aosp_15_r20/external/crosvm/devices/src/virtio/vhost/user/device/handler.rs (revision bb4ee6a4ae7042d18b07a98463b9c8b875e44b39)
1 // Copyright 2021 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 //! Library for implementing vhost-user device executables.
6 //!
7 //! This crate provides
8 //! * `VhostUserDevice` trait, which is a collection of methods to handle vhost-user requests, and
9 //! * `DeviceRequestHandler` struct, which makes a connection to a VMM and starts an event loop.
10 //!
11 //! They are expected to be used as follows:
12 //!
13 //! 1. Define a struct and implement `VhostUserDevice` for it.
14 //! 2. Create a `DeviceRequestHandler` with the backend struct.
15 //! 3. Drive the `DeviceRequestHandler::run` async fn with an executor.
16 //!
17 //! ```ignore
18 //! struct MyBackend {
19 //!   /* fields */
20 //! }
21 //!
22 //! impl VhostUserDevice for MyBackend {
23 //!   /* implement methods */
24 //! }
25 //!
26 //! fn main() -> Result<(), Box<dyn Error>> {
27 //!   let backend = MyBackend { /* initialize fields */ };
28 //!   let handler = DeviceRequestHandler::new(backend);
//!   let socket = std::path::Path::new("/path/to/socket");
30 //!   let ex = cros_async::Executor::new()?;
31 //!
32 //!   if let Err(e) = ex.run_until(handler.run(socket, &ex)) {
33 //!     eprintln!("error happened: {}", e);
34 //!   }
35 //!   Ok(())
36 //! }
37 //! ```
38 // Implementation note:
39 // This code lets us take advantage of the vmm_vhost low level implementation of the vhost user
40 // protocol. DeviceRequestHandler implements the Backend trait from vmm_vhost, and includes some
41 // common code for setting up guest memory and managing partially configured vrings.
42 // DeviceRequestHandler::run watches the vhost-user socket and then calls handle_request() when it
43 // becomes readable. handle_request() reads and parses the message and then calls one of the
44 // Backend trait methods. These dispatch back to the supplied VhostUserDevice implementation (this
45 // is what our devices implement).
46 
47 pub(super) mod sys;
48 
49 use std::collections::BTreeMap;
50 use std::convert::From;
51 use std::fs::File;
52 use std::num::Wrapping;
53 #[cfg(any(target_os = "android", target_os = "linux"))]
54 use std::os::unix::io::AsRawFd;
55 use std::sync::Arc;
56 
57 use anyhow::bail;
58 use anyhow::Context;
59 #[cfg(any(target_os = "android", target_os = "linux"))]
60 use base::clear_fd_flags;
61 use base::error;
62 use base::trace;
63 use base::warn;
64 use base::Event;
65 use base::Protection;
66 use base::SafeDescriptor;
67 use base::SharedMemory;
68 use base::WorkerThread;
69 use cros_async::TaskHandle;
70 use hypervisor::MemCacheType;
71 use serde::Deserialize;
72 use serde::Serialize;
73 use sync::Mutex;
74 use thiserror::Error as ThisError;
75 use vm_control::VmMemorySource;
76 use vm_memory::GuestAddress;
77 use vm_memory::GuestMemory;
78 use vm_memory::MemoryRegion;
79 use vmm_vhost::message::VhostSharedMemoryRegion;
80 use vmm_vhost::message::VhostUserConfigFlags;
81 use vmm_vhost::message::VhostUserExternalMapMsg;
82 use vmm_vhost::message::VhostUserGpuMapMsg;
83 use vmm_vhost::message::VhostUserInflight;
84 use vmm_vhost::message::VhostUserMemoryRegion;
85 use vmm_vhost::message::VhostUserMigrationPhase;
86 use vmm_vhost::message::VhostUserProtocolFeatures;
87 use vmm_vhost::message::VhostUserShmemMapMsg;
88 use vmm_vhost::message::VhostUserShmemMapMsgFlags;
89 use vmm_vhost::message::VhostUserShmemUnmapMsg;
90 use vmm_vhost::message::VhostUserSingleMemoryRegion;
91 use vmm_vhost::message::VhostUserTransferDirection;
92 use vmm_vhost::message::VhostUserVringAddrFlags;
93 use vmm_vhost::message::VhostUserVringState;
94 use vmm_vhost::BackendReq;
95 use vmm_vhost::Connection;
96 use vmm_vhost::Error as VhostError;
97 use vmm_vhost::Frontend;
98 use vmm_vhost::FrontendClient;
99 use vmm_vhost::Result as VhostResult;
100 use vmm_vhost::VHOST_USER_F_PROTOCOL_FEATURES;
101 
102 use crate::virtio::Interrupt;
103 use crate::virtio::Queue;
104 use crate::virtio::QueueConfig;
105 use crate::virtio::SharedMemoryMapper;
106 use crate::virtio::SharedMemoryRegion;
107 
/// Keeps a mapping from the VMM's virtual addresses to guest addresses.
///
/// Used to translate addresses carried in messages from the VMM into guest addresses (see
/// `vmm_va_to_gpa`).
#[derive(Default)]
pub struct MappingInfo {
    /// Start of the region in the VMM's virtual address space.
    pub vmm_addr: u64,
    /// Start of the same region in guest physical address space.
    pub guest_phys: u64,
    /// Length of the region, in the same units as the two addresses.
    pub size: u64,
}
116 
vmm_va_to_gpa(maps: &[MappingInfo], vmm_va: u64) -> VhostResult<GuestAddress>117 pub fn vmm_va_to_gpa(maps: &[MappingInfo], vmm_va: u64) -> VhostResult<GuestAddress> {
118     for map in maps {
119         if vmm_va >= map.vmm_addr && vmm_va < map.vmm_addr + map.size {
120             return Ok(GuestAddress(vmm_va - map.vmm_addr + map.guest_phys));
121         }
122     }
123     Err(VhostError::InvalidMessage)
124 }
125 
/// Trait for vhost-user devices. Analogous to the `VirtioDevice` trait.
///
/// In contrast with [`vmm_vhost::Backend`], which closely matches the vhost-user spec, this trait
/// is designed to follow crosvm conventions for implementing devices.
pub trait VhostUserDevice {
    /// The maximum number of queues that this backend can manage.
    fn max_queue_num(&self) -> usize;

    /// The set of feature bits that this backend supports.
    fn features(&self) -> u64;

    /// Acknowledges that this set of features should be enabled.
    ///
    /// Implementations only need to handle device-specific feature bits; the `DeviceRequestHandler`
    /// framework will manage generic vhost and vring features.
    ///
    /// `DeviceRequestHandler` checks for valid features before calling this function, so the
    /// features in `value` will always be a subset of those advertised by `features()`.
    ///
    /// The default implementation accepts any value and does nothing.
    fn ack_features(&mut self, _value: u64) -> anyhow::Result<()> {
        Ok(())
    }

    /// The set of protocol feature bits that this backend supports.
    fn protocol_features(&self) -> VhostUserProtocolFeatures;

    /// Reads this device's configuration space at `offset` into `dst`.
    fn read_config(&self, offset: u64, dst: &mut [u8]);

    /// Writes `data` to this device's configuration space at `offset`.
    ///
    /// The default implementation ignores the write.
    fn write_config(&self, _offset: u64, _data: &[u8]) {}

    /// Indicates that the backend should start processing requests for virtio queue number `idx`.
    /// This method must not block the current thread so device backends should either spawn an
    /// async task or another thread to handle messages from the Queue.
    fn start_queue(&mut self, idx: usize, queue: Queue, mem: GuestMemory) -> anyhow::Result<()>;

    /// Indicates that the backend should stop processing requests for virtio queue number `idx`.
    /// This method should return the queue passed to `start_queue` for the corresponding `idx`.
    /// This method will only be called for queues that were previously started by `start_queue`.
    fn stop_queue(&mut self, idx: usize) -> anyhow::Result<Queue>;

    /// Resets the vhost-user backend.
    fn reset(&mut self);

    /// Returns the device's shared memory region if present.
    ///
    /// The default implementation reports no shared memory region.
    fn get_shared_memory_region(&self) -> Option<SharedMemoryRegion> {
        None
    }

    /// Accepts `VhostBackendReqConnection` to conduct Vhost backend to frontend message
    /// handling.
    ///
    /// The backend is given an `Arc` instead of full ownership so that the framework can also use
    /// the connection.
    ///
    /// This method will be called when `VhostUserProtocolFeatures::BACKEND_REQ` is
    /// negotiated.
    ///
    /// The default implementation drops the connection handle.
    fn set_backend_req_connection(&mut self, _conn: Arc<VhostBackendReqConnection>) {}

    /// Enter the "suspended device state" described in the vhost-user spec. See the spec for
    /// requirements.
    ///
    /// One reasonably foolproof way to satisfy the requirements is to stop all worker threads.
    ///
    /// Called after a `stop_queue` call if there are no running queues left. Also called soon
    /// after device creation to ensure the device is acting suspended immediately on construction.
    ///
    /// The next `start_queue` call implicitly exits the "suspend device state".
    ///
    /// * Ok(())    => device successfully suspended
    /// * Err(_)    => unrecoverable error
    fn enter_suspended_state(&mut self) -> anyhow::Result<()>;

    /// Snapshot device and return serialized state.
    fn snapshot(&mut self) -> anyhow::Result<serde_json::Value>;

    /// Restore device state from a snapshot.
    fn restore(&mut self, data: serde_json::Value) -> anyhow::Result<()>;
}
205 
/// A virtio ring entry.
///
/// Holds the per-queue state that the handler accumulates from vhost-user messages.
struct Vring {
    // The queue config. This doesn't get mutated by the queue workers.
    queue: QueueConfig,
    // Interrupt handed to the queue when it is activated; installed by `set_vring_call`.
    doorbell: Option<Interrupt>,
    // Enable flag written by `set_features` and `set_vring_enable`.
    enabled: bool,
}
213 
214 impl Vring {
new(max_size: u16, features: u64) -> Self215     fn new(max_size: u16, features: u64) -> Self {
216         Self {
217             queue: QueueConfig::new(max_size, features),
218             doorbell: None,
219             enabled: false,
220         }
221     }
222 
reset(&mut self)223     fn reset(&mut self) {
224         self.queue.reset();
225         self.doorbell = None;
226         self.enabled = false;
227     }
228 }
229 
230 /// Ops for running vhost-user over a stream (i.e. regular protocol).
231 pub(super) struct VhostUserRegularOps;
232 
233 impl VhostUserRegularOps {
set_mem_table( contexts: &[VhostUserMemoryRegion], files: Vec<File>, ) -> VhostResult<(GuestMemory, Vec<MappingInfo>)>234     pub fn set_mem_table(
235         contexts: &[VhostUserMemoryRegion],
236         files: Vec<File>,
237     ) -> VhostResult<(GuestMemory, Vec<MappingInfo>)> {
238         if files.len() != contexts.len() {
239             return Err(VhostError::InvalidParam);
240         }
241 
242         let mut regions = Vec::with_capacity(files.len());
243         for (region, file) in contexts.iter().zip(files.into_iter()) {
244             let region = MemoryRegion::new_from_shm(
245                 region.memory_size,
246                 GuestAddress(region.guest_phys_addr),
247                 region.mmap_offset,
248                 Arc::new(
249                     SharedMemory::from_safe_descriptor(
250                         SafeDescriptor::from(file),
251                         region.memory_size,
252                     )
253                     .unwrap(),
254                 ),
255             )
256             .map_err(|e| {
257                 error!("failed to create a memory region: {}", e);
258                 VhostError::InvalidOperation
259             })?;
260             regions.push(region);
261         }
262         let guest_mem = GuestMemory::from_regions(regions).map_err(|e| {
263             error!("failed to create guest memory: {}", e);
264             VhostError::InvalidOperation
265         })?;
266 
267         let vmm_maps = contexts
268             .iter()
269             .map(|region| MappingInfo {
270                 vmm_addr: region.user_addr,
271                 guest_phys: region.guest_phys_addr,
272                 size: region.memory_size,
273             })
274             .collect();
275         Ok((guest_mem, vmm_maps))
276     }
277 
set_vring_kick(_index: u8, file: Option<File>) -> VhostResult<Event>278     pub fn set_vring_kick(_index: u8, file: Option<File>) -> VhostResult<Event> {
279         let file = file.ok_or(VhostError::InvalidParam)?;
280         // Remove O_NONBLOCK from kick_fd. Otherwise, uring_executor will fails when we read
281         // values via `next_val()` later.
282         // This is only required (and can only be done) on Unix platforms.
283         #[cfg(any(target_os = "android", target_os = "linux"))]
284         if let Err(e) = clear_fd_flags(file.as_raw_fd(), libc::O_NONBLOCK) {
285             error!("failed to remove O_NONBLOCK for kick fd: {}", e);
286             return Err(VhostError::InvalidParam);
287         }
288         Ok(Event::from(SafeDescriptor::from(file)))
289     }
290 
set_vring_call( _index: u8, file: Option<File>, signal_config_change_fn: Box<dyn Fn() + Send + Sync>, ) -> VhostResult<Interrupt>291     pub fn set_vring_call(
292         _index: u8,
293         file: Option<File>,
294         signal_config_change_fn: Box<dyn Fn() + Send + Sync>,
295     ) -> VhostResult<Interrupt> {
296         let file = file.ok_or(VhostError::InvalidParam)?;
297         Ok(Interrupt::new_vhost_user(
298             Event::from(SafeDescriptor::from(file)),
299             signal_config_change_fn,
300         ))
301     }
302 }
303 
/// An adapter that implements `vmm_vhost::Backend` for any type implementing `VhostUserDevice`.
pub struct DeviceRequestHandler<T: VhostUserDevice> {
    // One entry per queue the backend can manage (sized from `max_queue_num()` in `new`).
    vrings: Vec<Vring>,
    // Set by `set_owner`, cleared by `reset_owner`.
    owned: bool,
    // VMM virtual address -> guest address table; populated by `set_mem_table`.
    vmm_maps: Option<Vec<MappingInfo>>,
    // Guest memory; populated by `set_mem_table`.
    mem: Option<GuestMemory>,
    // Feature bits accumulated by `set_features`.
    acked_features: u64,
    // Protocol feature bits stored by `set_protocol_features`.
    acked_protocol_features: VhostUserProtocolFeatures,
    // The device implementation that requests are dispatched to.
    backend: T,
    // Backend-to-frontend connection state, shared with the config-change closures created in
    // `set_vring_call`.
    backend_req_connection: Arc<Mutex<VhostBackendReqConnectionState>>,
    // Thread processing active device state FD.
    device_state_thread: Option<DeviceStateThread>,
}
317 
/// Worker thread for an in-progress device state transfer started by `set_device_state_fd`.
enum DeviceStateThread {
    /// Worker serializing a snapshot into the state FD.
    Save(WorkerThread<serde_json::Result<()>>),
    /// Worker reading and deserializing a snapshot from the state FD.
    Load(WorkerThread<serde_json::Result<DeviceRequestHandlerSnapshot>>),
}
322 
/// Serialized form of the handler state exchanged over the device state FD.
#[derive(Serialize, Deserialize)]
pub struct DeviceRequestHandlerSnapshot {
    // Copy of the handler's `acked_features`.
    acked_features: u64,
    // Raw bits of the handler's `acked_protocol_features`.
    acked_protocol_features: u64,
    // Value returned by the backend's `snapshot()`.
    backend: serde_json::Value,
}
329 
330 impl<T: VhostUserDevice> DeviceRequestHandler<T> {
331     /// Creates a vhost-user handler instance for `backend`.
new(mut backend: T) -> Self332     pub(crate) fn new(mut backend: T) -> Self {
333         let mut vrings = Vec::with_capacity(backend.max_queue_num());
334         for _ in 0..backend.max_queue_num() {
335             vrings.push(Vring::new(Queue::MAX_SIZE, backend.features()));
336         }
337 
338         // VhostUserDevice implementations must support `enter_suspended_state()`.
339         // Call it on startup to ensure it works and to initialize the device in a suspended state.
340         backend
341             .enter_suspended_state()
342             .expect("enter_suspended_state failed on device init");
343 
344         DeviceRequestHandler {
345             vrings,
346             owned: false,
347             vmm_maps: None,
348             mem: None,
349             acked_features: 0,
350             acked_protocol_features: VhostUserProtocolFeatures::empty(),
351             backend,
352             backend_req_connection: Arc::new(Mutex::new(
353                 VhostBackendReqConnectionState::NoConnection,
354             )),
355             device_state_thread: None,
356         }
357     }
358 
359     /// Check if all queues are stopped.
360     ///
361     /// The device can be suspended with `enter_suspended_state()` only when all queues are stopped.
all_queues_stopped(&self) -> bool362     fn all_queues_stopped(&self) -> bool {
363         self.vrings.iter().all(|vring| !vring.queue.ready())
364     }
365 }
366 
/// Gives shared access to the wrapped backend.
impl<T: VhostUserDevice> AsRef<T> for DeviceRequestHandler<T> {
    fn as_ref(&self) -> &T {
        &self.backend
    }
}
372 
/// Gives exclusive access to the wrapped backend.
impl<T: VhostUserDevice> AsMut<T> for DeviceRequestHandler<T> {
    fn as_mut(&mut self) -> &mut T {
        &mut self.backend
    }
}
378 
379 impl<T: VhostUserDevice> vmm_vhost::Backend for DeviceRequestHandler<T> {
set_owner(&mut self) -> VhostResult<()>380     fn set_owner(&mut self) -> VhostResult<()> {
381         if self.owned {
382             return Err(VhostError::InvalidOperation);
383         }
384         self.owned = true;
385         Ok(())
386     }
387 
    /// Clears the owner flag and the acknowledged features, then resets the backend.
    ///
    /// Always succeeds.
    fn reset_owner(&mut self) -> VhostResult<()> {
        self.owned = false;
        self.acked_features = 0;
        self.backend.reset();
        Ok(())
    }
394 
get_features(&mut self) -> VhostResult<u64>395     fn get_features(&mut self) -> VhostResult<u64> {
396         let features = self.backend.features();
397         Ok(features)
398     }
399 
set_features(&mut self, features: u64) -> VhostResult<()>400     fn set_features(&mut self, features: u64) -> VhostResult<()> {
401         if !self.owned {
402             return Err(VhostError::InvalidOperation);
403         }
404 
405         let unexpected_features = features & !self.backend.features();
406         if unexpected_features != 0 {
407             error!("unexpected set_features {:#x}", unexpected_features);
408             return Err(VhostError::InvalidParam);
409         }
410 
411         if let Err(e) = self.backend.ack_features(features) {
412             error!("failed to acknowledge features 0x{:x}: {}", features, e);
413             return Err(VhostError::InvalidOperation);
414         }
415 
416         self.acked_features |= features;
417 
418         // If VHOST_USER_F_PROTOCOL_FEATURES has not been negotiated, the ring is initialized in an
419         // enabled state.
420         // If VHOST_USER_F_PROTOCOL_FEATURES has been negotiated, the ring is initialized in a
421         // disabled state.
422         // Client must not pass data to/from the backend until ring is enabled by
423         // VHOST_USER_SET_VRING_ENABLE with parameter 1, or after it has been disabled by
424         // VHOST_USER_SET_VRING_ENABLE with parameter 0.
425         let vring_enabled = self.acked_features & 1 << VHOST_USER_F_PROTOCOL_FEATURES != 0;
426         for v in &mut self.vrings {
427             v.enabled = vring_enabled;
428         }
429 
430         Ok(())
431     }
432 
    /// Reports the protocol feature bits offered by the backend.
    fn get_protocol_features(&mut self) -> VhostResult<VhostUserProtocolFeatures> {
        Ok(self.backend.protocol_features())
    }
436 
set_protocol_features(&mut self, features: u64) -> VhostResult<()>437     fn set_protocol_features(&mut self, features: u64) -> VhostResult<()> {
438         let features = match VhostUserProtocolFeatures::from_bits(features) {
439             Some(proto_features) => proto_features,
440             None => {
441                 error!(
442                     "unsupported bits in VHOST_USER_SET_PROTOCOL_FEATURES: {:#x}",
443                     features
444                 );
445                 return Err(VhostError::InvalidOperation);
446             }
447         };
448         let supported = self.backend.protocol_features();
449         self.acked_protocol_features = features & supported;
450         Ok(())
451     }
452 
set_mem_table( &mut self, contexts: &[VhostUserMemoryRegion], files: Vec<File>, ) -> VhostResult<()>453     fn set_mem_table(
454         &mut self,
455         contexts: &[VhostUserMemoryRegion],
456         files: Vec<File>,
457     ) -> VhostResult<()> {
458         let (guest_mem, vmm_maps) = VhostUserRegularOps::set_mem_table(contexts, files)?;
459         self.mem = Some(guest_mem);
460         self.vmm_maps = Some(vmm_maps);
461         Ok(())
462     }
463 
    /// Reports how many vrings the handler holds.
    fn get_queue_num(&mut self) -> VhostResult<u64> {
        Ok(self.vrings.len() as u64)
    }
467 
set_vring_num(&mut self, index: u32, num: u32) -> VhostResult<()>468     fn set_vring_num(&mut self, index: u32, num: u32) -> VhostResult<()> {
469         if index as usize >= self.vrings.len() || num == 0 || num > Queue::MAX_SIZE.into() {
470             return Err(VhostError::InvalidParam);
471         }
472         self.vrings[index as usize].queue.set_size(num as u16);
473 
474         Ok(())
475     }
476 
set_vring_addr( &mut self, index: u32, _flags: VhostUserVringAddrFlags, descriptor: u64, used: u64, available: u64, _log: u64, ) -> VhostResult<()>477     fn set_vring_addr(
478         &mut self,
479         index: u32,
480         _flags: VhostUserVringAddrFlags,
481         descriptor: u64,
482         used: u64,
483         available: u64,
484         _log: u64,
485     ) -> VhostResult<()> {
486         if index as usize >= self.vrings.len() {
487             return Err(VhostError::InvalidParam);
488         }
489 
490         let vmm_maps = self.vmm_maps.as_ref().ok_or(VhostError::InvalidParam)?;
491         let vring = &mut self.vrings[index as usize];
492         vring
493             .queue
494             .set_desc_table(vmm_va_to_gpa(vmm_maps, descriptor)?);
495         vring
496             .queue
497             .set_avail_ring(vmm_va_to_gpa(vmm_maps, available)?);
498         vring.queue.set_used_ring(vmm_va_to_gpa(vmm_maps, used)?);
499 
500         Ok(())
501     }
502 
set_vring_base(&mut self, index: u32, base: u32) -> VhostResult<()>503     fn set_vring_base(&mut self, index: u32, base: u32) -> VhostResult<()> {
504         if index as usize >= self.vrings.len() || base >= Queue::MAX_SIZE.into() {
505             return Err(VhostError::InvalidParam);
506         }
507 
508         let vring = &mut self.vrings[index as usize];
509         vring.queue.set_next_avail(Wrapping(base as u16));
510         vring.queue.set_next_used(Wrapping(base as u16));
511 
512         Ok(())
513     }
514 
    /// Handles `VHOST_USER_GET_VRING_BASE`: stops vring `index` if it is running and reports
    /// its base index.
    ///
    /// If the queue is ready, it is stopped through the backend, the vring is reset, and the
    /// stopped queue's `next_avail_to_process()` is returned. If that was the last running
    /// queue, the backend is also asked to enter the suspended state. If the queue is not
    /// ready, the reported base is 0.
    ///
    /// # Errors
    ///
    /// * `VhostError::InvalidParam` if `index` is out of range.
    /// * `VhostError::BackendInternalError` if the backend fails to stop the queue.
    /// * `VhostError::EnterSuspendedState` if the backend fails to suspend.
    fn get_vring_base(&mut self, index: u32) -> VhostResult<VhostUserVringState> {
        let vring = self
            .vrings
            .get_mut(index as usize)
            .ok_or(VhostError::InvalidParam)?;

        // Quotation from vhost-user spec:
        // "The back-end must [...] stop ring upon receiving VHOST_USER_GET_VRING_BASE."
        // We only call `queue.set_ready()` when starting the queue, so if the queue is ready, that
        // means it is started and should be stopped.
        let vring_base = if vring.queue.ready() {
            let queue = match self.backend.stop_queue(index as usize) {
                Ok(q) => q,
                Err(e) => {
                    error!("Failed to stop queue in get_vring_base: {:#}", e);
                    return Err(VhostError::BackendInternalError);
                }
            };

            trace!("stopped queue {index}");

            // Resetting clears the ready flag, so this vring counts as stopped below.
            vring.reset();

            if self.all_queues_stopped() {
                trace!("all queues stopped; entering suspended state");
                self.backend
                    .enter_suspended_state()
                    .map_err(VhostError::EnterSuspendedState)?;
            }

            queue.next_avail_to_process()
        } else {
            0
        };

        Ok(VhostUserVringState::new(index, vring_base.into()))
    }
552 
set_vring_kick(&mut self, index: u8, file: Option<File>) -> VhostResult<()>553     fn set_vring_kick(&mut self, index: u8, file: Option<File>) -> VhostResult<()> {
554         if index as usize >= self.vrings.len() {
555             return Err(VhostError::InvalidParam);
556         }
557 
558         let vring = &mut self.vrings[index as usize];
559         if vring.queue.ready() {
560             error!("kick fd cannot replaced after queue is started");
561             return Err(VhostError::InvalidOperation);
562         }
563 
564         let kick_evt = VhostUserRegularOps::set_vring_kick(index, file)?;
565 
566         // Enable any virtqueue features that were negotiated (like VIRTIO_RING_F_EVENT_IDX).
567         vring.queue.ack_features(self.acked_features);
568         vring.queue.set_ready(true);
569 
570         let mem = self
571             .mem
572             .as_ref()
573             .cloned()
574             .ok_or(VhostError::InvalidOperation)?;
575 
576         let doorbell = vring.doorbell.clone().ok_or(VhostError::InvalidOperation)?;
577 
578         let queue = match vring.queue.activate(&mem, kick_evt, doorbell) {
579             Ok(queue) => queue,
580             Err(e) => {
581                 error!("failed to activate vring: {:#}", e);
582                 return Err(VhostError::BackendInternalError);
583             }
584         };
585 
586         if let Err(e) = self.backend.start_queue(index as usize, queue, mem) {
587             error!("Failed to start queue {}: {}", index, e);
588             return Err(VhostError::BackendInternalError);
589         }
590 
591         trace!("started queue {index}");
592 
593         Ok(())
594     }
595 
    /// Handles `VHOST_USER_SET_VRING_CALL`: installs the doorbell for vring `index`.
    ///
    /// The doorbell also carries a closure that sends a config-change notification over the
    /// backend request connection, if one has been established by then.
    ///
    /// Fails with `VhostError::InvalidParam` if `index` is out of range or no file was supplied.
    fn set_vring_call(&mut self, index: u8, file: Option<File>) -> VhostResult<()> {
        if index as usize >= self.vrings.len() {
            return Err(VhostError::InvalidParam);
        }

        // The closure keeps its own handle to the shared connection state, so it observes a
        // connection established after this call.
        let backend_req_conn = self.backend_req_connection.clone();
        let signal_config_change_fn = Box::new(move || match &*backend_req_conn.lock() {
            VhostBackendReqConnectionState::Connected(frontend) => {
                if let Err(e) = frontend.send_config_changed() {
                    error!("Failed to notify config change: {:#}", e);
                }
            }
            VhostBackendReqConnectionState::NoConnection => {
                error!("No Backend request connection found");
            }
        });

        let doorbell = VhostUserRegularOps::set_vring_call(index, file, signal_config_change_fn)?;
        self.vrings[index as usize].doorbell = Some(doorbell);
        Ok(())
    }
617 
    /// Not implemented yet: the error descriptor is dropped and success is reported.
    fn set_vring_err(&mut self, _index: u8, _fd: Option<File>) -> VhostResult<()> {
        // TODO
        Ok(())
    }
622 
set_vring_enable(&mut self, index: u32, enable: bool) -> VhostResult<()>623     fn set_vring_enable(&mut self, index: u32, enable: bool) -> VhostResult<()> {
624         if index as usize >= self.vrings.len() {
625             return Err(VhostError::InvalidParam);
626         }
627 
628         // This request should be handled only when VHOST_USER_F_PROTOCOL_FEATURES
629         // has been negotiated.
630         if self.acked_features & 1 << VHOST_USER_F_PROTOCOL_FEATURES == 0 {
631             return Err(VhostError::InvalidOperation);
632         }
633 
634         // Backend must not pass data to/from the ring until ring is enabled by
635         // VHOST_USER_SET_VRING_ENABLE with parameter 1, or after it has been disabled by
636         // VHOST_USER_SET_VRING_ENABLE with parameter 0.
637         self.vrings[index as usize].enabled = enable;
638 
639         Ok(())
640     }
641 
get_config( &mut self, offset: u32, size: u32, _flags: VhostUserConfigFlags, ) -> VhostResult<Vec<u8>>642     fn get_config(
643         &mut self,
644         offset: u32,
645         size: u32,
646         _flags: VhostUserConfigFlags,
647     ) -> VhostResult<Vec<u8>> {
648         let mut data = vec![0; size as usize];
649         self.backend.read_config(u64::from(offset), &mut data);
650         Ok(data)
651     }
652 
    /// Writes `buf` to the backend's configuration space at `offset`.
    ///
    /// `_flags` is ignored. Always succeeds.
    fn set_config(
        &mut self,
        offset: u32,
        buf: &[u8],
        _flags: VhostUserConfigFlags,
    ) -> VhostResult<()> {
        self.backend.write_config(u64::from(offset), buf);
        Ok(())
    }
662 
set_backend_req_fd(&mut self, ep: Connection<BackendReq>)663     fn set_backend_req_fd(&mut self, ep: Connection<BackendReq>) {
664         let conn = Arc::new(VhostBackendReqConnection::new(
665             FrontendClient::new(ep),
666             self.backend.get_shared_memory_region().map(|r| r.id),
667         ));
668 
669         {
670             let mut backend_req_conn = self.backend_req_connection.lock();
671             if let VhostBackendReqConnectionState::Connected(_) = &*backend_req_conn {
672                 warn!("Backend Request Connection already established. Overwriting");
673             }
674             *backend_req_conn = VhostBackendReqConnectionState::Connected(conn.clone());
675         }
676 
677         self.backend.set_backend_req_connection(conn);
678     }
679 
    /// Not supported.
    ///
    /// # Panics
    ///
    /// Always panics via `unimplemented!`.
    fn get_inflight_fd(
        &mut self,
        _inflight: &VhostUserInflight,
    ) -> VhostResult<(VhostUserInflight, File)> {
        unimplemented!("get_inflight_fd");
    }
686 
    /// Not supported.
    ///
    /// # Panics
    ///
    /// Always panics via `unimplemented!`.
    fn set_inflight_fd(&mut self, _inflight: &VhostUserInflight, _file: File) -> VhostResult<()> {
        unimplemented!("set_inflight_fd");
    }
690 
    /// Not implemented yet: always reports 0 memory slots.
    fn get_max_mem_slots(&mut self) -> VhostResult<u64> {
        //TODO
        Ok(0)
    }
695 
    /// Not implemented yet: the region and its descriptor are dropped and success is reported.
    fn add_mem_region(
        &mut self,
        _region: &VhostUserSingleMemoryRegion,
        _fd: File,
    ) -> VhostResult<()> {
        //TODO
        Ok(())
    }
704 
    /// Not implemented yet: the request is ignored and success is reported.
    fn remove_mem_region(&mut self, _region: &VhostUserSingleMemoryRegion) -> VhostResult<()> {
        //TODO
        Ok(())
    }
709 
    /// Starts a device state transfer over `fd`.
    ///
    /// For `Save`, the handler and backend state are snapshotted inline and a worker thread
    /// serializes the snapshot into `fd`. For `Load`, a worker thread reads and deserializes a
    /// snapshot from `fd`. In both cases the worker is stored in `device_state_thread` and
    /// `Ok(None)` is returned.
    ///
    /// # Errors
    ///
    /// * `VhostError::InvalidOperation` if `migration_phase` is not `Stopped`, if any queue is
    ///   still running, or if a previous transfer has not been checked yet.
    /// * `VhostError::SnapshotError` if the backend fails to produce a snapshot.
    fn set_device_state_fd(
        &mut self,
        transfer_direction: VhostUserTransferDirection,
        migration_phase: VhostUserMigrationPhase,
        mut fd: File,
    ) -> VhostResult<Option<File>> {
        if migration_phase != VhostUserMigrationPhase::Stopped {
            return Err(VhostError::InvalidOperation);
        }
        if !self.all_queues_stopped() {
            return Err(VhostError::InvalidOperation);
        }
        if self.device_state_thread.is_some() {
            error!("must call check_device_state before starting new state transfer");
            return Err(VhostError::InvalidOperation);
        }
        // `set_device_state_fd` is designed to allow snapshot/restore concurrently with other
        // methods, but, for simplicity, we do those operations inline and only spawn a thread to
        // handle the serialization and data transfer (the latter of which seems necessary to
        // implement the API correctly without, e.g., deadlocking because a pipe is full).
        match transfer_direction {
            VhostUserTransferDirection::Save => {
                // Snapshot the state.
                let snapshot = DeviceRequestHandlerSnapshot {
                    acked_features: self.acked_features,
                    acked_protocol_features: self.acked_protocol_features.bits(),
                    backend: self.backend.snapshot().map_err(VhostError::SnapshotError)?,
                };
                // Spawn thread to write the serialized bytes.
                self.device_state_thread = Some(DeviceStateThread::Save(WorkerThread::start(
                    "device_state_save",
                    move |_kill_event| serde_json::to_writer(&mut fd, &snapshot),
                )));
                Ok(None)
            }
            VhostUserTransferDirection::Load => {
                // Spawn a thread to read the bytes and deserialize. Restore will happen in
                // `check_device_state`.
                self.device_state_thread = Some(DeviceStateThread::Load(WorkerThread::start(
                    "device_state_load",
                    move |_kill_event| serde_json::from_reader(&mut fd),
                )));
                Ok(None)
            }
        }
    }
756 
check_device_state(&mut self) -> VhostResult<()>757     fn check_device_state(&mut self) -> VhostResult<()> {
758         let Some(thread) = self.device_state_thread.take() else {
759             error!("check_device_state: no active state transfer");
760             return Err(VhostError::InvalidOperation);
761         };
762         match thread {
763             DeviceStateThread::Save(worker) => {
764                 worker.stop().map_err(|e| {
765                     error!("device state save thread failed: {:#}", e);
766                     VhostError::BackendInternalError
767                 })?;
768                 Ok(())
769             }
770             DeviceStateThread::Load(worker) => {
771                 let snapshot = worker.stop().map_err(|e| {
772                     error!("device state load thread failed: {:#}", e);
773                     VhostError::BackendInternalError
774                 })?;
775                 self.acked_features = snapshot.acked_features;
776                 self.acked_protocol_features =
777                     VhostUserProtocolFeatures::from_bits(snapshot.acked_protocol_features)
778                         .with_context(|| {
779                             format!(
780                                 "unsupported bits in acked_protocol_features: {:#x}",
781                                 snapshot.acked_protocol_features
782                             )
783                         })
784                         .map_err(VhostError::RestoreError)?;
785                 self.backend
786                     .restore(snapshot.backend)
787                     .map_err(VhostError::RestoreError)?;
788                 Ok(())
789             }
790         }
791     }
792 
get_shared_memory_regions(&mut self) -> VhostResult<Vec<VhostSharedMemoryRegion>>793     fn get_shared_memory_regions(&mut self) -> VhostResult<Vec<VhostSharedMemoryRegion>> {
794         Ok(if let Some(r) = self.backend.get_shared_memory_region() {
795             vec![VhostSharedMemoryRegion::new(r.id, r.length)]
796         } else {
797             Vec::new()
798         })
799     }
800 }
801 
/// Indicates the state of the backend request connection.
pub enum VhostBackendReqConnectionState {
    /// A backend request connection (`VhostBackendReqConnection`) is established. The
    /// connection is held behind an `Arc` so it can be shared.
    Connected(Arc<VhostBackendReqConnection>),
    /// No backend request connection has been established yet.
    NoConnection,
}
809 
/// Keeps track of the vhost-user backend request connection.
///
/// Besides the connection itself, this holds the optional shared memory information that
/// `take_shmem_mapper` hands out (once) in the form of a `SharedMemoryMapper`.
pub struct VhostBackendReqConnection {
    /// Client used to send backend requests to the frontend. It is shared with the mapper
    /// created by `take_shmem_mapper`.
    conn: Arc<Mutex<FrontendClient>>,
    /// Shared memory information; `None` if no shmid was supplied or once it has been taken.
    shmem_info: Mutex<Option<ShmemInfo>>,
}
815 
/// Bookkeeping for one shared memory region and the mappings currently recorded in it.
#[derive(Clone)]
struct ShmemInfo {
    /// Identifier of the shared memory region, passed along in every map/unmap message.
    shmid: u8,
    /// Mappings recorded by `add_mapping`, keyed by offset, so that `remove_mapping` can look up
    /// the size to unmap.
    mapped_regions: BTreeMap<u64 /* offset */, u64 /* size */>,
}
821 
822 impl VhostBackendReqConnection {
new(conn: FrontendClient, shmid: Option<u8>) -> Self823     pub fn new(conn: FrontendClient, shmid: Option<u8>) -> Self {
824         let shmem_info = Mutex::new(shmid.map(|shmid| ShmemInfo {
825             shmid,
826             mapped_regions: BTreeMap::new(),
827         }));
828         Self {
829             conn: Arc::new(Mutex::new(conn)),
830             shmem_info,
831         }
832     }
833 
834     /// Send `VHOST_USER_CONFIG_CHANGE_MSG` to the frontend
send_config_changed(&self) -> anyhow::Result<()>835     pub fn send_config_changed(&self) -> anyhow::Result<()> {
836         self.conn
837             .lock()
838             .handle_config_change()
839             .context("Could not send config change message")?;
840         Ok(())
841     }
842 
843     /// Create a SharedMemoryMapper trait object from the ShmemInfo.
take_shmem_mapper(&self) -> anyhow::Result<Box<dyn SharedMemoryMapper>>844     pub fn take_shmem_mapper(&self) -> anyhow::Result<Box<dyn SharedMemoryMapper>> {
845         let shmem_info = self
846             .shmem_info
847             .lock()
848             .take()
849             .context("could not take shared memory mapper information")?;
850 
851         Ok(Box::new(VhostShmemMapper {
852             conn: self.conn.clone(),
853             shmem_info,
854         }))
855     }
856 }
857 
/// `SharedMemoryMapper` implementation that forwards map/unmap requests to the frontend over
/// the backend request connection.
struct VhostShmemMapper {
    /// Client used to send the map/unmap messages.
    conn: Arc<Mutex<FrontendClient>>,
    /// Region identifier plus the table of mappings recorded so far.
    shmem_info: ShmemInfo,
}
862 
863 impl SharedMemoryMapper for VhostShmemMapper {
add_mapping( &mut self, source: VmMemorySource, offset: u64, prot: Protection, _cache: MemCacheType, ) -> anyhow::Result<()>864     fn add_mapping(
865         &mut self,
866         source: VmMemorySource,
867         offset: u64,
868         prot: Protection,
869         _cache: MemCacheType,
870     ) -> anyhow::Result<()> {
871         let size = match source {
872             VmMemorySource::Vulkan {
873                 descriptor,
874                 handle_type,
875                 memory_idx,
876                 device_uuid,
877                 driver_uuid,
878                 size,
879             } => {
880                 let msg = VhostUserGpuMapMsg::new(
881                     self.shmem_info.shmid,
882                     offset,
883                     size,
884                     memory_idx,
885                     handle_type,
886                     device_uuid,
887                     driver_uuid,
888                 );
889                 self.conn
890                     .lock()
891                     .gpu_map(&msg, &descriptor)
892                     .context("failed to map memory")?;
893                 size
894             }
895             VmMemorySource::ExternalMapping { ptr, size } => {
896                 let msg = VhostUserExternalMapMsg::new(self.shmem_info.shmid, offset, size, ptr);
897                 self.conn
898                     .lock()
899                     .external_map(&msg)
900                     .context("failed to map memory")?;
901                 size
902             }
903             source => {
904                 // The last two sources use the same VhostUserShmemMapMsg, continue matching here
905                 // on the aliased `source` above.
906                 let (descriptor, fd_offset, size) = match source {
907                     VmMemorySource::Descriptor {
908                         descriptor,
909                         offset,
910                         size,
911                     } => (descriptor, offset, size),
912                     VmMemorySource::SharedMemory(shmem) => {
913                         let size = shmem.size();
914                         let descriptor = SafeDescriptor::from(shmem);
915                         (descriptor, 0, size)
916                     }
917                     _ => bail!("unsupported source"),
918                 };
919                 let flags = VhostUserShmemMapMsgFlags::from(prot);
920                 let msg = VhostUserShmemMapMsg::new(
921                     self.shmem_info.shmid,
922                     offset,
923                     fd_offset,
924                     size,
925                     flags,
926                 );
927                 self.conn
928                     .lock()
929                     .shmem_map(&msg, &descriptor)
930                     .context("failed to map memory")?;
931                 size
932             }
933         };
934 
935         self.shmem_info.mapped_regions.insert(offset, size);
936         Ok(())
937     }
938 
remove_mapping(&mut self, offset: u64) -> anyhow::Result<()>939     fn remove_mapping(&mut self, offset: u64) -> anyhow::Result<()> {
940         let size = self
941             .shmem_info
942             .mapped_regions
943             .remove(&offset)
944             .context("unknown offset")?;
945         let msg = VhostUserShmemUnmapMsg::new(self.shmem_info.shmid, offset, size);
946         self.conn
947             .lock()
948             .shmem_unmap(&msg)
949             .context("failed to map memory")
950             .map(|_| ())
951     }
952 }
953 
/// Bundles a queue of type `T` with the handle of a task producing a `U`.
// NOTE(review): presumably `queue_task` is the async task servicing `queue`; confirm against
// the device implementations that construct this.
pub(crate) struct WorkerState<T, U> {
    /// Handle of the task associated with the queue.
    pub(crate) queue_task: TaskHandle<U>,
    /// The queue itself.
    pub(crate) queue: T,
}
958 
/// Errors for device operations
#[derive(Debug, ThisError)]
pub enum Error {
    /// A queue was asked to stop but no worker was registered for it.
    #[error("worker not found when stopping queue")]
    WorkerNotFound,
}
965 
966 #[cfg(test)]
967 mod tests {
968     use std::sync::mpsc::channel;
969     use std::sync::Barrier;
970 
971     use anyhow::bail;
972     use base::Event;
973     use vmm_vhost::BackendServer;
974     use vmm_vhost::FrontendReq;
975     use zerocopy::AsBytes;
976     use zerocopy::FromBytes;
977     use zerocopy::FromZeroes;
978 
979     use super::*;
980     use crate::virtio::vhost_user_frontend::VhostUserFrontend;
981     use crate::virtio::DeviceType;
982     use crate::virtio::VirtioDevice;
983 
    /// Fake device configuration space served by `FakeBackend::read_config`.
    ///
    /// `repr(C, packed(4))` caps field alignment at 4 bytes, so `y` directly follows `x` and
    /// the struct is 12 bytes with no padding.
    #[derive(Clone, Copy, Debug, PartialEq, Eq, AsBytes, FromZeroes, FromBytes)]
    #[repr(C, packed(4))]
    struct FakeConfig {
        x: u32,
        y: u64,
    }

    /// Config contents the VMM side of the test expects to read back.
    const FAKE_CONFIG_DATA: FakeConfig = FakeConfig { x: 1, y: 2 };
992 
    /// Minimal `VhostUserDevice` implementation used to exercise `DeviceRequestHandler`.
    pub(super) struct FakeBackend {
        /// Feature bits the backend offers.
        avail_features: u64,
        /// Feature bits acknowledged so far through `ack_features`.
        acked_features: u64,
        /// One slot per queue index; `Some` while the queue is started.
        active_queues: Vec<Option<Queue>>,
        /// Whether `BACKEND_REQ` is advertised in the protocol features.
        allow_backend_req: bool,
        /// Backend request connection handed over via `set_backend_req_connection`, if any.
        backend_conn: Option<Arc<VhostBackendReqConnection>>,
    }
1000 
    /// Snapshot payload produced by `FakeBackend::snapshot` and checked by
    /// `FakeBackend::restore`.
    #[derive(Deserialize, Serialize)]
    struct FakeBackendSnapshot {
        /// Arbitrary marker bytes; `restore` asserts they round-trip unchanged.
        data: Vec<u8>,
    }
1005 
1006     impl FakeBackend {
1007         const MAX_QUEUE_NUM: usize = 16;
1008 
new() -> Self1009         pub(super) fn new() -> Self {
1010             let mut active_queues = Vec::new();
1011             active_queues.resize_with(Self::MAX_QUEUE_NUM, Default::default);
1012             Self {
1013                 avail_features: 1 << VHOST_USER_F_PROTOCOL_FEATURES,
1014                 acked_features: 0,
1015                 active_queues,
1016                 allow_backend_req: false,
1017                 backend_conn: None,
1018             }
1019         }
1020     }
1021 
1022     impl VhostUserDevice for FakeBackend {
max_queue_num(&self) -> usize1023         fn max_queue_num(&self) -> usize {
1024             Self::MAX_QUEUE_NUM
1025         }
1026 
features(&self) -> u641027         fn features(&self) -> u64 {
1028             self.avail_features
1029         }
1030 
ack_features(&mut self, value: u64) -> anyhow::Result<()>1031         fn ack_features(&mut self, value: u64) -> anyhow::Result<()> {
1032             let unrequested_features = value & !self.avail_features;
1033             if unrequested_features != 0 {
1034                 bail!(
1035                     "invalid protocol features are given: 0x{:x}",
1036                     unrequested_features
1037                 );
1038             }
1039             self.acked_features |= value;
1040             Ok(())
1041         }
1042 
protocol_features(&self) -> VhostUserProtocolFeatures1043         fn protocol_features(&self) -> VhostUserProtocolFeatures {
1044             let mut features =
1045                 VhostUserProtocolFeatures::CONFIG | VhostUserProtocolFeatures::DEVICE_STATE;
1046             if self.allow_backend_req {
1047                 features |= VhostUserProtocolFeatures::BACKEND_REQ;
1048             }
1049             features
1050         }
1051 
read_config(&self, offset: u64, dst: &mut [u8])1052         fn read_config(&self, offset: u64, dst: &mut [u8]) {
1053             dst.copy_from_slice(&FAKE_CONFIG_DATA.as_bytes()[offset as usize..]);
1054         }
1055 
reset(&mut self)1056         fn reset(&mut self) {}
1057 
start_queue( &mut self, idx: usize, queue: Queue, _mem: GuestMemory, ) -> anyhow::Result<()>1058         fn start_queue(
1059             &mut self,
1060             idx: usize,
1061             queue: Queue,
1062             _mem: GuestMemory,
1063         ) -> anyhow::Result<()> {
1064             self.active_queues[idx] = Some(queue);
1065             Ok(())
1066         }
1067 
stop_queue(&mut self, idx: usize) -> anyhow::Result<Queue>1068         fn stop_queue(&mut self, idx: usize) -> anyhow::Result<Queue> {
1069             Ok(self.active_queues[idx]
1070                 .take()
1071                 .ok_or(Error::WorkerNotFound)?)
1072         }
1073 
set_backend_req_connection(&mut self, conn: Arc<VhostBackendReqConnection>)1074         fn set_backend_req_connection(&mut self, conn: Arc<VhostBackendReqConnection>) {
1075             self.backend_conn = Some(conn);
1076         }
1077 
enter_suspended_state(&mut self) -> anyhow::Result<()>1078         fn enter_suspended_state(&mut self) -> anyhow::Result<()> {
1079             Ok(())
1080         }
1081 
snapshot(&mut self) -> anyhow::Result<serde_json::Value>1082         fn snapshot(&mut self) -> anyhow::Result<serde_json::Value> {
1083             serde_json::to_value(FakeBackendSnapshot {
1084                 data: vec![1, 2, 3],
1085             })
1086             .context("failed to serialize snapshot")
1087         }
1088 
restore(&mut self, data: serde_json::Value) -> anyhow::Result<()>1089         fn restore(&mut self, data: serde_json::Value) -> anyhow::Result<()> {
1090             let snapshot: FakeBackendSnapshot =
1091                 serde_json::from_value(data).context("failed to deserialize snapshot")?;
1092             assert_eq!(snapshot.data, vec![1, 2, 3], "bad snapshot data");
1093             Ok(())
1094         }
1095     }
1096 
1097     #[test]
test_vhost_user_lifecycle()1098     fn test_vhost_user_lifecycle() {
1099         test_vhost_user_lifecycle_parameterized(false);
1100     }
1101 
1102     #[test]
1103     #[cfg(not(windows))] // Windows requries more complex connection setup.
test_vhost_user_lifecycle_with_backend_req()1104     fn test_vhost_user_lifecycle_with_backend_req() {
1105         test_vhost_user_lifecycle_parameterized(true);
1106     }
1107 
    /// Drives a complete vhost-user session between a `VhostUserFrontend` (the "VMM side",
    /// running on a spawned thread) and a `DeviceRequestHandler` wrapping `FakeBackend` (the
    /// "device side", running on the calling thread).
    ///
    /// The VMM side reads the config, activates, resets, re-activates, sleeps, snapshots,
    /// restores and wakes the device. The device side processes exactly the matching sequence
    /// of requests through `handle_request`, which asserts the type of every message, so any
    /// change in the frontend's message order makes this test fail.
    ///
    /// When `allow_backend_req` is true the backend additionally advertises `BACKEND_REQ`, a
    /// `SET_BACKEND_REQ_FD` request is expected during setup, and the backend request
    /// connection is checked to still work after reset/reactivate and after sleep/wake.
    fn test_vhost_user_lifecycle_parameterized(allow_backend_req: bool) {
        const QUEUES_NUM: usize = 2;

        let (client_connection, server_connection) =
            vmm_vhost::Connection::<FrontendReq>::pair().unwrap();

        // Barrier used at the very end so the device side only checks for `ClientExit` after
        // the VMM side has dropped its frontend.
        let vmm_bar = Arc::new(Barrier::new(2));
        let dev_bar = vmm_bar.clone();

        // `ready_*`: device -> VMM, "the handler exists". `shutdown_*`: device -> VMM, "all
        // expected requests were processed, you may disconnect".
        let (ready_tx, ready_rx) = channel();
        let (shutdown_tx, shutdown_rx) = channel();

        std::thread::spawn(move || {
            // VMM side
            ready_rx.recv().unwrap(); // Ensure the device is ready.

            let mut vmm_device =
                VhostUserFrontend::new(DeviceType::Console, 0, client_connection, None, None)
                    .unwrap();

            println!("read_config");
            let mut buf = vec![0; std::mem::size_of::<FakeConfig>()];
            vmm_device.read_config(0, &mut buf);
            // Check if the obtained config data is correct.
            let config = FakeConfig::read_from(buf.as_bytes()).unwrap();
            assert_eq!(config, FAKE_CONFIG_DATA);

            // Activates the frontend with fresh guest memory and `QUEUES_NUM` ready queues.
            let activate = |vmm_device: &mut VhostUserFrontend| {
                let mem = GuestMemory::new(&[(GuestAddress(0x0), 0x10000)]).unwrap();
                let interrupt = Interrupt::new_for_test_with_msix();

                let mut queues = BTreeMap::new();
                for idx in 0..QUEUES_NUM {
                    let mut queue = QueueConfig::new(0x10, 0);
                    queue.set_ready(true);
                    let queue = queue
                        .activate(&mem, Event::new().unwrap(), interrupt.clone())
                        .expect("QueueConfig::activate");
                    queues.insert(idx, queue);
                }

                println!("activate");
                vmm_device
                    .activate(mem.clone(), interrupt.clone(), queues)
                    .unwrap();
            };

            activate(&mut vmm_device);

            println!("reset");
            let reset_result = vmm_device.reset();
            assert!(
                reset_result.is_ok(),
                "reset failed: {:#}",
                reset_result.unwrap_err()
            );

            // The device must be usable again after a reset.
            activate(&mut vmm_device);

            println!("virtio_sleep");
            let queues = vmm_device
                .virtio_sleep()
                .unwrap()
                .expect("virtio_sleep unexpectedly returned None");

            println!("virtio_snapshot");
            let snapshot = vmm_device
                .virtio_snapshot()
                .expect("virtio_snapshot failed");
            println!("virtio_restore");
            vmm_device
                .virtio_restore(snapshot)
                .expect("virtio_restore failed");

            println!("virtio_wake");
            let mem = GuestMemory::new(&[(GuestAddress(0x0), 0x10000)]).unwrap();
            let interrupt = Interrupt::new_for_test_with_msix();
            vmm_device
                .virtio_wake(Some((mem, interrupt, queues)))
                .unwrap();

            println!("wait for shutdown signal");
            shutdown_rx.recv().unwrap();

            // The VMM side is supposed to stop before the device side.
            println!("drop");
            drop(vmm_device);

            vmm_bar.wait();
        });

        // Device side
        let mut handler = DeviceRequestHandler::new(FakeBackend::new());
        handler.as_mut().allow_backend_req = allow_backend_req;

        // Notify listener is ready.
        ready_tx.send(()).unwrap();

        let mut req_handler = BackendServer::new(server_connection, handler);

        // VhostUserFrontend::new()
        handle_request(&mut req_handler, FrontendReq::SET_OWNER).unwrap();
        handle_request(&mut req_handler, FrontendReq::GET_FEATURES).unwrap();
        handle_request(&mut req_handler, FrontendReq::SET_FEATURES).unwrap();
        handle_request(&mut req_handler, FrontendReq::GET_PROTOCOL_FEATURES).unwrap();
        handle_request(&mut req_handler, FrontendReq::SET_PROTOCOL_FEATURES).unwrap();
        if allow_backend_req {
            handle_request(&mut req_handler, FrontendReq::SET_BACKEND_REQ_FD).unwrap();
        }

        // VhostUserFrontend::read_config()
        handle_request(&mut req_handler, FrontendReq::GET_CONFIG).unwrap();

        // VhostUserFrontend::activate()
        handle_request(&mut req_handler, FrontendReq::SET_MEM_TABLE).unwrap();
        for _ in 0..QUEUES_NUM {
            handle_request(&mut req_handler, FrontendReq::SET_VRING_NUM).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_ADDR).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_BASE).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_CALL).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_KICK).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_ENABLE).unwrap();
        }

        // VhostUserFrontend::reset()
        for _ in 0..QUEUES_NUM {
            handle_request(&mut req_handler, FrontendReq::SET_VRING_ENABLE).unwrap();
            handle_request(&mut req_handler, FrontendReq::GET_VRING_BASE).unwrap();
        }

        // VhostUserFrontend::activate()
        handle_request(&mut req_handler, FrontendReq::SET_MEM_TABLE).unwrap();
        for _ in 0..QUEUES_NUM {
            handle_request(&mut req_handler, FrontendReq::SET_VRING_NUM).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_ADDR).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_BASE).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_CALL).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_KICK).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_ENABLE).unwrap();
        }

        if allow_backend_req {
            // Make sure the connection still works even after reset/reactivate.
            req_handler
                .as_ref()
                .as_ref()
                .backend_conn
                .as_ref()
                .expect("backend_conn missing")
                .send_config_changed()
                .expect("send_config_changed failed");
        }

        // VhostUserFrontend::virtio_sleep()
        for _ in 0..QUEUES_NUM {
            handle_request(&mut req_handler, FrontendReq::SET_VRING_ENABLE).unwrap();
            handle_request(&mut req_handler, FrontendReq::GET_VRING_BASE).unwrap();
        }

        // VhostUserFrontend::virtio_snapshot()
        handle_request(&mut req_handler, FrontendReq::SET_DEVICE_STATE_FD).unwrap();
        handle_request(&mut req_handler, FrontendReq::CHECK_DEVICE_STATE).unwrap();
        // VhostUserFrontend::virtio_restore()
        handle_request(&mut req_handler, FrontendReq::SET_DEVICE_STATE_FD).unwrap();
        handle_request(&mut req_handler, FrontendReq::CHECK_DEVICE_STATE).unwrap();

        // VhostUserFrontend::virtio_wake()
        handle_request(&mut req_handler, FrontendReq::SET_MEM_TABLE).unwrap();
        for _ in 0..QUEUES_NUM {
            handle_request(&mut req_handler, FrontendReq::SET_VRING_NUM).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_ADDR).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_BASE).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_CALL).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_KICK).unwrap();
            handle_request(&mut req_handler, FrontendReq::SET_VRING_ENABLE).unwrap();
        }

        if allow_backend_req {
            // Make sure the connection still works even after sleep/wake.
            req_handler
                .as_ref()
                .as_ref()
                .backend_conn
                .as_ref()
                .expect("backend_conn missing")
                .send_config_changed()
                .expect("send_config_changed failed");
        }

        // Ask the client to shut down, then wait for it to finish.
        shutdown_tx.send(()).unwrap();
        dev_bar.wait();

        // Verify recv_header fails with `ClientExit` after the client has disconnected.
        match req_handler.recv_header() {
            Err(VhostError::ClientExit) => (),
            r => panic!("expected Err(ClientExit) but got {:?}", r),
        }
    }
1307 
handle_request<S: vmm_vhost::Backend>( handler: &mut BackendServer<S>, expected_message_type: FrontendReq, ) -> Result<(), VhostError>1308     fn handle_request<S: vmm_vhost::Backend>(
1309         handler: &mut BackendServer<S>,
1310         expected_message_type: FrontendReq,
1311     ) -> Result<(), VhostError> {
1312         let (hdr, files) = handler.recv_header()?;
1313         assert_eq!(hdr.get_code(), Ok(expected_message_type));
1314         handler.process_message(hdr, files)
1315     }
1316 }
1317