1 /*
2  * Copyright (C) 2021 The Android Open Source Project
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *      http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 use openssl::hash::{DigestBytes, Hasher, MessageDigest};
18 use std::io::{Cursor, Read, Result, Write};
19 
/// `HashTree` is a merkle tree (and its root hash) that is compatible with fs-verity.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct HashTree {
    /// Binary representation of the merkle tree. Levels are stored top-down: the level nearest
    /// the root comes first and the level hashing the input data blocks comes last. Empty when
    /// the input fits in a single block.
    pub tree: Vec<u8>,
    /// Root hash: the hash of the top tree block, or of the (zero-padded) input itself when no
    /// tree was generated.
    pub root_hash: Vec<u8>,
}
27 
28 impl HashTree {
29     /// Creates merkle tree from `input`, using the given `salt` and hashing `algorithm`. `input`
30     /// is divided into `block_size` chunks.
from<R: Read>( input: &mut R, input_size: usize, salt: &[u8], block_size: usize, algorithm: MessageDigest, ) -> Result<Self>31     pub fn from<R: Read>(
32         input: &mut R,
33         input_size: usize,
34         salt: &[u8],
35         block_size: usize,
36         algorithm: MessageDigest,
37     ) -> Result<Self> {
38         let salt = zero_pad_salt(salt, algorithm);
39         let tree = generate_hash_tree(input, input_size, &salt, block_size, algorithm)?;
40 
41         // Root hash is from the first block of the hash or the input data if there is no hash tree
42         // generated which can happen when input data is smaller than block size
43         let root_hash = if tree.is_empty() {
44             let mut data = Vec::new();
45             input.read_to_end(&mut data)?;
46             hash_one_block(&data, &salt, block_size, algorithm)?.as_ref().to_vec()
47         } else {
48             let first_block = &tree[0..block_size];
49             hash_one_block(first_block, &salt, block_size, algorithm)?.as_ref().to_vec()
50         };
51         Ok(HashTree { tree, root_hash })
52     }
53 }
54 
55 /// Calculate hash tree for the blocks in `input`.
56 ///
57 /// This function implements: https://www.kernel.org/doc/html/latest/filesystems/fsverity.html#merkle-tree
58 ///
59 /// The file contents is divided into blocks, where the block size is configurable but is usually
60 /// 4096 bytes. The end of the last block is zero-padded if needed. Each block is then hashed,
61 /// producing the first level of hashes. Then, the hashes in this first level are grouped into
62 /// blocksize-byte blocks (zero-padding the ends as needed) and these blocks are hashed,
63 /// producing the second level of hashes. This proceeds up the tree until only a single block
64 /// remains.
generate_hash_tree<R: Read>( input: &mut R, input_size: usize, salt: &[u8], block_size: usize, algorithm: MessageDigest, ) -> Result<Vec<u8>>65 pub fn generate_hash_tree<R: Read>(
66     input: &mut R,
67     input_size: usize,
68     salt: &[u8],
69     block_size: usize,
70     algorithm: MessageDigest,
71 ) -> Result<Vec<u8>> {
72     let digest_size = algorithm.size();
73     let levels = calc_hash_levels(input_size, block_size, digest_size);
74     let tree_size = levels.iter().map(|r| r.len()).sum();
75 
76     // The contiguous memory that holds the entire merkle tree
77     let mut hash_tree = vec![0; tree_size];
78 
79     for (n, cur) in levels.iter().enumerate() {
80         if n == 0 {
81             // Level 0: the (zero-padded) input stream is hashed into level 0
82             let pad_size = round_to_multiple(input_size, block_size) - input_size;
83             let mut input = input.chain(Cursor::new(vec![0; pad_size]));
84             let mut level0 = Cursor::new(&mut hash_tree[cur.start..cur.end]);
85 
86             let mut a_block = vec![0; block_size];
87             let mut num_blocks = (input_size + block_size - 1) / block_size;
88             while num_blocks > 0 {
89                 input.read_exact(&mut a_block)?;
90                 let h = hash_one_block(&a_block, salt, block_size, algorithm)?;
91                 level0.write_all(h.as_ref()).unwrap();
92                 num_blocks -= 1;
93             }
94         } else {
95             // Intermediate levels: level n - 1 is hashed into level n
96             // Both levels belong to the same `hash_tree`. In order to have a mutable slice for
97             // level n while having a slice for level n - 1, take the mutable slice for both levels
98             // and split it.
99             let prev = &levels[n - 1];
100             let cur_and_prev = &mut hash_tree[cur.start..prev.end];
101             let (cur, prev) = cur_and_prev.split_at_mut(prev.start - cur.start);
102             let mut cur = Cursor::new(cur);
103             for data in prev.chunks(block_size) {
104                 let h = hash_one_block(data, salt, block_size, algorithm)?;
105                 cur.write_all(h.as_ref()).unwrap();
106             }
107         }
108     }
109     Ok(hash_tree)
110 }
111 
112 /// Hash one block of input using the given hash algorithm and the salt. Input might be smaller
113 /// than a block, in which case zero is padded.
hash_one_block( input: &[u8], salt: &[u8], block_size: usize, algorithm: MessageDigest, ) -> Result<DigestBytes>114 fn hash_one_block(
115     input: &[u8],
116     salt: &[u8],
117     block_size: usize,
118     algorithm: MessageDigest,
119 ) -> Result<DigestBytes> {
120     let mut ctx = Hasher::new(algorithm)?;
121     ctx.update(salt)?;
122     ctx.update(input)?;
123     let pad_size = block_size - input.len();
124     ctx.update(&vec![0; pad_size])?;
125     Ok(ctx.finish()?)
126 }
127 
128 type Range = std::ops::Range<usize>;
129 
130 /// Calculate the ranges of hash for each level
calc_hash_levels(input_size: usize, block_size: usize, digest_size: usize) -> Vec<Range>131 fn calc_hash_levels(input_size: usize, block_size: usize, digest_size: usize) -> Vec<Range> {
132     // The input is split into multiple blocks and each block is hashed, which becomes the input
133     // for the next level. Size of a single hash is `digest_size`.
134     let mut level_sizes = Vec::new();
135     loop {
136         // Input for this level is from either the last level (if exists), or the input parameter.
137         let input_size = *level_sizes.last().unwrap_or(&input_size);
138         if input_size <= block_size {
139             break;
140         }
141         let num_blocks = (input_size + block_size - 1) / block_size;
142         let hashes_size = round_to_multiple(num_blocks * digest_size, block_size);
143         level_sizes.push(hashes_size);
144     }
145 
146     // The hash tree is stored upside down. The top level is at offset 0. The second level comes
147     // next, and so on. Level 0 is located at the end.
148     //
149     // Given level_sizes [10, 3, 1], the offsets for each label are ...
150     //
151     // Level 2 is at offset 0
152     // Level 1 is at offset 1 (because Level 2 is of size 1)
153     // Level 0 is at offset 4 (because Level 1 is of size 3)
154     //
155     // This is done by scanning the sizes in reverse order
156     let mut ranges = level_sizes
157         .iter()
158         .rev()
159         .scan(0, |prev_end, size| {
160             let range = *prev_end..*prev_end + size;
161             *prev_end = range.end;
162             Some(range)
163         })
164         .collect::<Vec<_>>();
165     ranges.reverse(); // reverse again so that index N is for level N
166     ranges
167 }
168 
/// Round `n` up to the nearest multiple of `unit`.
///
/// Works for any non-zero `unit`, not only powers of two; for powers of two the result is the
/// same as the usual `(n + unit - 1) & !(unit - 1)` bit trick.
///
/// # Panics
///
/// Panics if `unit` is zero.
fn round_to_multiple(n: usize, unit: usize) -> usize {
    (n + unit - 1) / unit * unit
}
173 
174 /// Pad zero to salt if necessary.
175 ///
176 /// According to https://www.kernel.org/doc/html/latest/filesystems/fsverity.html:
177 ///
178 /// If a salt was specified, then it’s zero-padded to the closest multiple of the input size of the
179 /// hash algorithm’s compression function, e.g. 64 bytes for SHA-256 or 128 bytes for SHA-512. The
180 /// padded salt is prepended to every data or Merkle tree block that is hashed.
zero_pad_salt(salt: &[u8], algorithm: MessageDigest) -> Vec<u8>181 fn zero_pad_salt(salt: &[u8], algorithm: MessageDigest) -> Vec<u8> {
182     if salt.is_empty() {
183         salt.to_vec()
184     } else {
185         let padded_len = round_to_multiple(salt.len(), algorithm.block_size());
186         let mut salt = salt.to_vec();
187         salt.resize(padded_len, 0);
188         salt
189     }
190 }
191 
#[cfg(test)]
mod tests {
    use super::*;
    use openssl::hash::MessageDigest;
    use std::fs::{self, File};

    /// Checks `HashTree::from` against golden trees and descriptors for inputs of several sizes.
    #[test]
    fn compare_with_golden_output() -> Result<()> {
        // Location of the SHA-256 root hash inside the golden descriptor file.
        const ROOT_HASH_OFFSET: usize = 16;
        const ROOT_HASH_LEN: usize = 32;

        // The golden outputs are generated by using the `fsverity` utility.
        let salt = [1u8, 2, 3, 4, 5, 6];
        for suffix in &["512", "4K", "1M", "10000000", "272629760"] {
            let input_name = format!("tests/data/input.{}", suffix);
            let mut input = File::open(&input_name)?;
            let golden_hash_tree = fs::read(format!("{}.hash", input_name))?;
            let golden_descriptor = fs::read(format!("{}.descriptor", input_name))?;
            let golden_root_hash =
                &golden_descriptor[ROOT_HASH_OFFSET..ROOT_HASH_OFFSET + ROOT_HASH_LEN];

            let input_len = fs::metadata(&input_name)?.len() as usize;
            let ht =
                HashTree::from(&mut input, input_len, &salt, 4096, MessageDigest::sha256())?;

            assert_eq!(golden_hash_tree.as_slice(), ht.tree.as_slice());
            assert_eq!(golden_root_hash, ht.root_hash.as_slice());
        }
        Ok(())
    }
}
219