xref: /aosp_15_r20/external/crosvm/devices/src/virtio/snd/vios_backend/shm_vios.rs (revision bb4ee6a4ae7042d18b07a98463b9c8b875e44b39)
1 // Copyright 2020 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 use std::collections::HashMap;
6 use std::collections::VecDeque;
7 use std::fs::File;
8 use std::io::Error as IOError;
9 use std::io::ErrorKind as IOErrorKind;
10 use std::io::Seek;
11 use std::io::SeekFrom;
12 use std::path::Path;
13 use std::path::PathBuf;
14 use std::sync::mpsc::channel;
15 use std::sync::mpsc::Receiver;
16 use std::sync::mpsc::RecvError;
17 use std::sync::mpsc::Sender;
18 use std::sync::Arc;
19 
20 use base::error;
21 use base::AsRawDescriptor;
22 use base::Error as BaseError;
23 use base::Event;
24 use base::EventToken;
25 use base::FromRawDescriptor;
26 use base::IntoRawDescriptor;
27 use base::MemoryMapping;
28 use base::MemoryMappingBuilder;
29 use base::MmapError;
30 use base::RawDescriptor;
31 use base::SafeDescriptor;
32 use base::ScmSocket;
33 use base::UnixSeqpacket;
34 use base::VolatileMemory;
35 use base::VolatileMemoryError;
36 use base::VolatileSlice;
37 use base::WaitContext;
38 use base::WorkerThread;
39 use remain::sorted;
40 use serde::Deserialize;
41 use serde::Serialize;
42 use sync::Mutex;
43 use thiserror::Error as ThisError;
44 use zerocopy::AsBytes;
45 use zerocopy::FromBytes;
46 use zerocopy::FromZeroes;
47 
48 use crate::virtio::snd::constants::*;
49 use crate::virtio::snd::layout::*;
50 use crate::virtio::snd::vios_backend::streams::StreamState;
51 
52 pub type Result<T> = std::result::Result<T, Error>;
53 
54 #[sorted]
55 #[derive(ThisError, Debug)]
56 pub enum Error {
57     #[error("Error memory mapping client_shm: {0}")]
58     BaseMmapError(BaseError),
59     #[error("Sender was dropped without sending buffer status, the recv thread may have exited")]
60     BufferStatusSenderLost(RecvError),
61     #[error("Command failed with status {0}")]
62     CommandFailed(u32),
63     #[error("Error duplicating file descriptor: {0}")]
64     DupError(BaseError),
65     #[error("Failed to create Recv event: {0}")]
66     EventCreateError(BaseError),
67     #[error("Failed to dup Recv event: {0}")]
68     EventDupError(BaseError),
69     #[error("Failed to signal event: {0}")]
70     EventWriteError(BaseError),
71     #[error("Failed to get size of tx shared memory: {0}")]
72     FileSizeError(IOError),
73     #[error("Error accessing guest's shared memory: {0}")]
74     GuestMmapError(MmapError),
75     #[error("No jack with id {0}")]
76     InvalidJackId(u32),
77     #[error("No stream with id {0}")]
78     InvalidStreamId(u32),
79     #[error("IO buffer operation failed: status = {0}")]
80     IOBufferError(u32),
81     #[error("No PCM streams available")]
82     NoStreamsAvailable,
83     #[error("Insuficient space for the new buffer in the queue's buffer area")]
84     OutOfSpace,
85     #[error("Platform not supported")]
86     PlatformNotSupported,
87     #[error("{0}")]
88     ProtocolError(ProtocolErrorKind),
89     #[error("Failed to connect to VioS server {1}: {0:?}")]
90     ServerConnectionError(IOError, PathBuf),
91     #[error("Failed to communicate with VioS server: {0:?}")]
92     ServerError(IOError),
93     #[error("Failed to communicate with VioS server: {0:?}")]
94     ServerIOError(IOError),
95     #[error("Error accessing VioS server's shared memory: {0}")]
96     ServerMmapError(MmapError),
97     #[error("Failed to duplicate UnixSeqpacket: {0}")]
98     UnixSeqpacketDupError(IOError),
99     #[error("Unsupported frame rate: {0}")]
100     UnsupportedFrameRate(u32),
101     #[error("Error accessing volatile memory: {0}")]
102     VolatileMemoryError(VolatileMemoryError),
103     #[error("Failed to create Recv thread's WaitContext: {0}")]
104     WaitContextCreateError(BaseError),
105     #[error("Error waiting for events")]
106     WaitError(BaseError),
107     #[error("Invalid operation for stream direction: {0}")]
108     WrongDirection(u8),
109     #[error("Set saved params should only be used while restoring the device")]
110     WrongSetParams,
111 }
112 
113 #[derive(ThisError, Debug)]
114 pub enum ProtocolErrorKind {
115     #[error("The server sent a config of the wrong size: {0}")]
116     UnexpectedConfigSize(usize),
117     #[error("Received {1} file descriptors from the server, expected {0}")]
118     UnexpectedNumberOfFileDescriptors(usize, usize), // expected, received
119     #[error("Server's version ({0}) doesn't match client's")]
120     VersionMismatch(u32),
121     #[error("Received a msg with an unexpected size: expected {0}, received {1}")]
122     UnexpectedMessageSize(usize, usize), // expected, received
123 }
124 
125 /// The client for the VioS backend
126 ///
127 /// Uses a protocol equivalent to virtio-snd over a shared memory file and a unix socket for
128 /// notifications. It's thread safe, it can be encapsulated in an Arc smart pointer and shared
129 /// between threads.
130 pub struct VioSClient {
131     // These mutexes should almost never be held simultaneously. If at some point they have to the
132     // locking order should match the order in which they are declared here.
133     config: VioSConfig,
134     jacks: Vec<virtio_snd_jack_info>,
135     streams: Vec<virtio_snd_pcm_info>,
136     chmaps: Vec<virtio_snd_chmap_info>,
137     // The control socket is used from multiple threads to send and wait for a reply, which needs
138     // to happen atomically, hence the need for a mutex instead of just sharing clones of the
139     // socket.
140     control_socket: Mutex<UnixSeqpacket>,
141     event_socket: UnixSeqpacket,
142     // These are thread safe and don't require locking
143     tx: IoBufferQueue,
144     rx: IoBufferQueue,
145     // This is accessed by the recv_thread and whatever thread processes the events
146     events: Arc<Mutex<VecDeque<virtio_snd_event>>>,
147     event_notifier: Event,
148     // These are accessed by the recv_thread and the stream threads
149     tx_subscribers: Arc<Mutex<HashMap<usize, Sender<BufferReleaseMsg>>>>,
150     rx_subscribers: Arc<Mutex<HashMap<usize, Sender<BufferReleaseMsg>>>>,
151     recv_thread_state: Arc<Mutex<ThreadFlags>>,
152     recv_thread: Mutex<Option<WorkerThread<Result<()>>>>,
153     // Params are required to be stored for snapshot/restore. On restore, we don't have the params
154     // locally available as the VM is started anew, so they need to be restored.
155     params: HashMap<u32, virtio_snd_pcm_set_params>,
156 }
157 
158 #[derive(Serialize, Deserialize)]
159 pub struct VioSClientSnapshot {
160     config: VioSConfig,
161     jacks: Vec<virtio_snd_jack_info>,
162     streams: Vec<virtio_snd_pcm_info>,
163     chmaps: Vec<virtio_snd_chmap_info>,
164     params: HashMap<u32, virtio_snd_pcm_set_params>,
165 }
166 
167 impl VioSClient {
168     /// Create a new client given the path to the audio server's socket.
try_new<P: AsRef<Path>>(server: P) -> Result<VioSClient>169     pub fn try_new<P: AsRef<Path>>(server: P) -> Result<VioSClient> {
170         let client_socket = ScmSocket::try_from(
171             UnixSeqpacket::connect(server.as_ref())
172                 .map_err(|e| Error::ServerConnectionError(e, server.as_ref().into()))?,
173         )
174         .map_err(|e| Error::ServerConnectionError(e, server.as_ref().into()))?;
175         let mut config: VioSConfig = Default::default();
176         const NUM_FDS: usize = 5;
177         let (recv_size, mut safe_fds) = client_socket
178             .recv_with_fds(config.as_bytes_mut(), NUM_FDS)
179             .map_err(Error::ServerError)?;
180 
181         if recv_size != std::mem::size_of::<VioSConfig>() {
182             return Err(Error::ProtocolError(
183                 ProtocolErrorKind::UnexpectedConfigSize(recv_size),
184             ));
185         }
186 
187         if config.version != VIOS_VERSION {
188             return Err(Error::ProtocolError(ProtocolErrorKind::VersionMismatch(
189                 config.version,
190             )));
191         }
192 
193         fn pop<T: FromRawDescriptor>(
194             safe_fds: &mut Vec<SafeDescriptor>,
195             expected: usize,
196             received: usize,
197         ) -> Result<T> {
198             // SAFETY:
199             // Safe because we transfer ownership from the SafeDescriptor to T
200             unsafe {
201                 Ok(T::from_raw_descriptor(
202                     safe_fds
203                         .pop()
204                         .ok_or(Error::ProtocolError(
205                             ProtocolErrorKind::UnexpectedNumberOfFileDescriptors(
206                                 expected, received,
207                             ),
208                         ))?
209                         .into_raw_descriptor(),
210                 ))
211             }
212         }
213 
214         let fd_count = safe_fds.len();
215         let rx_shm_file = pop::<File>(&mut safe_fds, NUM_FDS, fd_count)?;
216         let tx_shm_file = pop::<File>(&mut safe_fds, NUM_FDS, fd_count)?;
217         let rx_socket = pop::<UnixSeqpacket>(&mut safe_fds, NUM_FDS, fd_count)?;
218         let tx_socket = pop::<UnixSeqpacket>(&mut safe_fds, NUM_FDS, fd_count)?;
219         let event_socket = pop::<UnixSeqpacket>(&mut safe_fds, NUM_FDS, fd_count)?;
220 
221         if !safe_fds.is_empty() {
222             return Err(Error::ProtocolError(
223                 ProtocolErrorKind::UnexpectedNumberOfFileDescriptors(NUM_FDS, fd_count),
224             ));
225         }
226 
227         let tx_subscribers: Arc<Mutex<HashMap<usize, Sender<BufferReleaseMsg>>>> =
228             Arc::new(Mutex::new(HashMap::new()));
229         let rx_subscribers: Arc<Mutex<HashMap<usize, Sender<BufferReleaseMsg>>>> =
230             Arc::new(Mutex::new(HashMap::new()));
231         let recv_thread_state = Arc::new(Mutex::new(ThreadFlags {
232             reporting_events: false,
233         }));
234 
235         let mut client = VioSClient {
236             config,
237             jacks: Vec::new(),
238             streams: Vec::new(),
239             chmaps: Vec::new(),
240             control_socket: Mutex::new(client_socket.into_inner()),
241             event_socket,
242             tx: IoBufferQueue::new(tx_socket, tx_shm_file)?,
243             rx: IoBufferQueue::new(rx_socket, rx_shm_file)?,
244             events: Arc::new(Mutex::new(VecDeque::new())),
245             event_notifier: Event::new().map_err(Error::EventCreateError)?,
246             tx_subscribers,
247             rx_subscribers,
248             recv_thread_state,
249             recv_thread: Mutex::new(None),
250             params: HashMap::new(),
251         };
252         client.request_and_cache_info()?;
253         Ok(client)
254     }
255 
256     /// Get the number of jacks
num_jacks(&self) -> u32257     pub fn num_jacks(&self) -> u32 {
258         self.config.jacks
259     }
260 
261     /// Get the number of pcm streams
num_streams(&self) -> u32262     pub fn num_streams(&self) -> u32 {
263         self.config.streams
264     }
265 
266     /// Get the number of channel maps
num_chmaps(&self) -> u32267     pub fn num_chmaps(&self) -> u32 {
268         self.config.chmaps
269     }
270 
271     /// Get the configuration information on a jack
jack_info(&self, idx: u32) -> Option<virtio_snd_jack_info>272     pub fn jack_info(&self, idx: u32) -> Option<virtio_snd_jack_info> {
273         self.jacks.get(idx as usize).copied()
274     }
275 
276     /// Get the configuration information on a pcm stream
stream_info(&self, idx: u32) -> Option<virtio_snd_pcm_info>277     pub fn stream_info(&self, idx: u32) -> Option<virtio_snd_pcm_info> {
278         self.streams.get(idx as usize).cloned()
279     }
280 
281     /// Get the configuration information on a channel map
chmap_info(&self, idx: u32) -> Option<virtio_snd_chmap_info>282     pub fn chmap_info(&self, idx: u32) -> Option<virtio_snd_chmap_info> {
283         self.chmaps.get(idx as usize).copied()
284     }
285 
286     /// Starts the background thread that receives release messages from the server. If the thread
287     /// was already started this function does nothing.
288     /// This thread must be started prior to attempting any stream IO operation or the calling
289     /// thread would block.
start_bg_thread(&self) -> Result<()>290     pub fn start_bg_thread(&self) -> Result<()> {
291         if self.recv_thread.lock().is_some() {
292             return Ok(());
293         }
294         let tx_socket = self.tx.try_clone_socket()?;
295         let rx_socket = self.rx.try_clone_socket()?;
296         let event_socket = self
297             .event_socket
298             .try_clone()
299             .map_err(Error::UnixSeqpacketDupError)?;
300         let mut opt = self.recv_thread.lock();
301         // The lock on recv_thread was released above to avoid holding more than one lock at a time
302         // while duplicating the fds. So we have to check the condition again.
303         if opt.is_none() {
304             *opt = Some(spawn_recv_thread(
305                 self.tx_subscribers.clone(),
306                 self.rx_subscribers.clone(),
307                 self.event_notifier
308                     .try_clone()
309                     .map_err(Error::EventDupError)?,
310                 self.events.clone(),
311                 self.recv_thread_state.clone(),
312                 tx_socket,
313                 rx_socket,
314                 event_socket,
315             ));
316         }
317         Ok(())
318     }
319 
320     /// Stops the background thread.
stop_bg_thread(&self) -> Result<()>321     pub fn stop_bg_thread(&self) -> Result<()> {
322         if let Some(recv_thread) = self.recv_thread.lock().take() {
323             recv_thread.stop()?;
324         }
325         Ok(())
326     }
327 
328     /// Gets an Event object that will trigger every time an event is received from the server
get_event_notifier(&self) -> Result<Event>329     pub fn get_event_notifier(&self) -> Result<Event> {
330         // Let the background thread know that there is at least one consumer of events
331         self.recv_thread_state.lock().reporting_events = true;
332         self.event_notifier
333             .try_clone()
334             .map_err(Error::EventDupError)
335     }
336 
337     /// Retrieves one event. Callers should have received a notification through the event notifier
338     /// before calling this function.
pop_event(&self) -> Option<virtio_snd_event>339     pub fn pop_event(&self) -> Option<virtio_snd_event> {
340         self.events.lock().pop_front()
341     }
342 
343     /// Remap a jack. This should only be called if the jack announces support for the operation
344     /// through the features field in the corresponding virtio_snd_jack_info struct.
remap_jack(&self, jack_id: u32, association: u32, sequence: u32) -> Result<()>345     pub fn remap_jack(&self, jack_id: u32, association: u32, sequence: u32) -> Result<()> {
346         if jack_id >= self.config.jacks {
347             return Err(Error::InvalidJackId(jack_id));
348         }
349         let msg = virtio_snd_jack_remap {
350             hdr: virtio_snd_jack_hdr {
351                 hdr: virtio_snd_hdr {
352                     code: VIRTIO_SND_R_JACK_REMAP.into(),
353                 },
354                 jack_id: jack_id.into(),
355             },
356             association: association.into(),
357             sequence: sequence.into(),
358         };
359         let control_socket_lock = self.control_socket.lock();
360         send_cmd(&control_socket_lock, msg)
361     }
362 
363     /// Configures a stream with the given parameters.
set_stream_parameters( &mut self, stream_id: u32, params: VioSStreamParams, ) -> Result<()>364     pub fn set_stream_parameters(
365         &mut self,
366         stream_id: u32,
367         params: VioSStreamParams,
368     ) -> Result<()> {
369         self.streams
370             .get(stream_id as usize)
371             .ok_or(Error::InvalidStreamId(stream_id))?;
372         let raw_params: virtio_snd_pcm_set_params = (stream_id, params).into();
373         // Old value is not needed and is dropped
374         let _ = self.params.insert(stream_id, raw_params);
375         let control_socket_lock = self.control_socket.lock();
376         send_cmd(&control_socket_lock, raw_params)
377     }
378 
379     /// Configures a stream with the given parameters.
set_stream_parameters_raw( &mut self, raw_params: virtio_snd_pcm_set_params, ) -> Result<()>380     pub fn set_stream_parameters_raw(
381         &mut self,
382         raw_params: virtio_snd_pcm_set_params,
383     ) -> Result<()> {
384         let stream_id = raw_params.hdr.stream_id.to_native();
385         // Old value is not needed and is dropped
386         let _ = self.params.insert(stream_id, raw_params);
387         self.streams
388             .get(stream_id as usize)
389             .ok_or(Error::InvalidStreamId(stream_id))?;
390         let control_socket_lock = self.control_socket.lock();
391         send_cmd(&control_socket_lock, raw_params)
392     }
393 
394     /// Send the PREPARE_STREAM command to the server.
prepare_stream(&self, stream_id: u32) -> Result<()>395     pub fn prepare_stream(&self, stream_id: u32) -> Result<()> {
396         self.common_stream_op(stream_id, VIRTIO_SND_R_PCM_PREPARE)
397     }
398 
399     /// Send the RELEASE_STREAM command to the server.
release_stream(&self, stream_id: u32) -> Result<()>400     pub fn release_stream(&self, stream_id: u32) -> Result<()> {
401         self.common_stream_op(stream_id, VIRTIO_SND_R_PCM_RELEASE)
402     }
403 
404     /// Send the START_STREAM command to the server.
start_stream(&self, stream_id: u32) -> Result<()>405     pub fn start_stream(&self, stream_id: u32) -> Result<()> {
406         self.common_stream_op(stream_id, VIRTIO_SND_R_PCM_START)
407     }
408 
409     /// Send the STOP_STREAM command to the server.
stop_stream(&self, stream_id: u32) -> Result<()>410     pub fn stop_stream(&self, stream_id: u32) -> Result<()> {
411         self.common_stream_op(stream_id, VIRTIO_SND_R_PCM_STOP)
412     }
413 
414     /// Send audio frames to the server. Blocks the calling thread until the server acknowledges
415     /// the data.
inject_audio_data<R, Cb: FnOnce(VolatileSlice) -> R>( &self, stream_id: u32, size: usize, callback: Cb, ) -> Result<(u32, R)>416     pub fn inject_audio_data<R, Cb: FnOnce(VolatileSlice) -> R>(
417         &self,
418         stream_id: u32,
419         size: usize,
420         callback: Cb,
421     ) -> Result<(u32, R)> {
422         if self
423             .streams
424             .get(stream_id as usize)
425             .ok_or(Error::InvalidStreamId(stream_id))?
426             .direction
427             != VIRTIO_SND_D_OUTPUT
428         {
429             return Err(Error::WrongDirection(VIRTIO_SND_D_OUTPUT));
430         }
431         self.streams
432             .get(stream_id as usize)
433             .ok_or(Error::InvalidStreamId(stream_id))?;
434         let dst_offset = self.tx.allocate_buffer(size)?;
435         let buffer_slice = self.tx.buffer_at(dst_offset, size)?;
436         let ret = callback(buffer_slice);
437         // Register to receive the status before sending the buffer to the server
438         let (sender, receiver): (Sender<BufferReleaseMsg>, Receiver<BufferReleaseMsg>) = channel();
439         self.tx_subscribers.lock().insert(dst_offset, sender);
440         self.tx.send_buffer(stream_id, dst_offset, size)?;
441         let (_, latency) = await_status(receiver)?;
442         Ok((latency, ret))
443     }
444 
445     /// Request audio frames from the server. It blocks until the data is available.
request_audio_data<R, Cb: FnOnce(&VolatileSlice) -> R>( &self, stream_id: u32, size: usize, callback: Cb, ) -> Result<(u32, R)>446     pub fn request_audio_data<R, Cb: FnOnce(&VolatileSlice) -> R>(
447         &self,
448         stream_id: u32,
449         size: usize,
450         callback: Cb,
451     ) -> Result<(u32, R)> {
452         if self
453             .streams
454             .get(stream_id as usize)
455             .ok_or(Error::InvalidStreamId(stream_id))?
456             .direction
457             != VIRTIO_SND_D_INPUT
458         {
459             return Err(Error::WrongDirection(VIRTIO_SND_D_INPUT));
460         }
461         let src_offset = self.rx.allocate_buffer(size)?;
462         // Register to receive the status before sending the buffer to the server
463         let (sender, receiver): (Sender<BufferReleaseMsg>, Receiver<BufferReleaseMsg>) = channel();
464         self.rx_subscribers.lock().insert(src_offset, sender);
465         self.rx.send_buffer(stream_id, src_offset, size)?;
466         // Make sure no mutexes are held while awaiting for the buffer to be written to
467         let (recv_size, latency) = await_status(receiver)?;
468         let buffer_slice = self.rx.buffer_at(src_offset, recv_size)?;
469         Ok((latency, callback(&buffer_slice)))
470     }
471 
472     /// Get a list of file descriptors used by the implementation.
keep_rds(&self) -> Vec<RawDescriptor>473     pub fn keep_rds(&self) -> Vec<RawDescriptor> {
474         let control_desc = self.control_socket.lock().as_raw_descriptor();
475         let event_desc = self.event_socket.as_raw_descriptor();
476         let event_notifier = self.event_notifier.as_raw_descriptor();
477         let mut ret = vec![control_desc, event_desc, event_notifier];
478         ret.append(&mut self.tx.keep_rds());
479         ret.append(&mut self.rx.keep_rds());
480         ret
481     }
482 
common_stream_op(&self, stream_id: u32, op: u32) -> Result<()>483     fn common_stream_op(&self, stream_id: u32, op: u32) -> Result<()> {
484         self.streams
485             .get(stream_id as usize)
486             .ok_or(Error::InvalidStreamId(stream_id))?;
487         let msg = virtio_snd_pcm_hdr {
488             hdr: virtio_snd_hdr { code: op.into() },
489             stream_id: stream_id.into(),
490         };
491         let control_socket_lock = self.control_socket.lock();
492         send_cmd(&control_socket_lock, msg)
493     }
494 
request_and_cache_info(&mut self) -> Result<()>495     fn request_and_cache_info(&mut self) -> Result<()> {
496         self.request_and_cache_jacks_info()?;
497         self.request_and_cache_streams_info()?;
498         self.request_and_cache_chmaps_info()?;
499         Ok(())
500     }
501 
request_info<T: AsBytes + FromBytes + Default + Copy + Clone>( &self, req_code: u32, count: usize, ) -> Result<Vec<T>>502     fn request_info<T: AsBytes + FromBytes + Default + Copy + Clone>(
503         &self,
504         req_code: u32,
505         count: usize,
506     ) -> Result<Vec<T>> {
507         let info_size = std::mem::size_of::<T>();
508         let status_size = std::mem::size_of::<virtio_snd_hdr>();
509         let req = virtio_snd_query_info {
510             hdr: virtio_snd_hdr {
511                 code: req_code.into(),
512             },
513             start_id: 0u32.into(),
514             count: (count as u32).into(),
515             size: (std::mem::size_of::<virtio_snd_query_info>() as u32).into(),
516         };
517         let control_socket_lock = self.control_socket.lock();
518         seq_socket_send(&control_socket_lock, req)?;
519         let reply = control_socket_lock
520             .recv_as_vec()
521             .map_err(Error::ServerIOError)?;
522         let mut status: virtio_snd_hdr = Default::default();
523         status
524             .as_bytes_mut()
525             .copy_from_slice(&reply[0..status_size]);
526         if status.code.to_native() != VIRTIO_SND_S_OK {
527             return Err(Error::CommandFailed(status.code.to_native()));
528         }
529         if reply.len() != status_size + count * info_size {
530             return Err(Error::ProtocolError(
531                 ProtocolErrorKind::UnexpectedMessageSize(count * info_size, reply.len()),
532             ));
533         }
534         Ok(reply[status_size..]
535             .chunks(info_size)
536             .map(|info_buffer| T::read_from(info_buffer).unwrap())
537             .collect())
538     }
539 
request_and_cache_jacks_info(&mut self) -> Result<()>540     fn request_and_cache_jacks_info(&mut self) -> Result<()> {
541         let num_jacks = self.config.jacks as usize;
542         if num_jacks == 0 {
543             return Ok(());
544         }
545         self.jacks = self.request_info(VIRTIO_SND_R_JACK_INFO, num_jacks)?;
546         Ok(())
547     }
548 
request_and_cache_streams_info(&mut self) -> Result<()>549     fn request_and_cache_streams_info(&mut self) -> Result<()> {
550         let num_streams = self.config.streams as usize;
551         if num_streams == 0 {
552             return Ok(());
553         }
554         self.streams = self.request_info(VIRTIO_SND_R_PCM_INFO, num_streams)?;
555         Ok(())
556     }
557 
request_and_cache_chmaps_info(&mut self) -> Result<()>558     fn request_and_cache_chmaps_info(&mut self) -> Result<()> {
559         let num_chmaps = self.config.chmaps as usize;
560         if num_chmaps == 0 {
561             return Ok(());
562         }
563         self.chmaps = self.request_info(VIRTIO_SND_R_CHMAP_INFO, num_chmaps)?;
564         Ok(())
565     }
566 
snapshot(&self) -> VioSClientSnapshot567     pub fn snapshot(&self) -> VioSClientSnapshot {
568         VioSClientSnapshot {
569             config: self.config,
570             jacks: self.jacks.clone(),
571             streams: self.streams.clone(),
572             chmaps: self.chmaps.clone(),
573             params: self.params.clone(),
574         }
575     }
576 
577     // Function called `restore` to signify it will happen as part of the snapshot/restore flow. No
578     // data is actually restored in the case of VioSClient.
restore(&mut self, data: VioSClientSnapshot) -> anyhow::Result<()>579     pub fn restore(&mut self, data: VioSClientSnapshot) -> anyhow::Result<()> {
580         anyhow::ensure!(
581             data.config == self.config,
582             "config doesn't match on restore: expected: {:?}, got: {:?}",
583             data.config,
584             self.config
585         );
586         self.jacks = data.jacks;
587         self.streams = data.streams;
588         self.chmaps = data.chmaps;
589         self.params = data.params;
590         Ok(())
591     }
592 
restore_stream(&mut self, stream_id: u32, state: StreamState) -> Result<()>593     pub fn restore_stream(&mut self, stream_id: u32, state: StreamState) -> Result<()> {
594         if let Some(params) = self.params.get(&stream_id).cloned() {
595             self.set_stream_parameters_raw(params)?;
596         }
597         match state {
598             StreamState::Started => {
599                 // If state != prepared, start will always fail.
600                 // As such, it is fine to only print the first error without returning, as the
601                 // second action will then fail.
602                 if let Err(e) = self.prepare_stream(stream_id) {
603                     error!("failed to prepare stream: {}", e);
604                 };
605                 self.start_stream(stream_id)
606             }
607             StreamState::Prepared => self.prepare_stream(stream_id),
608             // Nothing to do here
609             _ => Ok(()),
610         }
611     }
612 }
613 
/// Flags shared between the client and its recv thread.
#[derive(Clone, Copy)]
struct ThreadFlags {
    /// Set once a consumer asked for the event notifier. While false, the recv thread drops
    /// incoming events instead of queueing them.
    reporting_events: bool,
}
618 
619 #[derive(EventToken)]
620 enum Token {
621     Notification,
622     TxBufferMsg,
623     RxBufferMsg,
624     EventMsg,
625 }
626 
recv_buffer_status_msg( socket: &UnixSeqpacket, subscribers: &Arc<Mutex<HashMap<usize, Sender<BufferReleaseMsg>>>>, ) -> Result<()>627 fn recv_buffer_status_msg(
628     socket: &UnixSeqpacket,
629     subscribers: &Arc<Mutex<HashMap<usize, Sender<BufferReleaseMsg>>>>,
630 ) -> Result<()> {
631     let mut msg: IoStatusMsg = Default::default();
632     let size = socket
633         .recv(msg.as_bytes_mut())
634         .map_err(Error::ServerIOError)?;
635     if size != std::mem::size_of::<IoStatusMsg>() {
636         return Err(Error::ProtocolError(
637             ProtocolErrorKind::UnexpectedMessageSize(std::mem::size_of::<IoStatusMsg>(), size),
638         ));
639     }
640     let mut status = msg.status.status.into();
641     if status == u32::MAX {
642         // Anyone waiting for this would continue to wait for as long as status is
643         // u32::MAX
644         status -= 1;
645     }
646     let latency = msg.status.latency_bytes.into();
647     let offset = msg.buffer_offset as usize;
648     let consumed_len = msg.consumed_len as usize;
649     let promise_opt = subscribers.lock().remove(&offset);
650     match promise_opt {
651         None => error!(
652             "Received an unexpected buffer status message: {}. This is a BUG!!",
653             offset
654         ),
655         Some(sender) => {
656             if let Err(e) = sender.send(BufferReleaseMsg {
657                 status,
658                 latency,
659                 consumed_len,
660             }) {
661                 error!("Failed to notify waiting thread: {:?}", e);
662             }
663         }
664     }
665     Ok(())
666 }
667 
recv_event(socket: &UnixSeqpacket) -> Result<virtio_snd_event>668 fn recv_event(socket: &UnixSeqpacket) -> Result<virtio_snd_event> {
669     let mut msg: virtio_snd_event = Default::default();
670     let size = socket
671         .recv(msg.as_bytes_mut())
672         .map_err(Error::ServerIOError)?;
673     if size != std::mem::size_of::<virtio_snd_event>() {
674         return Err(Error::ProtocolError(
675             ProtocolErrorKind::UnexpectedMessageSize(std::mem::size_of::<virtio_snd_event>(), size),
676         ));
677     }
678     Ok(msg)
679 }
680 
spawn_recv_thread( tx_subscribers: Arc<Mutex<HashMap<usize, Sender<BufferReleaseMsg>>>>, rx_subscribers: Arc<Mutex<HashMap<usize, Sender<BufferReleaseMsg>>>>, event_notifier: Event, event_queue: Arc<Mutex<VecDeque<virtio_snd_event>>>, state: Arc<Mutex<ThreadFlags>>, tx_socket: UnixSeqpacket, rx_socket: UnixSeqpacket, event_socket: UnixSeqpacket, ) -> WorkerThread<Result<()>>681 fn spawn_recv_thread(
682     tx_subscribers: Arc<Mutex<HashMap<usize, Sender<BufferReleaseMsg>>>>,
683     rx_subscribers: Arc<Mutex<HashMap<usize, Sender<BufferReleaseMsg>>>>,
684     event_notifier: Event,
685     event_queue: Arc<Mutex<VecDeque<virtio_snd_event>>>,
686     state: Arc<Mutex<ThreadFlags>>,
687     tx_socket: UnixSeqpacket,
688     rx_socket: UnixSeqpacket,
689     event_socket: UnixSeqpacket,
690 ) -> WorkerThread<Result<()>> {
691     WorkerThread::start("shm_vios", move |event| {
692         let wait_ctx: WaitContext<Token> = WaitContext::build_with(&[
693             (&tx_socket, Token::TxBufferMsg),
694             (&rx_socket, Token::RxBufferMsg),
695             (&event_socket, Token::EventMsg),
696             (&event, Token::Notification),
697         ])
698         .map_err(Error::WaitContextCreateError)?;
699         let mut running = true;
700         while running {
701             let events = wait_ctx.wait().map_err(Error::WaitError)?;
702             for evt in events {
703                 match evt.token {
704                     Token::TxBufferMsg => recv_buffer_status_msg(&tx_socket, &tx_subscribers)?,
705                     Token::RxBufferMsg => recv_buffer_status_msg(&rx_socket, &rx_subscribers)?,
706                     Token::EventMsg => {
707                         let evt = recv_event(&event_socket)?;
708                         let state_cpy = *state.lock();
709                         if state_cpy.reporting_events {
710                             event_queue.lock().push_back(evt);
711                             event_notifier.signal().map_err(Error::EventWriteError)?;
712                         } // else just drop the events
713                     }
714                     Token::Notification => {
715                         // Just consume the notification and check for termination on the next
716                         // iteration
717                         if let Err(e) = event.wait() {
718                             error!("Failed to consume notification from recv thread: {:?}", e);
719                         }
720                         running = false;
721                     }
722                 }
723             }
724         }
725         Ok(())
726     })
727 }
728 
await_status(promise: Receiver<BufferReleaseMsg>) -> Result<(usize, u32)>729 fn await_status(promise: Receiver<BufferReleaseMsg>) -> Result<(usize, u32)> {
730     let BufferReleaseMsg {
731         status,
732         latency,
733         consumed_len,
734     } = promise.recv().map_err(Error::BufferStatusSenderLost)?;
735     if status == VIRTIO_SND_S_OK {
736         Ok((consumed_len, latency))
737     } else {
738         Err(Error::IOBufferError(status))
739     }
740 }
741 
742 struct IoBufferQueue {
743     socket: UnixSeqpacket,
744     file: File,
745     mmap: MemoryMapping,
746     size: usize,
747     next: Mutex<usize>,
748 }
749 
750 impl IoBufferQueue {
new(socket: UnixSeqpacket, mut file: File) -> Result<IoBufferQueue>751     fn new(socket: UnixSeqpacket, mut file: File) -> Result<IoBufferQueue> {
752         let size = file.seek(SeekFrom::End(0)).map_err(Error::FileSizeError)? as usize;
753 
754         let mmap = MemoryMappingBuilder::new(size)
755             .from_file(&file)
756             .build()
757             .map_err(Error::ServerMmapError)?;
758 
759         Ok(IoBufferQueue {
760             socket,
761             file,
762             mmap,
763             size,
764             next: Mutex::new(0),
765         })
766     }
767 
allocate_buffer(&self, size: usize) -> Result<usize>768     fn allocate_buffer(&self, size: usize) -> Result<usize> {
769         if size > self.size {
770             return Err(Error::OutOfSpace);
771         }
772         let mut next_lock = self.next.lock();
773         let offset = if size > self.size - *next_lock {
774             // Can't fit the new buffer at the end of the area, so put it at the beginning
775             0
776         } else {
777             *next_lock
778         };
779         *next_lock = offset + size;
780         Ok(offset)
781     }
782 
buffer_at(&self, offset: usize, len: usize) -> Result<VolatileSlice>783     fn buffer_at(&self, offset: usize, len: usize) -> Result<VolatileSlice> {
784         self.mmap
785             .get_slice(offset, len)
786             .map_err(Error::VolatileMemoryError)
787     }
788 
try_clone_socket(&self) -> Result<UnixSeqpacket>789     fn try_clone_socket(&self) -> Result<UnixSeqpacket> {
790         self.socket
791             .try_clone()
792             .map_err(Error::UnixSeqpacketDupError)
793     }
794 
send_buffer(&self, stream_id: u32, offset: usize, size: usize) -> Result<()>795     fn send_buffer(&self, stream_id: u32, offset: usize, size: usize) -> Result<()> {
796         let msg = IoTransferMsg::new(stream_id, offset, size);
797         seq_socket_send(&self.socket, msg)
798     }
799 
keep_rds(&self) -> Vec<RawDescriptor>800     fn keep_rds(&self) -> Vec<RawDescriptor> {
801         vec![
802             self.file.as_raw_descriptor(),
803             self.socket.as_raw_descriptor(),
804         ]
805     }
806 }
807 
808 /// Groups the parameters used to configure a stream prior to using it.
809 pub struct VioSStreamParams {
810     pub buffer_bytes: u32,
811     pub period_bytes: u32,
812     pub features: u32,
813     pub channels: u8,
814     pub format: u8,
815     pub rate: u8,
816 }
817 
818 impl From<(u32, VioSStreamParams)> for virtio_snd_pcm_set_params {
from(val: (u32, VioSStreamParams)) -> Self819     fn from(val: (u32, VioSStreamParams)) -> Self {
820         virtio_snd_pcm_set_params {
821             hdr: virtio_snd_pcm_hdr {
822                 hdr: virtio_snd_hdr {
823                     code: VIRTIO_SND_R_PCM_SET_PARAMS.into(),
824                 },
825                 stream_id: val.0.into(),
826             },
827             buffer_bytes: val.1.buffer_bytes.into(),
828             period_bytes: val.1.period_bytes.into(),
829             features: val.1.features.into(),
830             channels: val.1.channels,
831             format: val.1.format,
832             rate: val.1.rate,
833             padding: 0u8,
834         }
835     }
836 }
837 
send_cmd<T: AsBytes>(control_socket: &UnixSeqpacket, data: T) -> Result<()>838 fn send_cmd<T: AsBytes>(control_socket: &UnixSeqpacket, data: T) -> Result<()> {
839     seq_socket_send(control_socket, data)?;
840     recv_cmd_status(control_socket)
841 }
842 
recv_cmd_status(control_socket: &UnixSeqpacket) -> Result<()>843 fn recv_cmd_status(control_socket: &UnixSeqpacket) -> Result<()> {
844     let mut status: virtio_snd_hdr = Default::default();
845     control_socket
846         .recv(status.as_bytes_mut())
847         .map_err(Error::ServerIOError)?;
848     if status.code.to_native() == VIRTIO_SND_S_OK {
849         Ok(())
850     } else {
851         Err(Error::CommandFailed(status.code.to_native()))
852     }
853 }
854 
seq_socket_send<T: AsBytes>(socket: &UnixSeqpacket, data: T) -> Result<()>855 fn seq_socket_send<T: AsBytes>(socket: &UnixSeqpacket, data: T) -> Result<()> {
856     loop {
857         let send_res = socket.send(data.as_bytes());
858         if let Err(e) = send_res {
859             match e.kind() {
860                 // Retry if interrupted
861                 IOErrorKind::Interrupted => continue,
862                 _ => return Err(Error::ServerIOError(e)),
863             }
864         }
865         // Success
866         break;
867     }
868     Ok(())
869 }
870 
871 const VIOS_VERSION: u32 = 2;
872 
873 #[repr(C)]
874 #[derive(
875     Copy,
876     Clone,
877     Default,
878     AsBytes,
879     FromZeroes,
880     FromBytes,
881     Serialize,
882     Deserialize,
883     PartialEq,
884     Eq,
885     Debug,
886 )]
887 struct VioSConfig {
888     version: u32,
889     jacks: u32,
890     streams: u32,
891     chmaps: u32,
892 }
893 
894 struct BufferReleaseMsg {
895     status: u32,
896     latency: u32,
897     consumed_len: usize,
898 }
899 
900 #[repr(C)]
901 #[derive(Copy, Clone, AsBytes, FromZeroes, FromBytes)]
902 struct IoTransferMsg {
903     io_xfer: virtio_snd_pcm_xfer,
904     buffer_offset: u32,
905     buffer_len: u32,
906 }
907 
908 impl IoTransferMsg {
new(stream_id: u32, buffer_offset: usize, buffer_len: usize) -> IoTransferMsg909     fn new(stream_id: u32, buffer_offset: usize, buffer_len: usize) -> IoTransferMsg {
910         IoTransferMsg {
911             io_xfer: virtio_snd_pcm_xfer {
912                 stream_id: stream_id.into(),
913             },
914             buffer_offset: buffer_offset as u32,
915             buffer_len: buffer_len as u32,
916         }
917     }
918 }
919 
920 #[repr(C)]
921 #[derive(Copy, Clone, Default, AsBytes, FromZeroes, FromBytes)]
922 struct IoStatusMsg {
923     status: virtio_snd_pcm_status,
924     buffer_offset: u32,
925     consumed_len: u32,
926 }
927