1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4
5 use super::leaf_node::LeafNode;
6 use super::node::{LeafIndex, NodeVec};
7 use super::tree_math::BfsIterTopDown;
8 use crate::client::MlsError;
9 use crate::crypto::CipherSuiteProvider;
10 use crate::tree_kem::math as tree_math;
11 use crate::tree_kem::node::Parent;
12 use crate::tree_kem::TreeKemPublic;
13 use alloc::collections::VecDeque;
14 use alloc::vec;
15 use alloc::vec::Vec;
16 use core::fmt::{self, Debug};
17 use itertools::Itertools;
18 use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
19 use mls_rs_core::error::IntoAnyError;
20 use tree_math::TreeIndex;
21
22 use core::ops::Deref;
23
24 #[derive(Clone, Default, MlsSize, MlsEncode, MlsDecode, PartialEq)]
25 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
26 pub(crate) struct TreeHash(
27 #[mls_codec(with = "mls_rs_codec::byte_vec")]
28 #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
29 Vec<u8>,
30 );
31
32 impl Debug for TreeHash {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result33 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
34 mls_rs_core::debug::pretty_bytes(&self.0)
35 .named("TreeHash")
36 .fmt(f)
37 }
38 }
39
40 impl Deref for TreeHash {
41 type Target = [u8];
42
deref(&self) -> &Self::Target43 fn deref(&self) -> &Self::Target {
44 &self.0
45 }
46 }
47
48 #[derive(Clone, Debug, Default, MlsSize, MlsEncode, MlsDecode, PartialEq)]
49 #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
50 pub(crate) struct TreeHashes {
51 pub current: Vec<TreeHash>,
52 }
53
54 #[derive(Debug, MlsSize, MlsEncode)]
55 struct LeafNodeHashInput<'a> {
56 leaf_index: LeafIndex,
57 leaf_node: Option<&'a LeafNode>,
58 }
59
60 #[derive(Debug, MlsSize, MlsEncode)]
61 struct ParentNodeTreeHashInput<'a> {
62 parent_node: Option<&'a Parent>,
63 #[mls_codec(with = "mls_rs_codec::byte_vec")]
64 left_hash: &'a [u8],
65 #[mls_codec(with = "mls_rs_codec::byte_vec")]
66 right_hash: &'a [u8],
67 }
68
69 #[derive(Debug, MlsSize, MlsEncode)]
70 #[repr(u8)]
71 enum TreeHashInput<'a> {
72 Leaf(LeafNodeHashInput<'a>) = 1u8,
73 Parent(ParentNodeTreeHashInput<'a>) = 2u8,
74 }
75
76 impl TreeKemPublic {
77 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
78 #[inline(never)]
tree_hash<P: CipherSuiteProvider>( &mut self, cipher_suite_provider: &P, ) -> Result<Vec<u8>, MlsError>79 pub async fn tree_hash<P: CipherSuiteProvider>(
80 &mut self,
81 cipher_suite_provider: &P,
82 ) -> Result<Vec<u8>, MlsError> {
83 self.initialize_hashes(cipher_suite_provider).await?;
84 let root = self.total_leaf_count().root();
85 Ok(self.tree_hashes.current[root as usize].to_vec())
86 }
87
88 // Update hashes after `committer` makes changes to the tree. `path_blank` is the
89 // list of leaves whose paths were blanked, i.e. updates and removes.
90 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
update_hashes<P: CipherSuiteProvider>( &mut self, updated_leaves: &[LeafIndex], cipher_suite_provider: &P, ) -> Result<(), MlsError>91 pub async fn update_hashes<P: CipherSuiteProvider>(
92 &mut self,
93 updated_leaves: &[LeafIndex],
94 cipher_suite_provider: &P,
95 ) -> Result<(), MlsError> {
96 let num_leaves = self.total_leaf_count();
97
98 let trailing_blanks = (0..num_leaves)
99 .rev()
100 .map_while(|l| {
101 self.tree_hashes
102 .current
103 .get(2 * l as usize)
104 .is_none()
105 .then_some(LeafIndex(l))
106 })
107 .collect::<Vec<_>>();
108
109 // Update the current hashes for direct paths of all modified leaves.
110 tree_hash(
111 &mut self.tree_hashes.current,
112 &self.nodes,
113 Some([updated_leaves, &trailing_blanks].concat()),
114 &[],
115 num_leaves,
116 cipher_suite_provider,
117 )
118 .await?;
119
120 Ok(())
121 }
122
123 // Initialize all hashes after creating / importing a tree.
124 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
initialize_hashes<P>(&mut self, cipher_suite_provider: &P) -> Result<(), MlsError> where P: CipherSuiteProvider,125 async fn initialize_hashes<P>(&mut self, cipher_suite_provider: &P) -> Result<(), MlsError>
126 where
127 P: CipherSuiteProvider,
128 {
129 if self.tree_hashes.current.is_empty() {
130 let num_leaves = self.total_leaf_count();
131
132 tree_hash(
133 &mut self.tree_hashes.current,
134 &self.nodes,
135 None,
136 &[],
137 num_leaves,
138 cipher_suite_provider,
139 )
140 .await?;
141 }
142
143 Ok(())
144 }
145
unmerged_in_subtree( &self, node_unmerged: u32, subtree_root: u32, ) -> Result<&[LeafIndex], MlsError>146 pub(crate) fn unmerged_in_subtree(
147 &self,
148 node_unmerged: u32,
149 subtree_root: u32,
150 ) -> Result<&[LeafIndex], MlsError> {
151 let unmerged = &self.nodes.borrow_as_parent(node_unmerged)?.unmerged_leaves;
152 let (left, right) = tree_math::subtree(subtree_root);
153 let mut start = 0;
154 while start < unmerged.len() && unmerged[start] < left {
155 start += 1;
156 }
157 let mut end = start;
158 while end < unmerged.len() && unmerged[end] < right {
159 end += 1;
160 }
161 Ok(&unmerged[start..end])
162 }
163
different_unmerged(&self, ancestor: u32, descendant: u32) -> Result<bool, MlsError>164 fn different_unmerged(&self, ancestor: u32, descendant: u32) -> Result<bool, MlsError> {
165 Ok(!self.nodes.is_blank(ancestor)?
166 && !self.nodes.is_blank(descendant)?
167 && self.unmerged_in_subtree(ancestor, descendant)?
168 != self.nodes.borrow_as_parent(descendant)?.unmerged_leaves)
169 }
170
171 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
compute_original_hashes<P: CipherSuiteProvider>( &self, cipher_suite: &P, ) -> Result<Vec<TreeHash>, MlsError>172 pub(crate) async fn compute_original_hashes<P: CipherSuiteProvider>(
173 &self,
174 cipher_suite: &P,
175 ) -> Result<Vec<TreeHash>, MlsError> {
176 let num_leaves = self.nodes.total_leaf_count() as usize;
177 let root = (num_leaves as u32).root();
178
179 // The value `filtered_sets[n]` is a list of all ancestors `a` of `n` s.t. we have to compute
180 // the tree hash of `n` with the unmerged leaves of `a` filtered out.
181 let mut filtered_sets = vec![vec![]; num_leaves * 2 - 1];
182 filtered_sets[root as usize].push(root);
183 let mut tree_hashes = vec![vec![]; num_leaves * 2 - 1];
184
185 let bfs_iter = BfsIterTopDown::new(num_leaves).skip(1);
186
187 for n in bfs_iter {
188 let Some(ps) = (n as u32).parent_sibling(&(num_leaves as u32)) else {
189 break;
190 };
191
192 let p = ps.parent;
193 filtered_sets[n] = filtered_sets[p as usize].clone();
194
195 if self.different_unmerged(*filtered_sets[p as usize].last().unwrap(), p)? {
196 filtered_sets[n].push(p);
197
198 // Compute tree hash of `n` without unmerged leaves of `p`. This also computes the tree hash
199 // for any descendants of `n` added to `filtered_sets` later via `clone`.
200 let (start_leaf, end_leaf) = tree_math::subtree(n as u32);
201
202 tree_hash(
203 &mut tree_hashes[p as usize],
204 &self.nodes,
205 Some((*start_leaf..*end_leaf).map(LeafIndex).collect_vec()),
206 &self.nodes.borrow_as_parent(p)?.unmerged_leaves,
207 num_leaves as u32,
208 cipher_suite,
209 )
210 .await?;
211 }
212 }
213
214 // Set the `original_hashes` based on the computed `hashes`.
215 let mut original_hashes = vec![TreeHash::default(); num_leaves * 2 - 1];
216
217 // If root has unmerged leaves, we recompute it's original hash. Else, we can use the current hash.
218 let root_original = if !self.nodes.is_blank(root)? && !self.nodes.is_leaf(root) {
219 let root_unmerged = &self.nodes.borrow_as_parent(root)?.unmerged_leaves;
220
221 if !root_unmerged.is_empty() {
222 let mut hashes = vec![];
223
224 tree_hash(
225 &mut hashes,
226 &self.nodes,
227 None,
228 root_unmerged,
229 num_leaves as u32,
230 cipher_suite,
231 )
232 .await?;
233
234 Some(hashes)
235 } else {
236 None
237 }
238 } else {
239 None
240 };
241
242 for (i, hash) in original_hashes.iter_mut().enumerate() {
243 let a = filtered_sets[i].last().unwrap();
244 *hash = if self.nodes.is_blank(*a)? || a == &root {
245 if let Some(root_original) = &root_original {
246 root_original[i].clone()
247 } else {
248 self.tree_hashes.current[i].clone()
249 }
250 } else {
251 tree_hashes[*a as usize][i].clone()
252 }
253 }
254
255 Ok(original_hashes)
256 }
257 }
258
259 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
tree_hash<P: CipherSuiteProvider>( hashes: &mut Vec<TreeHash>, nodes: &NodeVec, leaves_to_update: Option<Vec<LeafIndex>>, filtered_leaves: &[LeafIndex], num_leaves: u32, cipher_suite_provider: &P, ) -> Result<(), MlsError>260 async fn tree_hash<P: CipherSuiteProvider>(
261 hashes: &mut Vec<TreeHash>,
262 nodes: &NodeVec,
263 leaves_to_update: Option<Vec<LeafIndex>>,
264 filtered_leaves: &[LeafIndex],
265 num_leaves: u32,
266 cipher_suite_provider: &P,
267 ) -> Result<(), MlsError> {
268 let leaves_to_update =
269 leaves_to_update.unwrap_or_else(|| (0..num_leaves).map(LeafIndex).collect::<Vec<_>>());
270
271 // Resize the array in case the tree was extended or truncated
272 hashes.resize(num_leaves as usize * 2 - 1, TreeHash::default());
273
274 let mut node_queue = VecDeque::with_capacity(leaves_to_update.len());
275
276 for l in leaves_to_update.iter().filter(|l| ***l < num_leaves) {
277 let leaf = (!filtered_leaves.contains(l))
278 .then_some(nodes.borrow_as_leaf(*l).ok())
279 .flatten();
280
281 hashes[2 * **l as usize] = TreeHash(hash_for_leaf(*l, leaf, cipher_suite_provider).await?);
282
283 if let Some(ps) = (2 * **l).parent_sibling(&num_leaves) {
284 node_queue.push_back(ps.parent);
285 }
286 }
287
288 while let Some(n) = node_queue.pop_front() {
289 let hash = TreeHash(
290 hash_for_parent(
291 nodes.borrow_as_parent(n).ok(),
292 cipher_suite_provider,
293 filtered_leaves,
294 &hashes[n.left_unchecked() as usize],
295 &hashes[n.right_unchecked() as usize],
296 )
297 .await?,
298 );
299
300 hashes[n as usize] = hash;
301
302 if let Some(ps) = n.parent_sibling(&num_leaves) {
303 node_queue.push_back(ps.parent);
304 }
305 }
306
307 Ok(())
308 }
309
310 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
hash_for_leaf<P: CipherSuiteProvider>( leaf_index: LeafIndex, leaf_node: Option<&LeafNode>, cipher_suite_provider: &P, ) -> Result<Vec<u8>, MlsError>311 async fn hash_for_leaf<P: CipherSuiteProvider>(
312 leaf_index: LeafIndex,
313 leaf_node: Option<&LeafNode>,
314 cipher_suite_provider: &P,
315 ) -> Result<Vec<u8>, MlsError> {
316 let input = TreeHashInput::Leaf(LeafNodeHashInput {
317 leaf_index,
318 leaf_node,
319 });
320
321 cipher_suite_provider
322 .hash(&input.mls_encode_to_vec()?)
323 .await
324 .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
325 }
326
327 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
hash_for_parent<P: CipherSuiteProvider>( parent_node: Option<&Parent>, cipher_suite_provider: &P, filtered: &[LeafIndex], left_hash: &[u8], right_hash: &[u8], ) -> Result<Vec<u8>, MlsError>328 async fn hash_for_parent<P: CipherSuiteProvider>(
329 parent_node: Option<&Parent>,
330 cipher_suite_provider: &P,
331 filtered: &[LeafIndex],
332 left_hash: &[u8],
333 right_hash: &[u8],
334 ) -> Result<Vec<u8>, MlsError> {
335 let mut parent_node = parent_node.cloned();
336
337 if let Some(ref mut parent_node) = parent_node {
338 parent_node
339 .unmerged_leaves
340 .retain(|unmerged_index| !filtered.contains(unmerged_index));
341 }
342
343 let input = TreeHashInput::Parent(ParentNodeTreeHashInput {
344 parent_node: parent_node.as_ref(),
345 left_hash,
346 right_hash,
347 });
348
349 cipher_suite_provider
350 .hash(&input.mls_encode_to_vec()?)
351 .await
352 .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))
353 }
354
#[cfg(test)]
mod tests {
    use mls_rs_codec::MlsDecode;

    use crate::{
        cipher_suite::CipherSuite,
        crypto::test_utils::{test_cipher_suite_provider, try_test_cipher_suite_provider},
        identity::basic::BasicIdentityProvider,
        tree_kem::{node::NodeVec, parent_hash::test_utils::get_test_tree_fig_12},
    };

    use super::*;

    /// Known-answer vector: an encoded node vector and the expected root tree hash.
    #[derive(serde::Serialize, serde::Deserialize)]
    struct TestCase {
        cipher_suite: u16,
        #[serde(with = "hex::serde")]
        tree_data: Vec<u8>,
        #[serde(with = "hex::serde")]
        tree_hash: Vec<u8>,
    }

    impl TestCase {
        /// Builds one vector per supported cipher suite from the "figure 12" test tree.
        #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
        #[cfg_attr(coverage_nightly, coverage(off))]
        async fn generate() -> Vec<TestCase> {
            let mut generated = Vec::new();

            for suite in CipherSuite::all() {
                let mut tree = get_test_tree_fig_12(suite).await;

                let tree_data = tree.nodes.mls_encode_to_vec().unwrap();
                let provider = test_cipher_suite_provider(suite);
                let tree_hash = tree.tree_hash(&provider).await.unwrap();

                generated.push(TestCase {
                    cipher_suite: suite.into(),
                    tree_data,
                    tree_hash,
                });
            }

            generated
        }
    }

    #[cfg(mls_build_async)]
    async fn load_test_cases() -> Vec<TestCase> {
        load_test_case_json!(tree_hash, TestCase::generate().await)
    }

    #[cfg(not(mls_build_async))]
    fn load_test_cases() -> Vec<TestCase> {
        load_test_case_json!(tree_hash, TestCase::generate())
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_tree_hash() {
        for case in load_test_cases().await {
            // Skip vectors for cipher suites the active crypto provider does not support.
            let Some(provider) = try_test_cipher_suite_provider(case.cipher_suite) else {
                continue;
            };

            let node_data = NodeVec::mls_decode(&mut case.tree_data.as_slice()).unwrap();

            let mut tree =
                TreeKemPublic::import_node_data(node_data, &BasicIdentityProvider, &Default::default())
                    .await
                    .unwrap();

            let computed = tree.tree_hash(&provider).await.unwrap();

            assert_eq!(computed, case.tree_hash);
        }
    }
}
433