xref: /aosp_15_r20/external/crosvm/third_party/vmm_vhost/src/backend_client.rs (revision bb4ee6a4ae7042d18b07a98463b9c8b875e44b39)
1 // Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved.
2 // SPDX-License-Identifier: Apache-2.0
3 
4 use std::fs::File;
5 use std::mem;
6 
7 use base::AsRawDescriptor;
8 #[cfg(windows)]
9 use base::CloseNotifier;
10 use base::Event;
11 use base::RawDescriptor;
12 use base::ReadNotifier;
13 use base::INVALID_DESCRIPTOR;
14 use zerocopy::AsBytes;
15 use zerocopy::FromBytes;
16 
17 use crate::backend::VhostUserMemoryRegionInfo;
18 use crate::backend::VringConfigData;
19 use crate::into_single_file;
20 use crate::message::*;
21 use crate::Connection;
22 use crate::Error as VhostUserError;
23 use crate::FrontendReq;
24 use crate::Result as VhostUserResult;
25 use crate::Result;
26 
27 /// Client for a vhost-user device. The API is a thin abstraction over the vhost-user protocol.
28 pub struct BackendClient {
29     connection: Connection<FrontendReq>,
30     // Cached virtio features from the backend.
31     virtio_features: u64,
32     // Cached acked virtio features from the driver.
33     acked_virtio_features: u64,
34     // Cached vhost-user protocol features.
35     acked_protocol_features: u64,
36 }
37 
38 impl BackendClient {
39     /// Create a new instance.
new(connection: Connection<FrontendReq>) -> Self40     pub fn new(connection: Connection<FrontendReq>) -> Self {
41         BackendClient {
42             connection,
43             virtio_features: 0,
44             acked_virtio_features: 0,
45             acked_protocol_features: 0,
46         }
47     }
48 
49     /// Get a bitmask of supported virtio/vhost features.
get_features(&mut self) -> Result<u64>50     pub fn get_features(&mut self) -> Result<u64> {
51         let hdr = self.send_request_header(FrontendReq::GET_FEATURES, None)?;
52         let val = self.recv_reply::<VhostUserU64>(&hdr)?;
53         self.virtio_features = val.value;
54         Ok(self.virtio_features)
55     }
56 
57     /// Inform the vhost subsystem which features to enable.
58     /// This should be a subset of supported features from get_features().
set_features(&mut self, features: u64) -> Result<()>59     pub fn set_features(&mut self, features: u64) -> Result<()> {
60         let val = VhostUserU64::new(features);
61         let hdr = self.send_request_with_body(FrontendReq::SET_FEATURES, &val, None)?;
62         self.acked_virtio_features = features & self.virtio_features;
63         self.wait_for_ack(&hdr)
64     }
65 
66     /// Set the current process as the owner of the vhost backend.
67     /// This must be run before any other vhost commands.
set_owner(&self) -> Result<()>68     pub fn set_owner(&self) -> Result<()> {
69         let hdr = self.send_request_header(FrontendReq::SET_OWNER, None)?;
70         self.wait_for_ack(&hdr)
71     }
72 
73     /// Used to be sent to request disabling all rings
74     /// This is no longer used.
reset_owner(&self) -> Result<()>75     pub fn reset_owner(&self) -> Result<()> {
76         let hdr = self.send_request_header(FrontendReq::RESET_OWNER, None)?;
77         self.wait_for_ack(&hdr)
78     }
79 
80     /// Set the memory map regions on the backend so it can translate the vring
81     /// addresses. In the ancillary data there is an array of file descriptors
set_mem_table(&self, regions: &[VhostUserMemoryRegionInfo]) -> Result<()>82     pub fn set_mem_table(&self, regions: &[VhostUserMemoryRegionInfo]) -> Result<()> {
83         if regions.is_empty() || regions.len() > MAX_ATTACHED_FD_ENTRIES {
84             return Err(VhostUserError::InvalidParam);
85         }
86 
87         let mut ctx = VhostUserMemoryContext::new();
88         for region in regions.iter() {
89             if region.memory_size == 0 || region.mmap_handle == INVALID_DESCRIPTOR {
90                 return Err(VhostUserError::InvalidParam);
91             }
92 
93             let reg = VhostUserMemoryRegion {
94                 guest_phys_addr: region.guest_phys_addr,
95                 memory_size: region.memory_size,
96                 user_addr: region.userspace_addr,
97                 mmap_offset: region.mmap_offset,
98             };
99             ctx.append(&reg, region.mmap_handle);
100         }
101 
102         let body = VhostUserMemory::new(ctx.regions.len() as u32);
103         let hdr = self.send_request_with_payload(
104             FrontendReq::SET_MEM_TABLE,
105             &body,
106             ctx.regions.as_bytes(),
107             Some(ctx.fds.as_slice()),
108         )?;
109         self.wait_for_ack(&hdr)
110     }
111 
112     /// Set base address for page modification logging.
set_log_base(&self, base: u64, fd: Option<RawDescriptor>) -> Result<()>113     pub fn set_log_base(&self, base: u64, fd: Option<RawDescriptor>) -> Result<()> {
114         let val = VhostUserU64::new(base);
115 
116         let should_have_fd =
117             self.acked_protocol_features & VhostUserProtocolFeatures::LOG_SHMFD.bits() != 0;
118         if should_have_fd != fd.is_some() {
119             return Err(VhostUserError::InvalidParam);
120         }
121 
122         let _ = self.send_request_with_body(
123             FrontendReq::SET_LOG_BASE,
124             &val,
125             fd.as_ref().map(std::slice::from_ref),
126         )?;
127 
128         Ok(())
129     }
130 
131     /// Specify an event file descriptor to signal on log write.
set_log_fd(&self, fd: RawDescriptor) -> Result<()>132     pub fn set_log_fd(&self, fd: RawDescriptor) -> Result<()> {
133         let fds = [fd];
134         let hdr = self.send_request_header(FrontendReq::SET_LOG_FD, Some(&fds))?;
135         self.wait_for_ack(&hdr)
136     }
137 
138     /// Set the number of descriptors in the vring.
set_vring_num(&self, queue_index: usize, num: u16) -> Result<()>139     pub fn set_vring_num(&self, queue_index: usize, num: u16) -> Result<()> {
140         let val = VhostUserVringState::new(queue_index as u32, num.into());
141         let hdr = self.send_request_with_body(FrontendReq::SET_VRING_NUM, &val, None)?;
142         self.wait_for_ack(&hdr)
143     }
144 
145     /// Set the addresses for a given vring.
set_vring_addr(&self, queue_index: usize, config_data: &VringConfigData) -> Result<()>146     pub fn set_vring_addr(&self, queue_index: usize, config_data: &VringConfigData) -> Result<()> {
147         if config_data.flags & !(VhostUserVringAddrFlags::all().bits()) != 0 {
148             return Err(VhostUserError::InvalidParam);
149         }
150 
151         let val = VhostUserVringAddr::from_config_data(queue_index as u32, config_data);
152         let hdr = self.send_request_with_body(FrontendReq::SET_VRING_ADDR, &val, None)?;
153         self.wait_for_ack(&hdr)
154     }
155 
156     /// Set the first index to look for available descriptors.
157     // TODO: b/331466964 - Arguments and message format are wrong for packed queues.
set_vring_base(&self, queue_index: usize, base: u16) -> Result<()>158     pub fn set_vring_base(&self, queue_index: usize, base: u16) -> Result<()> {
159         let val = VhostUserVringState::new(queue_index as u32, base.into());
160         let hdr = self.send_request_with_body(FrontendReq::SET_VRING_BASE, &val, None)?;
161         self.wait_for_ack(&hdr)
162     }
163 
164     /// Get the available vring base offset.
165     // TODO: b/331466964 - Return type is wrong for packed queues.
get_vring_base(&self, queue_index: usize) -> Result<u32>166     pub fn get_vring_base(&self, queue_index: usize) -> Result<u32> {
167         let req = VhostUserVringState::new(queue_index as u32, 0);
168         let hdr = self.send_request_with_body(FrontendReq::GET_VRING_BASE, &req, None)?;
169         let reply = self.recv_reply::<VhostUserVringState>(&hdr)?;
170         Ok(reply.num)
171     }
172 
173     /// Set the event to trigger when buffers have been used by the host.
174     ///
175     /// Bits (0-7) of the payload contain the vring index. Bit 8 is the invalid FD flag. This flag
176     /// is set when there is no file descriptor in the ancillary data. This signals that polling
177     /// will be used instead of waiting for the call.
set_vring_call(&self, queue_index: usize, event: &Event) -> Result<()>178     pub fn set_vring_call(&self, queue_index: usize, event: &Event) -> Result<()> {
179         let hdr = self.send_fd_for_vring(
180             FrontendReq::SET_VRING_CALL,
181             queue_index,
182             event.as_raw_descriptor(),
183         )?;
184         self.wait_for_ack(&hdr)
185     }
186 
187     /// Set the event that will be signaled by the guest when buffers are available for the host to
188     /// process.
189     ///
190     /// Bits (0-7) of the payload contain the vring index. Bit 8 is the invalid FD flag. This flag
191     /// is set when there is no file descriptor in the ancillary data. This signals that polling
192     /// should be used instead of waiting for a kick.
set_vring_kick(&self, queue_index: usize, event: &Event) -> Result<()>193     pub fn set_vring_kick(&self, queue_index: usize, event: &Event) -> Result<()> {
194         let hdr = self.send_fd_for_vring(
195             FrontendReq::SET_VRING_KICK,
196             queue_index,
197             event.as_raw_descriptor(),
198         )?;
199         self.wait_for_ack(&hdr)
200     }
201 
202     /// Set the event that will be signaled by the guest when error happens.
203     ///
204     /// Bits (0-7) of the payload contain the vring index. Bit 8 is the invalid FD flag. This flag
205     /// is set when there is no file descriptor in the ancillary data.
set_vring_err(&self, queue_index: usize, event: &Event) -> Result<()>206     pub fn set_vring_err(&self, queue_index: usize, event: &Event) -> Result<()> {
207         let hdr = self.send_fd_for_vring(
208             FrontendReq::SET_VRING_ERR,
209             queue_index,
210             event.as_raw_descriptor(),
211         )?;
212         self.wait_for_ack(&hdr)
213     }
214 
215     /// Front-end and back-end negotiate a channel over which to transfer the back-end’s internal
216     /// state during migration.
217     ///
218     /// Requires VHOST_USER_PROTOCOL_F_DEVICE_STATE to be negotiated.
set_device_state_fd( &self, transfer_direction: VhostUserTransferDirection, migration_phase: VhostUserMigrationPhase, fd: &impl AsRawDescriptor, ) -> Result<Option<File>>219     pub fn set_device_state_fd(
220         &self,
221         transfer_direction: VhostUserTransferDirection,
222         migration_phase: VhostUserMigrationPhase,
223         fd: &impl AsRawDescriptor,
224     ) -> Result<Option<File>> {
225         if self.acked_protocol_features & VhostUserProtocolFeatures::DEVICE_STATE.bits() == 0 {
226             return Err(VhostUserError::InvalidOperation);
227         }
228         // Send request.
229         let req = DeviceStateTransferParameters {
230             transfer_direction: match transfer_direction {
231                 VhostUserTransferDirection::Save => 0,
232                 VhostUserTransferDirection::Load => 1,
233             },
234             migration_phase: match migration_phase {
235                 VhostUserMigrationPhase::Stopped => 0,
236             },
237         };
238         let hdr = self.send_request_with_body(
239             FrontendReq::SET_DEVICE_STATE_FD,
240             &req,
241             Some(&[fd.as_raw_descriptor()]),
242         )?;
243         // Receive reply.
244         let (reply, files) = self.recv_reply_with_files::<VhostUserU64>(&hdr)?;
245         let has_err = reply.value & 0xff != 0;
246         let invalid_fd = reply.value & 0x100 != 0;
247         if has_err {
248             return Err(VhostUserError::BackendInternalError);
249         }
250         match (invalid_fd, files.len()) {
251             (true, 0) => Ok(None),
252             (false, 1) => Ok(files.into_iter().next()),
253             _ => Err(VhostUserError::IncorrectFds),
254         }
255     }
256 
257     /// After transferring the back-end’s internal state during migration, check whether the
258     /// back-end was able to successfully fully process the state.
check_device_state(&self) -> Result<()>259     pub fn check_device_state(&self) -> Result<()> {
260         if self.acked_protocol_features & VhostUserProtocolFeatures::DEVICE_STATE.bits() == 0 {
261             return Err(VhostUserError::InvalidOperation);
262         }
263         let hdr = self.send_request_header(FrontendReq::CHECK_DEVICE_STATE, None)?;
264         let reply = self.recv_reply::<VhostUserU64>(&hdr)?;
265         if reply.value != 0 {
266             return Err(VhostUserError::BackendInternalError);
267         }
268         Ok(())
269     }
270 
271     /// Get the protocol feature bitmask from the underlying vhost implementation.
get_protocol_features(&self) -> Result<VhostUserProtocolFeatures>272     pub fn get_protocol_features(&self) -> Result<VhostUserProtocolFeatures> {
273         if self.virtio_features & 1 << VHOST_USER_F_PROTOCOL_FEATURES == 0 {
274             return Err(VhostUserError::InvalidOperation);
275         }
276         let hdr = self.send_request_header(FrontendReq::GET_PROTOCOL_FEATURES, None)?;
277         let val = self.recv_reply::<VhostUserU64>(&hdr)?;
278         Ok(VhostUserProtocolFeatures::from_bits_truncate(val.value))
279     }
280 
281     /// Enable protocol features in the underlying vhost implementation.
set_protocol_features(&mut self, features: VhostUserProtocolFeatures) -> Result<()>282     pub fn set_protocol_features(&mut self, features: VhostUserProtocolFeatures) -> Result<()> {
283         if self.virtio_features & 1 << VHOST_USER_F_PROTOCOL_FEATURES == 0 {
284             return Err(VhostUserError::InvalidOperation);
285         }
286         if features.contains(VhostUserProtocolFeatures::SHARED_MEMORY_REGIONS)
287             && !features.contains(VhostUserProtocolFeatures::BACKEND_REQ)
288         {
289             return Err(VhostUserError::FeatureMismatch);
290         }
291         let val = VhostUserU64::new(features.bits());
292         let hdr = self.send_request_with_body(FrontendReq::SET_PROTOCOL_FEATURES, &val, None)?;
293         // Don't wait for ACK here because the protocol feature negotiation process hasn't been
294         // completed yet.
295         self.acked_protocol_features = features.bits();
296         self.wait_for_ack(&hdr)
297     }
298 
299     /// Query how many queues the backend supports.
get_queue_num(&self) -> Result<u64>300     pub fn get_queue_num(&self) -> Result<u64> {
301         if !self.is_feature_mq_available() {
302             return Err(VhostUserError::InvalidOperation);
303         }
304 
305         let hdr = self.send_request_header(FrontendReq::GET_QUEUE_NUM, None)?;
306         let val = self.recv_reply::<VhostUserU64>(&hdr)?;
307         if val.value > VHOST_USER_MAX_VRINGS {
308             return Err(VhostUserError::InvalidMessage);
309         }
310         Ok(val.value)
311     }
312 
313     /// Signal backend to enable or disable corresponding vring.
314     ///
315     /// Backend must not pass data to/from the ring until ring is enabled by
316     /// VHOST_USER_SET_VRING_ENABLE with parameter 1, or after it has been
317     /// disabled by VHOST_USER_SET_VRING_ENABLE with parameter 0.
set_vring_enable(&self, queue_index: usize, enable: bool) -> Result<()>318     pub fn set_vring_enable(&self, queue_index: usize, enable: bool) -> Result<()> {
319         // set_vring_enable() is supported only when PROTOCOL_FEATURES has been enabled.
320         if self.acked_virtio_features & 1 << VHOST_USER_F_PROTOCOL_FEATURES == 0 {
321             return Err(VhostUserError::InvalidOperation);
322         }
323 
324         let val = VhostUserVringState::new(queue_index as u32, enable.into());
325         let hdr = self.send_request_with_body(FrontendReq::SET_VRING_ENABLE, &val, None)?;
326         self.wait_for_ack(&hdr)
327     }
328 
329     /// Fetch the contents of the virtio device configuration space.
get_config( &self, offset: u32, size: u32, flags: VhostUserConfigFlags, buf: &[u8], ) -> Result<(VhostUserConfig, VhostUserConfigPayload)>330     pub fn get_config(
331         &self,
332         offset: u32,
333         size: u32,
334         flags: VhostUserConfigFlags,
335         buf: &[u8],
336     ) -> Result<(VhostUserConfig, VhostUserConfigPayload)> {
337         let body = VhostUserConfig::new(offset, size, flags);
338         if !body.is_valid() {
339             return Err(VhostUserError::InvalidParam);
340         }
341 
342         // depends on VhostUserProtocolFeatures::CONFIG
343         if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 {
344             return Err(VhostUserError::InvalidOperation);
345         }
346 
347         // vhost-user spec states that:
348         // "Request payload: virtio device config space"
349         // "Reply payload: virtio device config space"
350         let hdr = self.send_request_with_payload(FrontendReq::GET_CONFIG, &body, buf, None)?;
351         let (body_reply, buf_reply, rfds) =
352             self.recv_reply_with_payload::<VhostUserConfig>(&hdr)?;
353         if !rfds.is_empty() {
354             return Err(VhostUserError::InvalidMessage);
355         } else if body_reply.size == 0 {
356             return Err(VhostUserError::BackendInternalError);
357         } else if body_reply.size != body.size
358             || body_reply.size as usize != buf.len()
359             || body_reply.offset != body.offset
360         {
361             return Err(VhostUserError::InvalidMessage);
362         }
363 
364         Ok((body_reply, buf_reply))
365     }
366 
367     /// Change the virtio device configuration space. It also can be used for live migration on the
368     /// destination host to set readonly configuration space fields.
set_config(&self, offset: u32, flags: VhostUserConfigFlags, buf: &[u8]) -> Result<()>369     pub fn set_config(&self, offset: u32, flags: VhostUserConfigFlags, buf: &[u8]) -> Result<()> {
370         let body = VhostUserConfig::new(
371             offset,
372             buf.len()
373                 .try_into()
374                 .map_err(VhostUserError::InvalidCastToInt)?,
375             flags,
376         );
377         if !body.is_valid() {
378             return Err(VhostUserError::InvalidParam);
379         }
380 
381         // depends on VhostUserProtocolFeatures::CONFIG
382         if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIG.bits() == 0 {
383             return Err(VhostUserError::InvalidOperation);
384         }
385 
386         let hdr = self.send_request_with_payload(FrontendReq::SET_CONFIG, &body, buf, None)?;
387         self.wait_for_ack(&hdr)
388     }
389 
390     /// Setup backend communication channel.
set_backend_req_fd(&self, fd: &dyn AsRawDescriptor) -> Result<()>391     pub fn set_backend_req_fd(&self, fd: &dyn AsRawDescriptor) -> Result<()> {
392         if self.acked_protocol_features & VhostUserProtocolFeatures::BACKEND_REQ.bits() == 0 {
393             return Err(VhostUserError::InvalidOperation);
394         }
395         let fds = [fd.as_raw_descriptor()];
396         let hdr = self.send_request_header(FrontendReq::SET_BACKEND_REQ_FD, Some(&fds))?;
397         self.wait_for_ack(&hdr)
398     }
399 
400     /// Retrieve shared buffer for inflight I/O tracking.
get_inflight_fd( &self, inflight: &VhostUserInflight, ) -> Result<(VhostUserInflight, File)>401     pub fn get_inflight_fd(
402         &self,
403         inflight: &VhostUserInflight,
404     ) -> Result<(VhostUserInflight, File)> {
405         if self.acked_protocol_features & VhostUserProtocolFeatures::INFLIGHT_SHMFD.bits() == 0 {
406             return Err(VhostUserError::InvalidOperation);
407         }
408 
409         let hdr = self.send_request_with_body(FrontendReq::GET_INFLIGHT_FD, inflight, None)?;
410         let (inflight, files) = self.recv_reply_with_files::<VhostUserInflight>(&hdr)?;
411 
412         match into_single_file(files) {
413             Some(file) => Ok((inflight, file)),
414             None => Err(VhostUserError::IncorrectFds),
415         }
416     }
417 
418     /// Set shared buffer for inflight I/O tracking.
set_inflight_fd(&self, inflight: &VhostUserInflight, fd: RawDescriptor) -> Result<()>419     pub fn set_inflight_fd(&self, inflight: &VhostUserInflight, fd: RawDescriptor) -> Result<()> {
420         if self.acked_protocol_features & VhostUserProtocolFeatures::INFLIGHT_SHMFD.bits() == 0 {
421             return Err(VhostUserError::InvalidOperation);
422         }
423 
424         if inflight.mmap_size == 0
425             || inflight.num_queues == 0
426             || inflight.queue_size == 0
427             || fd == INVALID_DESCRIPTOR
428         {
429             return Err(VhostUserError::InvalidParam);
430         }
431 
432         let hdr =
433             self.send_request_with_body(FrontendReq::SET_INFLIGHT_FD, inflight, Some(&[fd]))?;
434         self.wait_for_ack(&hdr)
435     }
436 
437     /// Query the maximum amount of memory slots supported by the backend.
get_max_mem_slots(&self) -> Result<u64>438     pub fn get_max_mem_slots(&self) -> Result<u64> {
439         if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits() == 0
440         {
441             return Err(VhostUserError::InvalidOperation);
442         }
443 
444         let hdr = self.send_request_header(FrontendReq::GET_MAX_MEM_SLOTS, None)?;
445         let val = self.recv_reply::<VhostUserU64>(&hdr)?;
446 
447         Ok(val.value)
448     }
449 
450     /// Add a new guest memory mapping for vhost to use.
add_mem_region(&self, region: &VhostUserMemoryRegionInfo) -> Result<()>451     pub fn add_mem_region(&self, region: &VhostUserMemoryRegionInfo) -> Result<()> {
452         if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits() == 0
453         {
454             return Err(VhostUserError::InvalidOperation);
455         }
456 
457         if region.memory_size == 0 || region.mmap_handle == INVALID_DESCRIPTOR {
458             return Err(VhostUserError::InvalidParam);
459         }
460 
461         let body = VhostUserSingleMemoryRegion::new(
462             region.guest_phys_addr,
463             region.memory_size,
464             region.userspace_addr,
465             region.mmap_offset,
466         );
467         let fds = [region.mmap_handle];
468         let hdr = self.send_request_with_body(FrontendReq::ADD_MEM_REG, &body, Some(&fds))?;
469         self.wait_for_ack(&hdr)
470     }
471 
472     /// Remove a guest memory mapping from vhost.
remove_mem_region(&self, region: &VhostUserMemoryRegionInfo) -> Result<()>473     pub fn remove_mem_region(&self, region: &VhostUserMemoryRegionInfo) -> Result<()> {
474         if self.acked_protocol_features & VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS.bits() == 0
475         {
476             return Err(VhostUserError::InvalidOperation);
477         }
478         if region.memory_size == 0 {
479             return Err(VhostUserError::InvalidParam);
480         }
481 
482         let body = VhostUserSingleMemoryRegion::new(
483             region.guest_phys_addr,
484             region.memory_size,
485             region.userspace_addr,
486             region.mmap_offset,
487         );
488         let hdr = self.send_request_with_body(FrontendReq::REM_MEM_REG, &body, None)?;
489         self.wait_for_ack(&hdr)
490     }
491 
492     /// Gets the shared memory regions used by the device.
get_shared_memory_regions(&self) -> Result<Vec<VhostSharedMemoryRegion>>493     pub fn get_shared_memory_regions(&self) -> Result<Vec<VhostSharedMemoryRegion>> {
494         let hdr = self.send_request_header(FrontendReq::GET_SHARED_MEMORY_REGIONS, None)?;
495         let (body_reply, buf_reply, rfds) = self.recv_reply_with_payload::<VhostUserU64>(&hdr)?;
496         let struct_size = mem::size_of::<VhostSharedMemoryRegion>();
497         if !rfds.is_empty() || buf_reply.len() != body_reply.value as usize * struct_size {
498             return Err(VhostUserError::InvalidMessage);
499         }
500         let mut regions = Vec::new();
501         let mut offset = 0;
502         for _ in 0..body_reply.value {
503             regions.push(
504                 // Can't fail because the input is the correct size.
505                 VhostSharedMemoryRegion::read_from(&buf_reply[offset..(offset + struct_size)])
506                     .unwrap(),
507             );
508             offset += struct_size;
509         }
510         Ok(regions)
511     }
512 
send_request_header( &self, code: FrontendReq, fds: Option<&[RawDescriptor]>, ) -> VhostUserResult<VhostUserMsgHeader<FrontendReq>>513     fn send_request_header(
514         &self,
515         code: FrontendReq,
516         fds: Option<&[RawDescriptor]>,
517     ) -> VhostUserResult<VhostUserMsgHeader<FrontendReq>> {
518         let hdr = self.new_request_header(code, 0);
519         self.connection.send_header_only_message(&hdr, fds)?;
520         Ok(hdr)
521     }
522 
send_request_with_body<T: Sized + AsBytes>( &self, code: FrontendReq, msg: &T, fds: Option<&[RawDescriptor]>, ) -> VhostUserResult<VhostUserMsgHeader<FrontendReq>>523     fn send_request_with_body<T: Sized + AsBytes>(
524         &self,
525         code: FrontendReq,
526         msg: &T,
527         fds: Option<&[RawDescriptor]>,
528     ) -> VhostUserResult<VhostUserMsgHeader<FrontendReq>> {
529         let hdr = self.new_request_header(code, mem::size_of::<T>() as u32);
530         self.connection.send_message(&hdr, msg, fds)?;
531         Ok(hdr)
532     }
533 
send_request_with_payload<T: Sized + AsBytes>( &self, code: FrontendReq, msg: &T, payload: &[u8], fds: Option<&[RawDescriptor]>, ) -> VhostUserResult<VhostUserMsgHeader<FrontendReq>>534     fn send_request_with_payload<T: Sized + AsBytes>(
535         &self,
536         code: FrontendReq,
537         msg: &T,
538         payload: &[u8],
539         fds: Option<&[RawDescriptor]>,
540     ) -> VhostUserResult<VhostUserMsgHeader<FrontendReq>> {
541         if let Some(fd_arr) = fds {
542             if fd_arr.len() > MAX_ATTACHED_FD_ENTRIES {
543                 return Err(VhostUserError::InvalidParam);
544             }
545         }
546         let len = mem::size_of::<T>()
547             .checked_add(payload.len())
548             .ok_or(VhostUserError::OversizedMsg)?;
549         let hdr = self.new_request_header(
550             code,
551             len.try_into().map_err(VhostUserError::InvalidCastToInt)?,
552         );
553         self.connection
554             .send_message_with_payload(&hdr, msg, payload, fds)?;
555         Ok(hdr)
556     }
557 
send_fd_for_vring( &self, code: FrontendReq, queue_index: usize, fd: RawDescriptor, ) -> VhostUserResult<VhostUserMsgHeader<FrontendReq>>558     fn send_fd_for_vring(
559         &self,
560         code: FrontendReq,
561         queue_index: usize,
562         fd: RawDescriptor,
563     ) -> VhostUserResult<VhostUserMsgHeader<FrontendReq>> {
564         // Bits (0-7) of the payload contain the vring index. Bit 8 is the invalid FD flag.
565         // This flag is set when there is no file descriptor in the ancillary data. This signals
566         // that polling will be used instead of waiting for the call.
567         let msg = VhostUserU64::new(queue_index as u64);
568         let hdr = self.new_request_header(code, mem::size_of::<VhostUserU64>() as u32);
569         self.connection.send_message(&hdr, &msg, Some(&[fd]))?;
570         Ok(hdr)
571     }
572 
recv_reply<T: Sized + FromBytes + AsBytes + Default + VhostUserMsgValidator>( &self, hdr: &VhostUserMsgHeader<FrontendReq>, ) -> VhostUserResult<T>573     fn recv_reply<T: Sized + FromBytes + AsBytes + Default + VhostUserMsgValidator>(
574         &self,
575         hdr: &VhostUserMsgHeader<FrontendReq>,
576     ) -> VhostUserResult<T> {
577         if hdr.is_reply() {
578             return Err(VhostUserError::InvalidParam);
579         }
580         let (reply, body, rfds) = self.connection.recv_message::<T>()?;
581         if !reply.is_reply_for(hdr) || !rfds.is_empty() || !body.is_valid() {
582             return Err(VhostUserError::InvalidMessage);
583         }
584         Ok(body)
585     }
586 
recv_reply_with_files<T: Sized + AsBytes + FromBytes + Default + VhostUserMsgValidator>( &self, hdr: &VhostUserMsgHeader<FrontendReq>, ) -> VhostUserResult<(T, Vec<File>)>587     fn recv_reply_with_files<T: Sized + AsBytes + FromBytes + Default + VhostUserMsgValidator>(
588         &self,
589         hdr: &VhostUserMsgHeader<FrontendReq>,
590     ) -> VhostUserResult<(T, Vec<File>)> {
591         if hdr.is_reply() {
592             return Err(VhostUserError::InvalidParam);
593         }
594 
595         let (reply, body, files) = self.connection.recv_message::<T>()?;
596         if !reply.is_reply_for(hdr) || !body.is_valid() {
597             return Err(VhostUserError::InvalidMessage);
598         }
599         Ok((body, files))
600     }
601 
recv_reply_with_payload<T: Sized + AsBytes + FromBytes + Default + VhostUserMsgValidator>( &self, hdr: &VhostUserMsgHeader<FrontendReq>, ) -> VhostUserResult<(T, Vec<u8>, Vec<File>)>602     fn recv_reply_with_payload<T: Sized + AsBytes + FromBytes + Default + VhostUserMsgValidator>(
603         &self,
604         hdr: &VhostUserMsgHeader<FrontendReq>,
605     ) -> VhostUserResult<(T, Vec<u8>, Vec<File>)> {
606         if hdr.is_reply() {
607             return Err(VhostUserError::InvalidParam);
608         }
609 
610         let (reply, body, buf, files) = self.connection.recv_message_with_payload::<T>()?;
611         if !reply.is_reply_for(hdr) || !files.is_empty() || !body.is_valid() {
612             return Err(VhostUserError::InvalidMessage);
613         }
614 
615         Ok((body, buf, files))
616     }
617 
wait_for_ack(&self, hdr: &VhostUserMsgHeader<FrontendReq>) -> VhostUserResult<()>618     fn wait_for_ack(&self, hdr: &VhostUserMsgHeader<FrontendReq>) -> VhostUserResult<()> {
619         if self.acked_protocol_features & VhostUserProtocolFeatures::REPLY_ACK.bits() == 0
620             || !hdr.is_need_reply()
621         {
622             return Ok(());
623         }
624 
625         let (reply, body, rfds) = self.connection.recv_message::<VhostUserU64>()?;
626         if !reply.is_reply_for(hdr) || !rfds.is_empty() || !body.is_valid() {
627             return Err(VhostUserError::InvalidMessage);
628         }
629         if body.value != 0 {
630             return Err(VhostUserError::BackendInternalError);
631         }
632         Ok(())
633     }
634 
is_feature_mq_available(&self) -> bool635     fn is_feature_mq_available(&self) -> bool {
636         self.acked_protocol_features & VhostUserProtocolFeatures::MQ.bits() != 0
637     }
638 
639     #[inline]
new_request_header( &self, request: FrontendReq, size: u32, ) -> VhostUserMsgHeader<FrontendReq>640     fn new_request_header(
641         &self,
642         request: FrontendReq,
643         size: u32,
644     ) -> VhostUserMsgHeader<FrontendReq> {
645         VhostUserMsgHeader::new(request, 0x1, size)
646     }
647 }
648 
#[cfg(windows)]
impl CloseNotifier for BackendClient {
    /// Delegates to the close notifier of the underlying platform connection.
    fn get_close_notifier(&self) -> &dyn AsRawDescriptor {
        let platform_connection = &self.connection.0;
        platform_connection.get_close_notifier()
    }
}
655 
656 impl ReadNotifier for BackendClient {
get_read_notifier(&self) -> &dyn AsRawDescriptor657     fn get_read_notifier(&self) -> &dyn AsRawDescriptor {
658         self.connection.0.get_read_notifier()
659     }
660 }
661 
662 // TODO(b/221882601): likely need pairs of RDs and/or SharedMemory to represent mmaps on Windows.
663 /// Context object to pass guest memory configuration to BackendClient::set_mem_table().
664 struct VhostUserMemoryContext {
665     regions: VhostUserMemoryPayload,
666     fds: Vec<RawDescriptor>,
667 }
668 
669 impl VhostUserMemoryContext {
670     /// Create a context object.
new() -> Self671     pub fn new() -> Self {
672         VhostUserMemoryContext {
673             regions: VhostUserMemoryPayload::new(),
674             fds: Vec::new(),
675         }
676     }
677 
678     /// Append a user memory region and corresponding RawDescriptor into the context object.
append(&mut self, region: &VhostUserMemoryRegion, fd: RawDescriptor)679     pub fn append(&mut self, region: &VhostUserMemoryRegion, fd: RawDescriptor) {
680         self.regions.push(*region);
681         self.fds.push(fd);
682     }
683 }
684 
#[cfg(test)]
mod tests {
    use base::INVALID_DESCRIPTOR;
    use tempfile::tempfile;

    use super::*;

    const BUFFER_SIZE: usize = 0x1001;
    const INVALID_PROTOCOL_FEATURE: u64 = 1 << 63;

    /// Creates a `BackendClient` and returns it with the peer end of its connection.
    /// The tests use the peer to inspect outgoing requests and to inject replies.
    fn create_pair() -> (BackendClient, Connection<FrontendReq>) {
        let (client_connection, server_connection) = Connection::pair().unwrap();
        let backend_client = BackendClient::new(client_connection);
        (backend_client, server_connection)
    }

    /// Two requests sent back to back arrive in order. Each has a zero payload
    /// size, version 0x1, and no attached descriptors.
    #[test]
    fn create_backend_client() {
        let (backend_client, peer) = create_pair();

        assert!(backend_client.connection.as_raw_descriptor() != INVALID_DESCRIPTOR);
        // Send two messages continuously
        backend_client.set_owner().unwrap();
        backend_client.reset_owner().unwrap();

        let (hdr, rfds) = peer.recv_header().unwrap();
        assert_eq!(hdr.get_code(), Ok(FrontendReq::SET_OWNER));
        assert_eq!(hdr.get_size(), 0);
        assert_eq!(hdr.get_version(), 0x1);
        assert!(rfds.is_empty());

        let (hdr, rfds) = peer.recv_header().unwrap();
        assert_eq!(hdr.get_code(), Ok(FrontendReq::RESET_OWNER));
        assert_eq!(hdr.get_size(), 0);
        assert_eq!(hdr.get_version(), 0x1);
        assert!(rfds.is_empty());
    }

    /// GET_FEATURES returns the value supplied by the peer, and SET_FEATURES
    /// forwards the given value to the peer.
    #[test]
    fn test_features() {
        let (mut backend_client, peer) = create_pair();

        backend_client.set_owner().unwrap();
        let (hdr, rfds) = peer.recv_header().unwrap();
        assert_eq!(hdr.get_code(), Ok(FrontendReq::SET_OWNER));
        assert_eq!(hdr.get_size(), 0);
        assert_eq!(hdr.get_version(), 0x1);
        assert!(rfds.is_empty());

        let hdr = VhostUserMsgHeader::new(FrontendReq::GET_FEATURES, 0x4, 8);
        let msg = VhostUserU64::new(0x15);
        peer.send_message(&hdr, &msg, None).unwrap();
        let features = backend_client.get_features().unwrap();
        assert_eq!(features, 0x15u64);
        let (_hdr, rfds) = peer.recv_header().unwrap();
        assert!(rfds.is_empty());

        let hdr = VhostUserMsgHeader::new(FrontendReq::SET_FEATURES, 0x4, 8);
        let msg = VhostUserU64::new(0x15);
        peer.send_message(&hdr, &msg, None).unwrap();
        backend_client.set_features(0x15).unwrap();
        let (_hdr, msg, rfds) = peer.recv_message::<VhostUserU64>().unwrap();
        assert!(rfds.is_empty());
        let val = msg.value;
        assert_eq!(val, 0x15);

        // A GET_FEATURES reply whose body is a u32 rather than a u64 must not be accepted.
        // NOTE(review): the SET_FEATURES message the peer injected above is never read
        // by the client. This call may therefore fail on that stale message rather than
        // on the short body. TODO confirm which check actually trips here.
        let hdr = VhostUserMsgHeader::new(FrontendReq::GET_FEATURES, 0x4, 8);
        let msg = 0x15u32;
        peer.send_message(&hdr, &msg, None).unwrap();
        assert!(backend_client.get_features().is_err());
    }

    /// Protocol-feature requests fail until VHOST_USER_F_PROTOCOL_FEATURES has been
    /// negotiated. Afterwards, unknown bits in the peer's reply are dropped, and a
    /// reply carrying the wrong request code is rejected.
    #[test]
    fn test_protocol_features() {
        let (mut backend_client, peer) = create_pair();

        backend_client.set_owner().unwrap();
        let (hdr, rfds) = peer.recv_header().unwrap();
        assert_eq!(hdr.get_code(), Ok(FrontendReq::SET_OWNER));
        assert!(rfds.is_empty());

        assert!(backend_client.get_protocol_features().is_err());
        assert!(backend_client
            .set_protocol_features(VhostUserProtocolFeatures::all())
            .is_err());

        let vfeatures = 0x15 | 1 << VHOST_USER_F_PROTOCOL_FEATURES;
        let hdr = VhostUserMsgHeader::new(FrontendReq::GET_FEATURES, 0x4, 8);
        let msg = VhostUserU64::new(vfeatures);
        peer.send_message(&hdr, &msg, None).unwrap();
        let features = backend_client.get_features().unwrap();
        assert_eq!(features, vfeatures);
        let (_hdr, rfds) = peer.recv_header().unwrap();
        assert!(rfds.is_empty());

        backend_client.set_features(vfeatures).unwrap();
        let (_hdr, msg, rfds) = peer.recv_message::<VhostUserU64>().unwrap();
        assert!(rfds.is_empty());
        let val = msg.value;
        assert_eq!(val, vfeatures);

        let pfeatures = VhostUserProtocolFeatures::all();
        let hdr = VhostUserMsgHeader::new(FrontendReq::GET_PROTOCOL_FEATURES, 0x4, 8);
        // Unknown feature bits should be ignored.
        let msg = VhostUserU64::new(pfeatures.bits() | INVALID_PROTOCOL_FEATURE);
        peer.send_message(&hdr, &msg, None).unwrap();
        let features = backend_client.get_protocol_features().unwrap();
        assert_eq!(features, pfeatures);
        let (_hdr, rfds) = peer.recv_header().unwrap();
        assert!(rfds.is_empty());

        backend_client.set_protocol_features(pfeatures).unwrap();
        let (_hdr, msg, rfds) = peer.recv_message::<VhostUserU64>().unwrap();
        assert!(rfds.is_empty());
        let val = msg.value;
        assert_eq!(val, pfeatures.bits());

        // A reply coded SET_PROTOCOL_FEATURES does not answer GET_PROTOCOL_FEATURES.
        let hdr = VhostUserMsgHeader::new(FrontendReq::SET_PROTOCOL_FEATURES, 0x4, 8);
        let msg = VhostUserU64::new(pfeatures.bits());
        peer.send_message(&hdr, &msg, None).unwrap();
        assert!(backend_client.get_protocol_features().is_err());
    }

    /// set_config() fails while no features are cached or acked. Once they are,
    /// it accepts a 4-byte write at offset 0 but rejects the other
    /// offset/flag/payload combinations below.
    #[test]
    fn test_backend_client_set_config_negative() {
        let (mut backend_client, _peer) = create_pair();
        let buf = vec![0x0; BUFFER_SIZE];

        backend_client
            .set_config(0x100, VhostUserConfigFlags::WRITABLE, &buf[0..4])
            .unwrap_err();

        backend_client.virtio_features = 0xffff_ffff;
        backend_client.acked_virtio_features = 0xffff_ffff;
        backend_client.acked_protocol_features = 0xffff_ffff;

        backend_client
            .set_config(0, VhostUserConfigFlags::WRITABLE, &buf[0..4])
            .unwrap();
        backend_client
            .set_config(
                VHOST_USER_CONFIG_SIZE,
                VhostUserConfigFlags::WRITABLE,
                &buf[0..4],
            )
            .unwrap_err();
        backend_client
            .set_config(0x1000, VhostUserConfigFlags::WRITABLE, &buf[0..4])
            .unwrap_err();
        backend_client
            .set_config(
                0x100,
                VhostUserConfigFlags::from_bits_retain(0xffff_ffff),
                &buf[0..4],
            )
            .unwrap_err();
        backend_client
            .set_config(VHOST_USER_CONFIG_SIZE, VhostUserConfigFlags::WRITABLE, &buf)
            .unwrap_err();
        backend_client
            .set_config(VHOST_USER_CONFIG_SIZE, VhostUserConfigFlags::WRITABLE, &[])
            .unwrap_err();
    }

    /// Like `create_pair`, but the client's cached virtio and protocol feature
    /// words are all set to 0xffff_ffff, as if negotiation had already happened.
    fn create_pair2() -> (BackendClient, Connection<FrontendReq>) {
        let (mut backend_client, peer) = create_pair();

        backend_client.virtio_features = 0xffff_ffff;
        backend_client.acked_virtio_features = 0xffff_ffff;
        backend_client.acked_protocol_features = 0xffff_ffff;

        (backend_client, peer)
    }

    /// A GET_CONFIG reply tagged with a different request code is rejected.
    #[test]
    fn test_backend_client_get_config_negative0() {
        let (backend_client, peer) = create_pair2();
        let buf = vec![0x0; BUFFER_SIZE];

        let mut hdr = VhostUserMsgHeader::new(FrontendReq::GET_CONFIG, 0x4, 16);
        let msg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty());
        peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
            .unwrap();
        assert!(backend_client
            .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
            .is_ok());

        hdr.set_code(FrontendReq::GET_FEATURES);
        peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
            .unwrap();
        assert!(backend_client
            .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
            .is_err());
    }

    /// A GET_CONFIG reply whose reply flag has been cleared is rejected.
    #[test]
    fn test_backend_client_get_config_negative1() {
        let (backend_client, peer) = create_pair2();
        let buf = vec![0x0; BUFFER_SIZE];

        let mut hdr = VhostUserMsgHeader::new(FrontendReq::GET_CONFIG, 0x4, 16);
        let msg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty());
        peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
            .unwrap();
        assert!(backend_client
            .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
            .is_ok());

        hdr.set_reply(false);
        peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
            .unwrap();
        assert!(backend_client
            .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
            .is_err());
    }

    /// A well-formed GET_CONFIG reply is accepted.
    // NOTE(review): despite the "negative" name, this only exercises the success path.
    #[test]
    fn test_backend_client_get_config_negative2() {
        let (backend_client, peer) = create_pair2();
        let buf = vec![0x0; BUFFER_SIZE];

        let hdr = VhostUserMsgHeader::new(FrontendReq::GET_CONFIG, 0x4, 16);
        let msg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty());
        peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
            .unwrap();
        assert!(backend_client
            .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
            .is_ok());
    }

    /// A GET_CONFIG reply whose offset (0) differs from the requested offset
    /// (0x100) is rejected.
    #[test]
    fn test_backend_client_get_config_negative3() {
        let (backend_client, peer) = create_pair2();
        let buf = vec![0x0; BUFFER_SIZE];

        let hdr = VhostUserMsgHeader::new(FrontendReq::GET_CONFIG, 0x4, 16);
        let mut msg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty());
        peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
            .unwrap();
        assert!(backend_client
            .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
            .is_ok());

        msg.offset = 0;
        peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
            .unwrap();
        assert!(backend_client
            .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
            .is_err());
    }

    /// A GET_CONFIG reply whose offset (0x101) differs from the requested offset
    /// (0x100) is rejected.
    #[test]
    fn test_backend_client_get_config_negative4() {
        let (backend_client, peer) = create_pair2();
        let buf = vec![0x0; BUFFER_SIZE];

        let hdr = VhostUserMsgHeader::new(FrontendReq::GET_CONFIG, 0x4, 16);
        let mut msg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty());
        peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
            .unwrap();
        assert!(backend_client
            .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
            .is_ok());

        msg.offset = 0x101;
        peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
            .unwrap();
        assert!(backend_client
            .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
            .is_err());
    }

    /// A GET_CONFIG reply whose offset is BUFFER_SIZE rather than the requested
    /// 0x100 is rejected.
    #[test]
    fn test_backend_client_get_config_negative5() {
        let (backend_client, peer) = create_pair2();
        let buf = vec![0x0; BUFFER_SIZE];

        let hdr = VhostUserMsgHeader::new(FrontendReq::GET_CONFIG, 0x4, 16);
        let mut msg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty());
        peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
            .unwrap();
        assert!(backend_client
            .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
            .is_ok());

        msg.offset = BUFFER_SIZE as u32;
        peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
            .unwrap();
        assert!(backend_client
            .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
            .is_err());
    }

    /// A GET_CONFIG reply carrying 6 bytes when 4 were requested is rejected.
    #[test]
    fn test_backend_client_get_config_negative6() {
        let (backend_client, peer) = create_pair2();
        let buf = vec![0x0; BUFFER_SIZE];

        let hdr = VhostUserMsgHeader::new(FrontendReq::GET_CONFIG, 0x4, 16);
        let mut msg = VhostUserConfig::new(0x100, 4, VhostUserConfigFlags::empty());
        peer.send_message_with_payload(&hdr, &msg, &buf[0..4], None)
            .unwrap();
        assert!(backend_client
            .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
            .is_ok());

        msg.size = 6;
        peer.send_message_with_payload(&hdr, &msg, &buf[0..6], None)
            .unwrap();
        assert!(backend_client
            .get_config(0x100, 4, VhostUserConfigFlags::WRITABLE, &buf[0..4])
            .is_err());
    }

    /// set_mem_table() rejects an empty region list, and also a list with more
    /// regions than MAX_ATTACHED_FD_ENTRIES.
    #[test]
    fn test_backend_client_set_mem_table_failure() {
        let (backend_client, _peer) = create_pair2();

        // set_mem_table() with 0 regions is invalid
        backend_client.set_mem_table(&[]).unwrap_err();

        // set_mem_table() with more than MAX_ATTACHED_FD_ENTRIES is invalid
        let files: Vec<File> = (0..MAX_ATTACHED_FD_ENTRIES + 1)
            .map(|_| tempfile().unwrap())
            .collect();
        let tables: Vec<VhostUserMemoryRegionInfo> = files
            .iter()
            .map(|f| VhostUserMemoryRegionInfo {
                guest_phys_addr: 0,
                memory_size: 0x100000,
                userspace_addr: 0x800000,
                mmap_offset: 0,
                mmap_handle: f.as_raw_descriptor(),
            })
            .collect();
        backend_client.set_mem_table(&tables).unwrap_err();
    }
}
1024