xref: /aosp_15_r20/external/crosvm/devices/src/virtio/descriptor_utils.rs (revision bb4ee6a4ae7042d18b07a98463b9c8b875e44b39)
1 // Copyright 2019 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
5 use std::cmp;
6 use std::io;
7 use std::io::Write;
8 use std::iter::FromIterator;
9 use std::marker::PhantomData;
10 use std::mem::size_of;
11 use std::mem::MaybeUninit;
12 use std::ptr::copy_nonoverlapping;
13 use std::sync::Arc;
14 
15 use anyhow::Context;
16 use base::FileReadWriteAtVolatile;
17 use base::FileReadWriteVolatile;
18 use base::VolatileSlice;
19 use cros_async::MemRegion;
20 use cros_async::MemRegionIter;
21 use data_model::Le16;
22 use data_model::Le32;
23 use data_model::Le64;
24 use disk::AsyncDisk;
25 use smallvec::SmallVec;
26 use vm_memory::GuestAddress;
27 use vm_memory::GuestMemory;
28 use zerocopy::AsBytes;
29 use zerocopy::FromBytes;
30 use zerocopy::FromZeroes;
31 
32 use super::DescriptorChain;
33 use crate::virtio::SplitDescriptorChain;
34 
35 struct DescriptorChainRegions {
36     regions: SmallVec<[MemRegion; 2]>,
37 
38     // Index of the current region in `regions`.
39     current_region_index: usize,
40 
41     // Number of bytes consumed in the current region.
42     current_region_offset: usize,
43 
44     // Total bytes consumed in the entire descriptor chain.
45     bytes_consumed: usize,
46 }
47 
48 impl DescriptorChainRegions {
new(regions: SmallVec<[MemRegion; 2]>) -> Self49     fn new(regions: SmallVec<[MemRegion; 2]>) -> Self {
50         DescriptorChainRegions {
51             regions,
52             current_region_index: 0,
53             current_region_offset: 0,
54             bytes_consumed: 0,
55         }
56     }
57 
available_bytes(&self) -> usize58     fn available_bytes(&self) -> usize {
59         // This is guaranteed not to overflow because the total length of the chain is checked
60         // during all creations of `DescriptorChain` (see `DescriptorChain::new()`).
61         self.get_remaining_regions()
62             .fold(0usize, |count, region| count + region.len)
63     }
64 
bytes_consumed(&self) -> usize65     fn bytes_consumed(&self) -> usize {
66         self.bytes_consumed
67     }
68 
69     /// Returns all the remaining buffers in the `DescriptorChain`. Calling this function does not
70     /// consume any bytes from the `DescriptorChain`. Instead callers should use the `consume`
71     /// method to advance the `DescriptorChain`. Multiple calls to `get` with no intervening calls
72     /// to `consume` will return the same data.
get_remaining_regions(&self) -> MemRegionIter73     fn get_remaining_regions(&self) -> MemRegionIter {
74         MemRegionIter::new(&self.regions[self.current_region_index..])
75             .skip_bytes(self.current_region_offset)
76     }
77 
78     /// Like `get_remaining_regions` but guarantees that the combined length of all the returned
79     /// iovecs is not greater than `count`. The combined length of the returned iovecs may be less
80     /// than `count` but will always be greater than 0 as long as there is still space left in the
81     /// `DescriptorChain`.
get_remaining_regions_with_count(&self, count: usize) -> MemRegionIter82     fn get_remaining_regions_with_count(&self, count: usize) -> MemRegionIter {
83         MemRegionIter::new(&self.regions[self.current_region_index..])
84             .skip_bytes(self.current_region_offset)
85             .take_bytes(count)
86     }
87 
88     /// Returns all the remaining buffers in the `DescriptorChain` as `VolatileSlice`s of the given
89     /// `GuestMemory`. Calling this function does not consume any bytes from the `DescriptorChain`.
90     /// Instead callers should use the `consume` method to advance the `DescriptorChain`. Multiple
91     /// calls to `get` with no intervening calls to `consume` will return the same data.
get_remaining<'mem>(&self, mem: &'mem GuestMemory) -> SmallVec<[VolatileSlice<'mem>; 16]>92     fn get_remaining<'mem>(&self, mem: &'mem GuestMemory) -> SmallVec<[VolatileSlice<'mem>; 16]> {
93         self.get_remaining_regions()
94             .filter_map(|region| {
95                 mem.get_slice_at_addr(GuestAddress(region.offset), region.len)
96                     .ok()
97             })
98             .collect()
99     }
100 
101     /// Like 'get_remaining_regions_with_count' except convert the offsets to volatile slices in
102     /// the 'GuestMemory' given by 'mem'.
get_remaining_with_count<'mem>( &self, mem: &'mem GuestMemory, count: usize, ) -> SmallVec<[VolatileSlice<'mem>; 16]>103     fn get_remaining_with_count<'mem>(
104         &self,
105         mem: &'mem GuestMemory,
106         count: usize,
107     ) -> SmallVec<[VolatileSlice<'mem>; 16]> {
108         self.get_remaining_regions_with_count(count)
109             .filter_map(|region| {
110                 mem.get_slice_at_addr(GuestAddress(region.offset), region.len)
111                     .ok()
112             })
113             .collect()
114     }
115 
116     /// Consumes `count` bytes from the `DescriptorChain`. If `count` is larger than
117     /// `self.available_bytes()` then all remaining bytes in the `DescriptorChain` will be consumed.
consume(&mut self, mut count: usize)118     fn consume(&mut self, mut count: usize) {
119         while let Some(region) = self.regions.get(self.current_region_index) {
120             let region_remaining = region.len - self.current_region_offset;
121             if count < region_remaining {
122                 // The remaining count to consume is less than the remaining un-consumed length of
123                 // the current region. Adjust the region offset without advancing to the next region
124                 // and stop.
125                 self.current_region_offset += count;
126                 self.bytes_consumed += count;
127                 return;
128             }
129 
130             // The current region has been exhausted. Advance to the next region.
131             self.current_region_index += 1;
132             self.current_region_offset = 0;
133 
134             self.bytes_consumed += region_remaining;
135             count -= region_remaining;
136         }
137     }
138 
split_at(&mut self, offset: usize) -> DescriptorChainRegions139     fn split_at(&mut self, offset: usize) -> DescriptorChainRegions {
140         let mut other = DescriptorChainRegions {
141             regions: self.regions.clone(),
142             current_region_index: self.current_region_index,
143             current_region_offset: self.current_region_offset,
144             bytes_consumed: self.bytes_consumed,
145         };
146         other.consume(offset);
147         other.bytes_consumed = 0;
148 
149         let mut rem = offset;
150         let mut end = self.current_region_index;
151         for region in &mut self.regions[self.current_region_index..] {
152             if rem <= region.len {
153                 region.len = rem;
154                 break;
155             }
156 
157             end += 1;
158             rem -= region.len;
159         }
160 
161         self.regions.truncate(end + 1);
162 
163         other
164     }
165 }
166 
167 /// Provides high-level interface over the sequence of memory regions
168 /// defined by readable descriptors in the descriptor chain.
169 ///
170 /// Note that virtio spec requires driver to place any device-writable
171 /// descriptors after any device-readable descriptors (2.6.4.2 in Virtio Spec v1.1).
172 /// Reader will skip iterating over descriptor chain when first writable
173 /// descriptor is encountered.
174 pub struct Reader {
175     mem: GuestMemory,
176     regions: DescriptorChainRegions,
177 }
178 
179 /// An iterator over `FromBytes` objects on readable descriptors in the descriptor chain.
180 pub struct ReaderIterator<'a, T: FromBytes> {
181     reader: &'a mut Reader,
182     phantom: PhantomData<T>,
183 }
184 
185 impl<'a, T: FromBytes> Iterator for ReaderIterator<'a, T> {
186     type Item = io::Result<T>;
187 
next(&mut self) -> Option<io::Result<T>>188     fn next(&mut self) -> Option<io::Result<T>> {
189         if self.reader.available_bytes() == 0 {
190             None
191         } else {
192             Some(self.reader.read_obj())
193         }
194     }
195 }
196 
197 impl Reader {
198     /// Construct a new Reader wrapper over `readable_regions`.
new_from_regions( mem: &GuestMemory, readable_regions: SmallVec<[MemRegion; 2]>, ) -> Reader199     pub fn new_from_regions(
200         mem: &GuestMemory,
201         readable_regions: SmallVec<[MemRegion; 2]>,
202     ) -> Reader {
203         Reader {
204             mem: mem.clone(),
205             regions: DescriptorChainRegions::new(readable_regions),
206         }
207     }
208 
209     /// Reads an object from the descriptor chain buffer without consuming it.
peek_obj<T: FromBytes>(&self) -> io::Result<T>210     pub fn peek_obj<T: FromBytes>(&self) -> io::Result<T> {
211         let mut obj = MaybeUninit::uninit();
212 
213         // SAFETY: We pass a valid pointer and size of `obj`.
214         let copied = unsafe {
215             copy_regions_to_mut_ptr(
216                 &self.mem,
217                 self.get_remaining_regions(),
218                 obj.as_mut_ptr() as *mut u8,
219                 size_of::<T>(),
220             )?
221         };
222         if copied != size_of::<T>() {
223             return Err(io::Error::from(io::ErrorKind::UnexpectedEof));
224         }
225 
226         // SAFETY: `FromBytes` guarantees any set of initialized bytes is a valid value for `T`, and
227         // we initialized all bytes in `obj` in the copy above.
228         Ok(unsafe { obj.assume_init() })
229     }
230 
231     /// Reads and consumes an object from the descriptor chain buffer.
read_obj<T: FromBytes>(&mut self) -> io::Result<T>232     pub fn read_obj<T: FromBytes>(&mut self) -> io::Result<T> {
233         let obj = self.peek_obj::<T>()?;
234         self.consume(size_of::<T>());
235         Ok(obj)
236     }
237 
238     /// Reads objects by consuming all the remaining data in the descriptor chain buffer and returns
239     /// them as a collection. Returns an error if the size of the remaining data is indivisible by
240     /// the size of an object of type `T`.
collect<C: FromIterator<io::Result<T>>, T: FromBytes>(&mut self) -> C241     pub fn collect<C: FromIterator<io::Result<T>>, T: FromBytes>(&mut self) -> C {
242         self.iter().collect()
243     }
244 
245     /// Creates an iterator for sequentially reading `FromBytes` objects from the `Reader`.
246     /// Unlike `collect`, this doesn't consume all the remaining data in the `Reader` and
247     /// doesn't require the objects to be stored in a separate collection.
iter<T: FromBytes>(&mut self) -> ReaderIterator<T>248     pub fn iter<T: FromBytes>(&mut self) -> ReaderIterator<T> {
249         ReaderIterator {
250             reader: self,
251             phantom: PhantomData,
252         }
253     }
254 
255     /// Reads data into a volatile slice up to the minimum of the slice's length or the number of
256     /// bytes remaining. Returns the number of bytes read.
read_to_volatile_slice(&mut self, slice: VolatileSlice) -> usize257     pub fn read_to_volatile_slice(&mut self, slice: VolatileSlice) -> usize {
258         let mut read = 0usize;
259         let mut dst = slice;
260         for src in self.get_remaining() {
261             src.copy_to_volatile_slice(dst);
262             let copied = std::cmp::min(src.size(), dst.size());
263             read += copied;
264             dst = match dst.offset(copied) {
265                 Ok(v) => v,
266                 Err(_) => break, // The slice is fully consumed
267             };
268         }
269         self.regions.consume(read);
270         read
271     }
272 
273     /// Reads data from the descriptor chain buffer and passes the `VolatileSlice`s to the callback
274     /// `cb`.
read_to_cb<C: FnOnce(&[VolatileSlice]) -> usize>( &mut self, cb: C, count: usize, ) -> usize275     pub fn read_to_cb<C: FnOnce(&[VolatileSlice]) -> usize>(
276         &mut self,
277         cb: C,
278         count: usize,
279     ) -> usize {
280         let iovs = self.regions.get_remaining_with_count(&self.mem, count);
281         let written = cb(&iovs[..]);
282         self.regions.consume(written);
283         written
284     }
285 
286     /// Reads data from the descriptor chain buffer into a writable object.
287     /// Returns the number of bytes read from the descriptor chain buffer.
288     /// The number of bytes read can be less than `count` if there isn't
289     /// enough data in the descriptor chain buffer.
read_to<F: FileReadWriteVolatile>( &mut self, mut dst: F, count: usize, ) -> io::Result<usize>290     pub fn read_to<F: FileReadWriteVolatile>(
291         &mut self,
292         mut dst: F,
293         count: usize,
294     ) -> io::Result<usize> {
295         let iovs = self.regions.get_remaining_with_count(&self.mem, count);
296         let written = dst.write_vectored_volatile(&iovs[..])?;
297         self.regions.consume(written);
298         Ok(written)
299     }
300 
301     /// Reads data from the descriptor chain buffer into a File at offset `off`.
302     /// Returns the number of bytes read from the descriptor chain buffer.
303     /// The number of bytes read can be less than `count` if there isn't
304     /// enough data in the descriptor chain buffer.
read_to_at<F: FileReadWriteAtVolatile>( &mut self, dst: &F, count: usize, off: u64, ) -> io::Result<usize>305     pub fn read_to_at<F: FileReadWriteAtVolatile>(
306         &mut self,
307         dst: &F,
308         count: usize,
309         off: u64,
310     ) -> io::Result<usize> {
311         let iovs = self.regions.get_remaining_with_count(&self.mem, count);
312         let written = dst.write_vectored_at_volatile(&iovs[..], off)?;
313         self.regions.consume(written);
314         Ok(written)
315     }
316 
317     /// Reads data from the descriptor chain similar to 'read_to' except reading 'count' or
318     /// returning an error if 'count' bytes can't be read.
read_exact_to<F: FileReadWriteVolatile>( &mut self, mut dst: F, mut count: usize, ) -> io::Result<()>319     pub fn read_exact_to<F: FileReadWriteVolatile>(
320         &mut self,
321         mut dst: F,
322         mut count: usize,
323     ) -> io::Result<()> {
324         while count > 0 {
325             match self.read_to(&mut dst, count) {
326                 Ok(0) => {
327                     return Err(io::Error::new(
328                         io::ErrorKind::UnexpectedEof,
329                         "failed to fill whole buffer",
330                     ))
331                 }
332                 Ok(n) => count -= n,
333                 Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {}
334                 Err(e) => return Err(e),
335             }
336         }
337 
338         Ok(())
339     }
340 
341     /// Reads data from the descriptor chain similar to 'read_to_at' except reading 'count' or
342     /// returning an error if 'count' bytes can't be read.
read_exact_to_at<F: FileReadWriteAtVolatile>( &mut self, dst: &F, mut count: usize, mut off: u64, ) -> io::Result<()>343     pub fn read_exact_to_at<F: FileReadWriteAtVolatile>(
344         &mut self,
345         dst: &F,
346         mut count: usize,
347         mut off: u64,
348     ) -> io::Result<()> {
349         while count > 0 {
350             match self.read_to_at(dst, count, off) {
351                 Ok(0) => {
352                     return Err(io::Error::new(
353                         io::ErrorKind::UnexpectedEof,
354                         "failed to fill whole buffer",
355                     ))
356                 }
357                 Ok(n) => {
358                     count -= n;
359                     off += n as u64;
360                 }
361                 Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {}
362                 Err(e) => return Err(e),
363             }
364         }
365 
366         Ok(())
367     }
368 
369     /// Reads data from the descriptor chain buffer into an `AsyncDisk` at offset `off`.
370     /// Returns the number of bytes read from the descriptor chain buffer.
371     /// The number of bytes read can be less than `count` if there isn't
372     /// enough data in the descriptor chain buffer.
read_to_at_fut<F: AsyncDisk + ?Sized>( &mut self, dst: &F, count: usize, off: u64, ) -> disk::Result<usize>373     pub async fn read_to_at_fut<F: AsyncDisk + ?Sized>(
374         &mut self,
375         dst: &F,
376         count: usize,
377         off: u64,
378     ) -> disk::Result<usize> {
379         let written = dst
380             .write_from_mem(
381                 off,
382                 Arc::new(self.mem.clone()),
383                 self.regions.get_remaining_regions_with_count(count),
384             )
385             .await?;
386         self.regions.consume(written);
387         Ok(written)
388     }
389 
390     /// Reads exactly `count` bytes from the chain to the disk asynchronously or returns an error if
391     /// not enough data can be read.
read_exact_to_at_fut<F: AsyncDisk + ?Sized>( &mut self, dst: &F, mut count: usize, mut off: u64, ) -> disk::Result<()>392     pub async fn read_exact_to_at_fut<F: AsyncDisk + ?Sized>(
393         &mut self,
394         dst: &F,
395         mut count: usize,
396         mut off: u64,
397     ) -> disk::Result<()> {
398         while count > 0 {
399             let nread = self.read_to_at_fut(dst, count, off).await?;
400             if nread == 0 {
401                 return Err(disk::Error::ReadingData(io::Error::new(
402                     io::ErrorKind::UnexpectedEof,
403                     "failed to write whole buffer",
404                 )));
405             }
406             count -= nread;
407             off += nread as u64;
408         }
409 
410         Ok(())
411     }
412 
413     /// Returns number of bytes available for reading.  May return an error if the combined
414     /// lengths of all the buffers in the DescriptorChain would cause an integer overflow.
available_bytes(&self) -> usize415     pub fn available_bytes(&self) -> usize {
416         self.regions.available_bytes()
417     }
418 
419     /// Returns number of bytes already read from the descriptor chain buffer.
bytes_read(&self) -> usize420     pub fn bytes_read(&self) -> usize {
421         self.regions.bytes_consumed()
422     }
423 
get_remaining_regions(&self) -> MemRegionIter424     pub fn get_remaining_regions(&self) -> MemRegionIter {
425         self.regions.get_remaining_regions()
426     }
427 
428     /// Returns a `&[VolatileSlice]` that represents all the remaining data in this `Reader`.
429     /// Calling this method does not actually consume any data from the `Reader` and callers should
430     /// call `consume` to advance the `Reader`.
get_remaining(&self) -> SmallVec<[VolatileSlice; 16]>431     pub fn get_remaining(&self) -> SmallVec<[VolatileSlice; 16]> {
432         self.regions.get_remaining(&self.mem)
433     }
434 
435     /// Consumes `amt` bytes from the underlying descriptor chain. If `amt` is larger than the
436     /// remaining data left in this `Reader`, then all remaining data will be consumed.
consume(&mut self, amt: usize)437     pub fn consume(&mut self, amt: usize) {
438         self.regions.consume(amt)
439     }
440 
441     /// Splits this `Reader` into two at the given offset in the `DescriptorChain` buffer. After the
442     /// split, `self` will be able to read up to `offset` bytes while the returned `Reader` can read
443     /// up to `available_bytes() - offset` bytes. If `offset > self.available_bytes()`, then the
444     /// returned `Reader` will not be able to read any bytes.
split_at(&mut self, offset: usize) -> Reader445     pub fn split_at(&mut self, offset: usize) -> Reader {
446         Reader {
447             mem: self.mem.clone(),
448             regions: self.regions.split_at(offset),
449         }
450     }
451 }
452 
453 /// Copy up to `size` bytes from `src` into `dst`.
454 ///
455 /// Returns the total number of bytes copied.
456 ///
457 /// # Safety
458 ///
459 /// The caller must ensure that it is safe to write `size` bytes of data into `dst`.
460 ///
461 /// After the function returns, it is only safe to assume that the number of bytes indicated by the
462 /// return value (which may be less than the requested `size`) have been initialized. Bytes beyond
463 /// that point are not initialized by this function.
copy_regions_to_mut_ptr( mem: &GuestMemory, src: MemRegionIter, dst: *mut u8, size: usize, ) -> io::Result<usize>464 unsafe fn copy_regions_to_mut_ptr(
465     mem: &GuestMemory,
466     src: MemRegionIter,
467     dst: *mut u8,
468     size: usize,
469 ) -> io::Result<usize> {
470     let mut copied = 0;
471     for src_region in src {
472         if copied >= size {
473             break;
474         }
475 
476         let remaining = size - copied;
477         let count = cmp::min(remaining, src_region.len);
478 
479         let vslice = mem
480             .get_slice_at_addr(GuestAddress(src_region.offset), count)
481             .map_err(|_e| io::Error::from(io::ErrorKind::InvalidData))?;
482 
483         // SAFETY: `get_slice_at_addr()` verified that the region points to valid memory, and
484         // the `count` calculation ensures we will write at most `size` bytes into `dst`.
485         unsafe {
486             copy_nonoverlapping(vslice.as_ptr(), dst.add(copied), count);
487         }
488 
489         copied += count;
490     }
491 
492     Ok(copied)
493 }
494 
495 impl io::Read for Reader {
read(&mut self, buf: &mut [u8]) -> io::Result<usize>496     fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
497         // SAFETY: We pass a valid pointer and size combination derived from `buf`.
498         let total = unsafe {
499             copy_regions_to_mut_ptr(
500                 &self.mem,
501                 self.regions.get_remaining_regions(),
502                 buf.as_mut_ptr(),
503                 buf.len(),
504             )?
505         };
506         self.regions.consume(total);
507         Ok(total)
508     }
509 }
510 
511 /// Provides high-level interface over the sequence of memory regions
512 /// defined by writable descriptors in the descriptor chain.
513 ///
514 /// Note that virtio spec requires driver to place any device-writable
515 /// descriptors after any device-readable descriptors (2.6.4.2 in Virtio Spec v1.1).
516 /// Writer will start iterating the descriptors from the first writable one and will
517 /// assume that all following descriptors are writable.
518 pub struct Writer {
519     mem: GuestMemory,
520     regions: DescriptorChainRegions,
521 }
522 
523 impl Writer {
524     /// Construct a new Writer wrapper over `writable_regions`.
new_from_regions( mem: &GuestMemory, writable_regions: SmallVec<[MemRegion; 2]>, ) -> Writer525     pub fn new_from_regions(
526         mem: &GuestMemory,
527         writable_regions: SmallVec<[MemRegion; 2]>,
528     ) -> Writer {
529         Writer {
530             mem: mem.clone(),
531             regions: DescriptorChainRegions::new(writable_regions),
532         }
533     }
534 
535     /// Writes an object to the descriptor chain buffer.
write_obj<T: AsBytes>(&mut self, val: T) -> io::Result<()>536     pub fn write_obj<T: AsBytes>(&mut self, val: T) -> io::Result<()> {
537         self.write_all(val.as_bytes())
538     }
539 
540     /// Writes all objects produced by `iter` into the descriptor chain buffer. Unlike `consume`,
541     /// this doesn't require the values to be stored in an intermediate collection first. It also
542     /// allows callers to choose which elements in a collection to write, for example by using the
543     /// `filter` or `take` methods of the `Iterator` trait.
write_iter<T: AsBytes, I: Iterator<Item = T>>(&mut self, mut iter: I) -> io::Result<()>544     pub fn write_iter<T: AsBytes, I: Iterator<Item = T>>(&mut self, mut iter: I) -> io::Result<()> {
545         iter.try_for_each(|v| self.write_obj(v))
546     }
547 
548     /// Writes a collection of objects into the descriptor chain buffer.
consume<T: AsBytes, C: IntoIterator<Item = T>>(&mut self, vals: C) -> io::Result<()>549     pub fn consume<T: AsBytes, C: IntoIterator<Item = T>>(&mut self, vals: C) -> io::Result<()> {
550         self.write_iter(vals.into_iter())
551     }
552 
553     /// Returns number of bytes available for writing.  May return an error if the combined
554     /// lengths of all the buffers in the DescriptorChain would cause an overflow.
available_bytes(&self) -> usize555     pub fn available_bytes(&self) -> usize {
556         self.regions.available_bytes()
557     }
558 
559     /// Reads data into a volatile slice up to the minimum of the slice's length or the number of
560     /// bytes remaining. Returns the number of bytes read.
write_from_volatile_slice(&mut self, slice: VolatileSlice) -> usize561     pub fn write_from_volatile_slice(&mut self, slice: VolatileSlice) -> usize {
562         let mut written = 0usize;
563         let mut src = slice;
564         for dst in self.get_remaining() {
565             src.copy_to_volatile_slice(dst);
566             let copied = std::cmp::min(src.size(), dst.size());
567             written += copied;
568             src = match src.offset(copied) {
569                 Ok(v) => v,
570                 Err(_) => break, // The slice is fully consumed
571             };
572         }
573         self.regions.consume(written);
574         written
575     }
576 
577     /// Writes data to the descriptor chain buffer from a readable object.
578     /// Returns the number of bytes written to the descriptor chain buffer.
579     /// The number of bytes written can be less than `count` if
580     /// there isn't enough data in the descriptor chain buffer.
write_from<F: FileReadWriteVolatile>( &mut self, mut src: F, count: usize, ) -> io::Result<usize>581     pub fn write_from<F: FileReadWriteVolatile>(
582         &mut self,
583         mut src: F,
584         count: usize,
585     ) -> io::Result<usize> {
586         let iovs = self.regions.get_remaining_with_count(&self.mem, count);
587         let read = src.read_vectored_volatile(&iovs[..])?;
588         self.regions.consume(read);
589         Ok(read)
590     }
591 
592     /// Writes data to the descriptor chain buffer from a File at offset `off`.
593     /// Returns the number of bytes written to the descriptor chain buffer.
594     /// The number of bytes written can be less than `count` if
595     /// there isn't enough data in the descriptor chain buffer.
write_from_at<F: FileReadWriteAtVolatile>( &mut self, src: &F, count: usize, off: u64, ) -> io::Result<usize>596     pub fn write_from_at<F: FileReadWriteAtVolatile>(
597         &mut self,
598         src: &F,
599         count: usize,
600         off: u64,
601     ) -> io::Result<usize> {
602         let iovs = self.regions.get_remaining_with_count(&self.mem, count);
603         let read = src.read_vectored_at_volatile(&iovs[..], off)?;
604         self.regions.consume(read);
605         Ok(read)
606     }
607 
write_all_from<F: FileReadWriteVolatile>( &mut self, mut src: F, mut count: usize, ) -> io::Result<()>608     pub fn write_all_from<F: FileReadWriteVolatile>(
609         &mut self,
610         mut src: F,
611         mut count: usize,
612     ) -> io::Result<()> {
613         while count > 0 {
614             match self.write_from(&mut src, count) {
615                 Ok(0) => {
616                     return Err(io::Error::new(
617                         io::ErrorKind::WriteZero,
618                         "failed to write whole buffer",
619                     ))
620                 }
621                 Ok(n) => count -= n,
622                 Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {}
623                 Err(e) => return Err(e),
624             }
625         }
626 
627         Ok(())
628     }
629 
write_all_from_at<F: FileReadWriteAtVolatile>( &mut self, src: &F, mut count: usize, mut off: u64, ) -> io::Result<()>630     pub fn write_all_from_at<F: FileReadWriteAtVolatile>(
631         &mut self,
632         src: &F,
633         mut count: usize,
634         mut off: u64,
635     ) -> io::Result<()> {
636         while count > 0 {
637             match self.write_from_at(src, count, off) {
638                 Ok(0) => {
639                     return Err(io::Error::new(
640                         io::ErrorKind::WriteZero,
641                         "failed to write whole buffer",
642                     ))
643                 }
644                 Ok(n) => {
645                     count -= n;
646                     off += n as u64;
647                 }
648                 Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {}
649                 Err(e) => return Err(e),
650             }
651         }
652         Ok(())
653     }
654     /// Writes data to the descriptor chain buffer from an `AsyncDisk` at offset `off`.
655     /// Returns the number of bytes written to the descriptor chain buffer.
656     /// The number of bytes written can be less than `count` if
657     /// there isn't enough data in the descriptor chain buffer.
write_from_at_fut<F: AsyncDisk + ?Sized>( &mut self, src: &F, count: usize, off: u64, ) -> disk::Result<usize>658     pub async fn write_from_at_fut<F: AsyncDisk + ?Sized>(
659         &mut self,
660         src: &F,
661         count: usize,
662         off: u64,
663     ) -> disk::Result<usize> {
664         let read = src
665             .read_to_mem(
666                 off,
667                 Arc::new(self.mem.clone()),
668                 self.regions.get_remaining_regions_with_count(count),
669             )
670             .await?;
671         self.regions.consume(read);
672         Ok(read)
673     }
674 
write_all_from_at_fut<F: AsyncDisk + ?Sized>( &mut self, src: &F, mut count: usize, mut off: u64, ) -> disk::Result<()>675     pub async fn write_all_from_at_fut<F: AsyncDisk + ?Sized>(
676         &mut self,
677         src: &F,
678         mut count: usize,
679         mut off: u64,
680     ) -> disk::Result<()> {
681         while count > 0 {
682             let nwritten = self.write_from_at_fut(src, count, off).await?;
683             if nwritten == 0 {
684                 return Err(disk::Error::WritingData(io::Error::new(
685                     io::ErrorKind::UnexpectedEof,
686                     "failed to write whole buffer",
687                 )));
688             }
689             count -= nwritten;
690             off += nwritten as u64;
691         }
692         Ok(())
693     }
694 
695     /// Returns number of bytes already written to the descriptor chain buffer.
bytes_written(&self) -> usize696     pub fn bytes_written(&self) -> usize {
697         self.regions.bytes_consumed()
698     }
699 
get_remaining_regions(&self) -> MemRegionIter700     pub fn get_remaining_regions(&self) -> MemRegionIter {
701         self.regions.get_remaining_regions()
702     }
703 
704     /// Returns a `&[VolatileSlice]` that represents all the remaining data in this `Writer`.
705     /// Calling this method does not actually advance the current position of the `Writer` in the
706     /// buffer and callers should call `consume_bytes` to advance the `Writer`. Not calling
707     /// `consume_bytes` with the amount of data copied into the returned `VolatileSlice`s will
708     /// result in that that data being overwritten the next time data is written into the `Writer`.
get_remaining(&self) -> SmallVec<[VolatileSlice; 16]>709     pub fn get_remaining(&self) -> SmallVec<[VolatileSlice; 16]> {
710         self.regions.get_remaining(&self.mem)
711     }
712 
713     /// Consumes `amt` bytes from the underlying descriptor chain. If `amt` is larger than the
714     /// remaining data left in this `Reader`, then all remaining data will be consumed.
consume_bytes(&mut self, amt: usize)715     pub fn consume_bytes(&mut self, amt: usize) {
716         self.regions.consume(amt)
717     }
718 
719     /// Splits this `Writer` into two at the given offset in the `DescriptorChain` buffer. After the
720     /// split, `self` will be able to write up to `offset` bytes while the returned `Writer` can
721     /// write up to `available_bytes() - offset` bytes. If `offset > self.available_bytes()`, then
722     /// the returned `Writer` will not be able to write any bytes.
split_at(&mut self, offset: usize) -> Writer723     pub fn split_at(&mut self, offset: usize) -> Writer {
724         Writer {
725             mem: self.mem.clone(),
726             regions: self.regions.split_at(offset),
727         }
728     }
729 }
730 
731 impl io::Write for Writer {
write(&mut self, buf: &[u8]) -> io::Result<usize>732     fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
733         let mut rem = buf;
734         let mut total = 0;
735         for b in self.regions.get_remaining(&self.mem) {
736             if rem.is_empty() {
737                 break;
738             }
739 
740             let count = cmp::min(rem.len(), b.size());
741             // SAFETY:
742             // Safe because we have already verified that `vs` points to valid memory.
743             unsafe {
744                 copy_nonoverlapping(rem.as_ptr(), b.as_mut_ptr(), count);
745             }
746             rem = &rem[count..];
747             total += count;
748         }
749 
750         self.regions.consume(total);
751         Ok(total)
752     }
753 
flush(&mut self) -> io::Result<()>754     fn flush(&mut self) -> io::Result<()> {
755         // Nothing to flush since the writes go straight into the buffer.
756         Ok(())
757     }
758 }
759 
/// Descriptor flag: the chain continues with the descriptor whose index is in the `next` field.
const VIRTQ_DESC_F_NEXT: u16 = 1 << 0;
/// Descriptor flag: the buffer is device-writable (set for `DescriptorType::Writable`).
const VIRTQ_DESC_F_WRITE: u16 = 1 << 1;
762 
/// Direction of a descriptor's buffer, used when building chains with `create_descriptor_chain`.
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum DescriptorType {
    /// A buffer the device reads from. It is written without `VIRTQ_DESC_F_WRITE` and feeds the
    /// chain's `Reader`.
    Readable,
    /// A buffer the device writes to. It is written with `VIRTQ_DESC_F_WRITE` and feeds the
    /// chain's `Writer`.
    Writable,
}
768 
769 #[derive(Copy, Clone, Debug, FromZeroes, FromBytes, AsBytes)]
770 #[repr(C)]
771 struct virtq_desc {
772     addr: Le64,
773     len: Le32,
774     flags: Le16,
775     next: Le16,
776 }
777 
778 /// Test utility function to create a descriptor chain in guest memory.
create_descriptor_chain( memory: &GuestMemory, descriptor_array_addr: GuestAddress, mut buffers_start_addr: GuestAddress, descriptors: Vec<(DescriptorType, u32)>, spaces_between_regions: u32, ) -> anyhow::Result<DescriptorChain>779 pub fn create_descriptor_chain(
780     memory: &GuestMemory,
781     descriptor_array_addr: GuestAddress,
782     mut buffers_start_addr: GuestAddress,
783     descriptors: Vec<(DescriptorType, u32)>,
784     spaces_between_regions: u32,
785 ) -> anyhow::Result<DescriptorChain> {
786     let descriptors_len = descriptors.len();
787     for (index, (type_, size)) in descriptors.into_iter().enumerate() {
788         let mut flags = 0;
789         if let DescriptorType::Writable = type_ {
790             flags |= VIRTQ_DESC_F_WRITE;
791         }
792         if index + 1 < descriptors_len {
793             flags |= VIRTQ_DESC_F_NEXT;
794         }
795 
796         let index = index as u16;
797         let desc = virtq_desc {
798             addr: buffers_start_addr.offset().into(),
799             len: size.into(),
800             flags: flags.into(),
801             next: (index + 1).into(),
802         };
803 
804         let offset = size + spaces_between_regions;
805         buffers_start_addr = buffers_start_addr
806             .checked_add(offset as u64)
807             .context("Invalid buffers_start_addr)")?;
808 
809         let _ = memory.write_obj_at_addr(
810             desc,
811             descriptor_array_addr
812                 .checked_add(index as u64 * std::mem::size_of::<virtq_desc>() as u64)
813                 .context("Invalid descriptor_array_addr")?,
814         );
815     }
816 
817     let chain = SplitDescriptorChain::new(memory, descriptor_array_addr, 0x100, 0);
818     DescriptorChain::new(chain, memory, 0)
819 }
820 
#[cfg(test)]
mod tests {
    use std::fs::File;
    use std::io::Read;

    use cros_async::Executor;
    use tempfile::tempfile;
    use tempfile::NamedTempFile;

    use super::*;

    /// Creates the guest memory used by every test: one 0x10000-byte region at guest address 0.
    fn test_memory() -> GuestMemory {
        GuestMemory::new(&[(GuestAddress(0x0), 0x10000)]).unwrap()
    }

    /// Builds a descriptor chain in `memory`. The descriptor table sits at guest address 0x0 and
    /// the data buffers start at 0x100, each followed by `spaces_between_regions` unused bytes.
    fn test_chain(
        memory: &GuestMemory,
        descriptors: Vec<(DescriptorType, u32)>,
        spaces_between_regions: u32,
    ) -> DescriptorChain {
        create_descriptor_chain(
            memory,
            GuestAddress(0x0),
            GuestAddress(0x100),
            descriptors,
            spaces_between_regions,
        )
        .expect("create_descriptor_chain failed")
    }

    /// Descriptor layout shared by several tests: 128 readable bytes (16 + 16 + 96) followed by
    /// 68 writable bytes (64 + 1 + 3).
    fn mixed_descriptors() -> Vec<(DescriptorType, u32)> {
        use DescriptorType::*;
        vec![
            (Readable, 16),
            (Readable, 16),
            (Readable, 96),
            (Writable, 64),
            (Writable, 1),
            (Writable, 3),
        ]
    }

    #[test]
    fn reader_test_simple_chain() {
        use DescriptorType::*;

        let memory = test_memory();
        let mut chain = test_chain(
            &memory,
            vec![
                (Readable, 8),
                (Readable, 16),
                (Readable, 18),
                (Readable, 64),
            ],
            0,
        );
        let reader = &mut chain.reader;
        assert_eq!(reader.available_bytes(), 106);
        assert_eq!(reader.bytes_read(), 0);

        let mut buffer = [0u8; 64];
        reader
            .read_exact(&mut buffer)
            .expect("read_exact should not fail here");

        assert_eq!(reader.available_bytes(), 42);
        assert_eq!(reader.bytes_read(), 64);

        match reader.read(&mut buffer) {
            Err(_) => panic!("read should not fail here"),
            Ok(length) => assert_eq!(length, 42),
        }

        assert_eq!(reader.available_bytes(), 0);
        assert_eq!(reader.bytes_read(), 106);
    }

    #[test]
    fn writer_test_simple_chain() {
        use DescriptorType::*;

        let memory = test_memory();
        let mut chain = test_chain(
            &memory,
            vec![
                (Writable, 8),
                (Writable, 16),
                (Writable, 18),
                (Writable, 64),
            ],
            0,
        );
        let writer = &mut chain.writer;
        assert_eq!(writer.available_bytes(), 106);
        assert_eq!(writer.bytes_written(), 0);

        let buffer = [0; 64];
        writer
            .write_all(&buffer)
            .expect("write_all should not fail here");

        assert_eq!(writer.available_bytes(), 42);
        assert_eq!(writer.bytes_written(), 64);

        match writer.write(&buffer) {
            Err(_) => panic!("write should not fail here"),
            Ok(length) => assert_eq!(length, 42),
        }

        assert_eq!(writer.available_bytes(), 0);
        assert_eq!(writer.bytes_written(), 106);
    }

    #[test]
    fn reader_test_incompatible_chain() {
        use DescriptorType::*;

        let memory = test_memory();
        let mut chain = test_chain(&memory, vec![(Writable, 8)], 0);
        let reader = &mut chain.reader;
        assert_eq!(reader.available_bytes(), 0);
        assert_eq!(reader.bytes_read(), 0);

        assert!(reader.read_obj::<u8>().is_err());

        assert_eq!(reader.available_bytes(), 0);
        assert_eq!(reader.bytes_read(), 0);
    }

    #[test]
    fn writer_test_incompatible_chain() {
        use DescriptorType::*;

        let memory = test_memory();
        let mut chain = test_chain(&memory, vec![(Readable, 8)], 0);
        let writer = &mut chain.writer;
        assert_eq!(writer.available_bytes(), 0);
        assert_eq!(writer.bytes_written(), 0);

        assert!(writer.write_obj(0u8).is_err());

        assert_eq!(writer.available_bytes(), 0);
        assert_eq!(writer.bytes_written(), 0);
    }

    #[test]
    fn reader_failing_io() {
        use DescriptorType::*;

        let memory = test_memory();
        let mut chain = test_chain(&memory, vec![(Readable, 256), (Readable, 256)], 0);
        let reader = &mut chain.reader;

        // Open a device file in read-only mode so that writing to it triggers an I/O error.
        let device_file = if cfg!(windows) { "NUL" } else { "/dev/zero" };
        let mut ro_file = File::open(device_file).expect("failed to open device file");

        reader
            .read_exact_to(&mut ro_file, 512)
            .expect_err("successfully read more bytes than SharedMemory size");

        // The write above should have failed entirely, so no bytes were consumed from the chain.
        assert_eq!(reader.available_bytes(), 512);
        assert_eq!(reader.bytes_read(), 0);
    }

    #[test]
    fn writer_failing_io() {
        use DescriptorType::*;

        let memory = test_memory();
        let mut chain = test_chain(&memory, vec![(Writable, 256), (Writable, 256)], 0);
        let writer = &mut chain.writer;

        // The file only holds 384 bytes, so filling all 512 writable bytes must fail.
        let mut file = tempfile().unwrap();
        file.set_len(384).unwrap();

        writer
            .write_all_from(&mut file, 512)
            .expect_err("successfully wrote more bytes than in SharedMemory");

        // The 384 bytes the file could supply were still written into the chain.
        assert_eq!(writer.available_bytes(), 128);
        assert_eq!(writer.bytes_written(), 384);
    }

    #[test]
    fn reader_writer_shared_chain() {
        let memory = test_memory();
        let mut chain = test_chain(&memory, mixed_descriptors(), 0);
        let reader = &mut chain.reader;
        let writer = &mut chain.writer;

        assert_eq!(reader.bytes_read(), 0);
        assert_eq!(writer.bytes_written(), 0);

        let mut buffer = Vec::with_capacity(200);

        assert_eq!(
            reader
                .read_to_end(&mut buffer)
                .expect("read should not fail here"),
            128
        );

        // The writable descriptors are only 68 bytes long.
        writer
            .write_all(&buffer[..68])
            .expect("write should not fail here");

        assert_eq!(reader.available_bytes(), 0);
        assert_eq!(reader.bytes_read(), 128);
        assert_eq!(writer.available_bytes(), 0);
        assert_eq!(writer.bytes_written(), 68);
    }

    #[test]
    fn reader_writer_shattered_object() {
        use DescriptorType::*;

        let memory = test_memory();
        let secret: Le32 = 0x12345678.into();

        // Create a descriptor chain whose 1-byte regions are separated by 123-byte gaps, so the
        // 4-byte object is split across four descriptors.
        let mut chain_writer = test_chain(
            &memory,
            vec![(Writable, 1), (Writable, 1), (Writable, 1), (Writable, 1)],
            123,
        );
        let writer = &mut chain_writer.writer;
        writer
            .write_obj(secret)
            .expect("write_obj should not fail here");

        // Now create a new descriptor chain over the same memory and read the object back.
        let mut chain_reader = test_chain(
            &memory,
            vec![(Readable, 1), (Readable, 1), (Readable, 1), (Readable, 1)],
            123,
        );
        let reader = &mut chain_reader.reader;
        match reader.read_obj::<Le32>() {
            Err(_) => panic!("read_obj should not fail here"),
            Ok(read_secret) => assert_eq!(read_secret, secret),
        }
    }

    #[test]
    fn reader_unexpected_eof() {
        use DescriptorType::*;

        let memory = test_memory();
        let mut chain = test_chain(&memory, vec![(Readable, 256), (Readable, 256)], 0);
        let reader = &mut chain.reader;

        let mut buf = vec![0; 1024];

        assert_eq!(
            reader
                .read_exact(&mut buf[..])
                .expect_err("read more bytes than available")
                .kind(),
            io::ErrorKind::UnexpectedEof
        );
    }

    #[test]
    fn split_border() {
        let memory = test_memory();
        let mut chain = test_chain(&memory, mixed_descriptors(), 0);
        let reader = &mut chain.reader;

        let other = reader.split_at(32);
        assert_eq!(reader.available_bytes(), 32);
        assert_eq!(other.available_bytes(), 96);
    }

    #[test]
    fn split_middle() {
        let memory = test_memory();
        let mut chain = test_chain(&memory, mixed_descriptors(), 0);
        let reader = &mut chain.reader;

        let other = reader.split_at(24);
        assert_eq!(reader.available_bytes(), 24);
        assert_eq!(other.available_bytes(), 104);
    }

    #[test]
    fn split_end() {
        let memory = test_memory();
        let mut chain = test_chain(&memory, mixed_descriptors(), 0);
        let reader = &mut chain.reader;

        let other = reader.split_at(128);
        assert_eq!(reader.available_bytes(), 128);
        assert_eq!(other.available_bytes(), 0);
    }

    #[test]
    fn split_beginning() {
        let memory = test_memory();
        let mut chain = test_chain(&memory, mixed_descriptors(), 0);
        let reader = &mut chain.reader;

        let other = reader.split_at(0);
        assert_eq!(reader.available_bytes(), 0);
        assert_eq!(other.available_bytes(), 128);
    }

    #[test]
    fn split_outofbounds() {
        let memory = test_memory();
        let mut chain = test_chain(&memory, mixed_descriptors(), 0);
        let reader = &mut chain.reader;

        let other = reader.split_at(256);
        assert_eq!(
            other.available_bytes(),
            0,
            "Reader returned from out-of-bounds split still has available bytes"
        );
    }

    #[test]
    fn read_full() {
        use DescriptorType::*;

        let memory = test_memory();
        let mut chain = test_chain(
            &memory,
            vec![(Readable, 16), (Readable, 16), (Readable, 16)],
            0,
        );
        let reader = &mut chain.reader;

        let mut buf = [0u8; 64];
        assert_eq!(
            reader.read(&mut buf[..]).expect("failed to read to buffer"),
            48
        );
    }

    #[test]
    fn write_full() {
        use DescriptorType::*;

        let memory = test_memory();
        let mut chain = test_chain(
            &memory,
            vec![(Writable, 16), (Writable, 16), (Writable, 16)],
            0,
        );
        let writer = &mut chain.writer;

        let buf = [0xdeu8; 64];
        assert_eq!(
            writer.write(&buf[..]).expect("failed to write from buffer"),
            48
        );
    }

    #[test]
    fn consume_collect() {
        use DescriptorType::*;

        let memory = test_memory();
        let vs: Vec<Le64> = vec![
            0x0101010101010101.into(),
            0x0202020202020202.into(),
            0x0303030303030303.into(),
        ];

        let mut write_chain = test_chain(&memory, vec![(Writable, 24)], 0);
        let writer = &mut write_chain.writer;
        writer
            .consume(vs.clone())
            .expect("failed to consume() a vector");

        let mut read_chain = test_chain(&memory, vec![(Readable, 24)], 0);
        let reader = &mut read_chain.reader;
        let vs_read = reader
            .collect::<io::Result<Vec<Le64>>, _>()
            .expect("failed to collect() values");
        assert_eq!(vs, vs_read);
    }

    #[test]
    fn get_remaining_region_with_count() {
        let memory = test_memory();
        let chain = test_chain(&memory, mixed_descriptors(), 0);

        let Reader {
            mem: _,
            mut regions,
        } = chain.reader;

        let drain = regions
            .get_remaining_regions_with_count(usize::MAX)
            .fold(0usize, |total, region| total + region.len);
        assert_eq!(drain, 128);

        let exact = regions
            .get_remaining_regions_with_count(32)
            .fold(0usize, |total, region| total + region.len);
        assert!(exact > 0);
        assert!(exact <= 32);

        let split = regions
            .get_remaining_regions_with_count(24)
            .fold(0usize, |total, region| total + region.len);
        assert!(split > 0);
        assert!(split <= 24);

        regions.consume(64);

        let first = regions
            .get_remaining_regions_with_count(8)
            .fold(0usize, |total, region| total + region.len);
        assert!(first > 0);
        assert!(first <= 8);
    }

    #[test]
    fn get_remaining_with_count() {
        let memory = test_memory();
        let chain = test_chain(&memory, mixed_descriptors(), 0);

        let Reader {
            mem: _,
            mut regions,
        } = chain.reader;

        let drain = regions
            .get_remaining_with_count(&memory, usize::MAX)
            .iter()
            .fold(0usize, |total, iov| total + iov.size());
        assert_eq!(drain, 128);

        let exact = regions
            .get_remaining_with_count(&memory, 32)
            .iter()
            .fold(0usize, |total, iov| total + iov.size());
        assert!(exact > 0);
        assert!(exact <= 32);

        let split = regions
            .get_remaining_with_count(&memory, 24)
            .iter()
            .fold(0usize, |total, iov| total + iov.size());
        assert!(split > 0);
        assert!(split <= 24);

        regions.consume(64);

        let first = regions
            .get_remaining_with_count(&memory, 8)
            .iter()
            .fold(0usize, |total, iov| total + iov.size());
        assert!(first > 0);
        assert!(first <= 8);
    }

    #[test]
    fn reader_peek_obj() {
        use DescriptorType::*;

        let memory = test_memory();

        // Write test data to memory.
        memory
            .write_obj_at_addr(Le16::from(0xBEEF), GuestAddress(0x100))
            .unwrap();
        memory
            .write_obj_at_addr(Le16::from(0xDEAD), GuestAddress(0x200))
            .unwrap();

        // Two 2-byte descriptors: one at 0x100 and, after a 0xFE-byte gap, one at 0x200.
        let mut chain_reader = test_chain(&memory, vec![(Readable, 2), (Readable, 2)], 0x100 - 2);
        let reader = &mut chain_reader.reader;

        // peek_obj() at the beginning of the chain should return the first object.
        let peek1 = reader.peek_obj::<Le16>().unwrap();
        assert_eq!(peek1, Le16::from(0xBEEF));

        // peek_obj() again should return the same object, since it was not consumed.
        let peek2 = reader.peek_obj::<Le16>().unwrap();
        assert_eq!(peek2, Le16::from(0xBEEF));

        // peek_obj() of an object spanning two descriptors should copy from both.
        let peek3 = reader.peek_obj::<Le32>().unwrap();
        assert_eq!(peek3, Le32::from(0xDEADBEEF));

        // read_obj() should return the first object.
        let read1 = reader.read_obj::<Le16>().unwrap();
        assert_eq!(read1, Le16::from(0xBEEF));

        // peek_obj() of a value that is larger than the rest of the chain should fail.
        reader
            .peek_obj::<Le32>()
            .expect_err("peek_obj past end of chain");

        // read_obj() again should return the second object.
        let read2 = reader.read_obj::<Le16>().unwrap();
        assert_eq!(read2, Le16::from(0xDEAD));

        // peek_obj() should fail at the end of the chain.
        reader
            .peek_obj::<Le16>()
            .expect_err("peek_obj past end of chain");
    }

    #[test]
    fn region_reader_failing_io() {
        let ex = Executor::new().unwrap();
        ex.run_until(region_reader_failing_io_async(&ex)).unwrap();
    }
    async fn region_reader_failing_io_async(ex: &Executor) {
        use DescriptorType::*;

        let memory = test_memory();
        let mut chain = test_chain(&memory, vec![(Readable, 256), (Readable, 256)], 0);
        let reader = &mut chain.reader;

        // Open a temp file in read-only mode so that writing to it triggers an I/O error.
        let named_temp_file = NamedTempFile::new().expect("failed to create temp file");
        let ro_file =
            File::open(named_temp_file.path()).expect("failed to open temp file read only");
        let async_ro_file = disk::SingleFileDisk::new(ro_file, ex).expect("Failed to create SFD");

        reader
            .read_exact_to_at_fut(&async_ro_file, 512, 0)
            .await
            .expect_err("successfully read more bytes than SingleFileDisk size");

        // The write above should have failed entirely, so no bytes were consumed from the chain.
        assert_eq!(reader.available_bytes(), 512);
        assert_eq!(reader.bytes_read(), 0);
    }

    #[test]
    fn region_writer_failing_io() {
        let ex = Executor::new().unwrap();
        ex.run_until(region_writer_failing_io_async(&ex)).unwrap();
    }
    async fn region_writer_failing_io_async(ex: &Executor) {
        use DescriptorType::*;

        let memory = test_memory();
        let mut chain = test_chain(&memory, vec![(Writable, 256), (Writable, 256)], 0);
        let writer = &mut chain.writer;

        // The file only holds 384 bytes, so filling all 512 writable bytes must fail.
        let file = tempfile().expect("failed to create temp file");
        file.set_len(384).unwrap();
        let async_file = disk::SingleFileDisk::new(file, ex).expect("Failed to create SFD");

        writer
            .write_all_from_at_fut(&async_file, 512, 0)
            .await
            .expect_err("successfully wrote more bytes than in SingleFileDisk");

        // The 384 bytes the file could supply were still written into the chain.
        assert_eq!(writer.available_bytes(), 128);
        assert_eq!(writer.bytes_written(), 384);
    }
}
1618