1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4 
5 #[cfg(feature = "std")]
6 use std::collections::HashSet;
7 
8 #[cfg(not(feature = "std"))]
9 use alloc::{vec, vec::Vec};
10 use tree_math::TreeIndex;
11 
12 use super::node::{Node, NodeIndex};
13 use crate::client::MlsError;
14 use crate::crypto::CipherSuiteProvider;
15 use crate::group::GroupContext;
16 use crate::iter::wrap_impl_iter;
17 use crate::tree_kem::math as tree_math;
18 use crate::tree_kem::{leaf_node_validator::LeafNodeValidator, TreeKemPublic};
19 use mls_rs_core::identity::IdentityProvider;
20 
21 #[cfg(all(not(mls_build_async), feature = "rayon"))]
22 use rayon::prelude::*;
23 
24 #[cfg(mls_build_async)]
25 use futures::{StreamExt, TryStreamExt};
26 
27 pub(crate) struct TreeValidator<'a, C, CSP>
28 where
29     C: IdentityProvider,
30     CSP: CipherSuiteProvider,
31 {
32     expected_tree_hash: &'a [u8],
33     leaf_node_validator: LeafNodeValidator<'a, C, CSP>,
34     group_id: &'a [u8],
35     cipher_suite_provider: &'a CSP,
36 }
37 
impl<'a, C: IdentityProvider, CSP: CipherSuiteProvider> TreeValidator<'a, C, CSP> {
    /// Creates a validator whose expectations are borrowed from `context`.
    ///
    /// The expected tree hash and group id come from `context.tree_hash` and
    /// `context.group_id`. The context's extensions are given to the
    /// `LeafNodeValidator` that re-checks each leaf.
    pub fn new(
        cipher_suite_provider: &'a CSP,
        context: &'a GroupContext,
        identity_provider: &'a C,
    ) -> Self {
        TreeValidator {
            expected_tree_hash: &context.tree_hash,
            leaf_node_validator: LeafNodeValidator::new(
                cipher_suite_provider,
                identity_provider,
                Some(&context.extensions),
            ),
            group_id: &context.group_id,
            cipher_suite_provider,
        }
    }

    /// Runs every tree check in order and stops at the first failure.
    ///
    /// 1. The tree hash of `tree` equals the context's tree hash.
    /// 2. Parent hashes are valid (`TreeKemPublic::validate_parent_hashes`).
    /// 3. The node list is non-empty and does not end with a blank node.
    /// 4. Every non-blank leaf passes `LeafNodeValidator::revalidate`.
    /// 5. The unmerged-leaves lists of parent nodes are consistent.
    ///
    /// `tree` is taken as `&mut` only because `TreeKemPublic::tree_hash`
    /// requires it (presumably to cache hashes — TODO confirm).
    ///
    /// # Errors
    ///
    /// Returns `MlsError::TreeHashMismatch`, `MlsError::UnexpectedEmptyTree`,
    /// `MlsError::UnexpectedTrailingBlanks` or `MlsError::UnmergedLeavesMismatch`
    /// from the checks in this module. Errors from the parent-hash and leaf
    /// validators are propagated unchanged; the tests exercise
    /// `MlsError::ParentHashMismatch` and `MlsError::InvalidSignature`.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub async fn validate(&self, tree: &mut TreeKemPublic) -> Result<(), MlsError> {
        self.validate_tree_hash(tree).await?;

        tree.validate_parent_hashes(self.cipher_suite_provider)
            .await?;

        self.validate_no_trailing_blanks(tree)?;
        self.validate_leaves(tree).await?;
        validate_unmerged(tree)
    }

    /// Fails with `UnexpectedEmptyTree` if the tree has no nodes.
    ///
    /// Fails with `UnexpectedTrailingBlanks` if the last node is blank (`None`).
    fn validate_no_trailing_blanks(&self, tree: &TreeKemPublic) -> Result<(), MlsError> {
        tree.nodes
            .last()
            .ok_or(MlsError::UnexpectedEmptyTree)?
            .is_some()
            .then_some(())
            .ok_or(MlsError::UnexpectedTrailingBlanks)
    }

    /// Recomputes the tree hash of `tree` and compares it with the expected value.
    ///
    /// Fails with `TreeHashMismatch` if the two differ.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    async fn validate_tree_hash(&self, tree: &mut TreeKemPublic) -> Result<(), MlsError> {
        // Verify that the tree hash of the ratchet tree matches the tree_hash field in the GroupInfo.
        let tree_hash = tree.tree_hash(self.cipher_suite_provider).await?;

        if tree_hash != self.expected_tree_hash {
            return Err(MlsError::TreeHashMismatch);
        }

        Ok(())
    }

    /// Re-validates every non-blank leaf with the group id and leaf index.
    ///
    /// Returns the first error encountered.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    async fn validate_leaves(&self, tree: &TreeKemPublic) -> Result<(), MlsError> {
        // NOTE(review): `wrap_impl_iter` presumably yields a rayon parallel
        // iterator when the `rayon` feature is on (see the cfg'd
        // `rayon::prelude` import), and a stream in async builds — confirm
        // against `crate::iter`.
        let leaves = wrap_impl_iter(tree.nodes.non_empty_leaves());

        // In async builds the items are wrapped in `Ok` so the
        // `TryStreamExt::try_for_each` combinator can be used.
        #[cfg(mls_build_async)]
        let leaves = leaves.map(Ok);

        // `{ leaves }` moves the binding into a temporary, so `try_for_each`
        // can be called on it without declaring `leaves` as `mut`. In sync
        // builds `maybe_async::must_be_sync` strips the `async`/`.await` below.
        { leaves }
            .try_for_each(|(index, leaf_node)| async move {
                self.leaf_node_validator
                    .revalidate(leaf_node, self.group_id, *index)
                    .await
            })
            .await
    }
}
105 
validate_unmerged(tree: &TreeKemPublic) -> Result<(), MlsError>106 fn validate_unmerged(tree: &TreeKemPublic) -> Result<(), MlsError> {
107     let unmerged_sets = tree.nodes.iter().map(|n| {
108         #[cfg(feature = "std")]
109         if let Some(Node::Parent(p)) = n {
110             HashSet::from_iter(p.unmerged_leaves.iter().cloned())
111         } else {
112             HashSet::new()
113         }
114 
115         #[cfg(not(feature = "std"))]
116         if let Some(Node::Parent(p)) = n {
117             p.unmerged_leaves.clone()
118         } else {
119             vec![]
120         }
121     });
122 
123     let mut unmerged_sets = unmerged_sets.collect::<Vec<_>>();
124 
125     // For each leaf L, we search for the longest prefix P[1], P[2], ..., P[k] of the direct path of L
126     // such that for each i=1..k, either L is in the unmerged leaves of P[i], or P[i] is blank. We will
127     // then check that L is unmerged at each P[1], ..., P[k] and no other node.
128     let leaf_count = tree.total_leaf_count();
129 
130     for (index, _) in tree.nodes.non_empty_leaves() {
131         let mut n = NodeIndex::from(index);
132 
133         while let Some(ps) = n.parent_sibling(&leaf_count) {
134             if tree.nodes.is_blank(ps.parent)? {
135                 n = ps.parent;
136                 continue;
137             }
138 
139             let parent_node = tree.nodes.borrow_as_parent(ps.parent)?;
140 
141             if parent_node.unmerged_leaves.contains(&index) {
142                 unmerged_sets[ps.parent as usize].retain(|i| i != &index);
143 
144                 n = ps.parent;
145             } else {
146                 break;
147             }
148         }
149     }
150 
151     let unmerged_sets = unmerged_sets.iter().all(|set| set.is_empty());
152 
153     unmerged_sets
154         .then_some(())
155         .ok_or(MlsError::UnmergedLeavesMismatch)
156 }
157 
// Unit tests for `TreeValidator::validate` and `validate_unmerged`.
#[cfg(test)]
mod tests {
    use alloc::vec;
    use assert_matches::assert_matches;

    use super::*;
    use crate::{
        cipher_suite::CipherSuite,
        client::test_utils::TEST_CIPHER_SUITE,
        crypto::test_utils::test_cipher_suite_provider,
        crypto::test_utils::TestCryptoProvider,
        group::test_utils::{get_test_group_context, random_bytes},
        identity::basic::BasicIdentityProvider,
        tree_kem::{
            kem::TreeKem,
            leaf_node::test_utils::{default_properties, get_basic_test_node},
            node::{LeafIndex, Node, Parent},
            parent_hash::{test_utils::get_test_tree_fig_12, ParentHash},
            test_utils::get_test_tree,
        },
    };

    /// Builds a parent node with a freshly generated KEM public key, an empty
    /// parent hash and no unmerged leaves.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    async fn test_parent_node(cipher_suite: CipherSuite) -> Parent {
        let (_, public_key) = test_cipher_suite_provider(cipher_suite)
            .kem_generate()
            .await
            .unwrap();

        Parent {
            public_key,
            parent_hash: ParentHash::empty(),
            unmerged_leaves: vec![],
        }
    }

    /// Builds a tree that the tests expect to pass validation.
    ///
    /// Steps:
    /// 1. Start from `get_test_tree`.
    /// 2. Add two basic leaves.
    /// 3. Overwrite nodes 1 and 3 with fresh parent nodes.
    /// 4. Run `TreeKem::encap` with the creator's signing key.
    ///
    /// NOTE(review): encap presumably rewrites the creator's path so parent
    /// hashes become consistent. The `[LeafIndex(1), LeafIndex(2)]` argument's
    /// exact meaning (e.g. leaves to exclude) should be confirmed against
    /// `TreeKem::encap`.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    async fn get_valid_tree(cipher_suite: CipherSuite) -> TreeKemPublic {
        let cipher_suite_provider = test_cipher_suite_provider(cipher_suite);

        let mut test_tree = get_test_tree(cipher_suite).await;

        let leaf1 = get_basic_test_node(cipher_suite, "leaf1").await;
        let leaf2 = get_basic_test_node(cipher_suite, "leaf2").await;

        test_tree
            .public
            .add_leaves(
                vec![leaf1, leaf2],
                &BasicIdentityProvider,
                &cipher_suite_provider,
            )
            .await
            .unwrap();

        test_tree.public.nodes[1] = Some(Node::Parent(test_parent_node(cipher_suite).await));
        test_tree.public.nodes[3] = Some(Node::Parent(test_parent_node(cipher_suite).await));

        TreeKem::new(&mut test_tree.public, &mut test_tree.private)
            .encap(
                &mut get_test_group_context(42, cipher_suite).await,
                &[LeafIndex(1), LeafIndex(2)],
                &test_tree.creator_signing_key,
                default_properties(),
                None,
                &cipher_suite_provider,
                #[cfg(test)]
                &Default::default(),
            )
            .await
            .unwrap();

        test_tree.public
    }

    // A valid tree whose hash is copied into the context passes every check.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_valid_tree() {
        for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
            let cipher_suite_provider = test_cipher_suite_provider(cipher_suite);

            let mut test_tree = get_valid_tree(cipher_suite).await;

            let mut context = get_test_group_context(1, cipher_suite).await;
            context.tree_hash = test_tree.tree_hash(&cipher_suite_provider).await.unwrap();

            let validator =
                TreeValidator::new(&cipher_suite_provider, &context, &BasicIdentityProvider);

            validator.validate(&mut test_tree).await.unwrap();
        }
    }

    // The context's tree_hash is left as produced by get_test_group_context
    // (presumably not the hash of this tree), so validation must fail on the
    // tree-hash check.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_tree_hash_mismatch() {
        for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
            let mut test_tree = get_valid_tree(cipher_suite).await;

            let cipher_suite_provider = test_cipher_suite_provider(cipher_suite);
            let context = get_test_group_context(1, cipher_suite).await;

            let validator =
                TreeValidator::new(&cipher_suite_provider, &context, &BasicIdentityProvider);

            let res = validator.validate(&mut test_tree).await;

            assert_matches!(res, Err(MlsError::TreeHashMismatch));
        }
    }

    // Node 1's parent hash is replaced with random bytes. The expected tree
    // hash is recomputed afterwards so the tree-hash check still passes and
    // the parent-hash check is the one that fails.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_parent_hash_mismatch() {
        for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
            let mut test_tree = get_valid_tree(cipher_suite).await;

            let parent_node = test_tree.nodes.borrow_as_parent_mut(1).unwrap();
            parent_node.parent_hash = ParentHash::from(random_bytes(32));

            let cipher_suite_provider = test_cipher_suite_provider(cipher_suite);
            let mut context = get_test_group_context(1, cipher_suite).await;
            context.tree_hash = test_tree.tree_hash(&cipher_suite_provider).await.unwrap();

            let validator =
                TreeValidator::new(&cipher_suite_provider, &context, &BasicIdentityProvider);

            let res = validator.validate(&mut test_tree).await;

            assert_matches!(res, Err(MlsError::ParentHashMismatch));
        }
    }

    // Leaf 0's signature is replaced with random bytes, with the tree hash
    // recomputed, so leaf re-validation must reject it.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_key_package_validation_failure() {
        for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
            let mut test_tree = get_valid_tree(cipher_suite).await;

            test_tree
                .nodes
                .borrow_as_leaf_mut(LeafIndex(0))
                .unwrap()
                .signature = random_bytes(32);

            let cipher_suite_provider = test_cipher_suite_provider(cipher_suite);
            let mut context = get_test_group_context(1, cipher_suite).await;
            context.tree_hash = test_tree.tree_hash(&cipher_suite_provider).await.unwrap();

            let validator =
                TreeValidator::new(&cipher_suite_provider, &context, &BasicIdentityProvider);

            let res = validator.validate(&mut test_tree).await;

            assert_matches!(res, Err(MlsError::InvalidSignature));
        }
    }

    // The reference tree from get_test_tree_fig_12 (presumably the tree of
    // the figure of that number in the MLS specification) has consistent
    // unmerged leaves.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn verify_unmerged_with_correct_tree() {
        let tree = get_test_tree_fig_12(TEST_CIPHER_SUITE).await;
        validate_unmerged(&tree).unwrap();
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn verify_unmerged_with_blank_leaf() {
        let mut tree = get_test_tree_fig_12(TEST_CIPHER_SUITE).await;

        // Blank leaf D unmerged at nodes 3, 7
        tree.nodes[6] = None;

        assert_matches!(
            validate_unmerged(&tree),
            Err(MlsError::UnmergedLeavesMismatch)
        );
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn verify_unmerged_with_broken_path() {
        let mut tree = get_test_tree_fig_12(TEST_CIPHER_SUITE).await;

        // Make D with direct path [3, 7] unmerged at 7 but not 3
        tree.nodes.borrow_as_parent_mut(3).unwrap().unmerged_leaves = vec![];

        assert_matches!(
            validate_unmerged(&tree),
            Err(MlsError::UnmergedLeavesMismatch)
        );
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn verify_unmerged_with_leaf_outside_tree() {
        let mut tree = get_test_tree_fig_12(TEST_CIPHER_SUITE).await;

        // Add leaf E from the right subtree of the root to unmerged leaves of node 1 on the left
        tree.nodes.borrow_as_parent_mut(1).unwrap().unmerged_leaves = vec![LeafIndex(4)];

        assert_matches!(
            validate_unmerged(&tree),
            Err(MlsError::UnmergedLeavesMismatch)
        );
    }
}
357