xref: /aosp_15_r20/external/crosvm/disk/src/zstd.rs (revision bb4ee6a4ae7042d18b07a98463b9c8b875e44b39)
1 // Copyright 2024 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4 
//! Use a seekable zstd archive of a raw disk image as a read-only disk.
6 
7 use std::cmp::min;
8 use std::fs::File;
9 use std::io;
10 use std::io::ErrorKind;
11 use std::io::Read;
12 use std::io::Seek;
13 use std::sync::Arc;
14 
15 use anyhow::bail;
16 use anyhow::Context;
17 use async_trait::async_trait;
18 use base::AsRawDescriptor;
19 use base::FileAllocate;
20 use base::FileReadWriteAtVolatile;
21 use base::FileSetLen;
22 use base::RawDescriptor;
23 use base::VolatileSlice;
24 use cros_async::BackingMemory;
25 use cros_async::Executor;
26 use cros_async::IoSource;
27 
28 use crate::AsyncDisk;
29 use crate::DiskFile;
30 use crate::DiskGetLen;
31 use crate::Error as DiskError;
32 use crate::Result as DiskResult;
33 use crate::ToAsyncDisk;
34 
/// Magic number that identifies a regular Zstandard frame.
pub const ZSTD_FRAME_MAGIC: u32 = 0xFD2F_B528;

/// Lowest magic number a skippable frame may use (the range is inclusive).
pub const ZSTD_SKIPPABLE_MAGIC_LOW: u32 = 0x184D_2A50;
/// Highest magic number a skippable frame may use (the range is inclusive).
pub const ZSTD_SKIPPABLE_MAGIC_HIGH: u32 = 0x184D_2A5F;
/// Magic number stored in the last four bytes of a seekable archive's seek table footer.
pub const ZSTD_SEEK_TABLE_MAGIC: u32 = 0x8F92_EAB1;

/// Default frame size: 128 KiB.
pub const ZSTD_DEFAULT_FRAME_SIZE: usize = 128 * 1024;
44 
/// Frame boundaries of a seekable zstd archive, parsed from its seek table.
///
/// Both vectors are indexed by frame and start at `0`: frame `i` occupies
/// `[sizes[i], sizes[i + 1])` in the respective (decompressed or compressed) byte space.
#[derive(Clone, Debug)]
pub struct ZstdSeekTable {
    // Cumulative sum of decompressed sizes of all frames before the indexed frame.
    // The last element is the total decompressed size of the zstd archive.
    cumulative_decompressed_sizes: Vec<u64>,
    // Cumulative sum of compressed sizes of all frames before the indexed frame.
    // The last element is the total compressed size of the zstd archive; element `i` is
    // also the byte offset of frame `i` within the archive file.
    cumulative_compressed_sizes: Vec<u64>,
}
54 
55 impl ZstdSeekTable {
56     /// Read seek table entries from seek_table_entries
from_footer( seek_table_entries: &[u8], num_frames: u32, checksum_flag: bool, ) -> anyhow::Result<ZstdSeekTable>57     pub fn from_footer(
58         seek_table_entries: &[u8],
59         num_frames: u32,
60         checksum_flag: bool,
61     ) -> anyhow::Result<ZstdSeekTable> {
62         let mut cumulative_decompressed_size: u64 = 0;
63         let mut cumulative_compressed_size: u64 = 0;
64         let mut cumulative_decompressed_sizes = Vec::with_capacity(num_frames as usize + 1);
65         let mut cumulative_compressed_sizes = Vec::with_capacity(num_frames as usize + 1);
66         let mut offset = 0;
67         cumulative_decompressed_sizes.push(0);
68         cumulative_compressed_sizes.push(0);
69         for _ in 0..num_frames {
70             let compressed_size = u32::from_le_bytes(
71                 seek_table_entries
72                     .get(offset..offset + 4)
73                     .context("failed to parse seektable entry")?
74                     .try_into()?,
75             );
76             let decompressed_size = u32::from_le_bytes(
77                 seek_table_entries
78                     .get(offset + 4..offset + 8)
79                     .context("failed to parse seektable entry")?
80                     .try_into()?,
81             );
82             cumulative_decompressed_size += decompressed_size as u64;
83             cumulative_compressed_size += compressed_size as u64;
84             cumulative_decompressed_sizes.push(cumulative_decompressed_size);
85             cumulative_compressed_sizes.push(cumulative_compressed_size);
86             offset += 8 + (checksum_flag as usize * 4);
87         }
88         cumulative_decompressed_sizes.push(cumulative_decompressed_size);
89         cumulative_compressed_sizes.push(cumulative_compressed_size);
90 
91         Ok(ZstdSeekTable {
92             cumulative_decompressed_sizes,
93             cumulative_compressed_sizes,
94         })
95     }
96 
97     /// Returns the index of the frame that contains the given decompressed offset.
find_frame_index(&self, decompressed_offset: u64) -> Option<usize>98     pub fn find_frame_index(&self, decompressed_offset: u64) -> Option<usize> {
99         if self.cumulative_decompressed_sizes.is_empty()
100             || decompressed_offset >= *self.cumulative_decompressed_sizes.last().unwrap()
101         {
102             return None;
103         }
104         self.cumulative_decompressed_sizes
105             .partition_point(|&size| size <= decompressed_offset)
106             .checked_sub(1)
107     }
108 }
109 
/// Read-only disk backed by a seekable zstd archive of a raw disk image.
///
/// Each read decompresses the single zstd frame that contains the requested offset.
#[derive(Debug)]
pub struct ZstdDisk {
    // The compressed archive; frames are read from it on demand.
    file: File,
    // Frame boundaries parsed from the seek table at the end of the archive.
    seek_table: ZstdSeekTable,
}
115 
116 impl ZstdDisk {
from_file(mut file: File) -> anyhow::Result<ZstdDisk>117     pub fn from_file(mut file: File) -> anyhow::Result<ZstdDisk> {
118         // Verify file is large enough to contain a seek table (17 bytes)
119         if file.metadata()?.len() < 17 {
120             return Err(anyhow::anyhow!("File too small to contain zstd seek table"));
121         }
122 
123         // Read last 9 bytes as seek table footer
124         let mut seektable_footer = [0u8; 9];
125         file.seek(std::io::SeekFrom::End(-9))?;
126         file.read_exact(&mut seektable_footer)?;
127 
128         // Verify last 4 bytes of footer is seek table magic
129         if u32::from_le_bytes(seektable_footer[5..9].try_into()?) != ZSTD_SEEK_TABLE_MAGIC {
130             return Err(anyhow::anyhow!("Invalid zstd seek table magic"));
131         }
132 
133         // Get number of frame from seek table
134         let num_frames = u32::from_le_bytes(seektable_footer[0..4].try_into()?);
135 
136         // Read flags from seek table descriptor
137         let checksum_flag = (seektable_footer[4] >> 7) & 1 != 0;
138         if (seektable_footer[4] & 0x7C) != 0 {
139             bail!(
140                 "This zstd seekable decoder cannot parse seek table with non-zero reserved flags"
141             );
142         }
143 
144         let seek_table_entries_size = num_frames * (8 + (checksum_flag as u32 * 4));
145 
146         // Seek to the beginning of the seek table
147         file.seek(std::io::SeekFrom::End(
148             -(9 + seek_table_entries_size as i64),
149         ))?;
150 
151         // Return new ZstdDisk
152         let mut seek_table_entries: Vec<u8> = vec![0u8; seek_table_entries_size as usize];
153         file.read_exact(&mut seek_table_entries)?;
154 
155         let seek_table =
156             ZstdSeekTable::from_footer(&seek_table_entries, num_frames, checksum_flag)?;
157 
158         Ok(ZstdDisk { file, seek_table })
159     }
160 }
161 
162 impl DiskGetLen for ZstdDisk {
get_len(&self) -> std::io::Result<u64>163     fn get_len(&self) -> std::io::Result<u64> {
164         self.seek_table
165             .cumulative_decompressed_sizes
166             .last()
167             .copied()
168             .ok_or(io::ErrorKind::InvalidData.into())
169     }
170 }
171 
172 impl FileSetLen for ZstdDisk {
set_len(&self, _len: u64) -> std::io::Result<()>173     fn set_len(&self, _len: u64) -> std::io::Result<()> {
174         Err(io::Error::new(
175             io::ErrorKind::PermissionDenied,
176             "unsupported operation",
177         ))
178     }
179 }
180 
181 impl AsRawDescriptor for ZstdDisk {
as_raw_descriptor(&self) -> RawDescriptor182     fn as_raw_descriptor(&self) -> RawDescriptor {
183         self.file.as_raw_descriptor()
184     }
185 }
186 
/// Where to find, inside the compressed archive, the frame covering a decompressed offset.
struct CompressedReadInstruction {
    /// Byte offset of the frame within the archive file.
    read_offset: u64,
    /// Length in bytes of the compressed frame.
    read_size: u64,
    /// Index of the frame in the seek table.
    frame_index: usize,
}
192 
compresed_frame_read_instruction( seek_table: &ZstdSeekTable, offset: u64, ) -> anyhow::Result<CompressedReadInstruction>193 fn compresed_frame_read_instruction(
194     seek_table: &ZstdSeekTable,
195     offset: u64,
196 ) -> anyhow::Result<CompressedReadInstruction> {
197     let frame_index = seek_table
198         .find_frame_index(offset)
199         .with_context(|| format!("no frame for offset {}", offset))?;
200     let compressed_offset = seek_table.cumulative_compressed_sizes[frame_index];
201     let next_compressed_offset = seek_table
202         .cumulative_compressed_sizes
203         .get(frame_index + 1)
204         .context("Offset out of range (next_compressed_offset overflow)")?;
205     let compressed_size = next_compressed_offset - compressed_offset;
206     Ok(CompressedReadInstruction {
207         frame_index,
208         read_offset: compressed_offset,
209         read_size: compressed_size,
210     })
211 }
212 
213 impl FileReadWriteAtVolatile for ZstdDisk {
read_at_volatile(&self, slice: VolatileSlice, offset: u64) -> io::Result<usize>214     fn read_at_volatile(&self, slice: VolatileSlice, offset: u64) -> io::Result<usize> {
215         let read_instruction = compresed_frame_read_instruction(&self.seek_table, offset)
216             .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
217 
218         let mut compressed_data = vec![0u8; read_instruction.read_size as usize];
219 
220         let compressed_frame_slice = VolatileSlice::new(compressed_data.as_mut_slice());
221 
222         self.file
223             .read_at_volatile(compressed_frame_slice, read_instruction.read_offset)
224             .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
225 
226         let mut decompressor: zstd::bulk::Decompressor<'_> = zstd::bulk::Decompressor::new()?;
227         let mut decompressed_data = Vec::with_capacity(ZSTD_DEFAULT_FRAME_SIZE);
228         let decoded_size =
229             decompressor.decompress_to_buffer(&compressed_data, &mut decompressed_data)?;
230 
231         let decompressed_offset_in_frame =
232             offset - self.seek_table.cumulative_decompressed_sizes[read_instruction.frame_index];
233 
234         if decompressed_offset_in_frame >= decoded_size as u64 {
235             return Err(io::Error::new(
236                 io::ErrorKind::InvalidData,
237                 "BUG: Frame offset larger than decoded size",
238             ));
239         }
240 
241         let read_len = min(
242             slice.size() as u64,
243             (decoded_size as u64) - decompressed_offset_in_frame,
244         ) as usize;
245         let data_to_copy = &decompressed_data[decompressed_offset_in_frame as usize..][..read_len];
246         slice
247             .sub_slice(0, read_len)
248             .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?
249             .copy_from(data_to_copy);
250         Ok(data_to_copy.len())
251     }
252 
write_at_volatile(&self, _slice: VolatileSlice, _offset: u64) -> io::Result<usize>253     fn write_at_volatile(&self, _slice: VolatileSlice, _offset: u64) -> io::Result<usize> {
254         Err(io::Error::new(
255             io::ErrorKind::PermissionDenied,
256             "unsupported operation",
257         ))
258     }
259 }
260 
/// Asynchronous counterpart of `ZstdDisk`, produced by its `ToAsyncDisk::to_async_disk`.
pub struct AsyncZstdDisk {
    // The compressed archive file, registered with an executor via `Executor::async_from`.
    inner: IoSource<File>,
    // Frame boundaries, moved over unchanged from the originating `ZstdDisk`.
    seek_table: ZstdSeekTable,
}
265 
266 impl ToAsyncDisk for ZstdDisk {
to_async_disk(self: Box<Self>, ex: &Executor) -> DiskResult<Box<dyn AsyncDisk>>267     fn to_async_disk(self: Box<Self>, ex: &Executor) -> DiskResult<Box<dyn AsyncDisk>> {
268         Ok(Box::new(AsyncZstdDisk {
269             inner: ex.async_from(self.file).map_err(DiskError::ToAsync)?,
270             seek_table: self.seek_table,
271         }))
272     }
273 }
274 
275 impl DiskGetLen for AsyncZstdDisk {
get_len(&self) -> io::Result<u64>276     fn get_len(&self) -> io::Result<u64> {
277         self.seek_table
278             .cumulative_decompressed_sizes
279             .last()
280             .copied()
281             .ok_or(io::ErrorKind::InvalidData.into())
282     }
283 }
284 
285 impl FileSetLen for AsyncZstdDisk {
set_len(&self, _len: u64) -> io::Result<()>286     fn set_len(&self, _len: u64) -> io::Result<()> {
287         Err(io::Error::new(
288             io::ErrorKind::PermissionDenied,
289             "unsupported operation",
290         ))
291     }
292 }
293 
294 impl FileAllocate for AsyncZstdDisk {
allocate(&self, _offset: u64, _length: u64) -> io::Result<()>295     fn allocate(&self, _offset: u64, _length: u64) -> io::Result<()> {
296         Err(io::Error::new(
297             io::ErrorKind::PermissionDenied,
298             "unsupported operation",
299         ))
300     }
301 }
302 
303 #[async_trait(?Send)]
304 impl AsyncDisk for AsyncZstdDisk {
flush(&self) -> DiskResult<()>305     async fn flush(&self) -> DiskResult<()> {
306         // zstd is read-only, nothing to flush.
307         Ok(())
308     }
309 
fsync(&self) -> DiskResult<()>310     async fn fsync(&self) -> DiskResult<()> {
311         // Do nothing because it's read-only.
312         Ok(())
313     }
314 
fdatasync(&self) -> DiskResult<()>315     async fn fdatasync(&self) -> DiskResult<()> {
316         // Do nothing because it's read-only.
317         Ok(())
318     }
319 
320     /// Reads data from `file_offset` of decompressed disk image till the end of current
321     /// zstd frame and write them into memory `mem` at `mem_offsets`. This function should
322     /// function the same as running `preadv()` on decompressed zstd image and reading into
323     /// the array of `iovec`s specified with `mem` and `mem_offsets`.
read_to_mem<'a>( &'a self, file_offset: u64, mem: Arc<dyn BackingMemory + Send + Sync>, mem_offsets: cros_async::MemRegionIter<'a>, ) -> DiskResult<usize>324     async fn read_to_mem<'a>(
325         &'a self,
326         file_offset: u64,
327         mem: Arc<dyn BackingMemory + Send + Sync>,
328         mem_offsets: cros_async::MemRegionIter<'a>,
329     ) -> DiskResult<usize> {
330         let read_instruction = compresed_frame_read_instruction(&self.seek_table, file_offset)
331             .map_err(|e| DiskError::ReadingData(io::Error::new(io::ErrorKind::InvalidData, e)))?;
332 
333         let compressed_data = vec![0u8; read_instruction.read_size as usize];
334 
335         let (compressed_read_size, compressed_data) = self
336             .inner
337             .read_to_vec(Some(read_instruction.read_offset), compressed_data)
338             .await
339             .map_err(|e| DiskError::ReadingData(io::Error::new(ErrorKind::Other, e)))?;
340 
341         if compressed_read_size != read_instruction.read_size as usize {
342             return Err(DiskError::ReadingData(io::Error::new(
343                 ErrorKind::UnexpectedEof,
344                 "Read from compressed data result in wrong length",
345             )));
346         }
347 
348         let mut decompressor: zstd::bulk::Decompressor<'_> =
349             zstd::bulk::Decompressor::new().map_err(DiskError::ReadingData)?;
350         let mut decompressed_data = Vec::with_capacity(ZSTD_DEFAULT_FRAME_SIZE);
351         let decoded_size = decompressor
352             .decompress_to_buffer(&compressed_data, &mut decompressed_data)
353             .map_err(DiskError::ReadingData)?;
354 
355         let decompressed_offset_in_frame = file_offset
356             - self.seek_table.cumulative_decompressed_sizes[read_instruction.frame_index];
357 
358         if decompressed_offset_in_frame as usize > decoded_size {
359             return Err(DiskError::ReadingData(io::Error::new(
360                 ErrorKind::InvalidData,
361                 "BUG: Frame offset larger than decoded size",
362             )));
363         }
364 
365         // Copy the decompressed data to the provided memory regions.
366         let mut total_copied = 0;
367         for mem_region in mem_offsets {
368             let src_slice =
369                 &decompressed_data[decompressed_offset_in_frame as usize + total_copied..];
370             let dst_slice = mem
371                 .get_volatile_slice(mem_region)
372                 .map_err(DiskError::GuestMemory)?;
373 
374             let to_copy = min(src_slice.len(), dst_slice.size());
375 
376             if to_copy > 0 {
377                 dst_slice
378                     .sub_slice(0, to_copy)
379                     .map_err(|e| DiskError::ReadingData(io::Error::new(ErrorKind::Other, e)))?
380                     .copy_from(&src_slice[..to_copy]);
381 
382                 total_copied += to_copy;
383 
384                 // if fully copied destination buffers, break the loop.
385                 if total_copied == dst_slice.size() {
386                     break;
387                 }
388             }
389         }
390 
391         Ok(total_copied)
392     }
393 
write_from_mem<'a>( &'a self, _file_offset: u64, _mem: Arc<dyn BackingMemory + Send + Sync>, _mem_offsets: cros_async::MemRegionIter<'a>, ) -> DiskResult<usize>394     async fn write_from_mem<'a>(
395         &'a self,
396         _file_offset: u64,
397         _mem: Arc<dyn BackingMemory + Send + Sync>,
398         _mem_offsets: cros_async::MemRegionIter<'a>,
399     ) -> DiskResult<usize> {
400         Err(DiskError::UnsupportedOperation)
401     }
402 
punch_hole(&self, _file_offset: u64, _length: u64) -> DiskResult<()>403     async fn punch_hole(&self, _file_offset: u64, _length: u64) -> DiskResult<()> {
404         Err(DiskError::UnsupportedOperation)
405     }
406 
write_zeroes_at(&self, _file_offset: u64, _length: u64) -> DiskResult<()>407     async fn write_zeroes_at(&self, _file_offset: u64, _length: u64) -> DiskResult<()> {
408         Err(DiskError::UnsupportedOperation)
409     }
410 }
411 
// `ZstdDisk` opts into the crate's `DiskFile` trait with an empty body, so everything it
// needs comes from supertraits/default methods declared in the crate root.
// NOTE(review): confirm which of the trait impls above are `DiskFile` supertraits.
impl DiskFile for ZstdDisk {}
413 
#[cfg(test)]
mod tests {
    use super::*;

    /// Encodes seek-table entries the way `from_footer` expects them: little-endian
    /// (compressed, decompressed) size pairs, each followed by a 4-byte checksum if
    /// `checksum` is set.
    fn encode_entries(frames: &[(u32, u32)], checksum: bool) -> Vec<u8> {
        let mut out = Vec::new();
        for &(compressed, decompressed) in frames {
            out.extend_from_slice(&compressed.to_le_bytes());
            out.extend_from_slice(&decompressed.to_le_bytes());
            if checksum {
                out.extend_from_slice(&u32::MAX.to_le_bytes());
            }
        }
        out
    }

    #[test]
    fn test_find_frame_index_empty() {
        let seek_table = ZstdSeekTable {
            cumulative_decompressed_sizes: vec![0],
            cumulative_compressed_sizes: vec![0],
        };
        assert_eq!(seek_table.find_frame_index(0), None);
        assert_eq!(seek_table.find_frame_index(5), None);
    }

    #[test]
    fn test_find_frame_index_single_frame() {
        let seek_table = ZstdSeekTable {
            cumulative_decompressed_sizes: vec![0, 100],
            cumulative_compressed_sizes: vec![0, 50],
        };
        assert_eq!(seek_table.find_frame_index(0), Some(0));
        assert_eq!(seek_table.find_frame_index(50), Some(0));
        assert_eq!(seek_table.find_frame_index(99), Some(0));
        assert_eq!(seek_table.find_frame_index(100), None);
    }

    #[test]
    fn test_find_frame_index_multiple_frames() {
        let seek_table = ZstdSeekTable {
            cumulative_decompressed_sizes: vec![0, 100, 300, 500],
            cumulative_compressed_sizes: vec![0, 50, 120, 200],
        };
        assert_eq!(seek_table.find_frame_index(0), Some(0));
        assert_eq!(seek_table.find_frame_index(99), Some(0));
        assert_eq!(seek_table.find_frame_index(100), Some(1));
        assert_eq!(seek_table.find_frame_index(299), Some(1));
        assert_eq!(seek_table.find_frame_index(300), Some(2));
        assert_eq!(seek_table.find_frame_index(499), Some(2));
        assert_eq!(seek_table.find_frame_index(500), None);
        assert_eq!(seek_table.find_frame_index(1000), None);
    }

    #[test]
    fn test_find_frame_index_with_skippable_frames() {
        let seek_table = ZstdSeekTable {
            cumulative_decompressed_sizes: vec![0, 100, 100, 100, 300],
            cumulative_compressed_sizes: vec![0, 50, 60, 70, 150],
        };
        assert_eq!(seek_table.find_frame_index(0), Some(0));
        assert_eq!(seek_table.find_frame_index(99), Some(0));
        // Correctly skips the skippable frames.
        assert_eq!(seek_table.find_frame_index(100), Some(3));
        assert_eq!(seek_table.find_frame_index(299), Some(3));
        assert_eq!(seek_table.find_frame_index(300), None);
    }

    #[test]
    fn test_find_frame_index_with_last_skippable_frame() {
        let seek_table = ZstdSeekTable {
            cumulative_decompressed_sizes: vec![0, 20, 40, 40, 60, 60, 80, 80],
            cumulative_compressed_sizes: vec![0, 10, 20, 30, 40, 50, 60, 70],
        };
        assert_eq!(seek_table.find_frame_index(0), Some(0));
        assert_eq!(seek_table.find_frame_index(20), Some(1));
        assert_eq!(seek_table.find_frame_index(21), Some(1));
        assert_eq!(seek_table.find_frame_index(79), Some(5));
        assert_eq!(seek_table.find_frame_index(80), None);
        assert_eq!(seek_table.find_frame_index(300), None);
    }

    #[test]
    fn test_from_footer_without_checksum() {
        let entries = encode_entries(&[(50, 100), (70, 200)], false);
        let table = ZstdSeekTable::from_footer(&entries, 2, false).unwrap();
        assert_eq!(table.cumulative_decompressed_sizes.last(), Some(&300));
        assert_eq!(table.cumulative_compressed_sizes.last(), Some(&120));
        assert_eq!(table.find_frame_index(0), Some(0));
        assert_eq!(table.find_frame_index(99), Some(0));
        assert_eq!(table.find_frame_index(100), Some(1));
        assert_eq!(table.find_frame_index(299), Some(1));
        assert_eq!(table.find_frame_index(300), None);
    }

    #[test]
    fn test_from_footer_with_checksum() {
        let entries = encode_entries(&[(50, 100), (70, 200)], true);
        let table = ZstdSeekTable::from_footer(&entries, 2, true).unwrap();
        assert_eq!(table.cumulative_decompressed_sizes.last(), Some(&300));
        assert_eq!(table.cumulative_compressed_sizes.last(), Some(&120));
        assert_eq!(table.find_frame_index(150), Some(1));
    }

    #[test]
    fn test_from_footer_truncated_entries() {
        let entries = encode_entries(&[(50, 100), (70, 200)], false);
        assert!(ZstdSeekTable::from_footer(&entries, 3, false).is_err());
        assert!(ZstdSeekTable::from_footer(&entries[..12], 2, false).is_err());
    }

    #[test]
    fn test_from_footer_no_frames() {
        let table = ZstdSeekTable::from_footer(&[], 0, false).unwrap();
        assert_eq!(table.cumulative_decompressed_sizes.last(), Some(&0));
        assert_eq!(table.find_frame_index(0), None);
    }

    #[test]
    fn test_compressed_frame_read_instruction() {
        let entries = encode_entries(&[(50, 100), (70, 200)], false);
        let table = ZstdSeekTable::from_footer(&entries, 2, false).unwrap();
        let instruction = compresed_frame_read_instruction(&table, 150).unwrap();
        assert_eq!(instruction.frame_index, 1);
        assert_eq!(instruction.read_offset, 50);
        assert_eq!(instruction.read_size, 70);
        assert!(compresed_frame_read_instruction(&table, 300).is_err());
    }
}
484