1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4 
5 use crate::client::MlsError;
6 use crate::crypto::{CipherSuiteProvider, HpkePublicKey};
7 use crate::tree_kem::math as tree_math;
8 use crate::tree_kem::node::{LeafIndex, Node, NodeIndex};
9 use crate::tree_kem::TreeKemPublic;
10 use alloc::vec::Vec;
11 use core::{
12     fmt::{self, Debug},
13     ops::Deref,
14 };
15 use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize};
16 use mls_rs_core::error::IntoAnyError;
17 use tree_math::TreeIndex;
18 
19 use super::leaf_node::LeafNodeSource;
20 
21 #[cfg(feature = "std")]
22 use std::collections::HashSet;
23 
24 #[cfg(not(feature = "std"))]
25 use alloc::collections::BTreeSet;
26 
/// Input to the parent hash computation, mirroring the `ParentHashInput` struct of RFC 9420.
///
/// `ParentHash::new` MLS-encodes a value of this type and hashes the encoded bytes. The
/// derived encoding writes the fields in declaration order, so the fields must not be
/// reordered (doing so would change every computed parent hash).
#[derive(Clone, Debug, MlsSize, MlsEncode)]
struct ParentHashInput<'a> {
    /// HPKE public key of the parent node being hashed.
    #[mls_codec(with = "mls_rs_codec::byte_vec")]
    public_key: &'a HpkePublicKey,
    /// Parent hash stored in that parent node (the value chained down from above it).
    #[mls_codec(with = "mls_rs_codec::byte_vec")]
    parent_hash: &'a [u8],
    /// Tree hash of the parent's child that lies on the co-path.
    #[mls_codec(with = "mls_rs_codec::byte_vec")]
    original_sibling_tree_hash: &'a [u8],
}
36 
/// A parent hash value: the cipher suite's hash of an encoded `ParentHashInput`.
///
/// `ParentHash::empty()` produces a zero-length value, which is used as the starting hash
/// at the top of a direct path. The inner bytes are encoded with
/// `mls_rs_codec::byte_vec` and, when the `serde` feature is enabled, serialized through
/// `mls_rs_core::vec_serde`.
#[derive(Clone, MlsSize, MlsEncode, MlsDecode, PartialEq, Eq)]
#[cfg_attr(feature = "arbitrary", derive(arbitrary::Arbitrary))]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub struct ParentHash(
    #[mls_codec(with = "mls_rs_codec::byte_vec")]
    #[cfg_attr(feature = "serde", serde(with = "mls_rs_core::vec_serde"))]
    Vec<u8>,
);
45 
46 impl Debug for ParentHash {
fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result47     fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
48         mls_rs_core::debug::pretty_bytes(&self.0)
49             .named("ParentHash")
50             .fmt(f)
51     }
52 }
53 
54 impl From<Vec<u8>> for ParentHash {
from(v: Vec<u8>) -> Self55     fn from(v: Vec<u8>) -> Self {
56         Self(v)
57     }
58 }
59 
60 impl Deref for ParentHash {
61     type Target = Vec<u8>;
62 
deref(&self) -> &Self::Target63     fn deref(&self) -> &Self::Target {
64         &self.0
65     }
66 }
67 
68 impl ParentHash {
69     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
new<P: CipherSuiteProvider>( cipher_suite_provider: &P, public_key: &HpkePublicKey, parent_hash: &ParentHash, original_sibling_tree_hash: &[u8], ) -> Result<Self, MlsError>70     pub async fn new<P: CipherSuiteProvider>(
71         cipher_suite_provider: &P,
72         public_key: &HpkePublicKey,
73         parent_hash: &ParentHash,
74         original_sibling_tree_hash: &[u8],
75     ) -> Result<Self, MlsError> {
76         let input = ParentHashInput {
77             public_key,
78             parent_hash,
79             original_sibling_tree_hash,
80         };
81 
82         let input_bytes = input.mls_encode_to_vec()?;
83 
84         let hash = cipher_suite_provider
85             .hash(&input_bytes)
86             .await
87             .map_err(|e| MlsError::CryptoProviderError(e.into_any_error()))?;
88 
89         Ok(Self(hash))
90     }
91 
empty() -> Self92     pub fn empty() -> Self {
93         ParentHash(Vec::new())
94     }
95 
matches(&self, hash: &ParentHash) -> bool96     pub fn matches(&self, hash: &ParentHash) -> bool {
97         //TODO: Constant time equals
98         hash == self
99     }
100 }
101 
102 impl Node {
get_parent_hash(&self) -> Option<ParentHash>103     fn get_parent_hash(&self) -> Option<ParentHash> {
104         match self {
105             Node::Parent(p) => Some(p.parent_hash.clone()),
106             Node::Leaf(l) => match &l.leaf_node_source {
107                 LeafNodeSource::Commit(parent_hash) => Some(parent_hash.clone()),
108                 _ => None,
109             },
110         }
111     }
112 }
113 
114 impl TreeKemPublic {
115     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
parent_hash_for_leaf<P: CipherSuiteProvider>( &mut self, cipher_suite_provider: &P, index: LeafIndex, ) -> Result<ParentHash, MlsError>116     async fn parent_hash_for_leaf<P: CipherSuiteProvider>(
117         &mut self,
118         cipher_suite_provider: &P,
119         index: LeafIndex,
120     ) -> Result<ParentHash, MlsError> {
121         let mut hash = ParentHash::empty();
122 
123         for node in self.nodes.direct_copath(index).into_iter().rev() {
124             if self.nodes.is_resolution_empty(node.copath) {
125                 continue;
126             }
127 
128             let parent = self.nodes.borrow_as_parent_mut(node.path)?;
129 
130             let calculated = ParentHash::new(
131                 cipher_suite_provider,
132                 &parent.public_key,
133                 &hash,
134                 &self.tree_hashes.current[node.copath as usize],
135             )
136             .await?;
137 
138             (parent.parent_hash, hash) = (hash, calculated);
139         }
140 
141         Ok(hash)
142     }
143 
144     // Updates all of the required parent hash values, and returns the calculated parent hash value for the leaf node
145     // If an update path is provided, additionally verify that the calculated parent hash matches
146     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
update_parent_hashes<P: CipherSuiteProvider>( &mut self, index: LeafIndex, verify_leaf_hash: bool, cipher_suite_provider: &P, ) -> Result<(), MlsError>147     pub(crate) async fn update_parent_hashes<P: CipherSuiteProvider>(
148         &mut self,
149         index: LeafIndex,
150         verify_leaf_hash: bool,
151         cipher_suite_provider: &P,
152     ) -> Result<(), MlsError> {
153         // First update the relevant original hashes used for parent hash computation.
154         self.update_hashes(&[index], cipher_suite_provider).await?;
155 
156         let leaf_hash = self
157             .parent_hash_for_leaf(cipher_suite_provider, index)
158             .await?;
159 
160         let leaf = self.nodes.borrow_as_leaf_mut(index)?;
161 
162         if verify_leaf_hash {
163             // Verify the parent hash of the new sender leaf node and update the parent hash values
164             // in the local tree
165             if let LeafNodeSource::Commit(parent_hash) = &leaf.leaf_node_source {
166                 if !leaf_hash.matches(parent_hash) {
167                     return Err(MlsError::ParentHashMismatch);
168                 }
169             } else {
170                 return Err(MlsError::InvalidLeafNodeSource);
171             }
172         } else {
173             leaf.leaf_node_source = LeafNodeSource::Commit(leaf_hash);
174         }
175 
176         // Update hashes after changes to the tree.
177         self.update_hashes(&[index], cipher_suite_provider).await
178     }
179 
180     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
validate_parent_hashes<P: CipherSuiteProvider>( &self, cipher_suite_provider: &P, ) -> Result<(), MlsError>181     pub(super) async fn validate_parent_hashes<P: CipherSuiteProvider>(
182         &self,
183         cipher_suite_provider: &P,
184     ) -> Result<(), MlsError> {
185         let original_hashes = self.compute_original_hashes(cipher_suite_provider).await?;
186 
187         let nodes_to_validate = self
188             .nodes
189             .non_empty_parents()
190             .map(|(node_index, _)| node_index);
191 
192         #[cfg(feature = "std")]
193         let mut nodes_to_validate = nodes_to_validate.collect::<HashSet<_>>();
194         #[cfg(not(feature = "std"))]
195         let mut nodes_to_validate = nodes_to_validate.collect::<BTreeSet<_>>();
196 
197         let num_leaves = self.total_leaf_count();
198 
199         // For each leaf l, validate all non-blank nodes on the chain from l up the tree.
200         for (leaf_index, _) in self.nodes.non_empty_leaves() {
201             let mut n = NodeIndex::from(leaf_index);
202 
203             while let Some(mut ps) = n.parent_sibling(&num_leaves) {
204                 // Find the first non-blank ancestor p of n and p's co-path child s.
205                 while self.nodes.is_blank(ps.parent)? {
206                     // If we reached the root, we're done with this chain.
207                     let Some(ps_parent) = ps.parent.parent_sibling(&num_leaves) else {
208                         return Ok(());
209                     };
210 
211                     ps = ps_parent;
212                 }
213 
214                 // Check is n's parent_hash field matches the parent hash of p with co-path child s.
215                 let p_parent = self.nodes.borrow_as_parent(ps.parent)?;
216 
217                 let n_node = self
218                     .nodes
219                     .borrow_node(n)?
220                     .as_ref()
221                     .ok_or(MlsError::ExpectedNode)?;
222 
223                 let calculated = ParentHash::new(
224                     cipher_suite_provider,
225                     &p_parent.public_key,
226                     &p_parent.parent_hash,
227                     &original_hashes[ps.sibling as usize],
228                 )
229                 .await?;
230 
231                 if n_node.get_parent_hash() == Some(calculated) {
232                     // Check that "n is in the resolution of c, and the intersection of p's unmerged_leaves with the subtree
233                     // under c is equal to the resolution of c with n removed".
234                     let Some(cp) = ps.sibling.parent_sibling(&num_leaves) else {
235                         return Err(MlsError::ParentHashMismatch);
236                     };
237 
238                     let c = cp.sibling;
239                     let c_resolution = self.nodes.get_resolution_index(c)?.into_iter();
240 
241                     #[cfg(feature = "std")]
242                     let mut c_resolution = c_resolution.collect::<HashSet<_>>();
243                     #[cfg(not(feature = "std"))]
244                     let mut c_resolution = c_resolution.collect::<BTreeSet<_>>();
245 
246                     let p_unmerged_in_c_subtree = self
247                         .unmerged_in_subtree(ps.parent, c)?
248                         .iter()
249                         .copied()
250                         .map(|x| *x * 2);
251 
252                     #[cfg(feature = "std")]
253                     let p_unmerged_in_c_subtree = p_unmerged_in_c_subtree.collect::<HashSet<_>>();
254                     #[cfg(not(feature = "std"))]
255                     let p_unmerged_in_c_subtree = p_unmerged_in_c_subtree.collect::<BTreeSet<_>>();
256 
257                     if c_resolution.remove(&n)
258                         && c_resolution == p_unmerged_in_c_subtree
259                         && nodes_to_validate.remove(&ps.parent)
260                     {
261                         // If n's parent_hash field matches and p has not been validated yet, mark p as validated and continue.
262                         n = ps.parent;
263                     } else {
264                         // If p is validated for the second time, the check fails ("all non-blank parent nodes are covered by exactly one such chain").
265                         return Err(MlsError::ParentHashMismatch);
266                     }
267                 } else {
268                     // If n's parent_hash field doesn't match, we're done with this chain.
269                     break;
270                 }
271             }
272         }
273 
274         // The check passes iff all non-blank nodes are validated.
275         if nodes_to_validate.is_empty() {
276             Ok(())
277         } else {
278             Err(MlsError::ParentHashMismatch)
279         }
280     }
281 }
282 
#[cfg(test)]
pub(crate) mod test_utils {

    use super::*;
    use crate::{
        cipher_suite::CipherSuite,
        crypto::test_utils::test_cipher_suite_provider,
        identity::basic::BasicIdentityProvider,
        tree_kem::{leaf_node::test_utils::get_basic_test_node, node::Parent},
    };

    use alloc::vec;

    /// Builds a parent with a freshly generated HPKE public key, an empty parent hash and
    /// the given unmerged leaves.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub(crate) async fn test_parent(
        cipher_suite: CipherSuite,
        unmerged_leaves: Vec<LeafIndex>,
    ) -> Parent {
        let provider = test_cipher_suite_provider(cipher_suite);
        let key_pair = provider.kem_generate().await.unwrap();

        Parent {
            public_key: key_pair.1,
            parent_hash: ParentHash::empty(),
            unmerged_leaves,
        }
    }

    /// Same as `test_parent`, wrapped in `Node::Parent`.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub(crate) async fn test_parent_node(
        cipher_suite: CipherSuite,
        unmerged_leaves: Vec<LeafIndex>,
    ) -> Node {
        let parent = test_parent(cipher_suite, unmerged_leaves).await;
        Node::Parent(parent)
    }

    /// Builds the tree of Figure 12 from the MLS RFC: seven leaves "A" through "G", parent
    /// nodes 1, 3, 7, 9 and 11 with the figure's unmerged leaves, and parent hashes written
    /// by updating from leaves 0 and 4.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub(crate) async fn get_test_tree_fig_12(cipher_suite: CipherSuite) -> TreeKemPublic {
        let provider = test_cipher_suite_provider(cipher_suite);

        let mut leaves = Vec::new();
        for name in ["A", "B", "C", "D", "E", "F", "G"] {
            let leaf = get_basic_test_node(cipher_suite, name).await;
            leaves.push(leaf);
        }

        let mut tree = TreeKemPublic::new();
        tree.add_leaves(leaves, &BasicIdentityProvider, &provider)
            .await
            .unwrap();

        tree.nodes[1] = Some(test_parent_node(cipher_suite, vec![]).await);
        tree.nodes[3] = Some(test_parent_node(cipher_suite, vec![LeafIndex(3)]).await);

        tree.nodes[7] =
            Some(test_parent_node(cipher_suite, vec![LeafIndex(3), LeafIndex(6)]).await);

        tree.nodes[9] = Some(test_parent_node(cipher_suite, vec![LeafIndex(5)]).await);

        tree.nodes[11] =
            Some(test_parent_node(cipher_suite, vec![LeafIndex(5), LeafIndex(6)]).await);

        // Write parent hashes along the direct paths of leaves A (0) and E (4).
        for committer in [LeafIndex(0), LeafIndex(4)] {
            tree.update_parent_hashes(committer, false, &provider)
                .await
                .unwrap();
        }

        tree
    }
}
360 
#[cfg(test)]
mod tests {
    use super::*;
    use crate::client::test_utils::TEST_CIPHER_SUITE;
    use crate::crypto::test_utils::test_cipher_suite_provider;
    use crate::tree_kem::leaf_node::test_utils::get_basic_test_node;
    use crate::tree_kem::leaf_node::LeafNodeSource;
    use crate::tree_kem::test_utils::TreeWithSigners;
    use crate::tree_kem::MlsError;
    use assert_matches::assert_matches;

    /// Verifying a leaf that was replaced by a fresh (non-commit) leaf node must fail with
    /// `InvalidLeafNodeSource`.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_missing_parent_hash() {
        let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
        let mut test_tree = TreeWithSigners::make_full_tree(8, &cs).await.tree;

        let replacement = get_basic_test_node(TEST_CIPHER_SUITE, "foo").await;
        *test_tree.nodes.borrow_as_leaf_mut(LeafIndex(0)).unwrap() = replacement;

        let verifier = test_cipher_suite_provider(TEST_CIPHER_SUITE);
        let outcome = test_tree
            .update_parent_hashes(LeafIndex(0), true, &verifier)
            .await;

        assert_matches!(outcome, Err(MlsError::InvalidLeafNodeSource));
    }

    /// Verifying a leaf whose committed parent hash is bogus must fail with
    /// `ParentHashMismatch`.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_parent_hash_mismatch() {
        let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
        let mut test_tree = TreeWithSigners::make_full_tree(8, &cs).await.tree;

        let leaf = test_tree.nodes.borrow_as_leaf_mut(LeafIndex(0)).unwrap();
        leaf.leaf_node_source = LeafNodeSource::Commit(ParentHash::from(hex!("f00d")));

        let verifier = test_cipher_suite_provider(TEST_CIPHER_SUITE);
        let outcome = test_tree
            .update_parent_hashes(LeafIndex(0), true, &verifier)
            .await;

        assert_matches!(outcome, Err(MlsError::ParentHashMismatch));
    }

    /// Blanking a leaf (node index 2) of a full tree must make whole-tree parent hash
    /// validation fail.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_parent_hash_invalid() {
        let cs = test_cipher_suite_provider(TEST_CIPHER_SUITE);
        let mut test_tree = TreeWithSigners::make_full_tree(8, &cs).await.tree;

        test_tree.nodes[2] = None;

        let verifier = test_cipher_suite_provider(TEST_CIPHER_SUITE);
        let outcome = test_tree.validate_parent_hashes(&verifier).await;

        assert_matches!(outcome, Err(MlsError::ParentHashMismatch));
    }
}
432