1 // Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2 // Copyright by contributors to this project.
3 // SPDX-License-Identifier: (Apache-2.0 OR MIT)
4 
5 use crate::client::MlsError;
6 use crate::crypto::{CipherSuiteProvider, SignatureSecretKey};
7 use crate::group::GroupContext;
8 use crate::identity::SigningIdentity;
9 use crate::iter::wrap_iter;
10 use crate::tree_kem::math as tree_math;
11 use alloc::vec;
12 use alloc::vec::Vec;
13 use itertools::Itertools;
14 use mls_rs_codec::MlsEncode;
15 use tree_math::{CopathNode, TreeIndex};
16 
17 #[cfg(all(not(mls_build_async), feature = "rayon"))]
18 use {crate::iter::ParallelIteratorExt, rayon::prelude::*};
19 
20 #[cfg(mls_build_async)]
21 use futures::{StreamExt, TryStreamExt};
22 
23 #[cfg(feature = "std")]
24 use std::collections::HashSet;
25 
26 use super::hpke_encryption::HpkeEncryptable;
27 use super::leaf_node::ConfigProperties;
28 use super::node::NodeTypeResolver;
29 use super::{
30     node::{LeafIndex, NodeIndex},
31     path_secret::{PathSecret, PathSecretGenerator},
32     TreeKemPrivate, TreeKemPublic, UpdatePath, UpdatePathNode, ValidatedUpdatePath,
33 };
34 
35 #[cfg(test)]
36 use crate::{group::CommitModifiers, signer::Signable};
37 
/// Borrowed view over one member's TreeKEM state.
///
/// Holds mutable references to the public tree and to the member's private
/// key material so that `encap` and `decap` can update both together.
pub struct TreeKem<'a> {
    // Public tree: `encap` installs new path keys here, `decap` reads node keys from it.
    tree_kem_public: &'a mut TreeKemPublic,
    // Private keys of the local member; `self_index` identifies its leaf.
    private_key: &'a mut TreeKemPrivate,
}
42 
/// Result produced by `TreeKem::encap`.
pub struct EncapGeneration {
    /// The committed leaf together with the encrypted updates for the path
    /// nodes that received a secret.
    pub update_path: UpdatePath,
    /// One entry per direct-path node visited by `encap`, in path order;
    /// `None` marks a node that was filtered and received no secret.
    pub path_secrets: Vec<Option<PathSecret>>,
    /// Secret drawn from the path secret generator after all path secrets.
    pub commit_secret: PathSecret,
}
48 
49 impl<'a> TreeKem<'a> {
new( tree_kem_public: &'a mut TreeKemPublic, private_key: &'a mut TreeKemPrivate, ) -> Self50     pub fn new(
51         tree_kem_public: &'a mut TreeKemPublic,
52         private_key: &'a mut TreeKemPrivate,
53     ) -> Self {
54         TreeKem {
55             tree_kem_public,
56             private_key,
57         }
58     }
59 
    /// Generates fresh keys along the local member's direct path and builds
    /// the matching [`UpdatePath`].
    ///
    /// Steps, in order:
    /// 1. For each direct-path node that is not filtered, draw the next path
    ///    secret, derive an HPKE key pair from it, store the secret key in the
    ///    private state and the public key in the public tree.
    /// 2. Update parent hashes, then call `commit` on the own leaf with
    ///    `update_leaf_properties`, `signing_identity` and `signer`; the value
    ///    it returns becomes the leaf secret key (slot 0).
    /// 3. Refresh the tree hashes and write the new tree hash into `context`.
    /// 4. Encrypt every path secret under the encoded `context`, skipping the
    ///    leaves in `excluding`.
    ///
    /// Returns the update path, the per-node path secrets and the commit
    /// secret (the next secret from the generator after the path secrets).
    ///
    /// # Errors
    ///
    /// Propagates any error from the tree operations, the leaf commit, the
    /// context encoding or the cipher suite provider. The public tree, the
    /// private keys and `context.tree_hash` are modified in place, so an early
    /// error return leaves whatever had already been written.
    // The argument list mirrors everything needed to re-commit the leaf, plus
    // the test-only modifier hook, hence the lint exemption.
    #[allow(clippy::too_many_arguments)]
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub async fn encap<P>(
        self,
        context: &mut GroupContext,
        excluding: &[LeafIndex],
        signer: &SignatureSecretKey,
        update_leaf_properties: ConfigProperties,
        signing_identity: Option<SigningIdentity>,
        cipher_suite_provider: &P,
        #[cfg(test)] commit_modifiers: &CommitModifiers,
    ) -> Result<EncapGeneration, MlsError>
    where
        P: CipherSuiteProvider + Send + Sync,
    {
        let self_index = self.private_key.self_index;
        let path = self.tree_kem_public.nodes.direct_copath(self_index);
        let filtered = self.tree_kem_public.nodes.filtered(self_index)?;

        // Slot 0 holds the leaf key; slots 1.. line up with the direct path entries.
        self.private_key.secret_keys.resize(path.len() + 1, None);

        let mut secret_generator = PathSecretGenerator::new(cipher_suite_provider);
        let mut path_secrets = vec![];

        for (i, (node, f)) in path.iter().zip(&filtered).enumerate() {
            if !f {
                let secret = secret_generator.next_secret().await?;

                let (secret_key, public_key) =
                    secret.to_hpke_key_pair(cipher_suite_provider).await?;

                self.private_key.secret_keys[i + 1] = Some(secret_key);
                self.tree_kem_public.update_node(public_key, node.path)?;
                path_secrets.push(Some(secret));
            } else {
                // Filtered nodes get neither a key nor a path secret.
                self.private_key.secret_keys[i + 1] = None;
                path_secrets.push(None);
            }
        }

        // Test-only hook that lets a test alter the tree before hashes are computed.
        #[cfg(test)]
        (commit_modifiers.modify_tree)(self.tree_kem_public);

        self.tree_kem_public
            .update_parent_hashes(self_index, false, cipher_suite_provider)
            .await?;

        let update_path_leaf = {
            let own_leaf = self.tree_kem_public.nodes.borrow_as_leaf_mut(self_index)?;

            // `commit` returns the value stored as the new leaf secret key.
            self.private_key.secret_keys[0] = Some(
                own_leaf
                    .commit(
                        cipher_suite_provider,
                        &context.group_id,
                        *self_index,
                        update_leaf_properties,
                        signing_identity,
                        signer,
                    )
                    .await?,
            );

            // Test-only hook: if it returns a signer, the (possibly modified)
            // leaf is signed again with that signer.
            #[cfg(test)]
            if let Some(signer) = (commit_modifiers.modify_leaf)(own_leaf, signer) {
                let context = &(context.group_id.as_slice(), *self_index).into();

                own_leaf
                    .sign(cipher_suite_provider, &signer, context)
                    .await
                    .unwrap();
            }

            own_leaf.clone()
        };

        // Tree modifications are all done so we can update the tree hash and encrypt with the new context
        self.tree_kem_public
            .update_hashes(&[self_index], cipher_suite_provider)
            .await?;

        context.tree_hash = self
            .tree_kem_public
            .tree_hash(cipher_suite_provider)
            .await?;

        // The encoded context (including the new tree hash) is passed to every encryption.
        let context_bytes = context.mls_encode_to_vec()?;

        let node_updates = self
            .encrypt_path_secrets(
                path,
                &path_secrets,
                &context_bytes,
                cipher_suite_provider,
                excluding,
            )
            .await?;

        // Test-only hook that lets a test alter the encrypted node updates.
        #[cfg(test)]
        let node_updates = (commit_modifiers.modify_path)(node_updates);

        // Create an update path with the new node and parent node updates
        let update_path = UpdatePath {
            leaf_node: update_path_leaf,
            nodes: node_updates,
        };

        Ok(EncapGeneration {
            update_path,
            path_secrets,
            // Drawn after all path secrets, so it follows the last one in the chain.
            commit_secret: secret_generator.next_secret().await?,
        })
    }
173 
174     #[cfg(any(mls_build_async, not(feature = "rayon")))]
175     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
encrypt_path_secrets<P: CipherSuiteProvider>( &self, path: Vec<CopathNode<NodeIndex>>, path_secrets: &[Option<PathSecret>], context_bytes: &[u8], cipher_suite: &P, excluding: &[LeafIndex], ) -> Result<Vec<UpdatePathNode>, MlsError>176     async fn encrypt_path_secrets<P: CipherSuiteProvider>(
177         &self,
178         path: Vec<CopathNode<NodeIndex>>,
179         path_secrets: &[Option<PathSecret>],
180         context_bytes: &[u8],
181         cipher_suite: &P,
182         excluding: &[LeafIndex],
183     ) -> Result<Vec<UpdatePathNode>, MlsError> {
184         let excluding = excluding.iter().copied().map(NodeIndex::from);
185 
186         #[cfg(feature = "std")]
187         let excluding = excluding.collect::<HashSet<NodeIndex>>();
188         #[cfg(not(feature = "std"))]
189         let excluding = excluding.collect::<Vec<NodeIndex>>();
190 
191         let mut node_updates = Vec::new();
192 
193         for (index, path_secret) in path.into_iter().zip(path_secrets.iter()) {
194             if let Some(path_secret) = path_secret {
195                 node_updates.push(
196                     self.encrypt_copath_node_resolution(
197                         cipher_suite,
198                         path_secret,
199                         index.copath,
200                         context_bytes,
201                         &excluding,
202                     )
203                     .await?,
204                 );
205             }
206         }
207 
208         Ok(node_updates)
209     }
210 
211     #[cfg(all(not(mls_build_async), feature = "rayon"))]
encrypt_path_secrets<P: CipherSuiteProvider>( &self, path: Vec<CopathNode<NodeIndex>>, path_secrets: &[Option<PathSecret>], context_bytes: &[u8], cipher_suite: &P, excluding: &[LeafIndex], ) -> Result<Vec<UpdatePathNode>, MlsError>212     fn encrypt_path_secrets<P: CipherSuiteProvider>(
213         &self,
214         path: Vec<CopathNode<NodeIndex>>,
215         path_secrets: &[Option<PathSecret>],
216         context_bytes: &[u8],
217         cipher_suite: &P,
218         excluding: &[LeafIndex],
219     ) -> Result<Vec<UpdatePathNode>, MlsError> {
220         let excluding = excluding.iter().copied().map(NodeIndex::from);
221 
222         #[cfg(feature = "std")]
223         let excluding = excluding.collect::<HashSet<NodeIndex>>();
224         #[cfg(not(feature = "std"))]
225         let excluding = excluding.collect::<Vec<NodeIndex>>();
226 
227         path.into_par_iter()
228             .zip(path_secrets.par_iter())
229             .filter_map(|(node, path_secret)| {
230                 path_secret.as_ref().map(|path_secret| {
231                     self.encrypt_copath_node_resolution(
232                         cipher_suite,
233                         path_secret,
234                         node.copath,
235                         context_bytes,
236                         &excluding,
237                     )
238                 })
239             })
240             .collect()
241     }
242 
243     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
decap<CP>( self, sender_index: LeafIndex, update_path: &ValidatedUpdatePath, added_leaves: &[LeafIndex], context_bytes: &[u8], cipher_suite_provider: &CP, ) -> Result<PathSecret, MlsError> where CP: CipherSuiteProvider,244     pub async fn decap<CP>(
245         self,
246         sender_index: LeafIndex,
247         update_path: &ValidatedUpdatePath,
248         added_leaves: &[LeafIndex],
249         context_bytes: &[u8],
250         cipher_suite_provider: &CP,
251     ) -> Result<PathSecret, MlsError>
252     where
253         CP: CipherSuiteProvider,
254     {
255         let self_index = self.private_key.self_index;
256 
257         let lca_index =
258             tree_math::leaf_lca_level(self_index.into(), sender_index.into()) as usize - 2;
259 
260         let mut path = self.tree_kem_public.nodes.direct_copath(self_index);
261         let leaf = CopathNode::new(self_index.into(), 0);
262         path.insert(0, leaf);
263         let resolved_pos = self.find_resolved_pos(&path, lca_index)?;
264 
265         let ct_pos =
266             self.find_ciphertext_pos(path[lca_index].path, path[resolved_pos].path, added_leaves)?;
267 
268         let lca_node = update_path.nodes[lca_index]
269             .as_ref()
270             .ok_or(MlsError::LcaNotFoundInDirectPath)?;
271 
272         let ct = lca_node
273             .encrypted_path_secret
274             .get(ct_pos)
275             .ok_or(MlsError::LcaNotFoundInDirectPath)?;
276 
277         let secret = self.private_key.secret_keys[resolved_pos]
278             .as_ref()
279             .ok_or(MlsError::UpdateErrorNoSecretKey)?;
280 
281         let public = self
282             .tree_kem_public
283             .nodes
284             .borrow_node(path[resolved_pos].path)?
285             .as_ref()
286             .ok_or(MlsError::UpdateErrorNoSecretKey)?
287             .public_key();
288 
289         let lca_path_secret =
290             PathSecret::decrypt(cipher_suite_provider, secret, public, context_bytes, ct).await?;
291 
292         // Derive the rest of the secrets for the tree and assign to the proper nodes
293         let mut node_secret_gen =
294             PathSecretGenerator::starting_with(cipher_suite_provider, lca_path_secret);
295 
296         // Update secrets based on the decrypted path secret in the update
297         self.private_key.secret_keys.resize(path.len() + 1, None);
298 
299         for (i, update) in update_path.nodes.iter().enumerate().skip(lca_index) {
300             if let Some(update) = update {
301                 let secret = node_secret_gen.next_secret().await?;
302 
303                 // Verify the private key we calculated properly matches the public key we inserted into the tree. This guarantees
304                 // that we will be able to decrypt later.
305                 let (hpke_private, hpke_public) =
306                     secret.to_hpke_key_pair(cipher_suite_provider).await?;
307 
308                 if hpke_public != update.public_key {
309                     return Err(MlsError::PubKeyMismatch);
310                 }
311 
312                 self.private_key.secret_keys[i + 1] = Some(hpke_private);
313             } else {
314                 self.private_key.secret_keys[i + 1] = None;
315             }
316         }
317 
318         node_secret_gen.next_secret().await
319     }
320 
    /// Encrypts `path_secret` to every non-excluded node in the resolution of
    /// `copath_index` and packages the ciphertexts with the public key of that
    /// node's parent.
    ///
    /// * `context` - bytes passed to `PathSecret::encrypt` for each ciphertext.
    /// * `excluding` - node indices that must not receive a ciphertext.
    ///
    /// # Errors
    ///
    /// Fails if the resolution cannot be computed, if `as_non_empty` rejects a
    /// resolution entry, if an encryption fails, with
    /// [`MlsError::ExpectedNode`] when `parent_sibling` yields nothing for
    /// `copath_index`, or if the parent cannot be borrowed as a parent node.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    async fn encrypt_copath_node_resolution<P: CipherSuiteProvider>(
        &self,
        cipher_suite_provider: &P,
        path_secret: &PathSecret,
        copath_index: NodeIndex,
        context: &[u8],
        #[cfg(feature = "std")] excluding: &HashSet<NodeIndex>,
        #[cfg(not(feature = "std"))] excluding: &[NodeIndex],
    ) -> Result<UpdatePathNode, MlsError> {
        let reso = self
            .tree_kem_public
            .nodes
            .get_resolution_index(copath_index)?;

        // Encrypts the path secret to the public key of one resolution node.
        let make_ctxt = |idx| async move {
            let node = self
                .tree_kem_public
                .nodes
                .borrow_node(idx)?
                .as_non_empty()?;

            path_secret
                .encrypt(cipher_suite_provider, node.public_key(), context)
                .await
        };

        // Skip the nodes the caller asked to exclude.
        let ctxts = wrap_iter(reso).filter(|&idx| async move { !excluding.contains(&idx) });

        // Sync builds apply the closure with `map`; async builds must await it via `then`.
        #[cfg(not(mls_build_async))]
        let ctxts = ctxts.map(make_ctxt);

        #[cfg(mls_build_async)]
        let ctxts = ctxts.then(make_ctxt);

        let ctxts = ctxts.try_collect().await?;

        // The update node carries the key of the copath node's parent.
        let path_index = copath_index
            .parent_sibling(&self.tree_kem_public.total_leaf_count())
            .ok_or(MlsError::ExpectedNode)?
            .parent;

        Ok(UpdatePathNode {
            public_key: self
                .tree_kem_public
                .nodes
                .borrow_as_parent(path_index)?
                .public_key
                .clone(),
            encrypted_path_secret: ctxts,
        })
    }
373 
374     #[inline]
find_resolved_pos( &self, path: &[CopathNode<NodeIndex>], mut lca_index: usize, ) -> Result<usize, MlsError>375     fn find_resolved_pos(
376         &self,
377         path: &[CopathNode<NodeIndex>],
378         mut lca_index: usize,
379     ) -> Result<usize, MlsError> {
380         while self.tree_kem_public.nodes.is_blank(path[lca_index].path)? {
381             lca_index -= 1;
382         }
383 
384         // If we don't have the key, we should be an unmerged leaf at the resolved node. (If
385         // we're not, an error will be thrown later.)
386         if self.private_key.secret_keys[lca_index].is_none() {
387             lca_index = 0;
388         }
389 
390         Ok(lca_index)
391     }
392 
393     #[inline]
find_ciphertext_pos( &self, lca: NodeIndex, resolved: NodeIndex, excluding: &[LeafIndex], ) -> Result<usize, MlsError>394     fn find_ciphertext_pos(
395         &self,
396         lca: NodeIndex,
397         resolved: NodeIndex,
398         excluding: &[LeafIndex],
399     ) -> Result<usize, MlsError> {
400         let reso = self.tree_kem_public.nodes.get_resolution_index(lca)?;
401 
402         let (ct_pos, _) = reso
403             .iter()
404             .filter(|idx| **idx % 2 == 1 || !excluding.contains(&LeafIndex(**idx / 2)))
405             .find_position(|idx| idx == &&resolved)
406             .ok_or(MlsError::UpdateErrorNoSecretKey)?;
407 
408         Ok(ct_pos)
409     }
410 }
411 
412 #[cfg(test)]
413 mod tests {
414     use super::{tree_math, TreeKem};
415     use crate::{
416         cipher_suite::CipherSuite,
417         client::test_utils::TEST_CIPHER_SUITE,
418         crypto::test_utils::{test_cipher_suite_provider, TestCryptoProvider},
419         extension::test_utils::TestExtension,
420         group::test_utils::{get_test_group_context, random_bytes},
421         identity::basic::BasicIdentityProvider,
422         tree_kem::{
423             leaf_node::{
424                 test_utils::{get_basic_test_node_sig_key, get_test_capabilities},
425                 ConfigProperties,
426             },
427             node::LeafIndex,
428             Capabilities, TreeKemPrivate, TreeKemPublic, UpdatePath, ValidatedUpdatePath,
429         },
430         ExtensionList,
431     };
432     use alloc::{format, vec, vec::Vec};
433     use mls_rs_codec::MlsEncode;
434     use mls_rs_core::crypto::CipherSuiteProvider;
435     use tree_math::TreeIndex;
436 
437     // Verify that the tree is in the correct state after generating an update path
verify_tree_update_path( tree: &TreeKemPublic, update_path: &UpdatePath, index: LeafIndex, capabilities: Option<Capabilities>, extensions: Option<ExtensionList>, )438     fn verify_tree_update_path(
439         tree: &TreeKemPublic,
440         update_path: &UpdatePath,
441         index: LeafIndex,
442         capabilities: Option<Capabilities>,
443         extensions: Option<ExtensionList>,
444     ) {
445         // Make sure the update path is based on the direct path of the sender
446         let direct_path = tree.nodes.direct_copath(index);
447 
448         for (i, n) in direct_path.iter().enumerate() {
449             assert_eq!(
450                 *tree
451                     .nodes
452                     .borrow_node(n.path)
453                     .unwrap()
454                     .as_ref()
455                     .unwrap()
456                     .public_key(),
457                 update_path.nodes[i].public_key
458             );
459         }
460 
461         // Verify that the leaf from the update path has been installed
462         assert_eq!(
463             tree.nodes.borrow_as_leaf(index).unwrap(),
464             &update_path.leaf_node
465         );
466 
467         // Verify that updated capabilities were installed
468         if let Some(capabilities) = capabilities {
469             assert_eq!(update_path.leaf_node.capabilities, capabilities);
470         }
471 
472         // Verify that update extensions were installed
473         if let Some(extensions) = extensions {
474             assert_eq!(update_path.leaf_node.extensions, extensions);
475         }
476 
477         // Verify that we have a public keys up to the root
478         let root = tree.total_leaf_count().root();
479         assert!(tree.nodes.borrow_node(root).unwrap().is_some());
480     }
481 
482     #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
verify_tree_private_path( cipher_suite: &CipherSuite, public_tree: &TreeKemPublic, private_tree: &TreeKemPrivate, index: LeafIndex, )483     async fn verify_tree_private_path(
484         cipher_suite: &CipherSuite,
485         public_tree: &TreeKemPublic,
486         private_tree: &TreeKemPrivate,
487         index: LeafIndex,
488     ) {
489         let provider = test_cipher_suite_provider(*cipher_suite);
490 
491         assert_eq!(private_tree.self_index, index);
492 
493         // Make sure we have private values along the direct path, and the public keys match
494         let path_iter = public_tree
495             .nodes
496             .direct_copath(index)
497             .into_iter()
498             .enumerate();
499 
500         for (i, n) in path_iter {
501             let secret_key = private_tree.secret_keys[i + 1].as_ref().unwrap();
502 
503             let public_key = public_tree
504                 .nodes
505                 .borrow_node(n.path)
506                 .unwrap()
507                 .as_ref()
508                 .unwrap()
509                 .public_key();
510 
511             let test_data = random_bytes(32);
512 
513             let sealed = provider
514                 .hpke_seal(public_key, &[], None, &test_data)
515                 .await
516                 .unwrap();
517 
518             let opened = provider
519                 .hpke_open(&sealed, secret_key, public_key, &[], None)
520                 .await
521                 .unwrap();
522 
523             assert_eq!(test_data, opened);
524         }
525     }
526 
    // End-to-end check of `encap` followed by `decap`.
    //
    // Builds a tree of `size` leaves (leaf 0 is the sender), runs `encap` for
    // leaf 0 with the optional `capabilities` / `extensions`, verifies the
    // sender's public and private state, then applies the update path and
    // runs `decap` for every other leaf and checks that each receiver's
    // public tree equals the sender's.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    async fn encap_decap(
        cipher_suite: CipherSuite,
        size: usize,
        capabilities: Option<Capabilities>,
        extensions: Option<ExtensionList>,
    ) {
        let cipher_suite_provider = test_cipher_suite_provider(cipher_suite);

        // Generate signing keys and key package generations, and private keys for multiple
        // participants in order to set up state

        let mut leaf_nodes = Vec::new();
        let mut private_keys = Vec::new();

        // Leaves 1..size are the receivers; leaf 0 is created separately below.
        for index in 1..size {
            let (leaf_node, hpke_secret, _) =
                get_basic_test_node_sig_key(cipher_suite, &format!("{index}")).await;

            let private_key = TreeKemPrivate::new_self_leaf(LeafIndex(index as u32), hpke_secret);

            leaf_nodes.push(leaf_node);
            private_keys.push(private_key);
        }

        let (encap_node, encap_hpke_secret, encap_signer) =
            get_basic_test_node_sig_key(cipher_suite, "encap").await;

        // Build a test tree we can clone for all leaf nodes
        let (mut test_tree, mut encap_private_key) = TreeKemPublic::derive(
            encap_node,
            encap_hpke_secret,
            &BasicIdentityProvider,
            &Default::default(),
        )
        .await
        .unwrap();

        test_tree
            .add_leaves(leaf_nodes, &BasicIdentityProvider, &cipher_suite_provider)
            .await
            .unwrap();

        // Clone the tree for the sender and prepare the properties for its updated leaf
        let mut encap_tree = test_tree.clone();

        let update_leaf_properties = ConfigProperties {
            capabilities: capabilities.clone().unwrap_or_else(get_test_capabilities),
            extensions: extensions.clone().unwrap_or_default(),
        };

        // Perform the encap function
        let encap_gen = TreeKem::new(&mut encap_tree, &mut encap_private_key)
            .encap(
                &mut get_test_group_context(42, cipher_suite).await,
                &[],
                &encap_signer,
                update_leaf_properties,
                None,
                &cipher_suite_provider,
                #[cfg(test)]
                &Default::default(),
            )
            .await
            .unwrap();

        // Verify that the state of the tree matches the produced update path
        verify_tree_update_path(
            &encap_tree,
            &encap_gen.update_path,
            LeafIndex(0),
            capabilities,
            extensions,
        );

        // Verify that the private key matches the data in the public key
        verify_tree_private_path(&cipher_suite, &encap_tree, &encap_private_key, LeafIndex(0))
            .await;

        // Expand the update path back to one slot per direct-path position:
        // filtered positions stay `None`, the others take the update path
        // nodes in order.
        let filtered = test_tree.nodes.filtered(LeafIndex(0)).unwrap();
        let mut unfiltered_nodes = vec![None; filtered.len()];
        filtered
            .into_iter()
            .enumerate()
            .filter(|(_, f)| !*f)
            .zip(encap_gen.update_path.nodes.iter())
            .for_each(|((i, _), node)| {
                unfiltered_nodes[i] = Some(node.clone());
            });

        // Apply the update path to the rest of the leaf nodes using the decap function
        let validated_update_path = ValidatedUpdatePath {
            leaf_node: encap_gen.update_path.leaf_node,
            nodes: unfiltered_nodes,
        };

        encap_tree
            .update_hashes(&[LeafIndex(0)], &cipher_suite_provider)
            .await
            .unwrap();

        // Every receiver starts from its own copy of the pre-update tree.
        let mut receiver_trees: Vec<TreeKemPublic> = (1..size).map(|_| test_tree.clone()).collect();

        for (i, tree) in receiver_trees.iter_mut().enumerate() {
            tree.apply_update_path(
                LeafIndex(0),
                &validated_update_path,
                &Default::default(),
                BasicIdentityProvider,
                &cipher_suite_provider,
            )
            .await
            .unwrap();

            // Decryption uses a context carrying the tree hash computed after
            // the update path has been applied.
            let mut context = get_test_group_context(42, cipher_suite).await;
            context.tree_hash = tree.tree_hash(&cipher_suite_provider).await.unwrap();

            TreeKem::new(tree, &mut private_keys[i])
                .decap(
                    LeafIndex(0),
                    &validated_update_path,
                    &[],
                    &context.mls_encode_to_vec().unwrap(),
                    &cipher_suite_provider,
                )
                .await
                .unwrap();

            tree.update_hashes(&[LeafIndex(0)], &cipher_suite_provider)
                .await
                .unwrap();

            // The receiver's public tree must now be identical to the sender's.
            assert_eq!(tree, &encap_tree);
        }
    }
662 
663     #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
test_encap_decap()664     async fn test_encap_decap() {
665         for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
666             encap_decap(cipher_suite, 10, None, None).await;
667         }
668     }
669 
670     #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
test_encap_capabilities()671     async fn test_encap_capabilities() {
672         let cipher_suite = TEST_CIPHER_SUITE;
673         let mut capabilities = get_test_capabilities();
674         capabilities.extensions.push(42.into());
675 
676         encap_decap(cipher_suite, 10, Some(capabilities.clone()), None).await;
677     }
678 
679     #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
test_encap_extensions()680     async fn test_encap_extensions() {
681         let cipher_suite = TEST_CIPHER_SUITE;
682         let mut extensions = ExtensionList::default();
683         extensions.set_from(TestExtension { foo: 10 }).unwrap();
684 
685         encap_decap(cipher_suite, 10, None, Some(extensions)).await;
686     }
687 
688     #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
test_encap_capabilities_extensions()689     async fn test_encap_capabilities_extensions() {
690         let cipher_suite = TEST_CIPHER_SUITE;
691         let mut capabilities = get_test_capabilities();
692         capabilities.extensions.push(42.into());
693 
694         let mut extensions = ExtensionList::default();
695         extensions.set_from(TestExtension { foo: 10 }).unwrap();
696 
697         encap_decap(cipher_suite, 10, Some(capabilities), Some(extensions)).await;
698     }
699 }
700