1 // Copyright (C) 2019 Alibaba Cloud Computing. All rights reserved.
2 // SPDX-License-Identifier: Apache-2.0
3 
4 use std::fs::File;
5 use std::mem;
6 use std::os::unix::io::{AsRawFd, FromRawFd, IntoRawFd, RawFd};
7 use std::os::unix::net::UnixStream;
8 use std::slice;
9 use std::sync::{Arc, Mutex};
10 
11 use vm_memory::ByteValued;
12 
13 use super::connection::Endpoint;
14 use super::message::*;
15 use super::slave_req::Slave;
16 use super::{take_single_file, Error, Result};
17 
18 /// Services provided to the master by the slave with interior mutability.
19 ///
20 /// The [VhostUserSlaveReqHandler] trait defines the services provided to the master by the slave.
21 /// And the [VhostUserSlaveReqHandlerMut] trait is a helper mirroring [VhostUserSlaveReqHandler],
22 /// but without interior mutability.
23 /// The vhost-user specification defines a master communication channel, by which masters could
24 /// request services from slaves. The [VhostUserSlaveReqHandler] trait defines services provided by
25 /// slaves, and it's used both on the master side and slave side.
26 ///
27 /// - on the master side, a stub forwarder implementing [VhostUserSlaveReqHandler] will proxy
28 ///   service requests to slaves.
29 /// - on the slave side, the [SlaveReqHandler] will forward service requests to a handler
30 ///   implementing [VhostUserSlaveReqHandler].
31 ///
32 /// The [VhostUserSlaveReqHandler] trait is design with interior mutability to improve performance
33 /// for multi-threading.
34 ///
35 /// [VhostUserSlaveReqHandler]: trait.VhostUserSlaveReqHandler.html
36 /// [VhostUserSlaveReqHandlerMut]: trait.VhostUserSlaveReqHandlerMut.html
37 /// [SlaveReqHandler]: struct.SlaveReqHandler.html
38 #[allow(missing_docs)]
39 pub trait VhostUserSlaveReqHandler {
set_owner(&self) -> Result<()>40     fn set_owner(&self) -> Result<()>;
reset_owner(&self) -> Result<()>41     fn reset_owner(&self) -> Result<()>;
get_features(&self) -> Result<u64>42     fn get_features(&self) -> Result<u64>;
set_features(&self, features: u64) -> Result<()>43     fn set_features(&self, features: u64) -> Result<()>;
set_mem_table(&self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()>44     fn set_mem_table(&self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()>;
set_vring_num(&self, index: u32, num: u32) -> Result<()>45     fn set_vring_num(&self, index: u32, num: u32) -> Result<()>;
set_vring_addr( &self, index: u32, flags: VhostUserVringAddrFlags, descriptor: u64, used: u64, available: u64, log: u64, ) -> Result<()>46     fn set_vring_addr(
47         &self,
48         index: u32,
49         flags: VhostUserVringAddrFlags,
50         descriptor: u64,
51         used: u64,
52         available: u64,
53         log: u64,
54     ) -> Result<()>;
set_vring_base(&self, index: u32, base: u32) -> Result<()>55     fn set_vring_base(&self, index: u32, base: u32) -> Result<()>;
get_vring_base(&self, index: u32) -> Result<VhostUserVringState>56     fn get_vring_base(&self, index: u32) -> Result<VhostUserVringState>;
set_vring_kick(&self, index: u8, fd: Option<File>) -> Result<()>57     fn set_vring_kick(&self, index: u8, fd: Option<File>) -> Result<()>;
set_vring_call(&self, index: u8, fd: Option<File>) -> Result<()>58     fn set_vring_call(&self, index: u8, fd: Option<File>) -> Result<()>;
set_vring_err(&self, index: u8, fd: Option<File>) -> Result<()>59     fn set_vring_err(&self, index: u8, fd: Option<File>) -> Result<()>;
60 
get_protocol_features(&self) -> Result<VhostUserProtocolFeatures>61     fn get_protocol_features(&self) -> Result<VhostUserProtocolFeatures>;
set_protocol_features(&self, features: u64) -> Result<()>62     fn set_protocol_features(&self, features: u64) -> Result<()>;
get_queue_num(&self) -> Result<u64>63     fn get_queue_num(&self) -> Result<u64>;
set_vring_enable(&self, index: u32, enable: bool) -> Result<()>64     fn set_vring_enable(&self, index: u32, enable: bool) -> Result<()>;
get_config(&self, offset: u32, size: u32, flags: VhostUserConfigFlags) -> Result<Vec<u8>>65     fn get_config(&self, offset: u32, size: u32, flags: VhostUserConfigFlags) -> Result<Vec<u8>>;
set_config(&self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>66     fn set_config(&self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>;
set_slave_req_fd(&self, _slave: Slave)67     fn set_slave_req_fd(&self, _slave: Slave) {}
get_inflight_fd(&self, inflight: &VhostUserInflight) -> Result<(VhostUserInflight, File)>68     fn get_inflight_fd(&self, inflight: &VhostUserInflight) -> Result<(VhostUserInflight, File)>;
set_inflight_fd(&self, inflight: &VhostUserInflight, file: File) -> Result<()>69     fn set_inflight_fd(&self, inflight: &VhostUserInflight, file: File) -> Result<()>;
get_max_mem_slots(&self) -> Result<u64>70     fn get_max_mem_slots(&self) -> Result<u64>;
add_mem_region(&self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()>71     fn add_mem_region(&self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()>;
remove_mem_region(&self, region: &VhostUserSingleMemoryRegion) -> Result<()>72     fn remove_mem_region(&self, region: &VhostUserSingleMemoryRegion) -> Result<()>;
73 }
74 
75 /// Services provided to the master by the slave without interior mutability.
76 ///
77 /// This is a helper trait mirroring the [VhostUserSlaveReqHandler] trait.
78 #[allow(missing_docs)]
79 pub trait VhostUserSlaveReqHandlerMut {
set_owner(&mut self) -> Result<()>80     fn set_owner(&mut self) -> Result<()>;
reset_owner(&mut self) -> Result<()>81     fn reset_owner(&mut self) -> Result<()>;
get_features(&mut self) -> Result<u64>82     fn get_features(&mut self) -> Result<u64>;
set_features(&mut self, features: u64) -> Result<()>83     fn set_features(&mut self, features: u64) -> Result<()>;
set_mem_table(&mut self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()>84     fn set_mem_table(&mut self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()>;
set_vring_num(&mut self, index: u32, num: u32) -> Result<()>85     fn set_vring_num(&mut self, index: u32, num: u32) -> Result<()>;
set_vring_addr( &mut self, index: u32, flags: VhostUserVringAddrFlags, descriptor: u64, used: u64, available: u64, log: u64, ) -> Result<()>86     fn set_vring_addr(
87         &mut self,
88         index: u32,
89         flags: VhostUserVringAddrFlags,
90         descriptor: u64,
91         used: u64,
92         available: u64,
93         log: u64,
94     ) -> Result<()>;
set_vring_base(&mut self, index: u32, base: u32) -> Result<()>95     fn set_vring_base(&mut self, index: u32, base: u32) -> Result<()>;
get_vring_base(&mut self, index: u32) -> Result<VhostUserVringState>96     fn get_vring_base(&mut self, index: u32) -> Result<VhostUserVringState>;
set_vring_kick(&mut self, index: u8, fd: Option<File>) -> Result<()>97     fn set_vring_kick(&mut self, index: u8, fd: Option<File>) -> Result<()>;
set_vring_call(&mut self, index: u8, fd: Option<File>) -> Result<()>98     fn set_vring_call(&mut self, index: u8, fd: Option<File>) -> Result<()>;
set_vring_err(&mut self, index: u8, fd: Option<File>) -> Result<()>99     fn set_vring_err(&mut self, index: u8, fd: Option<File>) -> Result<()>;
100 
get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures>101     fn get_protocol_features(&mut self) -> Result<VhostUserProtocolFeatures>;
set_protocol_features(&mut self, features: u64) -> Result<()>102     fn set_protocol_features(&mut self, features: u64) -> Result<()>;
get_queue_num(&mut self) -> Result<u64>103     fn get_queue_num(&mut self) -> Result<u64>;
set_vring_enable(&mut self, index: u32, enable: bool) -> Result<()>104     fn set_vring_enable(&mut self, index: u32, enable: bool) -> Result<()>;
get_config( &mut self, offset: u32, size: u32, flags: VhostUserConfigFlags, ) -> Result<Vec<u8>>105     fn get_config(
106         &mut self,
107         offset: u32,
108         size: u32,
109         flags: VhostUserConfigFlags,
110     ) -> Result<Vec<u8>>;
set_config(&mut self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>111     fn set_config(&mut self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>;
set_slave_req_fd(&mut self, _slave: Slave)112     fn set_slave_req_fd(&mut self, _slave: Slave) {}
get_inflight_fd( &mut self, inflight: &VhostUserInflight, ) -> Result<(VhostUserInflight, File)>113     fn get_inflight_fd(
114         &mut self,
115         inflight: &VhostUserInflight,
116     ) -> Result<(VhostUserInflight, File)>;
set_inflight_fd(&mut self, inflight: &VhostUserInflight, file: File) -> Result<()>117     fn set_inflight_fd(&mut self, inflight: &VhostUserInflight, file: File) -> Result<()>;
get_max_mem_slots(&mut self) -> Result<u64>118     fn get_max_mem_slots(&mut self) -> Result<u64>;
add_mem_region(&mut self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()>119     fn add_mem_region(&mut self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()>;
remove_mem_region(&mut self, region: &VhostUserSingleMemoryRegion) -> Result<()>120     fn remove_mem_region(&mut self, region: &VhostUserSingleMemoryRegion) -> Result<()>;
121 }
122 
123 impl<T: VhostUserSlaveReqHandlerMut> VhostUserSlaveReqHandler for Mutex<T> {
set_owner(&self) -> Result<()>124     fn set_owner(&self) -> Result<()> {
125         self.lock().unwrap().set_owner()
126     }
127 
reset_owner(&self) -> Result<()>128     fn reset_owner(&self) -> Result<()> {
129         self.lock().unwrap().reset_owner()
130     }
131 
get_features(&self) -> Result<u64>132     fn get_features(&self) -> Result<u64> {
133         self.lock().unwrap().get_features()
134     }
135 
set_features(&self, features: u64) -> Result<()>136     fn set_features(&self, features: u64) -> Result<()> {
137         self.lock().unwrap().set_features(features)
138     }
139 
set_mem_table(&self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()>140     fn set_mem_table(&self, ctx: &[VhostUserMemoryRegion], files: Vec<File>) -> Result<()> {
141         self.lock().unwrap().set_mem_table(ctx, files)
142     }
143 
set_vring_num(&self, index: u32, num: u32) -> Result<()>144     fn set_vring_num(&self, index: u32, num: u32) -> Result<()> {
145         self.lock().unwrap().set_vring_num(index, num)
146     }
147 
set_vring_addr( &self, index: u32, flags: VhostUserVringAddrFlags, descriptor: u64, used: u64, available: u64, log: u64, ) -> Result<()>148     fn set_vring_addr(
149         &self,
150         index: u32,
151         flags: VhostUserVringAddrFlags,
152         descriptor: u64,
153         used: u64,
154         available: u64,
155         log: u64,
156     ) -> Result<()> {
157         self.lock()
158             .unwrap()
159             .set_vring_addr(index, flags, descriptor, used, available, log)
160     }
161 
set_vring_base(&self, index: u32, base: u32) -> Result<()>162     fn set_vring_base(&self, index: u32, base: u32) -> Result<()> {
163         self.lock().unwrap().set_vring_base(index, base)
164     }
165 
get_vring_base(&self, index: u32) -> Result<VhostUserVringState>166     fn get_vring_base(&self, index: u32) -> Result<VhostUserVringState> {
167         self.lock().unwrap().get_vring_base(index)
168     }
169 
set_vring_kick(&self, index: u8, fd: Option<File>) -> Result<()>170     fn set_vring_kick(&self, index: u8, fd: Option<File>) -> Result<()> {
171         self.lock().unwrap().set_vring_kick(index, fd)
172     }
173 
set_vring_call(&self, index: u8, fd: Option<File>) -> Result<()>174     fn set_vring_call(&self, index: u8, fd: Option<File>) -> Result<()> {
175         self.lock().unwrap().set_vring_call(index, fd)
176     }
177 
set_vring_err(&self, index: u8, fd: Option<File>) -> Result<()>178     fn set_vring_err(&self, index: u8, fd: Option<File>) -> Result<()> {
179         self.lock().unwrap().set_vring_err(index, fd)
180     }
181 
get_protocol_features(&self) -> Result<VhostUserProtocolFeatures>182     fn get_protocol_features(&self) -> Result<VhostUserProtocolFeatures> {
183         self.lock().unwrap().get_protocol_features()
184     }
185 
set_protocol_features(&self, features: u64) -> Result<()>186     fn set_protocol_features(&self, features: u64) -> Result<()> {
187         self.lock().unwrap().set_protocol_features(features)
188     }
189 
get_queue_num(&self) -> Result<u64>190     fn get_queue_num(&self) -> Result<u64> {
191         self.lock().unwrap().get_queue_num()
192     }
193 
set_vring_enable(&self, index: u32, enable: bool) -> Result<()>194     fn set_vring_enable(&self, index: u32, enable: bool) -> Result<()> {
195         self.lock().unwrap().set_vring_enable(index, enable)
196     }
197 
get_config(&self, offset: u32, size: u32, flags: VhostUserConfigFlags) -> Result<Vec<u8>>198     fn get_config(&self, offset: u32, size: u32, flags: VhostUserConfigFlags) -> Result<Vec<u8>> {
199         self.lock().unwrap().get_config(offset, size, flags)
200     }
201 
set_config(&self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()>202     fn set_config(&self, offset: u32, buf: &[u8], flags: VhostUserConfigFlags) -> Result<()> {
203         self.lock().unwrap().set_config(offset, buf, flags)
204     }
205 
set_slave_req_fd(&self, slave: Slave)206     fn set_slave_req_fd(&self, slave: Slave) {
207         self.lock().unwrap().set_slave_req_fd(slave)
208     }
209 
get_inflight_fd(&self, inflight: &VhostUserInflight) -> Result<(VhostUserInflight, File)>210     fn get_inflight_fd(&self, inflight: &VhostUserInflight) -> Result<(VhostUserInflight, File)> {
211         self.lock().unwrap().get_inflight_fd(inflight)
212     }
213 
set_inflight_fd(&self, inflight: &VhostUserInflight, file: File) -> Result<()>214     fn set_inflight_fd(&self, inflight: &VhostUserInflight, file: File) -> Result<()> {
215         self.lock().unwrap().set_inflight_fd(inflight, file)
216     }
217 
get_max_mem_slots(&self) -> Result<u64>218     fn get_max_mem_slots(&self) -> Result<u64> {
219         self.lock().unwrap().get_max_mem_slots()
220     }
221 
add_mem_region(&self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()>222     fn add_mem_region(&self, region: &VhostUserSingleMemoryRegion, fd: File) -> Result<()> {
223         self.lock().unwrap().add_mem_region(region, fd)
224     }
225 
remove_mem_region(&self, region: &VhostUserSingleMemoryRegion) -> Result<()>226     fn remove_mem_region(&self, region: &VhostUserSingleMemoryRegion) -> Result<()> {
227         self.lock().unwrap().remove_mem_region(region)
228     }
229 }
230 
231 /// Server to handle service requests from masters from the master communication channel.
232 ///
233 /// The [SlaveReqHandler] acts as a server on the slave side, to handle service requests from
234 /// masters on the master communication channel. It's actually a proxy invoking the registered
235 /// handler implementing [VhostUserSlaveReqHandler] to do the real work.
236 ///
237 /// The lifetime of the SlaveReqHandler object should be the same as the underline Unix Domain
238 /// Socket, so it gets simpler to recover from disconnect.
239 ///
240 /// [VhostUserSlaveReqHandler]: trait.VhostUserSlaveReqHandler.html
241 /// [SlaveReqHandler]: struct.SlaveReqHandler.html
242 pub struct SlaveReqHandler<S: VhostUserSlaveReqHandler> {
243     // underlying Unix domain socket for communication
244     main_sock: Endpoint<MasterReq>,
245     // the vhost-user backend device object
246     backend: Arc<S>,
247 
248     virtio_features: u64,
249     acked_virtio_features: u64,
250     protocol_features: VhostUserProtocolFeatures,
251     acked_protocol_features: u64,
252 
253     // sending ack for messages without payload
254     reply_ack_enabled: bool,
255     // whether the endpoint has encountered any failure
256     error: Option<i32>,
257 }
258 
259 impl<S: VhostUserSlaveReqHandler> SlaveReqHandler<S> {
260     /// Create a vhost-user slave endpoint.
new(main_sock: Endpoint<MasterReq>, backend: Arc<S>) -> Self261     pub(super) fn new(main_sock: Endpoint<MasterReq>, backend: Arc<S>) -> Self {
262         SlaveReqHandler {
263             main_sock,
264             backend,
265             virtio_features: 0,
266             acked_virtio_features: 0,
267             protocol_features: VhostUserProtocolFeatures::empty(),
268             acked_protocol_features: 0,
269             reply_ack_enabled: false,
270             error: None,
271         }
272     }
273 
check_feature(&self, feat: VhostUserVirtioFeatures) -> Result<()>274     fn check_feature(&self, feat: VhostUserVirtioFeatures) -> Result<()> {
275         if self.acked_virtio_features & feat.bits() != 0 {
276             Ok(())
277         } else {
278             Err(Error::InactiveFeature(feat))
279         }
280     }
281 
check_proto_feature(&self, feat: VhostUserProtocolFeatures) -> Result<()>282     fn check_proto_feature(&self, feat: VhostUserProtocolFeatures) -> Result<()> {
283         if self.acked_protocol_features & feat.bits() != 0 {
284             Ok(())
285         } else {
286             Err(Error::InactiveOperation(feat))
287         }
288     }
289 
290     /// Create a vhost-user slave endpoint from a connected socket.
from_stream(socket: UnixStream, backend: Arc<S>) -> Self291     pub fn from_stream(socket: UnixStream, backend: Arc<S>) -> Self {
292         Self::new(Endpoint::from_stream(socket), backend)
293     }
294 
295     /// Create a new vhost-user slave endpoint.
296     ///
297     /// # Arguments
298     /// * - `path` - path of Unix domain socket listener to connect to
299     /// * - `backend` - handler for requests from the master to the slave
connect(path: &str, backend: Arc<S>) -> Result<Self>300     pub fn connect(path: &str, backend: Arc<S>) -> Result<Self> {
301         Ok(Self::new(Endpoint::<MasterReq>::connect(path)?, backend))
302     }
303 
304     /// Mark endpoint as failed with specified error code.
set_failed(&mut self, error: i32)305     pub fn set_failed(&mut self, error: i32) {
306         self.error = Some(error);
307     }
308 
309     /// Main entrance to server slave request from the slave communication channel.
310     ///
311     /// Receive and handle one incoming request message from the master. The caller needs to:
312     /// - serialize calls to this function
313     /// - decide what to do when error happens
314     /// - optional recover from failure
handle_request(&mut self) -> Result<()>315     pub fn handle_request(&mut self) -> Result<()> {
316         // Return error if the endpoint is already in failed state.
317         self.check_state()?;
318 
319         // The underlying communication channel is a Unix domain socket in
320         // stream mode, and recvmsg() is a little tricky here. To successfully
321         // receive attached file descriptors, we need to receive messages and
322         // corresponding attached file descriptors in this way:
323         // . recv messsage header and optional attached file
324         // . validate message header
325         // . recv optional message body and payload according size field in
326         //   message header
327         // . validate message body and optional payload
328         let (hdr, files) = self.main_sock.recv_header()?;
329         self.check_attached_files(&hdr, &files)?;
330 
331         let (size, buf) = match hdr.get_size() {
332             0 => (0, vec![0u8; 0]),
333             len => {
334                 let (size2, rbuf) = self.main_sock.recv_data(len as usize)?;
335                 if size2 != len as usize {
336                     return Err(Error::InvalidMessage);
337                 }
338                 (size2, rbuf)
339             }
340         };
341 
342         match hdr.get_code() {
343             Ok(MasterReq::SET_OWNER) => {
344                 self.check_request_size(&hdr, size, 0)?;
345                 let res = self.backend.set_owner();
346                 self.send_ack_message(&hdr, res)?;
347             }
348             Ok(MasterReq::RESET_OWNER) => {
349                 self.check_request_size(&hdr, size, 0)?;
350                 let res = self.backend.reset_owner();
351                 self.send_ack_message(&hdr, res)?;
352             }
353             Ok(MasterReq::GET_FEATURES) => {
354                 self.check_request_size(&hdr, size, 0)?;
355                 let features = self.backend.get_features()?;
356                 let msg = VhostUserU64::new(features);
357                 self.send_reply_message(&hdr, &msg)?;
358                 self.virtio_features = features;
359                 self.update_reply_ack_flag();
360             }
361             Ok(MasterReq::SET_FEATURES) => {
362                 let msg = self.extract_request_body::<VhostUserU64>(&hdr, size, &buf)?;
363                 let res = self.backend.set_features(msg.value);
364                 self.acked_virtio_features = msg.value;
365                 self.update_reply_ack_flag();
366                 self.send_ack_message(&hdr, res)?;
367             }
368             Ok(MasterReq::SET_MEM_TABLE) => {
369                 let res = self.set_mem_table(&hdr, size, &buf, files);
370                 self.send_ack_message(&hdr, res)?;
371             }
372             Ok(MasterReq::SET_VRING_NUM) => {
373                 let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?;
374                 let res = self.backend.set_vring_num(msg.index, msg.num);
375                 self.send_ack_message(&hdr, res)?;
376             }
377             Ok(MasterReq::SET_VRING_ADDR) => {
378                 let msg = self.extract_request_body::<VhostUserVringAddr>(&hdr, size, &buf)?;
379                 let flags = match VhostUserVringAddrFlags::from_bits(msg.flags) {
380                     Some(val) => val,
381                     None => return Err(Error::InvalidMessage),
382                 };
383                 let res = self.backend.set_vring_addr(
384                     msg.index,
385                     flags,
386                     msg.descriptor,
387                     msg.used,
388                     msg.available,
389                     msg.log,
390                 );
391                 self.send_ack_message(&hdr, res)?;
392             }
393             Ok(MasterReq::SET_VRING_BASE) => {
394                 let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?;
395                 let res = self.backend.set_vring_base(msg.index, msg.num);
396                 self.send_ack_message(&hdr, res)?;
397             }
398             Ok(MasterReq::GET_VRING_BASE) => {
399                 let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?;
400                 let reply = self.backend.get_vring_base(msg.index)?;
401                 self.send_reply_message(&hdr, &reply)?;
402             }
403             Ok(MasterReq::SET_VRING_CALL) => {
404                 self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?;
405                 let (index, file) = self.handle_vring_fd_request(&buf, files)?;
406                 let res = self.backend.set_vring_call(index, file);
407                 self.send_ack_message(&hdr, res)?;
408             }
409             Ok(MasterReq::SET_VRING_KICK) => {
410                 self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?;
411                 let (index, file) = self.handle_vring_fd_request(&buf, files)?;
412                 let res = self.backend.set_vring_kick(index, file);
413                 self.send_ack_message(&hdr, res)?;
414             }
415             Ok(MasterReq::SET_VRING_ERR) => {
416                 self.check_request_size(&hdr, size, mem::size_of::<VhostUserU64>())?;
417                 let (index, file) = self.handle_vring_fd_request(&buf, files)?;
418                 let res = self.backend.set_vring_err(index, file);
419                 self.send_ack_message(&hdr, res)?;
420             }
421             Ok(MasterReq::GET_PROTOCOL_FEATURES) => {
422                 self.check_request_size(&hdr, size, 0)?;
423                 let features = self.backend.get_protocol_features()?;
424 
425                 // Enable the `XEN_MMAP` protocol feature for backends if xen feature is enabled.
426                 #[cfg(feature = "xen")]
427                 let features = features | VhostUserProtocolFeatures::XEN_MMAP;
428 
429                 let msg = VhostUserU64::new(features.bits());
430                 self.send_reply_message(&hdr, &msg)?;
431                 self.protocol_features = features;
432                 self.update_reply_ack_flag();
433             }
434             Ok(MasterReq::SET_PROTOCOL_FEATURES) => {
435                 let msg = self.extract_request_body::<VhostUserU64>(&hdr, size, &buf)?;
436                 let res = self.backend.set_protocol_features(msg.value);
437                 self.acked_protocol_features = msg.value;
438                 self.update_reply_ack_flag();
439                 self.send_ack_message(&hdr, res)?;
440 
441                 #[cfg(feature = "xen")]
442                 self.check_proto_feature(VhostUserProtocolFeatures::XEN_MMAP)?;
443             }
444             Ok(MasterReq::GET_QUEUE_NUM) => {
445                 self.check_proto_feature(VhostUserProtocolFeatures::MQ)?;
446                 self.check_request_size(&hdr, size, 0)?;
447                 let num = self.backend.get_queue_num()?;
448                 let msg = VhostUserU64::new(num);
449                 self.send_reply_message(&hdr, &msg)?;
450             }
451             Ok(MasterReq::SET_VRING_ENABLE) => {
452                 let msg = self.extract_request_body::<VhostUserVringState>(&hdr, size, &buf)?;
453                 self.check_feature(VhostUserVirtioFeatures::PROTOCOL_FEATURES)?;
454                 let enable = match msg.num {
455                     1 => true,
456                     0 => false,
457                     _ => return Err(Error::InvalidParam),
458                 };
459 
460                 let res = self.backend.set_vring_enable(msg.index, enable);
461                 self.send_ack_message(&hdr, res)?;
462             }
463             Ok(MasterReq::GET_CONFIG) => {
464                 self.check_proto_feature(VhostUserProtocolFeatures::CONFIG)?;
465                 self.check_request_size(&hdr, size, hdr.get_size() as usize)?;
466                 self.get_config(&hdr, &buf)?;
467             }
468             Ok(MasterReq::SET_CONFIG) => {
469                 self.check_proto_feature(VhostUserProtocolFeatures::CONFIG)?;
470                 self.check_request_size(&hdr, size, hdr.get_size() as usize)?;
471                 let res = self.set_config(size, &buf);
472                 self.send_ack_message(&hdr, res)?;
473             }
474             Ok(MasterReq::SET_SLAVE_REQ_FD) => {
475                 self.check_proto_feature(VhostUserProtocolFeatures::SLAVE_REQ)?;
476                 self.check_request_size(&hdr, size, hdr.get_size() as usize)?;
477                 let res = self.set_slave_req_fd(files);
478                 self.send_ack_message(&hdr, res)?;
479             }
480             Ok(MasterReq::GET_INFLIGHT_FD) => {
481                 self.check_proto_feature(VhostUserProtocolFeatures::INFLIGHT_SHMFD)?;
482 
483                 let msg = self.extract_request_body::<VhostUserInflight>(&hdr, size, &buf)?;
484                 let (inflight, file) = self.backend.get_inflight_fd(&msg)?;
485                 let reply_hdr = self.new_reply_header::<VhostUserInflight>(&hdr, 0)?;
486                 self.main_sock
487                     .send_message(&reply_hdr, &inflight, Some(&[file.as_raw_fd()]))?;
488             }
489             Ok(MasterReq::SET_INFLIGHT_FD) => {
490                 self.check_proto_feature(VhostUserProtocolFeatures::INFLIGHT_SHMFD)?;
491                 let file = take_single_file(files).ok_or(Error::IncorrectFds)?;
492                 let msg = self.extract_request_body::<VhostUserInflight>(&hdr, size, &buf)?;
493                 let res = self.backend.set_inflight_fd(&msg, file);
494                 self.send_ack_message(&hdr, res)?;
495             }
496             Ok(MasterReq::GET_MAX_MEM_SLOTS) => {
497                 self.check_proto_feature(VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS)?;
498                 self.check_request_size(&hdr, size, 0)?;
499                 let num = self.backend.get_max_mem_slots()?;
500                 let msg = VhostUserU64::new(num);
501                 self.send_reply_message(&hdr, &msg)?;
502             }
503             Ok(MasterReq::ADD_MEM_REG) => {
504                 self.check_proto_feature(VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS)?;
505                 let mut files = files.ok_or(Error::InvalidParam)?;
506                 if files.len() != 1 {
507                     return Err(Error::InvalidParam);
508                 }
509                 let msg =
510                     self.extract_request_body::<VhostUserSingleMemoryRegion>(&hdr, size, &buf)?;
511                 let res = self.backend.add_mem_region(&msg, files.swap_remove(0));
512                 self.send_ack_message(&hdr, res)?;
513             }
514             Ok(MasterReq::REM_MEM_REG) => {
515                 self.check_proto_feature(VhostUserProtocolFeatures::CONFIGURE_MEM_SLOTS)?;
516 
517                 let msg =
518                     self.extract_request_body::<VhostUserSingleMemoryRegion>(&hdr, size, &buf)?;
519                 let res = self.backend.remove_mem_region(&msg);
520                 self.send_ack_message(&hdr, res)?;
521             }
522             _ => {
523                 return Err(Error::InvalidMessage);
524             }
525         }
526         Ok(())
527     }
528 
set_mem_table( &mut self, hdr: &VhostUserMsgHeader<MasterReq>, size: usize, buf: &[u8], files: Option<Vec<File>>, ) -> Result<()>529     fn set_mem_table(
530         &mut self,
531         hdr: &VhostUserMsgHeader<MasterReq>,
532         size: usize,
533         buf: &[u8],
534         files: Option<Vec<File>>,
535     ) -> Result<()> {
536         self.check_request_size(hdr, size, hdr.get_size() as usize)?;
537 
538         // check message size is consistent
539         let hdrsize = mem::size_of::<VhostUserMemory>();
540         if size < hdrsize {
541             return Err(Error::InvalidMessage);
542         }
543         // SAFETY: Safe because we checked that `buf` size is at least that of
544         // VhostUserMemory.
545         let msg = unsafe { &*(buf.as_ptr() as *const VhostUserMemory) };
546         if !msg.is_valid() {
547             return Err(Error::InvalidMessage);
548         }
549         if size != hdrsize + msg.num_regions as usize * mem::size_of::<VhostUserMemoryRegion>() {
550             return Err(Error::InvalidMessage);
551         }
552 
553         // validate number of fds matching number of memory regions
554         let files = files.ok_or(Error::InvalidMessage)?;
555         if files.len() != msg.num_regions as usize {
556             return Err(Error::InvalidMessage);
557         }
558 
559         // Validate memory regions
560         //
561         // SAFETY: Safe because we checked that `buf` size is equal to that of
562         // VhostUserMemory, plus `msg.num_regions` elements of VhostUserMemoryRegion.
563         let regions = unsafe {
564             slice::from_raw_parts(
565                 buf.as_ptr().add(hdrsize) as *const VhostUserMemoryRegion,
566                 msg.num_regions as usize,
567             )
568         };
569         for region in regions.iter() {
570             if !region.is_valid() {
571                 return Err(Error::InvalidMessage);
572             }
573         }
574 
575         self.backend.set_mem_table(regions, files)
576     }
577 
get_config(&mut self, hdr: &VhostUserMsgHeader<MasterReq>, buf: &[u8]) -> Result<()>578     fn get_config(&mut self, hdr: &VhostUserMsgHeader<MasterReq>, buf: &[u8]) -> Result<()> {
579         let payload_offset = mem::size_of::<VhostUserConfig>();
580         if buf.len() > MAX_MSG_SIZE || buf.len() < payload_offset {
581             return Err(Error::InvalidMessage);
582         }
583         // SAFETY: Safe because we checked that `buf` size is at least that of VhostUserConfig.
584         let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const VhostUserConfig) };
585         if !msg.is_valid() {
586             return Err(Error::InvalidMessage);
587         }
588         if buf.len() - payload_offset != msg.size as usize {
589             return Err(Error::InvalidMessage);
590         }
591         let flags = match VhostUserConfigFlags::from_bits(msg.flags) {
592             Some(val) => val,
593             None => return Err(Error::InvalidMessage),
594         };
595         let res = self.backend.get_config(msg.offset, msg.size, flags);
596 
597         // vhost-user slave's payload size MUST match master's request
598         // on success, uses zero length of payload to indicate an error
599         // to vhost-user master.
600         match res {
601             Ok(ref buf) if buf.len() == msg.size as usize => {
602                 let reply = VhostUserConfig::new(msg.offset, buf.len() as u32, flags);
603                 self.send_reply_with_payload(hdr, &reply, buf.as_slice())?;
604             }
605             Ok(_) => {
606                 let reply = VhostUserConfig::new(msg.offset, 0, flags);
607                 self.send_reply_message(hdr, &reply)?;
608             }
609             Err(_) => {
610                 let reply = VhostUserConfig::new(msg.offset, 0, flags);
611                 self.send_reply_message(hdr, &reply)?;
612             }
613         }
614         Ok(())
615     }
616 
set_config(&mut self, size: usize, buf: &[u8]) -> Result<()>617     fn set_config(&mut self, size: usize, buf: &[u8]) -> Result<()> {
618         if size > MAX_MSG_SIZE || size < mem::size_of::<VhostUserConfig>() {
619             return Err(Error::InvalidMessage);
620         }
621         // SAFETY: Safe because we checked that `buf` size is at least that of VhostUserConfig.
622         let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const VhostUserConfig) };
623         if !msg.is_valid() {
624             return Err(Error::InvalidMessage);
625         }
626         if size - mem::size_of::<VhostUserConfig>() != msg.size as usize {
627             return Err(Error::InvalidMessage);
628         }
629         let flags = VhostUserConfigFlags::from_bits(msg.flags).ok_or(Error::InvalidMessage)?;
630 
631         self.backend
632             .set_config(msg.offset, &buf[mem::size_of::<VhostUserConfig>()..], flags)
633     }
634 
set_slave_req_fd(&mut self, files: Option<Vec<File>>) -> Result<()>635     fn set_slave_req_fd(&mut self, files: Option<Vec<File>>) -> Result<()> {
636         let file = take_single_file(files).ok_or(Error::InvalidMessage)?;
637         // SAFETY: Safe because we have ownership of the files that were
638         // checked when received. We have to trust that they are Unix sockets
639         // since we have no way to check this. If not, it will fail later.
640         let sock = unsafe { UnixStream::from_raw_fd(file.into_raw_fd()) };
641         let slave = Slave::from_stream(sock);
642         self.backend.set_slave_req_fd(slave);
643         Ok(())
644     }
645 
handle_vring_fd_request( &mut self, buf: &[u8], files: Option<Vec<File>>, ) -> Result<(u8, Option<File>)>646     fn handle_vring_fd_request(
647         &mut self,
648         buf: &[u8],
649         files: Option<Vec<File>>,
650     ) -> Result<(u8, Option<File>)> {
651         if buf.len() > MAX_MSG_SIZE || buf.len() < mem::size_of::<VhostUserU64>() {
652             return Err(Error::InvalidMessage);
653         }
654         // SAFETY: Safe because we checked that `buf` size is at least that of VhostUserU64.
655         let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const VhostUserU64) };
656         if !msg.is_valid() {
657             return Err(Error::InvalidMessage);
658         }
659 
660         // Bits (0-7) of the payload contain the vring index. Bit 8 is the
661         // invalid FD flag. This bit is set when there is no file descriptor
662         // in the ancillary data. This signals that polling will be used
663         // instead of waiting for the call.
664         // If Bit 8 is unset, the data must contain a file descriptor.
665         let has_fd = (msg.value & 0x100u64) == 0;
666 
667         let file = take_single_file(files);
668 
669         if has_fd && file.is_none() || !has_fd && file.is_some() {
670             return Err(Error::InvalidMessage);
671         }
672 
673         Ok((msg.value as u8, file))
674     }
675 
check_state(&self) -> Result<()>676     fn check_state(&self) -> Result<()> {
677         match self.error {
678             Some(e) => Err(Error::SocketBroken(std::io::Error::from_raw_os_error(e))),
679             None => Ok(()),
680         }
681     }
682 
check_request_size( &self, hdr: &VhostUserMsgHeader<MasterReq>, size: usize, expected: usize, ) -> Result<()>683     fn check_request_size(
684         &self,
685         hdr: &VhostUserMsgHeader<MasterReq>,
686         size: usize,
687         expected: usize,
688     ) -> Result<()> {
689         if hdr.get_size() as usize != expected
690             || hdr.is_reply()
691             || hdr.get_version() != 0x1
692             || size != expected
693         {
694             return Err(Error::InvalidMessage);
695         }
696         Ok(())
697     }
698 
check_attached_files( &self, hdr: &VhostUserMsgHeader<MasterReq>, files: &Option<Vec<File>>, ) -> Result<()>699     fn check_attached_files(
700         &self,
701         hdr: &VhostUserMsgHeader<MasterReq>,
702         files: &Option<Vec<File>>,
703     ) -> Result<()> {
704         match hdr.get_code() {
705             Ok(
706                 MasterReq::SET_MEM_TABLE
707                 | MasterReq::SET_VRING_CALL
708                 | MasterReq::SET_VRING_KICK
709                 | MasterReq::SET_VRING_ERR
710                 | MasterReq::SET_LOG_BASE
711                 | MasterReq::SET_LOG_FD
712                 | MasterReq::SET_SLAVE_REQ_FD
713                 | MasterReq::SET_INFLIGHT_FD
714                 | MasterReq::ADD_MEM_REG,
715             ) => Ok(()),
716             _ if files.is_some() => Err(Error::InvalidMessage),
717             _ => Ok(()),
718         }
719     }
720 
extract_request_body<T: Sized + VhostUserMsgValidator>( &self, hdr: &VhostUserMsgHeader<MasterReq>, size: usize, buf: &[u8], ) -> Result<T>721     fn extract_request_body<T: Sized + VhostUserMsgValidator>(
722         &self,
723         hdr: &VhostUserMsgHeader<MasterReq>,
724         size: usize,
725         buf: &[u8],
726     ) -> Result<T> {
727         self.check_request_size(hdr, size, mem::size_of::<T>())?;
728         // SAFETY: Safe because we checked that `buf` size is equal to T size.
729         let msg = unsafe { std::ptr::read_unaligned(buf.as_ptr() as *const T) };
730         if !msg.is_valid() {
731             return Err(Error::InvalidMessage);
732         }
733         Ok(msg)
734     }
735 
update_reply_ack_flag(&mut self)736     fn update_reply_ack_flag(&mut self) {
737         let vflag = VhostUserVirtioFeatures::PROTOCOL_FEATURES.bits();
738         let pflag = VhostUserProtocolFeatures::REPLY_ACK;
739         if (self.virtio_features & vflag) != 0
740             && self.protocol_features.contains(pflag)
741             && (self.acked_protocol_features & pflag.bits()) != 0
742         {
743             self.reply_ack_enabled = true;
744         } else {
745             self.reply_ack_enabled = false;
746         }
747     }
748 
new_reply_header<T: Sized>( &self, req: &VhostUserMsgHeader<MasterReq>, payload_size: usize, ) -> Result<VhostUserMsgHeader<MasterReq>>749     fn new_reply_header<T: Sized>(
750         &self,
751         req: &VhostUserMsgHeader<MasterReq>,
752         payload_size: usize,
753     ) -> Result<VhostUserMsgHeader<MasterReq>> {
754         if mem::size_of::<T>() > MAX_MSG_SIZE
755             || payload_size > MAX_MSG_SIZE
756             || mem::size_of::<T>() + payload_size > MAX_MSG_SIZE
757         {
758             return Err(Error::InvalidParam);
759         }
760         self.check_state()?;
761         Ok(VhostUserMsgHeader::new(
762             req.get_code()?,
763             VhostUserHeaderFlag::REPLY.bits(),
764             (mem::size_of::<T>() + payload_size) as u32,
765         ))
766     }
767 
send_ack_message( &mut self, req: &VhostUserMsgHeader<MasterReq>, res: Result<()>, ) -> Result<()>768     fn send_ack_message(
769         &mut self,
770         req: &VhostUserMsgHeader<MasterReq>,
771         res: Result<()>,
772     ) -> Result<()> {
773         if self.reply_ack_enabled && req.is_need_reply() {
774             let hdr = self.new_reply_header::<VhostUserU64>(req, 0)?;
775             let val = match res {
776                 Ok(_) => 0,
777                 Err(_) => 1,
778             };
779             let msg = VhostUserU64::new(val);
780             self.main_sock.send_message(&hdr, &msg, None)?;
781         }
782         res
783     }
784 
send_reply_message<T: ByteValued>( &mut self, req: &VhostUserMsgHeader<MasterReq>, msg: &T, ) -> Result<()>785     fn send_reply_message<T: ByteValued>(
786         &mut self,
787         req: &VhostUserMsgHeader<MasterReq>,
788         msg: &T,
789     ) -> Result<()> {
790         let hdr = self.new_reply_header::<T>(req, 0)?;
791         self.main_sock.send_message(&hdr, msg, None)?;
792         Ok(())
793     }
794 
send_reply_with_payload<T: ByteValued>( &mut self, req: &VhostUserMsgHeader<MasterReq>, msg: &T, payload: &[u8], ) -> Result<()>795     fn send_reply_with_payload<T: ByteValued>(
796         &mut self,
797         req: &VhostUserMsgHeader<MasterReq>,
798         msg: &T,
799         payload: &[u8],
800     ) -> Result<()> {
801         let hdr = self.new_reply_header::<T>(req, payload.len())?;
802         self.main_sock
803             .send_message_with_payload(&hdr, msg, payload, None)?;
804         Ok(())
805     }
806 }
807 
808 impl<S: VhostUserSlaveReqHandler> AsRawFd for SlaveReqHandler<S> {
as_raw_fd(&self) -> RawFd809     fn as_raw_fd(&self) -> RawFd {
810         self.main_sock.as_raw_fd()
811     }
812 }
813 
#[cfg(test)]
mod tests {
    use std::os::unix::io::AsRawFd;

    use super::*;
    use crate::vhost_user::dummy_slave::DummySlaveReqHandler;

    /// A new handler starts healthy. After `set_failed`, `check_state` reports a
    /// `SocketBroken` error carrying the recorded errno. The main socket descriptor is valid.
    #[test]
    fn test_slave_req_handler_new() {
        let (p1, _p2) = UnixStream::pair().unwrap();
        let endpoint = Endpoint::<MasterReq>::from_stream(p1);
        let backend = Arc::new(Mutex::new(DummySlaveReqHandler::new()));
        let mut handler = SlaveReqHandler::new(endpoint, backend);

        handler.check_state().unwrap();
        handler.set_failed(libc::EAGAIN);
        // Pin the exact error, not merely that some error is returned.
        match handler.check_state() {
            Err(Error::SocketBroken(e)) => assert_eq!(e.raw_os_error(), Some(libc::EAGAIN)),
            other => panic!("expected SocketBroken(EAGAIN), got {:?}", other),
        }
        assert!(handler.as_raw_fd() >= 0);
    }
}
834