// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // Copyright by contributors to this project. // SPDX-License-Identifier: (Apache-2.0 OR MIT) use alloc::vec; use alloc::vec::Vec; #[cfg(feature = "std")] use core::fmt::Display; use itertools::Itertools; use mls_rs_codec::{MlsDecode, MlsEncode, MlsSize}; use mls_rs_core::extension::ExtensionList; use mls_rs_core::{error::IntoAnyError, identity::IdentityProvider}; #[cfg(feature = "tree_index")] use mls_rs_core::identity::SigningIdentity; use math as tree_math; use node::{LeafIndex, NodeIndex, NodeVec}; use self::leaf_node::LeafNode; use crate::client::MlsError; use crate::crypto::{self, CipherSuiteProvider, HpkeSecretKey}; #[cfg(feature = "by_ref_proposal")] use crate::group::proposal::{AddProposal, UpdateProposal}; #[cfg(any(test, feature = "by_ref_proposal"))] use crate::group::proposal::RemoveProposal; use crate::group::proposal_filter::ProposalBundle; use crate::tree_kem::tree_hash::TreeHashes; mod capabilities; pub(crate) mod hpke_encryption; mod lifetime; pub(crate) mod math; pub mod node; pub mod parent_hash; pub mod path_secret; mod private; mod tree_hash; pub mod tree_validator; pub mod update_path; pub use capabilities::*; pub use lifetime::*; pub(crate) use private::*; pub use update_path::*; use tree_index::*; pub mod kem; pub mod leaf_node; pub mod leaf_node_validator; mod tree_index; #[cfg(feature = "std")] pub(crate) mod tree_utils; #[cfg(test)] mod interop_test_vectors; #[cfg(feature = "custom_proposal")] use crate::group::proposal::ProposalType; #[derive(Clone, Debug, MlsEncode, MlsDecode, MlsSize, Default)] #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))] pub struct TreeKemPublic { #[cfg(feature = "tree_index")] #[cfg_attr(feature = "serde", serde(skip))] index: TreeIndex, pub(crate) nodes: NodeVec, tree_hashes: TreeHashes, } impl PartialEq for TreeKemPublic { fn eq(&self, other: &Self) -> bool { self.nodes == other.nodes } } impl TreeKemPublic { pub 
fn new() -> TreeKemPublic { Default::default() } #[cfg_attr(not(feature = "tree_index"), allow(unused))] #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] pub(crate) async fn import_node_data( nodes: NodeVec, identity_provider: &IP, extensions: &ExtensionList, ) -> Result where IP: IdentityProvider, { let mut tree = TreeKemPublic { nodes, ..Default::default() }; #[cfg(feature = "tree_index")] tree.initialize_index_if_necessary(identity_provider, extensions) .await?; Ok(tree) } #[cfg(feature = "tree_index")] #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] pub(crate) async fn initialize_index_if_necessary( &mut self, identity_provider: &IP, extensions: &ExtensionList, ) -> Result<(), MlsError> { if !self.index.is_initialized() { self.index = TreeIndex::new(); for (leaf_index, leaf) in self.nodes.non_empty_leaves() { index_insert( &mut self.index, leaf, leaf_index, identity_provider, extensions, ) .await?; } } Ok(()) } #[cfg(feature = "tree_index")] pub(crate) fn get_leaf_node_with_identity(&self, identity: &[u8]) -> Option { self.index.get_leaf_index_with_identity(identity) } #[cfg(not(feature = "tree_index"))] #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] pub(crate) async fn get_leaf_node_with_identity( &self, identity: &[u8], id_provider: &I, extensions: &ExtensionList, ) -> Result, MlsError> { for (i, leaf) in self.nodes.non_empty_leaves() { let leaf_id = id_provider .identity(&leaf.signing_identity, extensions) .await .map_err(|e| MlsError::IdentityProviderError(e.into_any_error()))?; if leaf_id == identity { return Ok(Some(i)); } } Ok(None) } #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] pub async fn derive( leaf_node: LeafNode, secret_key: HpkeSecretKey, identity_provider: &I, extensions: &ExtensionList, ) -> Result<(TreeKemPublic, TreeKemPrivate), MlsError> { let mut public_tree = TreeKemPublic::new(); public_tree .add_leaf(leaf_node, identity_provider, extensions, None) .await?; let private_tree = 
TreeKemPrivate::new_self_leaf(LeafIndex(0), secret_key); Ok((public_tree, private_tree)) } pub fn total_leaf_count(&self) -> u32 { self.nodes.total_leaf_count() } #[cfg(any(test, all(feature = "custom_proposal", feature = "tree_index")))] pub fn occupied_leaf_count(&self) -> u32 { self.nodes.occupied_leaf_count() } pub fn get_leaf_node(&self, index: LeafIndex) -> Result<&LeafNode, MlsError> { self.nodes.borrow_as_leaf(index) } pub fn find_leaf_node(&self, leaf_node: &LeafNode) -> Option { self.nodes.non_empty_leaves().find_map( |(index, node)| { if node == leaf_node { Some(index) } else { None } }, ) } #[cfg(feature = "custom_proposal")] pub fn can_support_proposal(&self, proposal_type: ProposalType) -> bool { #[cfg(feature = "tree_index")] return self.index.count_supporting_proposal(proposal_type) == self.occupied_leaf_count(); #[cfg(not(feature = "tree_index"))] self.nodes .non_empty_leaves() .all(|(_, l)| l.capabilities.proposals.contains(&proposal_type)) } #[cfg(test)] #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] pub async fn add_leaves( &mut self, leaf_nodes: Vec, id_provider: &I, cipher_suite_provider: &CP, ) -> Result, MlsError> { let mut start = LeafIndex(0); let mut added = vec![]; for leaf in leaf_nodes.into_iter() { start = self .add_leaf(leaf, id_provider, &Default::default(), Some(start)) .await?; added.push(start); } self.update_hashes(&added, cipher_suite_provider).await?; Ok(added) } pub fn non_empty_leaves(&self) -> impl Iterator + '_ { self.nodes.non_empty_leaves() } #[cfg(feature = "prior_epoch")] pub fn leaves(&self) -> impl Iterator> + '_ { self.nodes.leaves() } pub(crate) fn update_node( &mut self, pub_key: crypto::HpkePublicKey, index: NodeIndex, ) -> Result<(), MlsError> { self.nodes .borrow_or_fill_node_as_parent(index, &pub_key) .map(|p| { p.public_key = pub_key; p.unmerged_leaves = vec![]; }) } #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] pub(crate) async fn apply_update_path( &mut self, sender: LeafIndex, 
update_path: &ValidatedUpdatePath, extensions: &ExtensionList, identity_provider: IP, cipher_suite_provider: &CP, ) -> Result<(), MlsError> where IP: IdentityProvider, CP: CipherSuiteProvider, { // Install the new leaf node let existing_leaf = self.nodes.borrow_as_leaf_mut(sender)?; #[cfg(feature = "tree_index")] let original_leaf_node = existing_leaf.clone(); #[cfg(feature = "tree_index")] let original_identity = identity_provider .identity(&original_leaf_node.signing_identity, extensions) .await .map_err(|e| MlsError::IdentityProviderError(e.into_any_error()))?; *existing_leaf = update_path.leaf_node.clone(); // Update the rest of the nodes on the direct path let path = self.nodes.direct_copath(sender); for (node, pn) in update_path.nodes.iter().zip(path) { node.as_ref() .map(|n| self.update_node(n.public_key.clone(), pn.path)) .transpose()?; } #[cfg(feature = "tree_index")] self.index.remove(&original_leaf_node, &original_identity); index_insert( #[cfg(feature = "tree_index")] &mut self.index, #[cfg(not(feature = "tree_index"))] &self.nodes, &update_path.leaf_node, sender, &identity_provider, extensions, ) .await?; // Verify the parent hash of the new sender leaf node and update the parent hash values // in the local tree self.update_parent_hashes(sender, true, cipher_suite_provider) .await?; Ok(()) } fn update_unmerged(&mut self, index: LeafIndex) -> Result<(), MlsError> { // For a given leaf index, find parent nodes and add the leaf to the unmerged leaf self.nodes.direct_copath(index).into_iter().for_each(|i| { if let Ok(p) = self.nodes.borrow_as_parent_mut(i.path) { p.unmerged_leaves.push(index) } }); Ok(()) } #[cfg(feature = "by_ref_proposal")] #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] pub async fn batch_edit( &mut self, proposal_bundle: &mut ProposalBundle, extensions: &ExtensionList, id_provider: &I, cipher_suite_provider: &CP, filter: bool, ) -> Result, MlsError> where I: IdentityProvider, CP: CipherSuiteProvider, { // Apply removes 
(they commute with updates because they don't touch the same leaves) for i in (0..proposal_bundle.remove_proposals().len()).rev() { let index = proposal_bundle.remove_proposals()[i].proposal.to_remove; let res = self.nodes.blank_leaf_node(index); if res.is_ok() { // This shouldn't fail if `blank_leaf_node` succedded. self.nodes.blank_direct_path(index)?; } #[cfg(feature = "tree_index")] if let Ok(old_leaf) = &res { // If this fails, it's not because the proposal is bad. let identity = identity(&old_leaf.signing_identity, id_provider, extensions).await?; self.index.remove(old_leaf, &identity); } if proposal_bundle.remove_proposals()[i].is_by_value() || !filter { res?; } else if res.is_err() { proposal_bundle.remove::(i); } } // Remove from the tree old leaves from updates let mut partial_updates = vec![]; let senders = proposal_bundle.update_senders.iter().copied(); for (i, (p, index)) in proposal_bundle.updates.iter().zip(senders).enumerate() { let new_leaf = p.proposal.leaf_node.clone(); match self.nodes.blank_leaf_node(index) { Ok(old_leaf) => { #[cfg(feature = "tree_index")] let old_id = identity(&old_leaf.signing_identity, id_provider, extensions).await?; #[cfg(feature = "tree_index")] self.index.remove(&old_leaf, &old_id); partial_updates.push((index, old_leaf, new_leaf, i)); } _ => { if !filter || !p.is_by_reference() { return Err(MlsError::UpdatingNonExistingMember); } } } } #[cfg(feature = "tree_index")] let index_clone = self.index.clone(); let mut removed_leaves = vec![]; let mut updated_indices = vec![]; let mut bad_indices = vec![]; // Apply updates one by one. If there's an update which we can't apply or revert, we revert // all updates. 
for (index, old_leaf, new_leaf, i) in partial_updates.into_iter() { #[cfg(feature = "tree_index")] let res = index_insert(&mut self.index, &new_leaf, index, id_provider, extensions).await; #[cfg(not(feature = "tree_index"))] let res = index_insert(&self.nodes, &new_leaf, index, id_provider, extensions).await; let err = res.is_err(); if !filter { res?; } if !err { self.nodes.insert_leaf(index, new_leaf); removed_leaves.push(old_leaf); updated_indices.push(index); } else { #[cfg(feature = "tree_index")] let res = index_insert(&mut self.index, &old_leaf, index, id_provider, extensions).await; #[cfg(not(feature = "tree_index"))] let res = index_insert(&self.nodes, &old_leaf, index, id_provider, extensions).await; if res.is_ok() { self.nodes.insert_leaf(index, old_leaf); bad_indices.push(i); } else { // Revert all updates and stop. We're already in the "filter" case, so we don't throw an error. #[cfg(feature = "tree_index")] { self.index = index_clone; } removed_leaves .into_iter() .zip(updated_indices.iter()) .for_each(|(leaf, index)| self.nodes.insert_leaf(*index, leaf)); updated_indices = vec![]; break; } } } // If we managed to update something, blank direct paths updated_indices .iter() .try_for_each(|index| self.nodes.blank_direct_path(*index).map(|_| ()))?; // Remove rejected updates from applied proposals if updated_indices.is_empty() { // This takes care of the "revert all" scenario proposal_bundle.updates = vec![]; } else { for i in bad_indices.into_iter().rev() { proposal_bundle.remove::(i); proposal_bundle.update_senders.remove(i); } } // Apply adds let mut start = LeafIndex(0); let mut added = vec![]; let mut bad_indexes = vec![]; for i in 0..proposal_bundle.additions.len() { let leaf = proposal_bundle.additions[i] .proposal .key_package .leaf_node .clone(); let res = self .add_leaf(leaf, id_provider, extensions, Some(start)) .await; if let Ok(index) = res { start = index; added.push(start); } else if proposal_bundle.additions[i].is_by_value() || !filter { 
res?; } else { bad_indexes.push(i); } } for i in bad_indexes.into_iter().rev() { proposal_bundle.remove::(i); } self.nodes.trim(); let updated_leaves = proposal_bundle .remove_proposals() .iter() .map(|p| p.proposal.to_remove) .chain(updated_indices) .chain(added.iter().copied()) .collect_vec(); self.update_hashes(&updated_leaves, cipher_suite_provider) .await?; Ok(added) } #[cfg(not(feature = "by_ref_proposal"))] #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] pub async fn batch_edit_lite( &mut self, proposal_bundle: &ProposalBundle, extensions: &ExtensionList, id_provider: &I, cipher_suite_provider: &CP, ) -> Result, MlsError> where I: IdentityProvider, CP: CipherSuiteProvider, { // Apply removes for p in &proposal_bundle.removals { let index = p.proposal.to_remove; #[cfg(feature = "tree_index")] { // If this fails, it's not because the proposal is bad. let old_leaf = self.nodes.blank_leaf_node(index)?; let identity = identity(&old_leaf.signing_identity, id_provider, extensions).await?; self.index.remove(&old_leaf, &identity); } #[cfg(not(feature = "tree_index"))] self.nodes.blank_leaf_node(index)?; self.nodes.blank_direct_path(index)?; } // Apply adds let mut start = LeafIndex(0); let mut added = vec![]; for p in &proposal_bundle.additions { let leaf = p.proposal.key_package.leaf_node.clone(); start = self .add_leaf(leaf, id_provider, extensions, Some(start)) .await?; added.push(start); } self.nodes.trim(); let updated_leaves = proposal_bundle .remove_proposals() .iter() .map(|p| p.proposal.to_remove) .chain(added.iter().copied()) .collect_vec(); self.update_hashes(&updated_leaves, cipher_suite_provider) .await?; Ok(added) } #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] pub(crate) async fn add_leaf( &mut self, leaf: LeafNode, id_provider: &I, extensions: &ExtensionList, start: Option, ) -> Result { let index = self.nodes.next_empty_leaf(start.unwrap_or(LeafIndex(0))); #[cfg(feature = "tree_index")] index_insert(&mut self.index, &leaf, 
index, id_provider, extensions).await?; #[cfg(not(feature = "tree_index"))] index_insert(&self.nodes, &leaf, index, id_provider, extensions).await?; self.nodes.insert_leaf(index, leaf); self.update_unmerged(index)?; Ok(index) } } #[cfg(feature = "tree_index")] #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] async fn identity( signing_id: &SigningIdentity, provider: &I, extensions: &ExtensionList, ) -> Result, MlsError> { provider .identity(signing_id, extensions) .await .map_err(|e| MlsError::IdentityProviderError(e.into_any_error())) } #[cfg(feature = "std")] impl Display for TreeKemPublic { fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { write!(f, "{}", tree_utils::build_ascii_tree(&self.nodes)) } } #[cfg(test)] use crate::group::{proposal::Proposal, proposal_filter::ProposalSource, Sender}; #[cfg(test)] impl TreeKemPublic { #[cfg(feature = "by_ref_proposal")] #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] pub async fn update_leaf( &mut self, leaf_index: u32, leaf_node: LeafNode, identity_provider: &I, cipher_suite_provider: &CP, ) -> Result<(), MlsError> where I: IdentityProvider, CP: CipherSuiteProvider, { let p = Proposal::Update(UpdateProposal { leaf_node }); let mut bundle = ProposalBundle::default(); bundle.add(p, Sender::Member(leaf_index), ProposalSource::ByValue); bundle.update_senders = vec![LeafIndex(leaf_index)]; self.batch_edit( &mut bundle, &Default::default(), identity_provider, cipher_suite_provider, true, ) .await?; Ok(()) } #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] pub async fn remove_leaves( &mut self, indexes: Vec, identity_provider: &I, cipher_suite_provider: &CP, ) -> Result, MlsError> where I: IdentityProvider, CP: CipherSuiteProvider, { let old_tree = self.clone(); let proposals = indexes .iter() .copied() .map(|to_remove| Proposal::Remove(RemoveProposal { to_remove })); let mut bundle = ProposalBundle::default(); for p in proposals { bundle.add(p, Sender::Member(0), 
ProposalSource::ByValue); } #[cfg(feature = "by_ref_proposal")] self.batch_edit( &mut bundle, &Default::default(), identity_provider, cipher_suite_provider, true, ) .await?; #[cfg(not(feature = "by_ref_proposal"))] self.batch_edit_lite( &bundle, &Default::default(), identity_provider, cipher_suite_provider, ) .await?; bundle .removals .iter() .map(|p| { let index = p.proposal.to_remove; let leaf = old_tree.get_leaf_node(index)?.clone(); Ok((index, leaf)) }) .collect() } pub fn get_leaf_nodes(&self) -> Vec<&LeafNode> { self.nodes.non_empty_leaves().map(|(_, l)| l).collect() } } #[cfg(test)] pub(crate) mod test_utils { use crate::crypto::test_utils::TestCryptoProvider; use crate::signer::Signable; use alloc::vec::Vec; use alloc::{format, vec}; use mls_rs_core::crypto::CipherSuiteProvider; use mls_rs_core::group::Capabilities; use mls_rs_core::identity::BasicCredential; use crate::identity::test_utils::get_test_signing_identity; use crate::{ cipher_suite::CipherSuite, crypto::{HpkeSecretKey, SignatureSecretKey}, identity::basic::BasicIdentityProvider, tree_kem::leaf_node::test_utils::get_basic_test_node_sig_key, }; use super::leaf_node::{ConfigProperties, LeafNodeSigningContext}; use super::node::LeafIndex; use super::Lifetime; use super::{ leaf_node::{test_utils::get_basic_test_node, LeafNode}, TreeKemPrivate, TreeKemPublic, }; #[derive(Debug)] pub(crate) struct TestTree { pub public: TreeKemPublic, pub private: TreeKemPrivate, pub creator_leaf: LeafNode, pub creator_signing_key: SignatureSecretKey, pub creator_hpke_secret: HpkeSecretKey, } #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] pub(crate) async fn get_test_tree(cipher_suite: CipherSuite) -> TestTree { let (creator_leaf, creator_hpke_secret, creator_signing_key) = get_basic_test_node_sig_key(cipher_suite, "creator").await; let (test_public, test_private) = TreeKemPublic::derive( creator_leaf.clone(), creator_hpke_secret.clone(), &BasicIdentityProvider, &Default::default(), ) .await .unwrap(); 
TestTree { public: test_public, private: test_private, creator_leaf, creator_signing_key, creator_hpke_secret, } } #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] pub async fn get_test_leaf_nodes(cipher_suite: CipherSuite) -> Vec { [ get_basic_test_node(cipher_suite, "A").await, get_basic_test_node(cipher_suite, "B").await, get_basic_test_node(cipher_suite, "C").await, ] .to_vec() } impl TreeKemPublic { #[cfg(feature = "tree_index")] pub fn equal_internals(&self, other: &TreeKemPublic) -> bool { self.tree_hashes == other.tree_hashes && self.index == other.index } } #[derive(Debug, Clone)] pub struct TreeWithSigners { pub tree: TreeKemPublic, pub signers: Vec>, pub group_id: Vec, } impl TreeWithSigners { #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] pub async fn make_full_tree( n_leaves: u32, cs: &P, ) -> TreeWithSigners { let mut tree = TreeWithSigners { tree: TreeKemPublic::new(), signers: vec![], group_id: cs.random_bytes_vec(cs.kdf_extract_size()).unwrap(), }; tree.add_member("Alice", cs).await; // A adds B, B adds C, C adds D etc. 
for i in 1..n_leaves { tree.add_member(&format!("Alice{i}"), cs).await; tree.update_committer_path(i - 1, cs).await; } tree } #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] pub async fn add_member(&mut self, name: &str, cs: &P) { let (leaf, signer) = make_leaf(name, cs).await; let index = self.tree.nodes.next_empty_leaf(LeafIndex(0)); self.tree.nodes.insert_leaf(index, leaf); self.tree.update_unmerged(index).unwrap(); let index = *index as usize; match self.signers.len() { l if l == index => self.signers.push(Some(signer)), l if l > index => self.signers[index] = Some(signer), _ => panic!("signer tree size mismatch"), } } #[cfg(feature = "rfc_compliant")] #[cfg_attr(coverage_nightly, coverage(off))] pub fn remove_member(&mut self, member: u32) { self.tree .nodes .blank_direct_path(LeafIndex(member)) .unwrap(); self.tree.nodes.blank_leaf_node(LeafIndex(member)).unwrap(); *self .signers .get_mut(member as usize) .expect("signer tree size mismatch") = None; } #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] pub async fn update_committer_path( &mut self, committer: u32, cs: &P, ) { let committer = LeafIndex(committer); let path = self.tree.nodes.direct_copath(committer); let filtered = self.tree.nodes.filtered(committer).unwrap(); for (n, f) in path.into_iter().zip(filtered) { if !f { self.tree .update_node(cs.kem_generate().await.unwrap().1, n.path) .unwrap(); } } self.tree.tree_hashes.current = vec![]; self.tree.tree_hash(cs).await.unwrap(); self.tree .update_parent_hashes(committer, false, cs) .await .unwrap(); self.tree.tree_hashes.current = vec![]; self.tree.tree_hash(cs).await.unwrap(); let context = LeafNodeSigningContext { group_id: Some(&self.group_id), leaf_index: Some(*committer), }; let signer = self.signers[*committer as usize].as_ref().unwrap(); self.tree .nodes .borrow_as_leaf_mut(committer) .unwrap() .sign(cs, signer, &context) .await .unwrap(); self.tree.tree_hashes.current = vec![]; self.tree.tree_hash(cs).await.unwrap(); } } 
/// Generates a leaf for `name` with basic-credential capabilities, all supported cipher suites
    /// and a one-year lifetime. Returns the leaf and its signature secret key.
    #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)]
    pub async fn make_leaf<P: CipherSuiteProvider>(
        name: &str,
        cs: &P,
    ) -> (LeafNode, SignatureSecretKey) {
        let (signing_identity, signature_key) =
            get_test_signing_identity(cs.cipher_suite(), name.as_bytes()).await;

        let capabilities = Capabilities {
            credentials: vec![BasicCredential::credential_type()],
            cipher_suites: TestCryptoProvider::all_supported_cipher_suites(),
            ..Default::default()
        };

        let properties = ConfigProperties {
            capabilities,
            extensions: Default::default(),
        };

        let (leaf, _) = LeafNode::generate(
            cs,
            properties,
            signing_identity,
            &signature_key,
            Lifetime::years(1).unwrap(),
        )
        .await
        .unwrap();

        (leaf, signature_key)
    }
}

#[cfg(test)]
mod tests {
    use crate::client::test_utils::TEST_CIPHER_SUITE;
    use crate::crypto::test_utils::{test_cipher_suite_provider, TestCryptoProvider};

    #[cfg(feature = "custom_proposal")]
    use crate::group::proposal::ProposalType;

    use crate::identity::basic::BasicIdentityProvider;
    use crate::tree_kem::leaf_node::LeafNode;
    use crate::tree_kem::node::{LeafIndex, Node, NodeIndex, NodeTypeResolver, Parent};
    use crate::tree_kem::parent_hash::ParentHash;
    use crate::tree_kem::test_utils::{get_test_leaf_nodes, get_test_tree};
    use crate::tree_kem::{MlsError, TreeKemPublic};
    use alloc::borrow::ToOwned;
    use alloc::vec;
    use alloc::vec::Vec;
    use assert_matches::assert_matches;

    #[cfg(feature = "by_ref_proposal")]
    use alloc::boxed::Box;

    #[cfg(feature = "by_ref_proposal")]
    use crate::{
        client::test_utils::TEST_PROTOCOL_VERSION,
        group::{
            proposal::{Proposal, RemoveProposal, UpdateProposal},
            proposal_filter::{ProposalBundle, ProposalSource},
            proposal_ref::ProposalRef,
            Sender,
        },
        key_package::test_utils::test_key_package,
    };

    // `custom_proposal_support` needs this helper too, so the feature name
    // must be spelled exactly as the crate declares it.
    #[cfg(any(feature = "by_ref_proposal", feature = "custom_proposal"))]
    use crate::tree_kem::leaf_node::test_utils::get_basic_test_node;

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_derive() {
        for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() {
            let test_tree = get_test_tree(cipher_suite).await;

            assert_eq!(
                test_tree.public.nodes[0],
                Some(Node::Leaf(test_tree.creator_leaf.clone()))
            );

            assert_eq!(test_tree.private.self_index, LeafIndex(0));

            assert_eq!(
                test_tree.private.secret_keys[0],
                Some(test_tree.creator_hpke_secret)
            );
        }
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_import_export() {
        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
        let mut test_tree = get_test_tree(TEST_CIPHER_SUITE).await;

        let additional_key_packages = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;

        test_tree
            .public
            .add_leaves(
                additional_key_packages,
                &BasicIdentityProvider,
                &cipher_suite_provider,
            )
            .await
            .unwrap();

        let imported = TreeKemPublic::import_node_data(
            test_tree.public.nodes.clone(),
            &BasicIdentityProvider,
            &Default::default(),
        )
        .await
        .unwrap();

        assert_eq!(test_tree.public.nodes, imported.nodes);

        #[cfg(feature = "tree_index")]
        assert_eq!(test_tree.public.index, imported.index);
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_add_leaf() {
        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
        let mut tree = TreeKemPublic::new();

        let leaf_nodes = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;

        let res = tree
            .add_leaves(
                leaf_nodes.clone(),
                &BasicIdentityProvider,
                &cipher_suite_provider,
            )
            .await
            .unwrap();

        // The leaf count should be equal to the number of packages we added
        assert_eq!(res.len(), leaf_nodes.len());
        assert_eq!(tree.occupied_leaf_count(), leaf_nodes.len() as u32);

        // Each added package should be at the proper index and searchable in the tree
        res.into_iter().zip(leaf_nodes.clone()).for_each(|(r, kp)| {
            assert_eq!(tree.get_leaf_node(r).unwrap(), &kp);
        });

        // Verify the underlying state
        #[cfg(feature = "tree_index")]
        assert_eq!(tree.index.len(), tree.occupied_leaf_count() as usize);

        assert_eq!(tree.nodes.len(), 5);
        assert_eq!(tree.nodes[0], leaf_nodes[0].clone().into());
        assert_eq!(tree.nodes[1], None);
        assert_eq!(tree.nodes[2], leaf_nodes[1].clone().into());
        assert_eq!(tree.nodes[3], None);
        assert_eq!(tree.nodes[4], leaf_nodes[2].clone().into());
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_get_key_packages() {
        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
        let mut tree = TreeKemPublic::new();

        let key_packages = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;

        // Keep a copy of what was inserted so the result is checked against
        // the input rather than against itself.
        let expected = key_packages.to_owned();

        tree.add_leaves(key_packages, &BasicIdentityProvider, &cipher_suite_provider)
            .await
            .unwrap();

        let stored = tree.get_leaf_nodes();
        assert_eq!(stored, expected.iter().collect::<Vec<_>>());
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_add_leaf_duplicate() {
        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
        let mut tree = TreeKemPublic::new();

        let key_packages = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;

        tree.add_leaves(
            key_packages.clone(),
            &BasicIdentityProvider,
            &cipher_suite_provider,
        )
        .await
        .unwrap();

        let res = tree
            .add_leaves(key_packages, &BasicIdentityProvider, &cipher_suite_provider)
            .await;

        assert_matches!(res, Err(MlsError::DuplicateLeafData(_)));
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_add_leaf_empty_leaf() {
        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
        let mut tree = get_test_tree(TEST_CIPHER_SUITE).await.public;

        let key_packages = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;

        tree.add_leaves(
            [key_packages[0].clone()].to_vec(),
            &BasicIdentityProvider,
            &cipher_suite_provider,
        )
        .await
        .unwrap();

        tree.nodes[0] = None; // Set the original first node to none

        tree.add_leaves(
            [key_packages[1].clone()].to_vec(),
            &BasicIdentityProvider,
            &cipher_suite_provider,
        )
        .await
        .unwrap();

        assert_eq!(tree.nodes[0], key_packages[1].clone().into());
        assert_eq!(tree.nodes[1], None);
        assert_eq!(tree.nodes[2], key_packages[0].clone().into());
        assert_eq!(tree.nodes.len(), 3)
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_add_leaf_unmerged() {
        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
        let mut tree = get_test_tree(TEST_CIPHER_SUITE).await.public;

        let key_packages = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;

        tree.add_leaves(
            [key_packages[0].clone(), key_packages[1].clone()].to_vec(),
            &BasicIdentityProvider,
            &cipher_suite_provider,
        )
        .await
        .unwrap();

        tree.nodes[3] = Parent {
            public_key: vec![].into(),
            parent_hash: ParentHash::empty(),
            unmerged_leaves: vec![],
        }
        .into();

        tree.add_leaves(
            [key_packages[2].clone()].to_vec(),
            &BasicIdentityProvider,
            &cipher_suite_provider,
        )
        .await
        .unwrap();

        assert_eq!(
            tree.nodes[3].as_parent().unwrap().unmerged_leaves,
            vec![LeafIndex(3)]
        )
    }

    #[cfg(feature = "by_ref_proposal")]
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_update_leaf() {
        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);

        // Create a tree
        let mut tree = get_test_tree(TEST_CIPHER_SUITE).await.public;

        let key_packages = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;

        tree.add_leaves(key_packages, &BasicIdentityProvider, &cipher_suite_provider)
            .await
            .unwrap();

        // Add in parent nodes so we can detect them clearing after update
        tree.nodes.direct_copath(LeafIndex(0)).iter().for_each(|n| {
            tree.nodes
                .borrow_or_fill_node_as_parent(n.path, &b"pub_key".to_vec().into())
                .unwrap();
        });

        let original_size = tree.occupied_leaf_count();
        let original_leaf_index = LeafIndex(1);

        let updated_leaf = get_basic_test_node(TEST_CIPHER_SUITE, "A").await;

        tree.update_leaf(
            *original_leaf_index,
            updated_leaf.clone(),
            &BasicIdentityProvider,
            &cipher_suite_provider,
        )
        .await
        .unwrap();

        // The tree should not have grown due to an update
        assert_eq!(tree.occupied_leaf_count(), original_size);

        // The cache of tree package indexes should not have grown
        #[cfg(feature = "tree_index")]
        assert_eq!(tree.index.len() as u32, tree.occupied_leaf_count());

        // The key package should be updated in the tree
        assert_eq!(
            tree.get_leaf_node(original_leaf_index).unwrap(),
            &updated_leaf
        );

        // Verify that the direct path has been cleared
        tree.nodes.direct_copath(LeafIndex(0)).iter().for_each(|n| {
            assert!(tree.nodes[n.path as usize].is_none());
        });
    }

    #[cfg(feature = "by_ref_proposal")]
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_update_leaf_not_found() {
        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);

        // Create a tree
        let mut tree = get_test_tree(TEST_CIPHER_SUITE).await.public;

        let key_packages = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;

        tree.add_leaves(key_packages, &BasicIdentityProvider, &cipher_suite_provider)
            .await
            .unwrap();

        let new_key_package = get_basic_test_node(TEST_CIPHER_SUITE, "new").await;

        let res = tree
            .update_leaf(
                128,
                new_key_package,
                &BasicIdentityProvider,
                &cipher_suite_provider,
            )
            .await;

        assert_matches!(res, Err(MlsError::UpdatingNonExistingMember));
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_remove_leaf() {
        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);

        // Create a tree
        let mut tree = get_test_tree(TEST_CIPHER_SUITE).await.public;

        let key_packages = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;

        let indexes = tree
            .add_leaves(
                key_packages.clone(),
                &BasicIdentityProvider,
                &cipher_suite_provider,
            )
            .await
            .unwrap();

        let original_leaf_count = tree.occupied_leaf_count();

        // Remove every leaf that was just added
        let expected_result: Vec<(LeafIndex, LeafNode)> =
            indexes.clone().into_iter().zip(key_packages).collect();

        let res = tree
            .remove_leaves(
                indexes.clone(),
                &BasicIdentityProvider,
                &cipher_suite_provider,
            )
            .await
            .unwrap();

        // The order may change
        assert!(res.iter().all(|x| expected_result.contains(x)));
        assert!(expected_result.iter().all(|x| res.contains(x)));

        // The leaves should be removed from the tree
        assert_eq!(
            tree.occupied_leaf_count(),
            original_leaf_count - indexes.len() as u32
        );
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_remove_leaf_middle() {
        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);

        // Create a tree
        let mut tree = get_test_tree(TEST_CIPHER_SUITE).await.public;

        let leaf_nodes = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;

        let to_remove = tree
            .add_leaves(
                leaf_nodes.clone(),
                &BasicIdentityProvider,
                &cipher_suite_provider,
            )
            .await
            .unwrap()[0];

        let original_leaf_count = tree.occupied_leaf_count();

        let res = tree
            .remove_leaves(
                vec![to_remove],
                &BasicIdentityProvider,
                &cipher_suite_provider,
            )
            .await
            .unwrap();

        assert_eq!(res, vec![(to_remove, leaf_nodes[0].clone())]);

        // The leaf count should have been reduced by 1
        assert_eq!(tree.occupied_leaf_count(), original_leaf_count - 1);

        // There should be a blank in the tree
        assert_eq!(
            tree.nodes.get(NodeIndex::from(to_remove) as usize).unwrap(),
            &None
        );
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_create_blanks() {
        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);

        // Create a tree
        let mut tree = get_test_tree(TEST_CIPHER_SUITE).await.public;

        let key_packages = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;

        tree.add_leaves(key_packages, &BasicIdentityProvider, &cipher_suite_provider)
            .await
            .unwrap();

        let original_leaf_count = tree.occupied_leaf_count();

        let to_remove = vec![LeafIndex(2)];

        // Remove the leaf from the tree
        tree.remove_leaves(to_remove, &BasicIdentityProvider, &cipher_suite_provider)
            .await
            .unwrap();

        // The occupied leaf count should have been reduced by 1
        assert_eq!(tree.occupied_leaf_count(), original_leaf_count - 1);

        // The total leaf count should remain unchanged
        assert_eq!(tree.total_leaf_count(), original_leaf_count);

        // The location of key_packages[1] should now be blank
        let removed_location = tree
            .nodes
            .get(NodeIndex::from(LeafIndex(2)) as usize)
            .unwrap();

        assert_eq!(removed_location, &None);
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_remove_leaf_failure() {
        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);

        // Create a tree
        let mut tree = get_test_tree(TEST_CIPHER_SUITE).await.public;

        let res = tree
            .remove_leaves(
                vec![LeafIndex(128)],
                &BasicIdentityProvider,
                &cipher_suite_provider,
            )
            .await;

        assert_matches!(res, Err(MlsError::InvalidNodeIndex(256)));
    }

    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn test_find_leaf_node() {
        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);

        // Create a tree
        let mut tree = get_test_tree(TEST_CIPHER_SUITE).await.public;

        let leaf_nodes = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;

        tree.add_leaves(
            leaf_nodes.clone(),
            &BasicIdentityProvider,
            &cipher_suite_provider,
        )
        .await
        .unwrap();

        // Find each node; leaf 0 is the creator, so added leaves start at 1
        for (i, leaf_node) in leaf_nodes.iter().enumerate() {
            let expected_index = LeafIndex(i as u32 + 1);
            assert_eq!(tree.find_leaf_node(leaf_node), Some(expected_index));
        }
    }

    // TODO add test for the lite version
    #[cfg(feature = "by_ref_proposal")]
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn batch_edit_works() {
        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);

        let mut tree = get_test_tree(TEST_CIPHER_SUITE).await.public;
        let leaf_nodes = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;

        tree.add_leaves(leaf_nodes, &BasicIdentityProvider, &cipher_suite_provider)
            .await
            .unwrap();

        let mut bundle = ProposalBundle::default();

        let kp = test_key_package(TEST_PROTOCOL_VERSION, TEST_CIPHER_SUITE, "D").await;
        let add = Proposal::Add(Box::new(kp.into()));

        bundle.add(add, Sender::Member(0), ProposalSource::ByValue);

        let update = UpdateProposal {
            leaf_node: get_basic_test_node(TEST_CIPHER_SUITE, "A").await,
        };

        let update = Proposal::Update(update);
        let pref = ProposalRef::new_fake(vec![1, 2, 3]);

        bundle.add(update, Sender::Member(1), ProposalSource::ByReference(pref));

        bundle.update_senders = vec![LeafIndex(1)];

        let remove = RemoveProposal {
            to_remove: LeafIndex(2),
        };

        let remove = Proposal::Remove(remove);

        bundle.add(remove, Sender::Member(0), ProposalSource::ByValue);

        tree.batch_edit(
            &mut bundle,
            &Default::default(),
            &BasicIdentityProvider,
            &cipher_suite_provider,
            true,
        )
        .await
        .unwrap();

        assert_eq!(bundle.add_proposals().len(), 1);
        assert_eq!(bundle.remove_proposals().len(), 1);
        assert_eq!(bundle.update_proposals().len(), 1);
    }

    #[cfg(feature = "custom_proposal")]
    #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))]
    async fn custom_proposal_support() {
        let cipher_suite_provider = test_cipher_suite_provider(TEST_CIPHER_SUITE);
        let mut tree = TreeKemPublic::new();

        let test_proposal_type = ProposalType::from(42);

        let mut leaf_nodes = get_test_leaf_nodes(TEST_CIPHER_SUITE).await;

        leaf_nodes
            .iter_mut()
            .for_each(|n| n.capabilities.proposals.push(test_proposal_type));

        tree.add_leaves(leaf_nodes, &BasicIdentityProvider, &cipher_suite_provider)
            .await
            .unwrap();

        assert!(tree.can_support_proposal(test_proposal_type));
        assert!(!tree.can_support_proposal(ProposalType::from(43)));

        // A leaf without the custom capability makes the whole tree unsupported
        let test_node = get_basic_test_node(TEST_CIPHER_SUITE, "another").await;

        tree.add_leaves(
            vec![test_node],
            &BasicIdentityProvider,
            &cipher_suite_provider,
        )
        .await
        .unwrap();

        assert!(!tree.can_support_proposal(test_proposal_type));
    }
}