1 // SPDX-License-Identifier: Apache-2.0 or BSD-3-Clause
2 
3 mod rxops;
4 mod rxqueue;
5 mod thread_backend;
6 mod txbuf;
7 mod vhu_vsock;
8 mod vhu_vsock_thread;
9 mod vsock_conn;
10 
11 use std::{
12     collections::HashMap,
13     convert::TryFrom,
14     process::exit,
15     sync::{Arc, RwLock},
16     thread,
17 };
18 
19 use crate::vhu_vsock::{CidMap, VhostUserVsockBackend, VsockConfig};
20 use clap::{Args, Parser};
21 use log::{error, info, warn};
22 use serde::Deserialize;
23 use thiserror::Error as ThisError;
24 use vhost::{vhost_user, vhost_user::Listener};
25 use vhost_user_backend::VhostUserDaemon;
26 use vm_memory::{GuestMemoryAtomic, GuestMemoryMmap};
27 
/// Guest CID used when none is given on the command line, in `--vm`, or in the config file.
const DEFAULT_GUEST_CID: u64 = 3;
/// TX buffer size used when none is given: 64 * 1024 (presumably bytes — confirm against `VsockConfig`).
const DEFAULT_TX_BUFFER_SIZE: u32 = 64 * 1024;
/// Group a device is placed in when no groups are given.
const DEFAULT_GROUP_NAME: &str = "default";
31 
/// Errors produced while turning the parsed command line into device configurations.
#[derive(Debug, ThisError)]
enum CliError {
    /// None of `--config`, `--vm`, or the standalone per-device flags supplied a configuration.
    #[error("No arguments provided")]
    NoArgsProvided,
    /// The `--config` file could not be loaded into a non-empty list of VM entries.
    #[error("Failed to parse configuration file")]
    ConfigParse,
}
39 
/// Errors produced while parsing the value of a single `--vm` argument.
#[derive(Debug, ThisError)]
enum VmArgsParseError {
    /// A comma-separated entry is not of the form `key=value`.
    #[error("Bad argument")]
    BadArgument,
    /// The entry's key is not one of the keys `parse_vm_params` accepts.
    #[error("Invalid key `{0}`")]
    InvalidKey(String),
    /// A numeric value (guest CID or TX buffer size) failed to parse.
    #[error("Unable to convert string to integer: {0}")]
    ParseInteger(std::num::ParseIntError),
    /// A mandatory key (`socket` or `uds-path`) was not supplied.
    #[error("Required key `{0}` not found")]
    RequiredKeyNotFound(String),
}
51 
/// Errors that make a backend server thread give up and return.
#[derive(Debug, ThisError)]
enum BackendError {
    /// `VhostUserVsockBackend::new` failed for the device's configuration.
    #[error("Could not create backend: {0}")]
    CouldNotCreateBackend(vhu_vsock::Error),
    /// `VhostUserDaemon::new` failed.
    #[error("Could not create daemon: {0}")]
    CouldNotCreateDaemon(vhost_user_backend::Error),
}
59 
60 #[derive(Args, Clone, Debug)]
61 struct VsockParam {
62     /// Context identifier of the guest which uniquely identifies the device for its lifetime.
63     #[arg(
64         long,
65         default_value_t = DEFAULT_GUEST_CID,
66         conflicts_with = "config",
67         conflicts_with = "vm"
68     )]
69     guest_cid: u64,
70 
71     /// Unix socket to which a hypervisor connects to and sets up the control path with the device.
72     #[arg(long, conflicts_with = "config", conflicts_with = "vm")]
73     socket: String,
74 
75     /// Unix socket to which a host-side application connects to.
76     #[arg(long, conflicts_with = "config", conflicts_with = "vm")]
77     uds_path: String,
78 
79     /// The size of the buffer used for the TX virtqueue
80     #[clap(long, default_value_t = DEFAULT_TX_BUFFER_SIZE, conflicts_with = "config", conflicts_with = "vm")]
81     tx_buffer_size: u32,
82 
83     /// The list of group names to which the device belongs.
84     /// A group is a set of devices that allow sibling communication between their guests.
85     #[arg(
86         long,
87         default_value_t = String::from(DEFAULT_GROUP_NAME),
88         conflicts_with = "config",
89         conflicts_with = "vm",
90         verbatim_doc_comment
91     )]
92     groups: String,
93 }
94 
/// One entry of the top-level `vms` list in a YAML configuration file (`--config`).
///
/// Field names are the YAML keys. Optional fields are filled with the same
/// defaults the command line uses when the entry is converted to a `VsockConfig`.
#[derive(Clone, Debug, Deserialize)]
struct ConfigFileVsockParam {
    /// Guest CID; defaults to `DEFAULT_GUEST_CID`.
    guest_cid: Option<u64>,
    /// vhost-user control socket path (required).
    socket: String,
    /// Host-side Unix socket path (required).
    uds_path: String,
    /// TX buffer size; defaults to `DEFAULT_TX_BUFFER_SIZE`.
    tx_buffer_size: Option<u32>,
    /// `+`-separated group names; defaults to `DEFAULT_GROUP_NAME`.
    groups: Option<String>,
}
103 
104 #[derive(Parser, Debug)]
105 #[command(version, about = None, long_about = None)]
106 struct VsockArgs {
107     #[command(flatten)]
108     param: Option<VsockParam>,
109 
110     /// Device parameters corresponding to a VM in the form of comma separated key=value pairs.
111     /// The allowed keys are: guest_cid, socket, uds_path, tx_buffer_size and group.
112     /// Example:
113     ///   --vm guest-cid=3,socket=/tmp/vhost3.socket,uds-path=/tmp/vm3.vsock,tx-buffer-size=65536,groups=group1+group2
114     /// Multiple instances of this argument can be provided to configure devices for multiple guests.
115     #[arg(long, conflicts_with = "config", verbatim_doc_comment, value_parser = parse_vm_params)]
116     vm: Option<Vec<VsockConfig>>,
117 
118     /// Load from a given configuration file
119     #[arg(long)]
120     config: Option<String>,
121 }
122 
parse_vm_params(s: &str) -> Result<VsockConfig, VmArgsParseError>123 fn parse_vm_params(s: &str) -> Result<VsockConfig, VmArgsParseError> {
124     let mut guest_cid = None;
125     let mut socket = None;
126     let mut uds_path = None;
127     let mut tx_buffer_size = None;
128     let mut groups = None;
129 
130     for arg in s.trim().split(',') {
131         let mut parts = arg.split('=');
132         let key = parts.next().ok_or(VmArgsParseError::BadArgument)?;
133         let val = parts.next().ok_or(VmArgsParseError::BadArgument)?;
134 
135         match key {
136             "guest_cid" | "guest-cid" => {
137                 guest_cid = Some(val.parse().map_err(VmArgsParseError::ParseInteger)?)
138             }
139             "socket" => socket = Some(val.to_string()),
140             "uds_path" | "uds-path" => uds_path = Some(val.to_string()),
141             "tx_buffer_size" | "tx-buffer-size" => {
142                 tx_buffer_size = Some(val.parse().map_err(VmArgsParseError::ParseInteger)?)
143             }
144             "groups" => groups = Some(val.split('+').map(String::from).collect()),
145             _ => return Err(VmArgsParseError::InvalidKey(key.to_string())),
146         }
147     }
148 
149     Ok(VsockConfig::new(
150         guest_cid.unwrap_or(DEFAULT_GUEST_CID),
151         socket.ok_or_else(|| VmArgsParseError::RequiredKeyNotFound("socket".to_string()))?,
152         uds_path.ok_or_else(|| VmArgsParseError::RequiredKeyNotFound("uds-path".to_string()))?,
153         tx_buffer_size.unwrap_or(DEFAULT_TX_BUFFER_SIZE),
154         groups.unwrap_or(vec![DEFAULT_GROUP_NAME.to_string()]),
155     ))
156 }
157 
158 impl VsockArgs {
parse_config(&self) -> Option<Result<Vec<VsockConfig>, CliError>>159     pub fn parse_config(&self) -> Option<Result<Vec<VsockConfig>, CliError>> {
160         if let Some(c) = &self.config {
161             let b = config::Config::builder()
162                 .add_source(config::File::new(c.as_str(), config::FileFormat::Yaml))
163                 .build();
164             if let Ok(s) = b {
165                 let mut v = s.get::<Vec<ConfigFileVsockParam>>("vms").unwrap();
166                 if !v.is_empty() {
167                     let parsed: Vec<VsockConfig> = v
168                         .drain(..)
169                         .map(|p| {
170                             VsockConfig::new(
171                                 p.guest_cid.unwrap_or(DEFAULT_GUEST_CID),
172                                 p.socket.trim().to_string(),
173                                 p.uds_path.trim().to_string(),
174                                 p.tx_buffer_size.unwrap_or(DEFAULT_TX_BUFFER_SIZE),
175                                 p.groups.map_or(vec![DEFAULT_GROUP_NAME.to_string()], |g| {
176                                     g.trim().split('+').map(String::from).collect()
177                                 }),
178                             )
179                         })
180                         .collect();
181                     return Some(Ok(parsed));
182                 } else {
183                     return Some(Err(CliError::ConfigParse));
184                 }
185             } else {
186                 return Some(Err(CliError::ConfigParse));
187             }
188         }
189         None
190     }
191 }
192 
193 impl TryFrom<VsockArgs> for Vec<VsockConfig> {
194     type Error = CliError;
195 
try_from(cmd_args: VsockArgs) -> Result<Self, CliError>196     fn try_from(cmd_args: VsockArgs) -> Result<Self, CliError> {
197         // we try to use the configuration first, if failed,  then fall back to the manual settings.
198         match cmd_args.parse_config() {
199             Some(c) => c,
200             _ => match cmd_args.vm {
201                 Some(v) => Ok(v),
202                 _ => cmd_args.param.map_or(Err(CliError::NoArgsProvided), |p| {
203                     Ok(vec![VsockConfig::new(
204                         p.guest_cid,
205                         p.socket.trim().to_string(),
206                         p.uds_path.trim().to_string(),
207                         p.tx_buffer_size,
208                         p.groups.trim().split('+').map(String::from).collect(),
209                     )])
210                 }),
211             },
212         }
213     }
214 }
215 
216 /// This is the public API through which an external program starts the
217 /// vhost-device-vsock backend server.
start_backend_server( config: VsockConfig, cid_map: Arc<RwLock<CidMap>>, ) -> Result<(), BackendError>218 pub(crate) fn start_backend_server(
219     config: VsockConfig,
220     cid_map: Arc<RwLock<CidMap>>,
221 ) -> Result<(), BackendError> {
222     loop {
223         let backend = Arc::new(
224             VhostUserVsockBackend::new(config.clone(), cid_map.clone())
225                 .map_err(BackendError::CouldNotCreateBackend)?,
226         );
227 
228         let listener = Listener::new(config.get_socket_path(), true).unwrap();
229 
230         let mut daemon = VhostUserDaemon::new(
231             String::from("vhost-device-vsock"),
232             backend.clone(),
233             GuestMemoryAtomic::new(GuestMemoryMmap::new()),
234         )
235         .map_err(BackendError::CouldNotCreateDaemon)?;
236 
237         let mut epoll_handlers = daemon.get_epoll_handlers();
238 
239         for thread in backend.threads.iter() {
240             thread
241                 .lock()
242                 .unwrap()
243                 .register_listeners(epoll_handlers.remove(0));
244         }
245 
246         daemon.start(listener).unwrap();
247 
248         match daemon.wait() {
249             Ok(()) => {
250                 info!("Stopping cleanly");
251             }
252             Err(vhost_user_backend::Error::HandleRequest(
253                 vhost_user::Error::PartialMessage | vhost_user::Error::Disconnected,
254             )) => {
255                 info!("vhost-user connection closed with partial message. If the VM is shutting down, this is expected behavior; otherwise, it might be a bug.");
256             }
257             Err(e) => {
258                 warn!("Error running daemon: {:?}", e);
259             }
260         }
261 
262         // No matter the result, we need to shut down the worker thread.
263         backend.exit_event.write(1).unwrap();
264     }
265 }
266 
start_backend_servers(configs: &[VsockConfig]) -> Result<(), BackendError>267 pub(crate) fn start_backend_servers(configs: &[VsockConfig]) -> Result<(), BackendError> {
268     let cid_map: Arc<RwLock<CidMap>> = Arc::new(RwLock::new(HashMap::new()));
269     let mut handles = Vec::new();
270 
271     for c in configs.iter() {
272         let config = c.clone();
273         let cid_map = cid_map.clone();
274         let handle = thread::Builder::new()
275             .name(format!("vhu-vsock-cid-{}", c.get_guest_cid()))
276             .spawn(move || start_backend_server(config, cid_map))
277             .unwrap();
278         handles.push(handle);
279     }
280 
281     for handle in handles {
282         handle.join().unwrap()?;
283     }
284 
285     Ok(())
286 }
287 
main()288 fn main() {
289     env_logger::init();
290 
291     let configs = match Vec::<VsockConfig>::try_from(VsockArgs::parse()) {
292         Ok(c) => c,
293         Err(e) => {
294             println!("Error parsing arguments: {}", e);
295             return;
296         }
297     };
298 
299     if let Err(e) = start_backend_servers(&configs) {
300         error!("{e}");
301         exit(1);
302     }
303 }
304 
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs::File;
    use std::io::Write;
    use tempfile::tempdir;

    // Test-only constructors that build `VsockArgs` directly, bypassing clap.
    impl VsockArgs {
        /// Builds `VsockArgs` as if only the standalone per-device flags were given.
        fn from_args(
            guest_cid: u64,
            socket: &str,
            uds_path: &str,
            tx_buffer_size: u32,
            groups: &str,
        ) -> Self {
            VsockArgs {
                param: Some(VsockParam {
                    guest_cid,
                    socket: socket.to_string(),
                    uds_path: uds_path.to_string(),
                    tx_buffer_size,
                    groups: groups.to_string(),
                }),
                vm: None,
                config: None,
            }
        }
        /// Builds `VsockArgs` as if only `--config <config>` were given.
        fn from_file(config: &str) -> Self {
            VsockArgs {
                param: None,
                vm: None,
                config: Some(config.to_string()),
            }
        }
    }

    /// Standalone flags produce exactly one config carrying the given values.
    #[test]
    fn test_vsock_config_setup() {
        let test_dir = tempdir().expect("Could not create a temp test directory.");

        let socket_path = test_dir.path().join("vhost4.socket").display().to_string();
        let uds_path = test_dir.path().join("vm4.vsock").display().to_string();
        let args = VsockArgs::from_args(3, &socket_path, &uds_path, 64 * 1024, "group1");

        let configs = Vec::<VsockConfig>::try_from(args);
        assert!(configs.is_ok());

        let configs = configs.unwrap();
        assert_eq!(configs.len(), 1);

        let config = &configs[0];
        assert_eq!(config.get_guest_cid(), 3);
        assert_eq!(config.get_socket_path(), socket_path);
        assert_eq!(config.get_uds_path(), uds_path);
        assert_eq!(config.get_tx_buffer_size(), 64 * 1024);
        assert_eq!(config.get_groups(), vec!["group1".to_string()]);

        test_dir.close().unwrap();
    }

    /// Repeated `--vm` arguments go through clap and `parse_vm_params`. This covers
    /// underscore and hyphen key spellings, arbitrary key order, defaults for
    /// omitted keys, and `+`-separated groups.
    #[test]
    fn test_vsock_config_setup_from_vm_args() {
        let test_dir = tempdir().expect("Could not create a temp test directory.");

        let socket_paths = [
            test_dir.path().join("vhost3.socket"),
            test_dir.path().join("vhost4.socket"),
            test_dir.path().join("vhost5.socket"),
        ];
        let uds_paths = [
            test_dir.path().join("vm3.vsock"),
            test_dir.path().join("vm4.vsock"),
            test_dir.path().join("vm5.vsock"),
        ];
        let params = format!(
            "--vm socket={vhost3_socket},uds_path={vm3_vsock} \
             --vm socket={vhost4_socket},uds-path={vm4_vsock},guest-cid=4,tx_buffer_size=65536,groups=group1 \
             --vm groups=group2+group3,guest-cid=5,socket={vhost5_socket},uds_path={vm5_vsock},tx-buffer-size=32768",
            vhost3_socket = socket_paths[0].display(),
            vhost4_socket = socket_paths[1].display(),
            vhost5_socket = socket_paths[2].display(),
            vm3_vsock = uds_paths[0].display(),
            vm4_vsock = uds_paths[1].display(),
            vm5_vsock = uds_paths[2].display(),
        );

        // NOTE(review): splitting on whitespace assumes the temp paths contain no spaces.
        let mut params = params.split_whitespace().collect::<Vec<&str>>();
        params.insert(0, ""); // to make the test binary name agnostic

        let args = VsockArgs::parse_from(params);

        let configs = Vec::<VsockConfig>::try_from(args);
        assert!(configs.is_ok());

        let configs = configs.unwrap();
        assert_eq!(configs.len(), 3);

        // First VM: only the required keys, so everything else takes its default.
        let config = configs.get(0).unwrap();
        assert_eq!(config.get_guest_cid(), 3);
        assert_eq!(
            config.get_socket_path(),
            socket_paths[0].display().to_string()
        );
        assert_eq!(config.get_uds_path(), uds_paths[0].display().to_string());
        assert_eq!(config.get_tx_buffer_size(), 65536);
        assert_eq!(config.get_groups(), vec![DEFAULT_GROUP_NAME.to_string()]);

        let config = configs.get(1).unwrap();
        assert_eq!(config.get_guest_cid(), 4);
        assert_eq!(
            config.get_socket_path(),
            socket_paths[1].display().to_string()
        );
        assert_eq!(config.get_uds_path(), uds_paths[1].display().to_string());
        assert_eq!(config.get_tx_buffer_size(), 65536);
        assert_eq!(config.get_groups(), vec!["group1".to_string()]);

        // Third VM: keys given in a non-canonical order, with multiple groups.
        let config = configs.get(2).unwrap();
        assert_eq!(config.get_guest_cid(), 5);
        assert_eq!(
            config.get_socket_path(),
            socket_paths[2].display().to_string()
        );
        assert_eq!(config.get_uds_path(), uds_paths[2].display().to_string());
        assert_eq!(config.get_tx_buffer_size(), 32768);
        assert_eq!(
            config.get_groups(),
            vec!["group2".to_string(), "group3".to_string()]
        );

        test_dir.close().unwrap();
    }

    /// A YAML `--config` file: first with every field set, then with only the
    /// required fields, to check that defaults are applied.
    #[test]
    fn test_vsock_config_setup_from_file() {
        let test_dir = tempdir().expect("Could not create a temp test directory.");

        let config_path = test_dir.path().join("config.yaml");
        let socket_path = test_dir.path().join("vhost4.socket");
        let uds_path = test_dir.path().join("vm4.vsock");

        // The YAML indentation inside the string literal is significant.
        let mut yaml = File::create(&config_path).unwrap();
        yaml.write_all(
            format!(
                "vms:
    - guest_cid: 4
      socket: {}
      uds_path: {}
      tx_buffer_size: 32768
      groups: group1+group2",
                socket_path.display(),
                uds_path.display(),
            )
            .as_bytes(),
        )
        .unwrap();
        let args = VsockArgs::from_file(&config_path.display().to_string());

        let configs = Vec::<VsockConfig>::try_from(args).unwrap();
        assert_eq!(configs.len(), 1);

        let config = &configs[0];
        assert_eq!(config.get_guest_cid(), 4);
        assert_eq!(config.get_socket_path(), socket_path.display().to_string());
        assert_eq!(config.get_uds_path(), uds_path.display().to_string());
        assert_eq!(config.get_tx_buffer_size(), 32768);
        assert_eq!(
            config.get_groups(),
            vec!["group1".to_string(), "group2".to_string()]
        );

        // Now test that optional parameters are correctly set to their default values.
        // `File::create` truncates the file written above.
        let mut yaml = File::create(&config_path).unwrap();
        yaml.write_all(
            format!(
                "vms:
    - socket: {}
      uds_path: {}",
                socket_path.display(),
                uds_path.display(),
            )
            .as_bytes(),
        )
        .unwrap();
        let args = VsockArgs::from_file(&config_path.display().to_string());

        let configs = Vec::<VsockConfig>::try_from(args).unwrap();
        assert_eq!(configs.len(), 1);

        let config = &configs[0];
        assert_eq!(config.get_guest_cid(), DEFAULT_GUEST_CID);
        assert_eq!(config.get_socket_path(), socket_path.display().to_string());
        assert_eq!(config.get_uds_path(), uds_path.display().to_string());
        assert_eq!(config.get_tx_buffer_size(), DEFAULT_TX_BUFFER_SIZE);
        assert_eq!(config.get_groups(), vec![DEFAULT_GROUP_NAME.to_string()]);

        std::fs::remove_file(&config_path).unwrap();
        test_dir.close().unwrap();
    }

    /// A backend and daemon can be built for one device. The daemon must expose
    /// exactly one epoll handler per backend worker thread, which
    /// `start_backend_server` relies on when it registers listeners.
    #[test]
    fn test_vsock_server() {
        const CID: u64 = 3;
        const CONN_TX_BUF_SIZE: u32 = 64 * 1024;

        let test_dir = tempdir().expect("Could not create a temp test directory.");

        let vhost_socket_path = test_dir
            .path()
            .join("test_vsock_server.socket")
            .display()
            .to_string();
        let vsock_socket_path = test_dir
            .path()
            .join("test_vsock_server.vsock")
            .display()
            .to_string();

        let config = VsockConfig::new(
            CID,
            vhost_socket_path,
            vsock_socket_path,
            CONN_TX_BUF_SIZE,
            vec![DEFAULT_GROUP_NAME.to_string()],
        );

        let cid_map: Arc<RwLock<CidMap>> = Arc::new(RwLock::new(HashMap::new()));

        let backend = Arc::new(VhostUserVsockBackend::new(config, cid_map).unwrap());

        let daemon = VhostUserDaemon::new(
            String::from("vhost-device-vsock"),
            backend.clone(),
            GuestMemoryAtomic::new(GuestMemoryMmap::new()),
        )
        .unwrap();

        let vring_workers = daemon.get_epoll_handlers();

        // VhostUserVsockBackend support a single thread that handles the TX and RX queues
        assert_eq!(backend.threads.len(), 1);

        assert_eq!(vring_workers.len(), backend.threads.len());

        test_dir.close().unwrap();
    }
}
552