1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4
5 #[cfg(feature = "std")]
6 use std::collections::HashSet;
7
8 #[cfg(not(feature = "std"))]
9 use alloc::{vec, vec::Vec};
10 use tree_math::TreeIndex;
11
12 use super::node::{Node, NodeIndex};
13 use crate::client::MlsError;
14 use crate::crypto::CipherSuiteProvider;
15 use crate::group::GroupContext;
16 use crate::iter::wrap_impl_iter;
17 use crate::tree_kem::math as tree_math;
18 use crate::tree_kem::{leaf_node_validator::LeafNodeValidator, TreeKemPublic};
19 use mls_rs_core::identity::IdentityProvider;
20
21 #[cfg(all(not(mls_build_async), feature = "rayon"))]
22 use rayon::prelude::*;
23
24 #[cfg(mls_build_async)]
25 use futures::{StreamExt, TryStreamExt};
26
27 pub(crate) struct TreeValidator<'a, C, CSP>
28 where
29 C: IdentityProvider,
30 CSP: CipherSuiteProvider,
31 {
32 expected_tree_hash: &'a [u8],
33 leaf_node_validator: LeafNodeValidator<'a, C, CSP>,
34 group_id: &'a [u8],
35 cipher_suite_provider: &'a CSP,
36 }
37
38 impl<'a, C: IdentityProvider, CSP: CipherSuiteProvider> TreeValidator<'a, C, CSP> {
new( cipher_suite_provider: &'a CSP, context: &'a GroupContext, identity_provider: &'a C, ) -> Self39 pub fn new(
40 cipher_suite_provider: &'a CSP,
41 context: &'a GroupContext,
42 identity_provider: &'a C,
43 ) -> Self {
44 TreeValidator {
45 expected_tree_hash: &context.tree_hash,
46 leaf_node_validator: LeafNodeValidator::new(
47 cipher_suite_provider,
48 identity_provider,
49 Some(&context.extensions),
50 ),
51 group_id: &context.group_id,
52 cipher_suite_provider,
53 }
54 }
55
56 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
validate(&self, tree: &mut TreeKemPublic) -> Result<(), MlsError>57 pub async fn validate(&self, tree: &mut TreeKemPublic) -> Result<(), MlsError> {
58 self.validate_tree_hash(tree).await?;
59
60 tree.validate_parent_hashes(self.cipher_suite_provider)
61 .await?;
62
63 self.validate_no_trailing_blanks(tree)?;
64 self.validate_leaves(tree).await?;
65 validate_unmerged(tree)
66 }
67
validate_no_trailing_blanks(&self, tree: &TreeKemPublic) -> Result<(), MlsError>68 fn validate_no_trailing_blanks(&self, tree: &TreeKemPublic) -> Result<(), MlsError> {
69 tree.nodes
70 .last()
71 .ok_or(MlsError::UnexpectedEmptyTree)?
72 .is_some()
73 .then_some(())
74 .ok_or(MlsError::UnexpectedTrailingBlanks)
75 }
76
77 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
validate_tree_hash(&self, tree: &mut TreeKemPublic) -> Result<(), MlsError>78 async fn validate_tree_hash(&self, tree: &mut TreeKemPublic) -> Result<(), MlsError> {
79 //Verify that the tree hash of the ratchet tree matches the tree_hash field in the GroupInfo.
80 let tree_hash = tree.tree_hash(self.cipher_suite_provider).await?;
81
82 if tree_hash != self.expected_tree_hash {
83 return Err(MlsError::TreeHashMismatch);
84 }
85
86 Ok(())
87 }
88
89 #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
validate_leaves(&self, tree: &TreeKemPublic) -> Result<(), MlsError>90 async fn validate_leaves(&self, tree: &TreeKemPublic) -> Result<(), MlsError> {
91 let leaves = wrap_impl_iter(tree.nodes.non_empty_leaves());
92
93 #[cfg(mls_build_async)]
94 let leaves = leaves.map(Ok);
95
96 { leaves }
97 .try_for_each(|(index, leaf_node)| async move {
98 self.leaf_node_validator
99 .revalidate(leaf_node, self.group_id, *index)
100 .await
101 })
102 .await
103 }
104 }
105
validate_unmerged(tree: &TreeKemPublic) -> Result<(), MlsError>106 fn validate_unmerged(tree: &TreeKemPublic) -> Result<(), MlsError> {
107 let unmerged_sets = tree.nodes.iter().map(|n| {
108 #[cfg(feature = "std")]
109 if let Some(Node::Parent(p)) = n {
110 HashSet::from_iter(p.unmerged_leaves.iter().cloned())
111 } else {
112 HashSet::new()
113 }
114
115 #[cfg(not(feature = "std"))]
116 if let Some(Node::Parent(p)) = n {
117 p.unmerged_leaves.clone()
118 } else {
119 vec![]
120 }
121 });
122
123 let mut unmerged_sets = unmerged_sets.collect::<Vec<_>>();
124
125 // For each leaf L, we search for the longest prefix P[1], P[2], ..., P[k] of the direct path of L
126 // such that for each i=1..k, either L is in the unmerged leaves of P[i], or P[i] is blank. We will
127 // then check that L is unmerged at each P[1], ..., P[k] and no other node.
128 let leaf_count = tree.total_leaf_count();
129
130 for (index, _) in tree.nodes.non_empty_leaves() {
131 let mut n = NodeIndex::from(index);
132
133 while let Some(ps) = n.parent_sibling(&leaf_count) {
134 if tree.nodes.is_blank(ps.parent)? {
135 n = ps.parent;
136 continue;
137 }
138
139 let parent_node = tree.nodes.borrow_as_parent(ps.parent)?;
140
141 if parent_node.unmerged_leaves.contains(&index) {
142 unmerged_sets[ps.parent as usize].retain(|i| i != &index);
143
144 n = ps.parent;
145 } else {
146 break;
147 }
148 }
149 }
150
151 let unmerged_sets = unmerged_sets.iter().all(|set| set.is_empty());
152
153 unmerged_sets
154 .then_some(())
155 .ok_or(MlsError::UnmergedLeavesMismatch)
156 }
157
// Unit tests for `TreeValidator` and `validate_unmerged`.
#[cfg(test)]
mod tests {
    use alloc::vec;
    use assert_matches::assert_matches;

    use super::*;
    use crate::{
        cipher_suite::CipherSuite,
        client::test_utils::TEST_CIPHER_SUITE,
        crypto::test_utils::test_cipher_suite_provider,
        crypto::test_utils::TestCryptoProvider,
        group::test_utils::{get_test_group_context, random_bytes},
        identity::basic::BasicIdentityProvider,
        tree_kem::{
            kem::TreeKem,
            leaf_node::test_utils::{default_properties, get_basic_test_node},
            node::{LeafIndex, Node, Parent},
            parent_hash::{test_utils::get_test_tree_fig_12, ParentHash},
            test_utils::get_test_tree,
        },
    };

    /// Builds a parent node with:
    /// * a freshly generated KEM public key;
    /// * an empty parent hash;
    /// * no unmerged leaves.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    async fn test_parent_node(cipher_suite: CipherSuite) -> Parent {
        let (_, public_key) = test_cipher_suite_provider(cipher_suite)
            .kem_generate()
            .await
            .unwrap();

        Parent {
            public_key,
            parent_hash: ParentHash::empty(),
            unmerged_leaves: vec![],
        }
    }

    /// Builds a public tree that is expected to pass validation.
    ///
    /// Steps:
    /// 1. Start from `get_test_tree` and add two basic leaves.
    /// 2. Overwrite nodes 1 and 3 with fresh parent nodes.
    /// 3. Run `TreeKem::encap` with the creator's signing key.
    ///
    /// The encap call should leave consistent parent hashes.
    /// NOTE(review): leaves 1 and 2 passed to `encap` are presumably the
    /// excluded (newly added) leaves — confirm against `TreeKem::encap`.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    async fn get_valid_tree(cipher_suite: CipherSuite) -> TreeKemPublic {
        let cipher_suite_provider = test_cipher_suite_provider(cipher_suite);

        let mut test_tree = get_test_tree(cipher_suite).await;

        let leaf1 = get_basic_test_node(cipher_suite, "leaf1").await;
        let leaf2 = get_basic_test_node(cipher_suite, "leaf2").await;

        test_tree
            .public
            .add_leaves(
                vec![leaf1, leaf2],
                &BasicIdentityProvider,
                &cipher_suite_provider,
            )
            .await
            .unwrap();

        // Populate two parent slots so that encap has non-blank parents to work with.
        test_tree.public.nodes[1] = Some(Node::Parent(test_parent_node(cipher_suite).await));
        test_tree.public.nodes[3] = Some(Node::Parent(test_parent_node(cipher_suite).await));

        TreeKem::new(&mut test_tree.public, &mut test_tree.private)
            .encap(
                &mut get_test_group_context(42, cipher_suite).await,
                &[LeafIndex(1), LeafIndex(2)],
                &test_tree.creator_signing_key,
                default_properties(),
                None,
                &cipher_suite_provider,
                #[cfg(test)]
                &Default::default(),
            )
            .await
            .unwrap();

        test_tree.public
    }

    // A tree whose hash is copied into the context passes every check.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_valid_tree() {
        for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
            let cipher_suite_provider = test_cipher_suite_provider(cipher_suite);

            let mut test_tree = get_valid_tree(cipher_suite).await;

            let mut context = get_test_group_context(1, cipher_suite).await;
            context.tree_hash = test_tree.tree_hash(&cipher_suite_provider).await.unwrap();

            let validator =
                TreeValidator::new(&cipher_suite_provider, &context, &BasicIdentityProvider);

            validator.validate(&mut test_tree).await.unwrap();
        }
    }

    // The context keeps the tree hash produced by `get_test_group_context`.
    // Since it was not computed from this tree, validation fails at the tree-hash check.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_tree_hash_mismatch() {
        for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
            let mut test_tree = get_valid_tree(cipher_suite).await;

            let cipher_suite_provider = test_cipher_suite_provider(cipher_suite);
            let context = get_test_group_context(1, cipher_suite).await;

            let validator =
                TreeValidator::new(&cipher_suite_provider, &context, &BasicIdentityProvider);

            let res = validator.validate(&mut test_tree).await;

            assert_matches!(res, Err(MlsError::TreeHashMismatch));
        }
    }

    // Corrupt node 1's parent hash before computing the context tree hash.
    // The tree-hash check therefore passes, and the parent-hash check is what fails.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_parent_hash_mismatch() {
        for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
            let mut test_tree = get_valid_tree(cipher_suite).await;

            let parent_node = test_tree.nodes.borrow_as_parent_mut(1).unwrap();
            parent_node.parent_hash = ParentHash::from(random_bytes(32));

            let cipher_suite_provider = test_cipher_suite_provider(cipher_suite);
            let mut context = get_test_group_context(1, cipher_suite).await;
            context.tree_hash = test_tree.tree_hash(&cipher_suite_provider).await.unwrap();

            let validator =
                TreeValidator::new(&cipher_suite_provider, &context, &BasicIdentityProvider);

            let res = validator.validate(&mut test_tree).await;

            assert_matches!(res, Err(MlsError::ParentHashMismatch));
        }
    }

    // Replace leaf 0's signature with random bytes before computing the context tree hash.
    // Despite the test name, this exercises leaf *node* revalidation in `validate_leaves`.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_key_package_validation_failure() {
        for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
            let mut test_tree = get_valid_tree(cipher_suite).await;

            test_tree
                .nodes
                .borrow_as_leaf_mut(LeafIndex(0))
                .unwrap()
                .signature = random_bytes(32);

            let cipher_suite_provider = test_cipher_suite_provider(cipher_suite);
            let mut context = get_test_group_context(1, cipher_suite).await;
            context.tree_hash = test_tree.tree_hash(&cipher_suite_provider).await.unwrap();

            let validator =
                TreeValidator::new(&cipher_suite_provider, &context, &BasicIdentityProvider);

            let res = validator.validate(&mut test_tree).await;

            assert_matches!(res, Err(MlsError::InvalidSignature));
        }
    }

    // The reference tree from `get_test_tree_fig_12` has consistent unmerged leaves.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn verify_unmerged_with_correct_tree() {
        let tree = get_test_tree_fig_12(TEST_CIPHER_SUITE).await;
        validate_unmerged(&tree).unwrap();
    }

    // Blanking a leaf that is still listed as unmerged must be rejected.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn verify_unmerged_with_blank_leaf() {
        let mut tree = get_test_tree_fig_12(TEST_CIPHER_SUITE).await;

        // Blank leaf D unmerged at nodes 3, 7
        tree.nodes[6] = None;

        assert_matches!(
            validate_unmerged(&tree),
            Err(MlsError::UnmergedLeavesMismatch)
        );
    }

    // A gap in the unmerged path must be rejected: the leaf is listed higher up
    // but not at a lower non-blank ancestor.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn verify_unmerged_with_broken_path() {
        let mut tree = get_test_tree_fig_12(TEST_CIPHER_SUITE).await;

        // Make D with direct path [3, 7] unmerged at 7 but not 3
        tree.nodes.borrow_as_parent_mut(3).unwrap().unmerged_leaves = vec![];

        assert_matches!(
            validate_unmerged(&tree),
            Err(MlsError::UnmergedLeavesMismatch)
        );
    }

    // Listing a leaf at a node that is not its ancestor must be rejected.
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn verify_unmerged_with_leaf_outside_tree() {
        let mut tree = get_test_tree_fig_12(TEST_CIPHER_SUITE).await;

        // Add leaf E from the right subtree of the root to unmerged leaves of node 1 on the left
        tree.nodes.borrow_as_parent_mut(1).unwrap().unmerged_leaves = vec![LeafIndex(4)];

        assert_matches!(
            validate_unmerged(&tree),
            Err(MlsError::UnmergedLeavesMismatch)
        );
    }
}
357