1 // Copyright 2024 The ChromiumOS Authors
2 // Use of this source code is governed by a BSD-style license that can be
3 // found in the LICENSE file.
4
5 //! Use seekable zstd archive of raw disk image as read only disk
6
7 use std::cmp::min;
8 use std::fs::File;
9 use std::io;
10 use std::io::ErrorKind;
11 use std::io::Read;
12 use std::io::Seek;
13 use std::sync::Arc;
14
15 use anyhow::bail;
16 use anyhow::Context;
17 use async_trait::async_trait;
18 use base::AsRawDescriptor;
19 use base::FileAllocate;
20 use base::FileReadWriteAtVolatile;
21 use base::FileSetLen;
22 use base::RawDescriptor;
23 use base::VolatileSlice;
24 use cros_async::BackingMemory;
25 use cros_async::Executor;
26 use cros_async::IoSource;
27
28 use crate::AsyncDisk;
29 use crate::DiskFile;
30 use crate::DiskGetLen;
31 use crate::Error as DiskError;
32 use crate::Result as DiskResult;
33 use crate::ToAsyncDisk;
34
35 // Zstandard frame magic
36 pub const ZSTD_FRAME_MAGIC: u32 = 0xFD2FB528;
37
38 // Skippable frame magic can be anything between [0x184D2A50, 0x184D2A5F]
39 pub const ZSTD_SKIPPABLE_MAGIC_LOW: u32 = 0x184D2A50;
40 pub const ZSTD_SKIPPABLE_MAGIC_HIGH: u32 = 0x184D2A5F;
41 pub const ZSTD_SEEK_TABLE_MAGIC: u32 = 0x8F92EAB1;
42
43 pub const ZSTD_DEFAULT_FRAME_SIZE: usize = 128 << 10; // 128KB
44
45 #[derive(Clone, Debug)]
46 pub struct ZstdSeekTable {
47 // Cumulative sum of decompressed sizes of all frames before the indexed frame.
48 // The last element is the total decompressed size of the zstd archive.
49 cumulative_decompressed_sizes: Vec<u64>,
50 // Cumulative sum of compressed sizes of all frames before the indexed frame.
51 // The last element is the total compressed size of the zstd archive.
52 cumulative_compressed_sizes: Vec<u64>,
53 }
54
55 impl ZstdSeekTable {
56 /// Read seek table entries from seek_table_entries
from_footer( seek_table_entries: &[u8], num_frames: u32, checksum_flag: bool, ) -> anyhow::Result<ZstdSeekTable>57 pub fn from_footer(
58 seek_table_entries: &[u8],
59 num_frames: u32,
60 checksum_flag: bool,
61 ) -> anyhow::Result<ZstdSeekTable> {
62 let mut cumulative_decompressed_size: u64 = 0;
63 let mut cumulative_compressed_size: u64 = 0;
64 let mut cumulative_decompressed_sizes = Vec::with_capacity(num_frames as usize + 1);
65 let mut cumulative_compressed_sizes = Vec::with_capacity(num_frames as usize + 1);
66 let mut offset = 0;
67 cumulative_decompressed_sizes.push(0);
68 cumulative_compressed_sizes.push(0);
69 for _ in 0..num_frames {
70 let compressed_size = u32::from_le_bytes(
71 seek_table_entries
72 .get(offset..offset + 4)
73 .context("failed to parse seektable entry")?
74 .try_into()?,
75 );
76 let decompressed_size = u32::from_le_bytes(
77 seek_table_entries
78 .get(offset + 4..offset + 8)
79 .context("failed to parse seektable entry")?
80 .try_into()?,
81 );
82 cumulative_decompressed_size += decompressed_size as u64;
83 cumulative_compressed_size += compressed_size as u64;
84 cumulative_decompressed_sizes.push(cumulative_decompressed_size);
85 cumulative_compressed_sizes.push(cumulative_compressed_size);
86 offset += 8 + (checksum_flag as usize * 4);
87 }
88 cumulative_decompressed_sizes.push(cumulative_decompressed_size);
89 cumulative_compressed_sizes.push(cumulative_compressed_size);
90
91 Ok(ZstdSeekTable {
92 cumulative_decompressed_sizes,
93 cumulative_compressed_sizes,
94 })
95 }
96
97 /// Returns the index of the frame that contains the given decompressed offset.
find_frame_index(&self, decompressed_offset: u64) -> Option<usize>98 pub fn find_frame_index(&self, decompressed_offset: u64) -> Option<usize> {
99 if self.cumulative_decompressed_sizes.is_empty()
100 || decompressed_offset >= *self.cumulative_decompressed_sizes.last().unwrap()
101 {
102 return None;
103 }
104 self.cumulative_decompressed_sizes
105 .partition_point(|&size| size <= decompressed_offset)
106 .checked_sub(1)
107 }
108 }
109
110 #[derive(Debug)]
111 pub struct ZstdDisk {
112 file: File,
113 seek_table: ZstdSeekTable,
114 }
115
116 impl ZstdDisk {
from_file(mut file: File) -> anyhow::Result<ZstdDisk>117 pub fn from_file(mut file: File) -> anyhow::Result<ZstdDisk> {
118 // Verify file is large enough to contain a seek table (17 bytes)
119 if file.metadata()?.len() < 17 {
120 return Err(anyhow::anyhow!("File too small to contain zstd seek table"));
121 }
122
123 // Read last 9 bytes as seek table footer
124 let mut seektable_footer = [0u8; 9];
125 file.seek(std::io::SeekFrom::End(-9))?;
126 file.read_exact(&mut seektable_footer)?;
127
128 // Verify last 4 bytes of footer is seek table magic
129 if u32::from_le_bytes(seektable_footer[5..9].try_into()?) != ZSTD_SEEK_TABLE_MAGIC {
130 return Err(anyhow::anyhow!("Invalid zstd seek table magic"));
131 }
132
133 // Get number of frame from seek table
134 let num_frames = u32::from_le_bytes(seektable_footer[0..4].try_into()?);
135
136 // Read flags from seek table descriptor
137 let checksum_flag = (seektable_footer[4] >> 7) & 1 != 0;
138 if (seektable_footer[4] & 0x7C) != 0 {
139 bail!(
140 "This zstd seekable decoder cannot parse seek table with non-zero reserved flags"
141 );
142 }
143
144 let seek_table_entries_size = num_frames * (8 + (checksum_flag as u32 * 4));
145
146 // Seek to the beginning of the seek table
147 file.seek(std::io::SeekFrom::End(
148 -(9 + seek_table_entries_size as i64),
149 ))?;
150
151 // Return new ZstdDisk
152 let mut seek_table_entries: Vec<u8> = vec![0u8; seek_table_entries_size as usize];
153 file.read_exact(&mut seek_table_entries)?;
154
155 let seek_table =
156 ZstdSeekTable::from_footer(&seek_table_entries, num_frames, checksum_flag)?;
157
158 Ok(ZstdDisk { file, seek_table })
159 }
160 }
161
162 impl DiskGetLen for ZstdDisk {
get_len(&self) -> std::io::Result<u64>163 fn get_len(&self) -> std::io::Result<u64> {
164 self.seek_table
165 .cumulative_decompressed_sizes
166 .last()
167 .copied()
168 .ok_or(io::ErrorKind::InvalidData.into())
169 }
170 }
171
172 impl FileSetLen for ZstdDisk {
set_len(&self, _len: u64) -> std::io::Result<()>173 fn set_len(&self, _len: u64) -> std::io::Result<()> {
174 Err(io::Error::new(
175 io::ErrorKind::PermissionDenied,
176 "unsupported operation",
177 ))
178 }
179 }
180
181 impl AsRawDescriptor for ZstdDisk {
as_raw_descriptor(&self) -> RawDescriptor182 fn as_raw_descriptor(&self) -> RawDescriptor {
183 self.file.as_raw_descriptor()
184 }
185 }
186
187 struct CompressedReadInstruction {
188 frame_index: usize,
189 read_offset: u64,
190 read_size: u64,
191 }
192
compresed_frame_read_instruction( seek_table: &ZstdSeekTable, offset: u64, ) -> anyhow::Result<CompressedReadInstruction>193 fn compresed_frame_read_instruction(
194 seek_table: &ZstdSeekTable,
195 offset: u64,
196 ) -> anyhow::Result<CompressedReadInstruction> {
197 let frame_index = seek_table
198 .find_frame_index(offset)
199 .with_context(|| format!("no frame for offset {}", offset))?;
200 let compressed_offset = seek_table.cumulative_compressed_sizes[frame_index];
201 let next_compressed_offset = seek_table
202 .cumulative_compressed_sizes
203 .get(frame_index + 1)
204 .context("Offset out of range (next_compressed_offset overflow)")?;
205 let compressed_size = next_compressed_offset - compressed_offset;
206 Ok(CompressedReadInstruction {
207 frame_index,
208 read_offset: compressed_offset,
209 read_size: compressed_size,
210 })
211 }
212
213 impl FileReadWriteAtVolatile for ZstdDisk {
read_at_volatile(&self, slice: VolatileSlice, offset: u64) -> io::Result<usize>214 fn read_at_volatile(&self, slice: VolatileSlice, offset: u64) -> io::Result<usize> {
215 let read_instruction = compresed_frame_read_instruction(&self.seek_table, offset)
216 .map_err(|e| io::Error::new(io::ErrorKind::InvalidData, e))?;
217
218 let mut compressed_data = vec![0u8; read_instruction.read_size as usize];
219
220 let compressed_frame_slice = VolatileSlice::new(compressed_data.as_mut_slice());
221
222 self.file
223 .read_at_volatile(compressed_frame_slice, read_instruction.read_offset)
224 .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?;
225
226 let mut decompressor: zstd::bulk::Decompressor<'_> = zstd::bulk::Decompressor::new()?;
227 let mut decompressed_data = Vec::with_capacity(ZSTD_DEFAULT_FRAME_SIZE);
228 let decoded_size =
229 decompressor.decompress_to_buffer(&compressed_data, &mut decompressed_data)?;
230
231 let decompressed_offset_in_frame =
232 offset - self.seek_table.cumulative_decompressed_sizes[read_instruction.frame_index];
233
234 if decompressed_offset_in_frame >= decoded_size as u64 {
235 return Err(io::Error::new(
236 io::ErrorKind::InvalidData,
237 "BUG: Frame offset larger than decoded size",
238 ));
239 }
240
241 let read_len = min(
242 slice.size() as u64,
243 (decoded_size as u64) - decompressed_offset_in_frame,
244 ) as usize;
245 let data_to_copy = &decompressed_data[decompressed_offset_in_frame as usize..][..read_len];
246 slice
247 .sub_slice(0, read_len)
248 .map_err(|e| io::Error::new(io::ErrorKind::Other, e))?
249 .copy_from(data_to_copy);
250 Ok(data_to_copy.len())
251 }
252
write_at_volatile(&self, _slice: VolatileSlice, _offset: u64) -> io::Result<usize>253 fn write_at_volatile(&self, _slice: VolatileSlice, _offset: u64) -> io::Result<usize> {
254 Err(io::Error::new(
255 io::ErrorKind::PermissionDenied,
256 "unsupported operation",
257 ))
258 }
259 }
260
261 pub struct AsyncZstdDisk {
262 inner: IoSource<File>,
263 seek_table: ZstdSeekTable,
264 }
265
266 impl ToAsyncDisk for ZstdDisk {
to_async_disk(self: Box<Self>, ex: &Executor) -> DiskResult<Box<dyn AsyncDisk>>267 fn to_async_disk(self: Box<Self>, ex: &Executor) -> DiskResult<Box<dyn AsyncDisk>> {
268 Ok(Box::new(AsyncZstdDisk {
269 inner: ex.async_from(self.file).map_err(DiskError::ToAsync)?,
270 seek_table: self.seek_table,
271 }))
272 }
273 }
274
275 impl DiskGetLen for AsyncZstdDisk {
get_len(&self) -> io::Result<u64>276 fn get_len(&self) -> io::Result<u64> {
277 self.seek_table
278 .cumulative_decompressed_sizes
279 .last()
280 .copied()
281 .ok_or(io::ErrorKind::InvalidData.into())
282 }
283 }
284
285 impl FileSetLen for AsyncZstdDisk {
set_len(&self, _len: u64) -> io::Result<()>286 fn set_len(&self, _len: u64) -> io::Result<()> {
287 Err(io::Error::new(
288 io::ErrorKind::PermissionDenied,
289 "unsupported operation",
290 ))
291 }
292 }
293
294 impl FileAllocate for AsyncZstdDisk {
allocate(&self, _offset: u64, _length: u64) -> io::Result<()>295 fn allocate(&self, _offset: u64, _length: u64) -> io::Result<()> {
296 Err(io::Error::new(
297 io::ErrorKind::PermissionDenied,
298 "unsupported operation",
299 ))
300 }
301 }
302
303 #[async_trait(?Send)]
304 impl AsyncDisk for AsyncZstdDisk {
flush(&self) -> DiskResult<()>305 async fn flush(&self) -> DiskResult<()> {
306 // zstd is read-only, nothing to flush.
307 Ok(())
308 }
309
fsync(&self) -> DiskResult<()>310 async fn fsync(&self) -> DiskResult<()> {
311 // Do nothing because it's read-only.
312 Ok(())
313 }
314
fdatasync(&self) -> DiskResult<()>315 async fn fdatasync(&self) -> DiskResult<()> {
316 // Do nothing because it's read-only.
317 Ok(())
318 }
319
320 /// Reads data from `file_offset` of decompressed disk image till the end of current
321 /// zstd frame and write them into memory `mem` at `mem_offsets`. This function should
322 /// function the same as running `preadv()` on decompressed zstd image and reading into
323 /// the array of `iovec`s specified with `mem` and `mem_offsets`.
read_to_mem<'a>( &'a self, file_offset: u64, mem: Arc<dyn BackingMemory + Send + Sync>, mem_offsets: cros_async::MemRegionIter<'a>, ) -> DiskResult<usize>324 async fn read_to_mem<'a>(
325 &'a self,
326 file_offset: u64,
327 mem: Arc<dyn BackingMemory + Send + Sync>,
328 mem_offsets: cros_async::MemRegionIter<'a>,
329 ) -> DiskResult<usize> {
330 let read_instruction = compresed_frame_read_instruction(&self.seek_table, file_offset)
331 .map_err(|e| DiskError::ReadingData(io::Error::new(io::ErrorKind::InvalidData, e)))?;
332
333 let compressed_data = vec![0u8; read_instruction.read_size as usize];
334
335 let (compressed_read_size, compressed_data) = self
336 .inner
337 .read_to_vec(Some(read_instruction.read_offset), compressed_data)
338 .await
339 .map_err(|e| DiskError::ReadingData(io::Error::new(ErrorKind::Other, e)))?;
340
341 if compressed_read_size != read_instruction.read_size as usize {
342 return Err(DiskError::ReadingData(io::Error::new(
343 ErrorKind::UnexpectedEof,
344 "Read from compressed data result in wrong length",
345 )));
346 }
347
348 let mut decompressor: zstd::bulk::Decompressor<'_> =
349 zstd::bulk::Decompressor::new().map_err(DiskError::ReadingData)?;
350 let mut decompressed_data = Vec::with_capacity(ZSTD_DEFAULT_FRAME_SIZE);
351 let decoded_size = decompressor
352 .decompress_to_buffer(&compressed_data, &mut decompressed_data)
353 .map_err(DiskError::ReadingData)?;
354
355 let decompressed_offset_in_frame = file_offset
356 - self.seek_table.cumulative_decompressed_sizes[read_instruction.frame_index];
357
358 if decompressed_offset_in_frame as usize > decoded_size {
359 return Err(DiskError::ReadingData(io::Error::new(
360 ErrorKind::InvalidData,
361 "BUG: Frame offset larger than decoded size",
362 )));
363 }
364
365 // Copy the decompressed data to the provided memory regions.
366 let mut total_copied = 0;
367 for mem_region in mem_offsets {
368 let src_slice =
369 &decompressed_data[decompressed_offset_in_frame as usize + total_copied..];
370 let dst_slice = mem
371 .get_volatile_slice(mem_region)
372 .map_err(DiskError::GuestMemory)?;
373
374 let to_copy = min(src_slice.len(), dst_slice.size());
375
376 if to_copy > 0 {
377 dst_slice
378 .sub_slice(0, to_copy)
379 .map_err(|e| DiskError::ReadingData(io::Error::new(ErrorKind::Other, e)))?
380 .copy_from(&src_slice[..to_copy]);
381
382 total_copied += to_copy;
383
384 // if fully copied destination buffers, break the loop.
385 if total_copied == dst_slice.size() {
386 break;
387 }
388 }
389 }
390
391 Ok(total_copied)
392 }
393
write_from_mem<'a>( &'a self, _file_offset: u64, _mem: Arc<dyn BackingMemory + Send + Sync>, _mem_offsets: cros_async::MemRegionIter<'a>, ) -> DiskResult<usize>394 async fn write_from_mem<'a>(
395 &'a self,
396 _file_offset: u64,
397 _mem: Arc<dyn BackingMemory + Send + Sync>,
398 _mem_offsets: cros_async::MemRegionIter<'a>,
399 ) -> DiskResult<usize> {
400 Err(DiskError::UnsupportedOperation)
401 }
402
punch_hole(&self, _file_offset: u64, _length: u64) -> DiskResult<()>403 async fn punch_hole(&self, _file_offset: u64, _length: u64) -> DiskResult<()> {
404 Err(DiskError::UnsupportedOperation)
405 }
406
write_zeroes_at(&self, _file_offset: u64, _length: u64) -> DiskResult<()>407 async fn write_zeroes_at(&self, _file_offset: u64, _length: u64) -> DiskResult<()> {
408 Err(DiskError::UnsupportedOperation)
409 }
410 }
411
412 impl DiskFile for ZstdDisk {}
413
414 #[cfg(test)]
415 mod tests {
416 use super::*;
417
418 #[test]
test_find_frame_index_empty()419 fn test_find_frame_index_empty() {
420 let seek_table = ZstdSeekTable {
421 cumulative_decompressed_sizes: vec![0],
422 cumulative_compressed_sizes: vec![0],
423 };
424 assert_eq!(seek_table.find_frame_index(0), None);
425 assert_eq!(seek_table.find_frame_index(5), None);
426 }
427
428 #[test]
test_find_frame_index_single_frame()429 fn test_find_frame_index_single_frame() {
430 let seek_table = ZstdSeekTable {
431 cumulative_decompressed_sizes: vec![0, 100],
432 cumulative_compressed_sizes: vec![0, 50],
433 };
434 assert_eq!(seek_table.find_frame_index(0), Some(0));
435 assert_eq!(seek_table.find_frame_index(50), Some(0));
436 assert_eq!(seek_table.find_frame_index(99), Some(0));
437 assert_eq!(seek_table.find_frame_index(100), None);
438 }
439
440 #[test]
test_find_frame_index_multiple_frames()441 fn test_find_frame_index_multiple_frames() {
442 let seek_table = ZstdSeekTable {
443 cumulative_decompressed_sizes: vec![0, 100, 300, 500],
444 cumulative_compressed_sizes: vec![0, 50, 120, 200],
445 };
446 assert_eq!(seek_table.find_frame_index(0), Some(0));
447 assert_eq!(seek_table.find_frame_index(99), Some(0));
448 assert_eq!(seek_table.find_frame_index(100), Some(1));
449 assert_eq!(seek_table.find_frame_index(299), Some(1));
450 assert_eq!(seek_table.find_frame_index(300), Some(2));
451 assert_eq!(seek_table.find_frame_index(499), Some(2));
452 assert_eq!(seek_table.find_frame_index(500), None);
453 assert_eq!(seek_table.find_frame_index(1000), None);
454 }
455
456 #[test]
test_find_frame_index_with_skippable_frames()457 fn test_find_frame_index_with_skippable_frames() {
458 let seek_table = ZstdSeekTable {
459 cumulative_decompressed_sizes: vec![0, 100, 100, 100, 300],
460 cumulative_compressed_sizes: vec![0, 50, 60, 70, 150],
461 };
462 assert_eq!(seek_table.find_frame_index(0), Some(0));
463 assert_eq!(seek_table.find_frame_index(99), Some(0));
464 // Correctly skips the skippable frames.
465 assert_eq!(seek_table.find_frame_index(100), Some(3));
466 assert_eq!(seek_table.find_frame_index(299), Some(3));
467 assert_eq!(seek_table.find_frame_index(300), None);
468 }
469
470 #[test]
test_find_frame_index_with_last_skippable_frame()471 fn test_find_frame_index_with_last_skippable_frame() {
472 let seek_table = ZstdSeekTable {
473 cumulative_decompressed_sizes: vec![0, 20, 40, 40, 60, 60, 80, 80],
474 cumulative_compressed_sizes: vec![0, 10, 20, 30, 40, 50, 60, 70],
475 };
476 assert_eq!(seek_table.find_frame_index(0), Some(0));
477 assert_eq!(seek_table.find_frame_index(20), Some(1));
478 assert_eq!(seek_table.find_frame_index(21), Some(1));
479 assert_eq!(seek_table.find_frame_index(79), Some(5));
480 assert_eq!(seek_table.find_frame_index(80), None);
481 assert_eq!(seek_table.find_frame_index(300), None);
482 }
483 }
484