1 // Copyright 2019 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 use std::cmp;
6 use std::io;
7 use std::io::Write;
8 use std::iter::FromIterator;
9 use std::marker::PhantomData;
10 use std::mem::size_of;
11 use std::mem::MaybeUninit;
12 use std::ptr::copy_nonoverlapping;
13 use std::sync::Arc;
14
15 use anyhow::Context;
16 use base::FileReadWriteAtVolatile;
17 use base::FileReadWriteVolatile;
18 use base::VolatileSlice;
19 use cros_async::MemRegion;
20 use cros_async::MemRegionIter;
21 use data_model::Le16;
22 use data_model::Le32;
23 use data_model::Le64;
24 use disk::AsyncDisk;
25 use smallvec::SmallVec;
26 use vm_memory::GuestAddress;
27 use vm_memory::GuestMemory;
28 use zerocopy::AsBytes;
29 use zerocopy::FromBytes;
30 use zerocopy::FromZeroes;
31
32 use super::DescriptorChain;
33 use crate::virtio::SplitDescriptorChain;
34
35 struct DescriptorChainRegions {
36 regions: SmallVec<[MemRegion; 2]>,
37
38 // Index of the current region in `regions`.
39 current_region_index: usize,
40
41 // Number of bytes consumed in the current region.
42 current_region_offset: usize,
43
44 // Total bytes consumed in the entire descriptor chain.
45 bytes_consumed: usize,
46 }
47
48 impl DescriptorChainRegions {
fn new(regions: SmallVec<[MemRegion; 2]>) -> Self {
50 DescriptorChainRegions {
51 regions,
52 current_region_index: 0,
53 current_region_offset: 0,
54 bytes_consumed: 0,
55 }
56 }
57
fn available_bytes(&self) -> usize {
59 // This is guaranteed not to overflow because the total length of the chain is checked
60 // during all creations of `DescriptorChain` (see `DescriptorChain::new()`).
61 self.get_remaining_regions()
62 .fold(0usize, |count, region| count + region.len)
63 }
64
fn bytes_consumed(&self) -> usize {
66 self.bytes_consumed
67 }
68
69 /// Returns all the remaining buffers in the `DescriptorChain`. Calling this function does not
70 /// consume any bytes from the `DescriptorChain`. Instead callers should use the `consume`
/// method to advance the `DescriptorChain`. Multiple calls to this method with no intervening calls
72 /// to `consume` will return the same data.
fn get_remaining_regions(&self) -> MemRegionIter {
74 MemRegionIter::new(&self.regions[self.current_region_index..])
75 .skip_bytes(self.current_region_offset)
76 }
77
78 /// Like `get_remaining_regions` but guarantees that the combined length of all the returned
79 /// iovecs is not greater than `count`. The combined length of the returned iovecs may be less
80 /// than `count` but will always be greater than 0 as long as there is still space left in the
81 /// `DescriptorChain`.
fn get_remaining_regions_with_count(&self, count: usize) -> MemRegionIter {
83 MemRegionIter::new(&self.regions[self.current_region_index..])
84 .skip_bytes(self.current_region_offset)
85 .take_bytes(count)
86 }
87
88 /// Returns all the remaining buffers in the `DescriptorChain` as `VolatileSlice`s of the given
89 /// `GuestMemory`. Calling this function does not consume any bytes from the `DescriptorChain`.
90 /// Instead callers should use the `consume` method to advance the `DescriptorChain`. Multiple
/// calls to this method with no intervening calls to `consume` will return the same data.
fn get_remaining<'mem>(&self, mem: &'mem GuestMemory) -> SmallVec<[VolatileSlice<'mem>; 16]> {
93 self.get_remaining_regions()
94 .filter_map(|region| {
95 mem.get_slice_at_addr(GuestAddress(region.offset), region.len)
96 .ok()
97 })
98 .collect()
99 }
100
/// Like `get_remaining_regions_with_count` except that it converts the regions to volatile
/// slices in the `GuestMemory` given by `mem`.
fn get_remaining_with_count<'mem>(
104 &self,
105 mem: &'mem GuestMemory,
106 count: usize,
107 ) -> SmallVec<[VolatileSlice<'mem>; 16]> {
108 self.get_remaining_regions_with_count(count)
109 .filter_map(|region| {
110 mem.get_slice_at_addr(GuestAddress(region.offset), region.len)
111 .ok()
112 })
113 .collect()
114 }
115
116 /// Consumes `count` bytes from the `DescriptorChain`. If `count` is larger than
117 /// `self.available_bytes()` then all remaining bytes in the `DescriptorChain` will be consumed.
fn consume(&mut self, mut count: usize) {
119 while let Some(region) = self.regions.get(self.current_region_index) {
120 let region_remaining = region.len - self.current_region_offset;
121 if count < region_remaining {
122 // The remaining count to consume is less than the remaining un-consumed length of
123 // the current region. Adjust the region offset without advancing to the next region
124 // and stop.
125 self.current_region_offset += count;
126 self.bytes_consumed += count;
127 return;
128 }
129
130 // The current region has been exhausted. Advance to the next region.
131 self.current_region_index += 1;
132 self.current_region_offset = 0;
133
134 self.bytes_consumed += region_remaining;
135 count -= region_remaining;
136 }
137 }
138
fn split_at(&mut self, offset: usize) -> DescriptorChainRegions {
140 let mut other = DescriptorChainRegions {
141 regions: self.regions.clone(),
142 current_region_index: self.current_region_index,
143 current_region_offset: self.current_region_offset,
144 bytes_consumed: self.bytes_consumed,
145 };
146 other.consume(offset);
147 other.bytes_consumed = 0;
148
149 let mut rem = offset;
150 let mut end = self.current_region_index;
151 for region in &mut self.regions[self.current_region_index..] {
152 if rem <= region.len {
153 region.len = rem;
154 break;
155 }
156
157 end += 1;
158 rem -= region.len;
159 }
160
161 self.regions.truncate(end + 1);
162
163 other
164 }
165 }
166
/// Provides a high-level interface over the sequence of memory regions
/// defined by the readable descriptors in the descriptor chain.
///
/// Note that the virtio spec requires the driver to place any device-writable
/// descriptors after any device-readable descriptors (2.6.4.2 in Virtio Spec v1.1).
/// `Reader` stops iterating over the descriptor chain when the first device-writable
/// descriptor is encountered.
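///
/// `Reader` also implements `std::io::Read`, so standard helpers work on it as well
/// (an illustrative sketch; the buffer size is arbitrary):
///
/// ```ignore
/// let mut buf = [0u8; 16];
/// reader.read_exact(&mut buf)?;
/// ```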
174 pub struct Reader {
175 mem: GuestMemory,
176 regions: DescriptorChainRegions,
177 }
178
179 /// An iterator over `FromBytes` objects on readable descriptors in the descriptor chain.
180 pub struct ReaderIterator<'a, T: FromBytes> {
181 reader: &'a mut Reader,
182 phantom: PhantomData<T>,
183 }
184
185 impl<'a, T: FromBytes> Iterator for ReaderIterator<'a, T> {
186 type Item = io::Result<T>;
187
fn next(&mut self) -> Option<io::Result<T>> {
189 if self.reader.available_bytes() == 0 {
190 None
191 } else {
192 Some(self.reader.read_obj())
193 }
194 }
195 }
196
197 impl Reader {
198 /// Construct a new Reader wrapper over `readable_regions`.
pub fn new_from_regions(
200 mem: &GuestMemory,
201 readable_regions: SmallVec<[MemRegion; 2]>,
202 ) -> Reader {
203 Reader {
204 mem: mem.clone(),
205 regions: DescriptorChainRegions::new(readable_regions),
206 }
207 }
208
209 /// Reads an object from the descriptor chain buffer without consuming it.
pub fn peek_obj<T: FromBytes>(&self) -> io::Result<T> {
211 let mut obj = MaybeUninit::uninit();
212
213 // SAFETY: We pass a valid pointer and size of `obj`.
214 let copied = unsafe {
215 copy_regions_to_mut_ptr(
216 &self.mem,
217 self.get_remaining_regions(),
218 obj.as_mut_ptr() as *mut u8,
219 size_of::<T>(),
220 )?
221 };
222 if copied != size_of::<T>() {
223 return Err(io::Error::from(io::ErrorKind::UnexpectedEof));
224 }
225
226 // SAFETY: `FromBytes` guarantees any set of initialized bytes is a valid value for `T`, and
227 // we initialized all bytes in `obj` in the copy above.
228 Ok(unsafe { obj.assume_init() })
229 }
230
231 /// Reads and consumes an object from the descriptor chain buffer.
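///
/// A minimal illustrative sketch (not compiled as a doctest; the descriptor layout mirrors the
/// tests in this module and the sizes are arbitrary):
///
/// ```ignore
/// let mem = GuestMemory::new(&[(GuestAddress(0x0), 0x10000)]).unwrap();
/// let mut chain = create_descriptor_chain(
///     &mem,
///     GuestAddress(0x0),   // descriptor table address
///     GuestAddress(0x100), // buffer start address
///     vec![(DescriptorType::Readable, 8)],
///     0,
/// )
/// .expect("create_descriptor_chain failed");
/// // Read a little-endian u64 from the readable portion of the chain; this advances the Reader.
/// let val: Le64 = chain.reader.read_obj()?;
/// ```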
pub fn read_obj<T: FromBytes>(&mut self) -> io::Result<T> {
233 let obj = self.peek_obj::<T>()?;
234 self.consume(size_of::<T>());
235 Ok(obj)
236 }
237
238 /// Reads objects by consuming all the remaining data in the descriptor chain buffer and returns
239 /// them as a collection. Returns an error if the size of the remaining data is indivisible by
240 /// the size of an object of type `T`.
pub fn collect<C: FromIterator<io::Result<T>>, T: FromBytes>(&mut self) -> C {
242 self.iter().collect()
243 }
244
245 /// Creates an iterator for sequentially reading `FromBytes` objects from the `Reader`.
246 /// Unlike `collect`, this doesn't consume all the remaining data in the `Reader` and
247 /// doesn't require the objects to be stored in a separate collection.
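///
/// An illustrative sketch (the element type `Le32` is arbitrary); each item yielded by the
/// iterator is an `io::Result<T>`:
///
/// ```ignore
/// for val in reader.iter::<Le32>() {
///     let val = val?;
///     // process `val` ...
/// }
/// ```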
pub fn iter<T: FromBytes>(&mut self) -> ReaderIterator<T> {
249 ReaderIterator {
250 reader: self,
251 phantom: PhantomData,
252 }
253 }
254
255 /// Reads data into a volatile slice up to the minimum of the slice's length or the number of
256 /// bytes remaining. Returns the number of bytes read.
pub fn read_to_volatile_slice(&mut self, slice: VolatileSlice) -> usize {
258 let mut read = 0usize;
259 let mut dst = slice;
260 for src in self.get_remaining() {
261 src.copy_to_volatile_slice(dst);
262 let copied = std::cmp::min(src.size(), dst.size());
263 read += copied;
264 dst = match dst.offset(copied) {
265 Ok(v) => v,
266 Err(_) => break, // The slice is fully consumed
267 };
268 }
269 self.regions.consume(read);
270 read
271 }
272
273 /// Reads data from the descriptor chain buffer and passes the `VolatileSlice`s to the callback
274 /// `cb`.
pub fn read_to_cb<C: FnOnce(&[VolatileSlice]) -> usize>(
276 &mut self,
277 cb: C,
278 count: usize,
279 ) -> usize {
280 let iovs = self.regions.get_remaining_with_count(&self.mem, count);
281 let written = cb(&iovs[..]);
282 self.regions.consume(written);
283 written
284 }
285
286 /// Reads data from the descriptor chain buffer into a writable object.
287 /// Returns the number of bytes read from the descriptor chain buffer.
288 /// The number of bytes read can be less than `count` if there isn't
289 /// enough data in the descriptor chain buffer.
pub fn read_to<F: FileReadWriteVolatile>(
291 &mut self,
292 mut dst: F,
293 count: usize,
294 ) -> io::Result<usize> {
295 let iovs = self.regions.get_remaining_with_count(&self.mem, count);
296 let written = dst.write_vectored_volatile(&iovs[..])?;
297 self.regions.consume(written);
298 Ok(written)
299 }
300
301 /// Reads data from the descriptor chain buffer into a File at offset `off`.
302 /// Returns the number of bytes read from the descriptor chain buffer.
303 /// The number of bytes read can be less than `count` if there isn't
304 /// enough data in the descriptor chain buffer.
pub fn read_to_at<F: FileReadWriteAtVolatile>(
306 &mut self,
307 dst: &F,
308 count: usize,
309 off: u64,
310 ) -> io::Result<usize> {
311 let iovs = self.regions.get_remaining_with_count(&self.mem, count);
312 let written = dst.write_vectored_at_volatile(&iovs[..], off)?;
313 self.regions.consume(written);
314 Ok(written)
315 }
316
/// Reads data from the descriptor chain like `read_to`, except that it reads exactly `count`
/// bytes or returns an error if `count` bytes can't be read.
pub fn read_exact_to<F: FileReadWriteVolatile>(
320 &mut self,
321 mut dst: F,
322 mut count: usize,
323 ) -> io::Result<()> {
324 while count > 0 {
325 match self.read_to(&mut dst, count) {
326 Ok(0) => {
327 return Err(io::Error::new(
328 io::ErrorKind::UnexpectedEof,
329 "failed to fill whole buffer",
330 ))
331 }
332 Ok(n) => count -= n,
333 Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {}
334 Err(e) => return Err(e),
335 }
336 }
337
338 Ok(())
339 }
340
/// Reads data from the descriptor chain like `read_to_at`, except that it reads exactly `count`
/// bytes or returns an error if `count` bytes can't be read.
pub fn read_exact_to_at<F: FileReadWriteAtVolatile>(
344 &mut self,
345 dst: &F,
346 mut count: usize,
347 mut off: u64,
348 ) -> io::Result<()> {
349 while count > 0 {
350 match self.read_to_at(dst, count, off) {
351 Ok(0) => {
352 return Err(io::Error::new(
353 io::ErrorKind::UnexpectedEof,
354 "failed to fill whole buffer",
355 ))
356 }
357 Ok(n) => {
358 count -= n;
359 off += n as u64;
360 }
361 Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {}
362 Err(e) => return Err(e),
363 }
364 }
365
366 Ok(())
367 }
368
369 /// Reads data from the descriptor chain buffer into an `AsyncDisk` at offset `off`.
370 /// Returns the number of bytes read from the descriptor chain buffer.
371 /// The number of bytes read can be less than `count` if there isn't
372 /// enough data in the descriptor chain buffer.
pub async fn read_to_at_fut<F: AsyncDisk + ?Sized>(
374 &mut self,
375 dst: &F,
376 count: usize,
377 off: u64,
378 ) -> disk::Result<usize> {
379 let written = dst
380 .write_from_mem(
381 off,
382 Arc::new(self.mem.clone()),
383 self.regions.get_remaining_regions_with_count(count),
384 )
385 .await?;
386 self.regions.consume(written);
387 Ok(written)
388 }
389
390 /// Reads exactly `count` bytes from the chain to the disk asynchronously or returns an error if
391 /// not enough data can be read.
pub async fn read_exact_to_at_fut<F: AsyncDisk + ?Sized>(
393 &mut self,
394 dst: &F,
395 mut count: usize,
396 mut off: u64,
397 ) -> disk::Result<()> {
398 while count > 0 {
399 let nread = self.read_to_at_fut(dst, count, off).await?;
400 if nread == 0 {
401 return Err(disk::Error::ReadingData(io::Error::new(
402 io::ErrorKind::UnexpectedEof,
403 "failed to write whole buffer",
404 )));
405 }
406 count -= nread;
407 off += nread as u64;
408 }
409
410 Ok(())
411 }
412
/// Returns the number of bytes available for reading. This cannot overflow because the total
/// length of the chain is validated when the `DescriptorChain` is created.
pub fn available_bytes(&self) -> usize {
416 self.regions.available_bytes()
417 }
418
419 /// Returns number of bytes already read from the descriptor chain buffer.
pub fn bytes_read(&self) -> usize {
421 self.regions.bytes_consumed()
422 }
423
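/// Returns an iterator over the remaining readable `MemRegion`s in this `Reader`. Calling this
/// does not consume any bytes; use `consume` to advance the `Reader`.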
pub fn get_remaining_regions(&self) -> MemRegionIter {
425 self.regions.get_remaining_regions()
426 }
427
428 /// Returns a `&[VolatileSlice]` that represents all the remaining data in this `Reader`.
429 /// Calling this method does not actually consume any data from the `Reader` and callers should
430 /// call `consume` to advance the `Reader`.
pub fn get_remaining(&self) -> SmallVec<[VolatileSlice; 16]> {
432 self.regions.get_remaining(&self.mem)
433 }
434
435 /// Consumes `amt` bytes from the underlying descriptor chain. If `amt` is larger than the
436 /// remaining data left in this `Reader`, then all remaining data will be consumed.
pub fn consume(&mut self, amt: usize) {
438 self.regions.consume(amt)
439 }
440
441 /// Splits this `Reader` into two at the given offset in the `DescriptorChain` buffer. After the
442 /// split, `self` will be able to read up to `offset` bytes while the returned `Reader` can read
443 /// up to `available_bytes() - offset` bytes. If `offset > self.available_bytes()`, then the
444 /// returned `Reader` will not be able to read any bytes.
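///
/// A minimal sketch (sizes are arbitrary): given a `Reader` with 128 readable bytes,
///
/// ```ignore
/// let tail = reader.split_at(32);
/// assert_eq!(reader.available_bytes(), 32);
/// assert_eq!(tail.available_bytes(), 96);
/// ```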
pub fn split_at(&mut self, offset: usize) -> Reader {
446 Reader {
447 mem: self.mem.clone(),
448 regions: self.regions.split_at(offset),
449 }
450 }
451 }
452
453 /// Copy up to `size` bytes from `src` into `dst`.
454 ///
455 /// Returns the total number of bytes copied.
456 ///
457 /// # Safety
458 ///
459 /// The caller must ensure that it is safe to write `size` bytes of data into `dst`.
460 ///
461 /// After the function returns, it is only safe to assume that the number of bytes indicated by the
462 /// return value (which may be less than the requested `size`) have been initialized. Bytes beyond
463 /// that point are not initialized by this function.
unsafe fn copy_regions_to_mut_ptr(
465 mem: &GuestMemory,
466 src: MemRegionIter,
467 dst: *mut u8,
468 size: usize,
469 ) -> io::Result<usize> {
470 let mut copied = 0;
471 for src_region in src {
472 if copied >= size {
473 break;
474 }
475
476 let remaining = size - copied;
477 let count = cmp::min(remaining, src_region.len);
478
479 let vslice = mem
480 .get_slice_at_addr(GuestAddress(src_region.offset), count)
481 .map_err(|_e| io::Error::from(io::ErrorKind::InvalidData))?;
482
483 // SAFETY: `get_slice_at_addr()` verified that the region points to valid memory, and
484 // the `count` calculation ensures we will write at most `size` bytes into `dst`.
485 unsafe {
486 copy_nonoverlapping(vslice.as_ptr(), dst.add(copied), count);
487 }
488
489 copied += count;
490 }
491
492 Ok(copied)
493 }
494
495 impl io::Read for Reader {
fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
497 // SAFETY: We pass a valid pointer and size combination derived from `buf`.
498 let total = unsafe {
499 copy_regions_to_mut_ptr(
500 &self.mem,
501 self.regions.get_remaining_regions(),
502 buf.as_mut_ptr(),
503 buf.len(),
504 )?
505 };
506 self.regions.consume(total);
507 Ok(total)
508 }
509 }
510
/// Provides a high-level interface over the sequence of memory regions
/// defined by the writable descriptors in the descriptor chain.
///
/// Note that the virtio spec requires the driver to place any device-writable
/// descriptors after any device-readable descriptors (2.6.4.2 in Virtio Spec v1.1).
/// `Writer` starts iterating at the first device-writable descriptor and assumes that all
/// following descriptors are writable.
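///
/// `Writer` also implements `std::io::Write`, so standard helpers work on it as well
/// (an illustrative sketch):
///
/// ```ignore
/// writer.write_all(b"hello")?;
/// ```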
518 pub struct Writer {
519 mem: GuestMemory,
520 regions: DescriptorChainRegions,
521 }
522
523 impl Writer {
524 /// Construct a new Writer wrapper over `writable_regions`.
pub fn new_from_regions(
526 mem: &GuestMemory,
527 writable_regions: SmallVec<[MemRegion; 2]>,
528 ) -> Writer {
529 Writer {
530 mem: mem.clone(),
531 regions: DescriptorChainRegions::new(writable_regions),
532 }
533 }
534
535 /// Writes an object to the descriptor chain buffer.
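///
/// Illustrative sketch (the value is arbitrary; any `AsBytes` type such as the `Le*` wrappers
/// can be written):
///
/// ```ignore
/// writer.write_obj(Le32::from(0x1234_5678u32))?;
/// ```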
pub fn write_obj<T: AsBytes>(&mut self, val: T) -> io::Result<()> {
537 self.write_all(val.as_bytes())
538 }
539
540 /// Writes all objects produced by `iter` into the descriptor chain buffer. Unlike `consume`,
541 /// this doesn't require the values to be stored in an intermediate collection first. It also
542 /// allows callers to choose which elements in a collection to write, for example by using the
543 /// `filter` or `take` methods of the `Iterator` trait.
pub fn write_iter<T: AsBytes, I: Iterator<Item = T>>(&mut self, mut iter: I) -> io::Result<()> {
545 iter.try_for_each(|v| self.write_obj(v))
546 }
547
548 /// Writes a collection of objects into the descriptor chain buffer.
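///
/// For example (a sketch; the values are arbitrary), writing a `Vec<Le64>` in one call:
///
/// ```ignore
/// let vals: Vec<Le64> = vec![1u64.into(), 2u64.into(), 3u64.into()];
/// writer.consume(vals)?;
/// ```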
pub fn consume<T: AsBytes, C: IntoIterator<Item = T>>(&mut self, vals: C) -> io::Result<()> {
550 self.write_iter(vals.into_iter())
551 }
552
/// Returns the number of bytes available for writing. This cannot overflow because the total
/// length of the chain is validated when the `DescriptorChain` is created.
pub fn available_bytes(&self) -> usize {
556 self.regions.available_bytes()
557 }
558
/// Writes data from a volatile slice into the descriptor chain buffer, up to the minimum of the
/// slice's length or the number of bytes remaining. Returns the number of bytes written.
pub fn write_from_volatile_slice(&mut self, slice: VolatileSlice) -> usize {
562 let mut written = 0usize;
563 let mut src = slice;
564 for dst in self.get_remaining() {
565 src.copy_to_volatile_slice(dst);
566 let copied = std::cmp::min(src.size(), dst.size());
567 written += copied;
568 src = match src.offset(copied) {
569 Ok(v) => v,
570 Err(_) => break, // The slice is fully consumed
571 };
572 }
573 self.regions.consume(written);
574 written
575 }
576
577 /// Writes data to the descriptor chain buffer from a readable object.
578 /// Returns the number of bytes written to the descriptor chain buffer.
579 /// The number of bytes written can be less than `count` if
/// there isn't enough space left in the descriptor chain buffer.
pub fn write_from<F: FileReadWriteVolatile>(
582 &mut self,
583 mut src: F,
584 count: usize,
585 ) -> io::Result<usize> {
586 let iovs = self.regions.get_remaining_with_count(&self.mem, count);
587 let read = src.read_vectored_volatile(&iovs[..])?;
588 self.regions.consume(read);
589 Ok(read)
590 }
591
592 /// Writes data to the descriptor chain buffer from a File at offset `off`.
593 /// Returns the number of bytes written to the descriptor chain buffer.
594 /// The number of bytes written can be less than `count` if
/// there isn't enough space left in the descriptor chain buffer.
pub fn write_from_at<F: FileReadWriteAtVolatile>(
597 &mut self,
598 src: &F,
599 count: usize,
600 off: u64,
601 ) -> io::Result<usize> {
602 let iovs = self.regions.get_remaining_with_count(&self.mem, count);
603 let read = src.read_vectored_at_volatile(&iovs[..], off)?;
604 self.regions.consume(read);
605 Ok(read)
606 }
607
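/// Writes data to the descriptor chain buffer from a readable object, like `write_from`, except
/// that it continues until exactly `count` bytes have been written, or returns a `WriteZero`
/// error if the buffer or the source is exhausted first.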
pub fn write_all_from<F: FileReadWriteVolatile>(
609 &mut self,
610 mut src: F,
611 mut count: usize,
612 ) -> io::Result<()> {
613 while count > 0 {
614 match self.write_from(&mut src, count) {
615 Ok(0) => {
616 return Err(io::Error::new(
617 io::ErrorKind::WriteZero,
618 "failed to write whole buffer",
619 ))
620 }
621 Ok(n) => count -= n,
622 Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {}
623 Err(e) => return Err(e),
624 }
625 }
626
627 Ok(())
628 }
629
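/// Like `write_all_from`, except the source is read at offset `off`, which is advanced as bytes
/// are written into the descriptor chain buffer.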
pub fn write_all_from_at<F: FileReadWriteAtVolatile>(
631 &mut self,
632 src: &F,
633 mut count: usize,
634 mut off: u64,
635 ) -> io::Result<()> {
636 while count > 0 {
637 match self.write_from_at(src, count, off) {
638 Ok(0) => {
639 return Err(io::Error::new(
640 io::ErrorKind::WriteZero,
641 "failed to write whole buffer",
642 ))
643 }
644 Ok(n) => {
645 count -= n;
646 off += n as u64;
647 }
648 Err(ref e) if e.kind() == io::ErrorKind::Interrupted => {}
649 Err(e) => return Err(e),
650 }
651 }
652 Ok(())
653 }
654 /// Writes data to the descriptor chain buffer from an `AsyncDisk` at offset `off`.
655 /// Returns the number of bytes written to the descriptor chain buffer.
656 /// The number of bytes written can be less than `count` if
/// there isn't enough space left in the descriptor chain buffer.
pub async fn write_from_at_fut<F: AsyncDisk + ?Sized>(
659 &mut self,
660 src: &F,
661 count: usize,
662 off: u64,
663 ) -> disk::Result<usize> {
664 let read = src
665 .read_to_mem(
666 off,
667 Arc::new(self.mem.clone()),
668 self.regions.get_remaining_regions_with_count(count),
669 )
670 .await?;
671 self.regions.consume(read);
672 Ok(read)
673 }
674
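/// Writes exactly `count` bytes into the descriptor chain buffer from the `AsyncDisk` at offset
/// `off`, or returns an error if that many bytes cannot be written.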
pub async fn write_all_from_at_fut<F: AsyncDisk + ?Sized>(
676 &mut self,
677 src: &F,
678 mut count: usize,
679 mut off: u64,
680 ) -> disk::Result<()> {
681 while count > 0 {
682 let nwritten = self.write_from_at_fut(src, count, off).await?;
683 if nwritten == 0 {
684 return Err(disk::Error::WritingData(io::Error::new(
685 io::ErrorKind::UnexpectedEof,
686 "failed to write whole buffer",
687 )));
688 }
689 count -= nwritten;
690 off += nwritten as u64;
691 }
692 Ok(())
693 }
694
695 /// Returns number of bytes already written to the descriptor chain buffer.
pub fn bytes_written(&self) -> usize {
697 self.regions.bytes_consumed()
698 }
699
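/// Returns an iterator over the remaining writable `MemRegion`s in this `Writer`. Calling this
/// does not consume any bytes; use `consume_bytes` to advance the `Writer`.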
pub fn get_remaining_regions(&self) -> MemRegionIter {
701 self.regions.get_remaining_regions()
702 }
703
704 /// Returns a `&[VolatileSlice]` that represents all the remaining data in this `Writer`.
705 /// Calling this method does not actually advance the current position of the `Writer` in the
706 /// buffer and callers should call `consume_bytes` to advance the `Writer`. Not calling
707 /// `consume_bytes` with the amount of data copied into the returned `VolatileSlice`s will
/// result in that data being overwritten the next time data is written into the `Writer`.
pub fn get_remaining(&self) -> SmallVec<[VolatileSlice; 16]> {
710 self.regions.get_remaining(&self.mem)
711 }
712
713 /// Consumes `amt` bytes from the underlying descriptor chain. If `amt` is larger than the
/// remaining data left in this `Writer`, then all remaining data will be consumed.
pub fn consume_bytes(&mut self, amt: usize) {
716 self.regions.consume(amt)
717 }
718
719 /// Splits this `Writer` into two at the given offset in the `DescriptorChain` buffer. After the
720 /// split, `self` will be able to write up to `offset` bytes while the returned `Writer` can
721 /// write up to `available_bytes() - offset` bytes. If `offset > self.available_bytes()`, then
722 /// the returned `Writer` will not be able to write any bytes.
pub fn split_at(&mut self, offset: usize) -> Writer {
724 Writer {
725 mem: self.mem.clone(),
726 regions: self.regions.split_at(offset),
727 }
728 }
729 }
730
731 impl io::Write for Writer {
fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
733 let mut rem = buf;
734 let mut total = 0;
735 for b in self.regions.get_remaining(&self.mem) {
736 if rem.is_empty() {
737 break;
738 }
739
740 let count = cmp::min(rem.len(), b.size());
741 // SAFETY:
// Safe because we have already verified that `b` points to valid memory.
743 unsafe {
744 copy_nonoverlapping(rem.as_ptr(), b.as_mut_ptr(), count);
745 }
746 rem = &rem[count..];
747 total += count;
748 }
749
750 self.regions.consume(total);
751 Ok(total)
752 }
753
fn flush(&mut self) -> io::Result<()> {
755 // Nothing to flush since the writes go straight into the buffer.
756 Ok(())
757 }
758 }
759
760 const VIRTQ_DESC_F_NEXT: u16 = 0x1;
761 const VIRTQ_DESC_F_WRITE: u16 = 0x2;
762
763 #[derive(Copy, Clone, PartialEq, Eq)]
764 pub enum DescriptorType {
765 Readable,
766 Writable,
767 }
768
769 #[derive(Copy, Clone, Debug, FromZeroes, FromBytes, AsBytes)]
770 #[repr(C)]
771 struct virtq_desc {
772 addr: Le64,
773 len: Le32,
774 flags: Le16,
775 next: Le16,
776 }
777
778 /// Test utility function to create a descriptor chain in guest memory.
pub fn create_descriptor_chain(
780 memory: &GuestMemory,
781 descriptor_array_addr: GuestAddress,
782 mut buffers_start_addr: GuestAddress,
783 descriptors: Vec<(DescriptorType, u32)>,
784 spaces_between_regions: u32,
785 ) -> anyhow::Result<DescriptorChain> {
786 let descriptors_len = descriptors.len();
787 for (index, (type_, size)) in descriptors.into_iter().enumerate() {
788 let mut flags = 0;
789 if let DescriptorType::Writable = type_ {
790 flags |= VIRTQ_DESC_F_WRITE;
791 }
792 if index + 1 < descriptors_len {
793 flags |= VIRTQ_DESC_F_NEXT;
794 }
795
796 let index = index as u16;
797 let desc = virtq_desc {
798 addr: buffers_start_addr.offset().into(),
799 len: size.into(),
800 flags: flags.into(),
801 next: (index + 1).into(),
802 };
803
804 let offset = size + spaces_between_regions;
805 buffers_start_addr = buffers_start_addr
806 .checked_add(offset as u64)
.context("Invalid buffers_start_addr")?;
808
809 let _ = memory.write_obj_at_addr(
810 desc,
811 descriptor_array_addr
812 .checked_add(index as u64 * std::mem::size_of::<virtq_desc>() as u64)
813 .context("Invalid descriptor_array_addr")?,
814 );
815 }
816
817 let chain = SplitDescriptorChain::new(memory, descriptor_array_addr, 0x100, 0);
818 DescriptorChain::new(chain, memory, 0)
819 }
820
821 #[cfg(test)]
822 mod tests {
823 use std::fs::File;
824 use std::io::Read;
825
826 use cros_async::Executor;
827 use tempfile::tempfile;
828 use tempfile::NamedTempFile;
829
830 use super::*;
831
832 #[test]
fn reader_test_simple_chain() {
834 use DescriptorType::*;
835
836 let memory_start_addr = GuestAddress(0x0);
837 let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
838
839 let mut chain = create_descriptor_chain(
840 &memory,
841 GuestAddress(0x0),
842 GuestAddress(0x100),
843 vec![
844 (Readable, 8),
845 (Readable, 16),
846 (Readable, 18),
847 (Readable, 64),
848 ],
849 0,
850 )
851 .expect("create_descriptor_chain failed");
852 let reader = &mut chain.reader;
853 assert_eq!(reader.available_bytes(), 106);
854 assert_eq!(reader.bytes_read(), 0);
855
856 let mut buffer = [0u8; 64];
857 reader
858 .read_exact(&mut buffer)
859 .expect("read_exact should not fail here");
860
861 assert_eq!(reader.available_bytes(), 42);
862 assert_eq!(reader.bytes_read(), 64);
863
864 match reader.read(&mut buffer) {
865 Err(_) => panic!("read should not fail here"),
866 Ok(length) => assert_eq!(length, 42),
867 }
868
869 assert_eq!(reader.available_bytes(), 0);
870 assert_eq!(reader.bytes_read(), 106);
871 }
872
873 #[test]
fn writer_test_simple_chain() {
875 use DescriptorType::*;
876
877 let memory_start_addr = GuestAddress(0x0);
878 let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
879
880 let mut chain = create_descriptor_chain(
881 &memory,
882 GuestAddress(0x0),
883 GuestAddress(0x100),
884 vec![
885 (Writable, 8),
886 (Writable, 16),
887 (Writable, 18),
888 (Writable, 64),
889 ],
890 0,
891 )
892 .expect("create_descriptor_chain failed");
893 let writer = &mut chain.writer;
894 assert_eq!(writer.available_bytes(), 106);
895 assert_eq!(writer.bytes_written(), 0);
896
897 let buffer = [0; 64];
898 writer
899 .write_all(&buffer)
900 .expect("write_all should not fail here");
901
902 assert_eq!(writer.available_bytes(), 42);
903 assert_eq!(writer.bytes_written(), 64);
904
905 match writer.write(&buffer) {
906 Err(_) => panic!("write should not fail here"),
907 Ok(length) => assert_eq!(length, 42),
908 }
909
910 assert_eq!(writer.available_bytes(), 0);
911 assert_eq!(writer.bytes_written(), 106);
912 }
913
914 #[test]
fn reader_test_incompatible_chain() {
916 use DescriptorType::*;
917
918 let memory_start_addr = GuestAddress(0x0);
919 let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
920
921 let mut chain = create_descriptor_chain(
922 &memory,
923 GuestAddress(0x0),
924 GuestAddress(0x100),
925 vec![(Writable, 8)],
926 0,
927 )
928 .expect("create_descriptor_chain failed");
929 let reader = &mut chain.reader;
930 assert_eq!(reader.available_bytes(), 0);
931 assert_eq!(reader.bytes_read(), 0);
932
933 assert!(reader.read_obj::<u8>().is_err());
934
935 assert_eq!(reader.available_bytes(), 0);
936 assert_eq!(reader.bytes_read(), 0);
937 }
938
939 #[test]
fn writer_test_incompatible_chain() {
941 use DescriptorType::*;
942
943 let memory_start_addr = GuestAddress(0x0);
944 let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
945
946 let mut chain = create_descriptor_chain(
947 &memory,
948 GuestAddress(0x0),
949 GuestAddress(0x100),
950 vec![(Readable, 8)],
951 0,
952 )
953 .expect("create_descriptor_chain failed");
954 let writer = &mut chain.writer;
955 assert_eq!(writer.available_bytes(), 0);
956 assert_eq!(writer.bytes_written(), 0);
957
958 assert!(writer.write_obj(0u8).is_err());
959
960 assert_eq!(writer.available_bytes(), 0);
961 assert_eq!(writer.bytes_written(), 0);
962 }
963
964 #[test]
fn reader_failing_io() {
966 use DescriptorType::*;
967
968 let memory_start_addr = GuestAddress(0x0);
969 let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
970
971 let mut chain = create_descriptor_chain(
972 &memory,
973 GuestAddress(0x0),
974 GuestAddress(0x100),
975 vec![(Readable, 256), (Readable, 256)],
976 0,
977 )
978 .expect("create_descriptor_chain failed");
979
980 let reader = &mut chain.reader;
981
// Open a file in read-only mode so writes to it trigger an I/O error.
983 let device_file = if cfg!(windows) { "NUL" } else { "/dev/zero" };
984 let mut ro_file = File::open(device_file).expect("failed to open device file");
985
986 reader
987 .read_exact_to(&mut ro_file, 512)
988 .expect_err("successfully read more bytes than SharedMemory size");
989
990 // The write above should have failed entirely, so we end up not writing any bytes at all.
991 assert_eq!(reader.available_bytes(), 512);
992 assert_eq!(reader.bytes_read(), 0);
993 }
994
995 #[test]
fn writer_failing_io() {
997 use DescriptorType::*;
998
999 let memory_start_addr = GuestAddress(0x0);
1000 let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
1001
1002 let mut chain = create_descriptor_chain(
1003 &memory,
1004 GuestAddress(0x0),
1005 GuestAddress(0x100),
1006 vec![(Writable, 256), (Writable, 256)],
1007 0,
1008 )
1009 .expect("create_descriptor_chain failed");
1010
1011 let writer = &mut chain.writer;
1012
1013 let mut file = tempfile().unwrap();
1014
1015 file.set_len(384).unwrap();
1016
1017 writer
1018 .write_all_from(&mut file, 512)
1019 .expect_err("successfully wrote more bytes than in SharedMemory");
1020
1021 assert_eq!(writer.available_bytes(), 128);
1022 assert_eq!(writer.bytes_written(), 384);
1023 }
1024
1025 #[test]
fn reader_writer_shared_chain() {
1027 use DescriptorType::*;
1028
1029 let memory_start_addr = GuestAddress(0x0);
1030 let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
1031
1032 let mut chain = create_descriptor_chain(
1033 &memory,
1034 GuestAddress(0x0),
1035 GuestAddress(0x100),
1036 vec![
1037 (Readable, 16),
1038 (Readable, 16),
1039 (Readable, 96),
1040 (Writable, 64),
1041 (Writable, 1),
1042 (Writable, 3),
1043 ],
1044 0,
1045 )
1046 .expect("create_descriptor_chain failed");
1047 let reader = &mut chain.reader;
1048 let writer = &mut chain.writer;
1049
1050 assert_eq!(reader.bytes_read(), 0);
1051 assert_eq!(writer.bytes_written(), 0);
1052
1053 let mut buffer = Vec::with_capacity(200);
1054
1055 assert_eq!(
1056 reader
1057 .read_to_end(&mut buffer)
1058 .expect("read should not fail here"),
1059 128
1060 );
1061
1062 // The writable descriptors are only 68 bytes long.
1063 writer
1064 .write_all(&buffer[..68])
1065 .expect("write should not fail here");
1066
1067 assert_eq!(reader.available_bytes(), 0);
1068 assert_eq!(reader.bytes_read(), 128);
1069 assert_eq!(writer.available_bytes(), 0);
1070 assert_eq!(writer.bytes_written(), 68);
1071 }
1072
1073 #[test]
fn reader_writer_shattered_object() {
1075 use DescriptorType::*;
1076
1077 let memory_start_addr = GuestAddress(0x0);
1078 let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
1079
1080 let secret: Le32 = 0x12345678.into();
1081
1082 // Create a descriptor chain with memory regions that are properly separated.
1083 let mut chain_writer = create_descriptor_chain(
1084 &memory,
1085 GuestAddress(0x0),
1086 GuestAddress(0x100),
1087 vec![(Writable, 1), (Writable, 1), (Writable, 1), (Writable, 1)],
1088 123,
1089 )
1090 .expect("create_descriptor_chain failed");
1091 let writer = &mut chain_writer.writer;
1092 writer
1093 .write_obj(secret)
1094 .expect("write_obj should not fail here");
1095
1096 // Now create new descriptor chain pointing to the same memory and try to read it.
1097 let mut chain_reader = create_descriptor_chain(
1098 &memory,
1099 GuestAddress(0x0),
1100 GuestAddress(0x100),
1101 vec![(Readable, 1), (Readable, 1), (Readable, 1), (Readable, 1)],
1102 123,
1103 )
1104 .expect("create_descriptor_chain failed");
1105 let reader = &mut chain_reader.reader;
1106 match reader.read_obj::<Le32>() {
1107 Err(_) => panic!("read_obj should not fail here"),
1108 Ok(read_secret) => assert_eq!(read_secret, secret),
1109 }
1110 }
1111
1112 #[test]
fn reader_unexpected_eof() {
1114 use DescriptorType::*;
1115
1116 let memory_start_addr = GuestAddress(0x0);
1117 let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
1118
1119 let mut chain = create_descriptor_chain(
1120 &memory,
1121 GuestAddress(0x0),
1122 GuestAddress(0x100),
1123 vec![(Readable, 256), (Readable, 256)],
1124 0,
1125 )
1126 .expect("create_descriptor_chain failed");
1127
1128 let reader = &mut chain.reader;
1129
1130 let mut buf = vec![0; 1024];
1131
1132 assert_eq!(
1133 reader
1134 .read_exact(&mut buf[..])
1135 .expect_err("read more bytes than available")
1136 .kind(),
1137 io::ErrorKind::UnexpectedEof
1138 );
1139 }
1140
1141 #[test]
fn split_border() {
1143 use DescriptorType::*;
1144
1145 let memory_start_addr = GuestAddress(0x0);
1146 let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
1147
1148 let mut chain = create_descriptor_chain(
1149 &memory,
1150 GuestAddress(0x0),
1151 GuestAddress(0x100),
1152 vec![
1153 (Readable, 16),
1154 (Readable, 16),
1155 (Readable, 96),
1156 (Writable, 64),
1157 (Writable, 1),
1158 (Writable, 3),
1159 ],
1160 0,
1161 )
1162 .expect("create_descriptor_chain failed");
1163 let reader = &mut chain.reader;
1164
1165 let other = reader.split_at(32);
1166 assert_eq!(reader.available_bytes(), 32);
1167 assert_eq!(other.available_bytes(), 96);
1168 }
1169
1170 #[test]
fn split_middle() {
1172 use DescriptorType::*;
1173
1174 let memory_start_addr = GuestAddress(0x0);
1175 let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
1176
1177 let mut chain = create_descriptor_chain(
1178 &memory,
1179 GuestAddress(0x0),
1180 GuestAddress(0x100),
1181 vec![
1182 (Readable, 16),
1183 (Readable, 16),
1184 (Readable, 96),
1185 (Writable, 64),
1186 (Writable, 1),
1187 (Writable, 3),
1188 ],
1189 0,
1190 )
1191 .expect("create_descriptor_chain failed");
1192 let reader = &mut chain.reader;
1193
1194 let other = reader.split_at(24);
1195 assert_eq!(reader.available_bytes(), 24);
1196 assert_eq!(other.available_bytes(), 104);
1197 }
1198
1199 #[test]
fn split_end() {
1201 use DescriptorType::*;
1202
1203 let memory_start_addr = GuestAddress(0x0);
1204 let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
1205
1206 let mut chain = create_descriptor_chain(
1207 &memory,
1208 GuestAddress(0x0),
1209 GuestAddress(0x100),
1210 vec![
1211 (Readable, 16),
1212 (Readable, 16),
1213 (Readable, 96),
1214 (Writable, 64),
1215 (Writable, 1),
1216 (Writable, 3),
1217 ],
1218 0,
1219 )
1220 .expect("create_descriptor_chain failed");
1221 let reader = &mut chain.reader;
1222
1223 let other = reader.split_at(128);
1224 assert_eq!(reader.available_bytes(), 128);
1225 assert_eq!(other.available_bytes(), 0);
1226 }
1227
1228 #[test]
fn split_beginning() {
1230 use DescriptorType::*;
1231
1232 let memory_start_addr = GuestAddress(0x0);
1233 let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
1234
1235 let mut chain = create_descriptor_chain(
1236 &memory,
1237 GuestAddress(0x0),
1238 GuestAddress(0x100),
1239 vec![
1240 (Readable, 16),
1241 (Readable, 16),
1242 (Readable, 96),
1243 (Writable, 64),
1244 (Writable, 1),
1245 (Writable, 3),
1246 ],
1247 0,
1248 )
1249 .expect("create_descriptor_chain failed");
1250 let reader = &mut chain.reader;
1251
1252 let other = reader.split_at(0);
1253 assert_eq!(reader.available_bytes(), 0);
1254 assert_eq!(other.available_bytes(), 128);
1255 }
1256
1257 #[test]
fn split_outofbounds() {
1259 use DescriptorType::*;
1260
1261 let memory_start_addr = GuestAddress(0x0);
1262 let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
1263
1264 let mut chain = create_descriptor_chain(
1265 &memory,
1266 GuestAddress(0x0),
1267 GuestAddress(0x100),
1268 vec![
1269 (Readable, 16),
1270 (Readable, 16),
1271 (Readable, 96),
1272 (Writable, 64),
1273 (Writable, 1),
1274 (Writable, 3),
1275 ],
1276 0,
1277 )
1278 .expect("create_descriptor_chain failed");
1279 let reader = &mut chain.reader;
1280
1281 let other = reader.split_at(256);
1282 assert_eq!(
1283 other.available_bytes(),
1284 0,
1285 "Reader returned from out-of-bounds split still has available bytes"
1286 );
1287 }
1288
1289 #[test]
fn read_full() {
1291 use DescriptorType::*;
1292
1293 let memory_start_addr = GuestAddress(0x0);
1294 let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
1295
1296 let mut chain = create_descriptor_chain(
1297 &memory,
1298 GuestAddress(0x0),
1299 GuestAddress(0x100),
1300 vec![(Readable, 16), (Readable, 16), (Readable, 16)],
1301 0,
1302 )
1303 .expect("create_descriptor_chain failed");
1304 let reader = &mut chain.reader;
1305
1306 let mut buf = [0u8; 64];
1307 assert_eq!(
1308 reader.read(&mut buf[..]).expect("failed to read to buffer"),
1309 48
1310 );
1311 }
1312
1313 #[test]
fn write_full() {
1315 use DescriptorType::*;
1316
1317 let memory_start_addr = GuestAddress(0x0);
1318 let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
1319
1320 let mut chain = create_descriptor_chain(
1321 &memory,
1322 GuestAddress(0x0),
1323 GuestAddress(0x100),
1324 vec![(Writable, 16), (Writable, 16), (Writable, 16)],
1325 0,
1326 )
1327 .expect("create_descriptor_chain failed");
1328 let writer = &mut chain.writer;
1329
1330 let buf = [0xdeu8; 64];
1331 assert_eq!(
1332 writer.write(&buf[..]).expect("failed to write from buffer"),
1333 48
1334 );
1335 }
1336
1337 #[test]
fn consume_collect() {
1339 use DescriptorType::*;
1340
1341 let memory_start_addr = GuestAddress(0x0);
1342 let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
1343 let vs: Vec<Le64> = vec![
1344 0x0101010101010101.into(),
1345 0x0202020202020202.into(),
1346 0x0303030303030303.into(),
1347 ];
1348
1349 let mut write_chain = create_descriptor_chain(
1350 &memory,
1351 GuestAddress(0x0),
1352 GuestAddress(0x100),
1353 vec![(Writable, 24)],
1354 0,
1355 )
1356 .expect("create_descriptor_chain failed");
1357 let writer = &mut write_chain.writer;
1358 writer
1359 .consume(vs.clone())
1360 .expect("failed to consume() a vector");
1361
1362 let mut read_chain = create_descriptor_chain(
1363 &memory,
1364 GuestAddress(0x0),
1365 GuestAddress(0x100),
1366 vec![(Readable, 24)],
1367 0,
1368 )
1369 .expect("create_descriptor_chain failed");
1370 let reader = &mut read_chain.reader;
1371 let vs_read = reader
1372 .collect::<io::Result<Vec<Le64>>, _>()
1373 .expect("failed to collect() values");
1374 assert_eq!(vs, vs_read);
1375 }
1376
1377 #[test]
fn get_remaining_region_with_count() {
1379 use DescriptorType::*;
1380
1381 let memory_start_addr = GuestAddress(0x0);
1382 let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
1383
1384 let chain = create_descriptor_chain(
1385 &memory,
1386 GuestAddress(0x0),
1387 GuestAddress(0x100),
1388 vec![
1389 (Readable, 16),
1390 (Readable, 16),
1391 (Readable, 96),
1392 (Writable, 64),
1393 (Writable, 1),
1394 (Writable, 3),
1395 ],
1396 0,
1397 )
1398 .expect("create_descriptor_chain failed");
1399
1400 let Reader {
1401 mem: _,
1402 mut regions,
1403 } = chain.reader;
1404
1405 let drain = regions
1406 .get_remaining_regions_with_count(usize::MAX)
1407 .fold(0usize, |total, region| total + region.len);
1408 assert_eq!(drain, 128);
1409
1410 let exact = regions
1411 .get_remaining_regions_with_count(32)
1412 .fold(0usize, |total, region| total + region.len);
1413 assert!(exact > 0);
1414 assert!(exact <= 32);
1415
1416 let split = regions
1417 .get_remaining_regions_with_count(24)
1418 .fold(0usize, |total, region| total + region.len);
1419 assert!(split > 0);
1420 assert!(split <= 24);
1421
1422 regions.consume(64);
1423
1424 let first = regions
1425 .get_remaining_regions_with_count(8)
1426 .fold(0usize, |total, region| total + region.len);
1427 assert!(first > 0);
1428 assert!(first <= 8);
1429 }
1430
1431 #[test]
fn get_remaining_with_count() {
1433 use DescriptorType::*;
1434
1435 let memory_start_addr = GuestAddress(0x0);
1436 let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
1437
1438 let chain = create_descriptor_chain(
1439 &memory,
1440 GuestAddress(0x0),
1441 GuestAddress(0x100),
1442 vec![
1443 (Readable, 16),
1444 (Readable, 16),
1445 (Readable, 96),
1446 (Writable, 64),
1447 (Writable, 1),
1448 (Writable, 3),
1449 ],
1450 0,
1451 )
1452 .expect("create_descriptor_chain failed");
1453 let Reader {
1454 mem: _,
1455 mut regions,
1456 } = chain.reader;
1457
1458 let drain = regions
1459 .get_remaining_with_count(&memory, usize::MAX)
1460 .iter()
1461 .fold(0usize, |total, iov| total + iov.size());
1462 assert_eq!(drain, 128);
1463
1464 let exact = regions
1465 .get_remaining_with_count(&memory, 32)
1466 .iter()
1467 .fold(0usize, |total, iov| total + iov.size());
1468 assert!(exact > 0);
1469 assert!(exact <= 32);
1470
1471 let split = regions
1472 .get_remaining_with_count(&memory, 24)
1473 .iter()
1474 .fold(0usize, |total, iov| total + iov.size());
1475 assert!(split > 0);
1476 assert!(split <= 24);
1477
1478 regions.consume(64);
1479
1480 let first = regions
1481 .get_remaining_with_count(&memory, 8)
1482 .iter()
1483 .fold(0usize, |total, iov| total + iov.size());
1484 assert!(first > 0);
1485 assert!(first <= 8);
1486 }
1487
1488 #[test]
fn reader_peek_obj() {
1490 use DescriptorType::*;
1491
1492 let memory_start_addr = GuestAddress(0x0);
1493 let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
1494
1495 // Write test data to memory.
1496 memory
1497 .write_obj_at_addr(Le16::from(0xBEEF), GuestAddress(0x100))
1498 .unwrap();
1499 memory
1500 .write_obj_at_addr(Le16::from(0xDEAD), GuestAddress(0x200))
1501 .unwrap();
1502
1503 let mut chain_reader = create_descriptor_chain(
1504 &memory,
1505 GuestAddress(0x0),
1506 GuestAddress(0x100),
1507 vec![(Readable, 2), (Readable, 2)],
1508 0x100 - 2,
1509 )
1510 .expect("create_descriptor_chain failed");
1511 let reader = &mut chain_reader.reader;
1512
1513 // peek_obj() at the beginning of the chain should return the first object.
1514 let peek1 = reader.peek_obj::<Le16>().unwrap();
1515 assert_eq!(peek1, Le16::from(0xBEEF));
1516
1517 // peek_obj() again should return the same object, since it was not consumed.
1518 let peek2 = reader.peek_obj::<Le16>().unwrap();
1519 assert_eq!(peek2, Le16::from(0xBEEF));
1520
1521 // peek_obj() of an object spanning two descriptors should copy from both.
1522 let peek3 = reader.peek_obj::<Le32>().unwrap();
1523 assert_eq!(peek3, Le32::from(0xDEADBEEF));
1524
1525 // read_obj() should return the first object.
1526 let read1 = reader.read_obj::<Le16>().unwrap();
1527 assert_eq!(read1, Le16::from(0xBEEF));
1528
1529 // peek_obj() of a value that is larger than the rest of the chain should fail.
1530 reader
1531 .peek_obj::<Le32>()
1532 .expect_err("peek_obj past end of chain");
1533
1534 // read_obj() again should return the second object.
1535 let read2 = reader.read_obj::<Le16>().unwrap();
1536 assert_eq!(read2, Le16::from(0xDEAD));
1537
1538 // peek_obj() should fail at the end of the chain.
1539 reader
1540 .peek_obj::<Le16>()
1541 .expect_err("peek_obj past end of chain");
1542 }
1543
1544 #[test]
fn region_reader_failing_io() {
1546 let ex = Executor::new().unwrap();
1547 ex.run_until(region_reader_failing_io_async(&ex)).unwrap();
1548 }
async fn region_reader_failing_io_async(ex: &Executor) {
1550 use DescriptorType::*;
1551
1552 let memory_start_addr = GuestAddress(0x0);
1553 let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
1554
1555 let mut chain = create_descriptor_chain(
1556 &memory,
1557 GuestAddress(0x0),
1558 GuestAddress(0x100),
1559 vec![(Readable, 256), (Readable, 256)],
1560 0,
1561 )
1562 .expect("create_descriptor_chain failed");
1563
1564 let reader = &mut chain.reader;
1565
// Open a file in read-only mode so writes to it trigger an I/O error.
1567 let named_temp_file = NamedTempFile::new().expect("failed to create temp file");
1568 let ro_file =
1569 File::open(named_temp_file.path()).expect("failed to open temp file read only");
let async_ro_file = disk::SingleFileDisk::new(ro_file, ex).expect("Failed to create SingleFileDisk");
1571
1572 reader
1573 .read_exact_to_at_fut(&async_ro_file, 512, 0)
1574 .await
1575 .expect_err("successfully read more bytes than SingleFileDisk size");
1576
1577 // The write above should have failed entirely, so we end up not writing any bytes at all.
1578 assert_eq!(reader.available_bytes(), 512);
1579 assert_eq!(reader.bytes_read(), 0);
1580 }
1581
1582 #[test]
fn region_writer_failing_io() {
1584 let ex = Executor::new().unwrap();
1585 ex.run_until(region_writer_failing_io_async(&ex)).unwrap()
1586 }
async fn region_writer_failing_io_async(ex: &Executor) {
1588 use DescriptorType::*;
1589
1590 let memory_start_addr = GuestAddress(0x0);
1591 let memory = GuestMemory::new(&[(memory_start_addr, 0x10000)]).unwrap();
1592
1593 let mut chain = create_descriptor_chain(
1594 &memory,
1595 GuestAddress(0x0),
1596 GuestAddress(0x100),
1597 vec![(Writable, 256), (Writable, 256)],
1598 0,
1599 )
1600 .expect("create_descriptor_chain failed");
1601
1602 let writer = &mut chain.writer;
1603
1604 let file = tempfile().expect("failed to create temp file");
1605
1606 file.set_len(384).unwrap();
let async_file = disk::SingleFileDisk::new(file, ex).expect("Failed to create SingleFileDisk");
1608
1609 writer
1610 .write_all_from_at_fut(&async_file, 512, 0)
1611 .await
1612 .expect_err("successfully wrote more bytes than in SingleFileDisk");
1613
1614 assert_eq!(writer.available_bytes(), 128);
1615 assert_eq!(writer.bytes_written(), 384);
1616 }
1617 }
1618