// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. // Copyright by contributors to this project. // SPDX-License-Identifier: (Apache-2.0 OR MIT) use crate::client::MlsError; use crate::crypto::{CipherSuiteProvider, SignatureSecretKey}; use crate::group::GroupContext; use crate::identity::SigningIdentity; use crate::iter::wrap_iter; use crate::tree_kem::math as tree_math; use alloc::vec; use alloc::vec::Vec; use itertools::Itertools; use mls_rs_codec::MlsEncode; use tree_math::{CopathNode, TreeIndex}; #[cfg(all(not(mls_build_async), feature = "rayon"))] use {crate::iter::ParallelIteratorExt, rayon::prelude::*}; #[cfg(mls_build_async)] use futures::{StreamExt, TryStreamExt}; #[cfg(feature = "std")] use std::collections::HashSet; use super::hpke_encryption::HpkeEncryptable; use super::leaf_node::ConfigProperties; use super::node::NodeTypeResolver; use super::{ node::{LeafIndex, NodeIndex}, path_secret::{PathSecret, PathSecretGenerator}, TreeKemPrivate, TreeKemPublic, UpdatePath, UpdatePathNode, ValidatedUpdatePath, }; #[cfg(test)] use crate::{group::CommitModifiers, signer::Signable}; pub struct TreeKem<'a> { tree_kem_public: &'a mut TreeKemPublic, private_key: &'a mut TreeKemPrivate, } pub struct EncapGeneration { pub update_path: UpdatePath, pub path_secrets: Vec>, pub commit_secret: PathSecret, } impl<'a> TreeKem<'a> { pub fn new( tree_kem_public: &'a mut TreeKemPublic, private_key: &'a mut TreeKemPrivate, ) -> Self { TreeKem { tree_kem_public, private_key, } } #[allow(clippy::too_many_arguments)] #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] pub async fn encap

( self, context: &mut GroupContext, excluding: &[LeafIndex], signer: &SignatureSecretKey, update_leaf_properties: ConfigProperties, signing_identity: Option, cipher_suite_provider: &P, #[cfg(test)] commit_modifiers: &CommitModifiers, ) -> Result where P: CipherSuiteProvider + Send + Sync, { let self_index = self.private_key.self_index; let path = self.tree_kem_public.nodes.direct_copath(self_index); let filtered = self.tree_kem_public.nodes.filtered(self_index)?; self.private_key.secret_keys.resize(path.len() + 1, None); let mut secret_generator = PathSecretGenerator::new(cipher_suite_provider); let mut path_secrets = vec![]; for (i, (node, f)) in path.iter().zip(&filtered).enumerate() { if !f { let secret = secret_generator.next_secret().await?; let (secret_key, public_key) = secret.to_hpke_key_pair(cipher_suite_provider).await?; self.private_key.secret_keys[i + 1] = Some(secret_key); self.tree_kem_public.update_node(public_key, node.path)?; path_secrets.push(Some(secret)); } else { self.private_key.secret_keys[i + 1] = None; path_secrets.push(None); } } #[cfg(test)] (commit_modifiers.modify_tree)(self.tree_kem_public); self.tree_kem_public .update_parent_hashes(self_index, false, cipher_suite_provider) .await?; let update_path_leaf = { let own_leaf = self.tree_kem_public.nodes.borrow_as_leaf_mut(self_index)?; self.private_key.secret_keys[0] = Some( own_leaf .commit( cipher_suite_provider, &context.group_id, *self_index, update_leaf_properties, signing_identity, signer, ) .await?, ); #[cfg(test)] if let Some(signer) = (commit_modifiers.modify_leaf)(own_leaf, signer) { let context = &(context.group_id.as_slice(), *self_index).into(); own_leaf .sign(cipher_suite_provider, &signer, context) .await .unwrap(); } own_leaf.clone() }; // Tree modifications are all done so we can update the tree hash and encrypt with the new context self.tree_kem_public .update_hashes(&[self_index], cipher_suite_provider) .await?; context.tree_hash = self .tree_kem_public .tree_hash(cipher_suite_provider) .await?; let context_bytes = context.mls_encode_to_vec()?; let node_updates = self .encrypt_path_secrets( path, &path_secrets, &context_bytes, cipher_suite_provider, excluding, ) .await?; #[cfg(test)] let node_updates = (commit_modifiers.modify_path)(node_updates); // Create an update path with the new node and parent node updates let update_path = UpdatePath { leaf_node: update_path_leaf, nodes: node_updates, }; Ok(EncapGeneration { update_path, path_secrets, commit_secret: secret_generator.next_secret().await?, }) } #[cfg(any(mls_build_async, not(feature = "rayon")))] #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] async fn encrypt_path_secrets( &self, path: Vec>, path_secrets: &[Option], context_bytes: &[u8], cipher_suite: &P, excluding: &[LeafIndex], ) -> Result, MlsError> { let excluding = excluding.iter().copied().map(NodeIndex::from); #[cfg(feature = "std")] let excluding = excluding.collect::>(); #[cfg(not(feature = "std"))] let excluding = excluding.collect::>(); let mut node_updates = Vec::new(); for (index, path_secret) in path.into_iter().zip(path_secrets.iter()) { if let Some(path_secret) = path_secret { node_updates.push( self.encrypt_copath_node_resolution( cipher_suite, path_secret, index.copath, context_bytes, &excluding, ) .await?, ); } } Ok(node_updates) } #[cfg(all(not(mls_build_async), feature = "rayon"))] fn encrypt_path_secrets( &self, path: Vec>, path_secrets: &[Option], context_bytes: &[u8], cipher_suite: &P, excluding: &[LeafIndex], ) -> Result, MlsError> { let excluding = excluding.iter().copied().map(NodeIndex::from); #[cfg(feature = "std")] let excluding = excluding.collect::>(); #[cfg(not(feature = "std"))] let excluding = excluding.collect::>(); path.into_par_iter() .zip(path_secrets.par_iter()) .filter_map(|(node, path_secret)| { path_secret.as_ref().map(|path_secret| { self.encrypt_copath_node_resolution( cipher_suite, path_secret, node.copath, context_bytes, &excluding, ) }) }) .collect() } #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] pub async fn decap( self, sender_index: LeafIndex, update_path: &ValidatedUpdatePath, added_leaves: &[LeafIndex], context_bytes: &[u8], cipher_suite_provider: &CP, ) -> Result where CP: CipherSuiteProvider, { let self_index = self.private_key.self_index; let lca_index = tree_math::leaf_lca_level(self_index.into(), sender_index.into()) as usize - 2; let mut path = self.tree_kem_public.nodes.direct_copath(self_index); let leaf = CopathNode::new(self_index.into(), 0); path.insert(0, leaf); let resolved_pos = self.find_resolved_pos(&path, lca_index)?; let ct_pos = self.find_ciphertext_pos(path[lca_index].path, path[resolved_pos].path, added_leaves)?; let lca_node = update_path.nodes[lca_index] .as_ref() .ok_or(MlsError::LcaNotFoundInDirectPath)?; let ct = lca_node .encrypted_path_secret .get(ct_pos) .ok_or(MlsError::LcaNotFoundInDirectPath)?; let secret = self.private_key.secret_keys[resolved_pos] .as_ref() .ok_or(MlsError::UpdateErrorNoSecretKey)?; let public = self .tree_kem_public .nodes .borrow_node(path[resolved_pos].path)? .as_ref() .ok_or(MlsError::UpdateErrorNoSecretKey)? .public_key(); let lca_path_secret = PathSecret::decrypt(cipher_suite_provider, secret, public, context_bytes, ct).await?; // Derive the rest of the secrets for the tree and assign to the proper nodes let mut node_secret_gen = PathSecretGenerator::starting_with(cipher_suite_provider, lca_path_secret); // Update secrets based on the decrypted path secret in the update self.private_key.secret_keys.resize(path.len() + 1, None); for (i, update) in update_path.nodes.iter().enumerate().skip(lca_index) { if let Some(update) = update { let secret = node_secret_gen.next_secret().await?; // Verify the private key we calculated properly matches the public key we inserted into the tree. This guarantees // that we will be able to decrypt later. let (hpke_private, hpke_public) = secret.to_hpke_key_pair(cipher_suite_provider).await?; if hpke_public != update.public_key { return Err(MlsError::PubKeyMismatch); } self.private_key.secret_keys[i + 1] = Some(hpke_private); } else { self.private_key.secret_keys[i + 1] = None; } } node_secret_gen.next_secret().await } #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] async fn encrypt_copath_node_resolution( &self, cipher_suite_provider: &P, path_secret: &PathSecret, copath_index: NodeIndex, context: &[u8], #[cfg(feature = "std")] excluding: &HashSet, #[cfg(not(feature = "std"))] excluding: &[NodeIndex], ) -> Result { let reso = self .tree_kem_public .nodes .get_resolution_index(copath_index)?; let make_ctxt = |idx| async move { let node = self .tree_kem_public .nodes .borrow_node(idx)? .as_non_empty()?; path_secret .encrypt(cipher_suite_provider, node.public_key(), context) .await }; let ctxts = wrap_iter(reso).filter(|&idx| async move { !excluding.contains(&idx) }); #[cfg(not(mls_build_async))] let ctxts = ctxts.map(make_ctxt); #[cfg(mls_build_async)] let ctxts = ctxts.then(make_ctxt); let ctxts = ctxts.try_collect().await?; let path_index = copath_index .parent_sibling(&self.tree_kem_public.total_leaf_count()) .ok_or(MlsError::ExpectedNode)? .parent; Ok(UpdatePathNode { public_key: self .tree_kem_public .nodes .borrow_as_parent(path_index)? .public_key .clone(), encrypted_path_secret: ctxts, }) } #[inline] fn find_resolved_pos( &self, path: &[CopathNode], mut lca_index: usize, ) -> Result { while self.tree_kem_public.nodes.is_blank(path[lca_index].path)? { lca_index -= 1; } // If we don't have the key, we should be an unmerged leaf at the resolved node. (If // we're not, an error will be thrown later.) if self.private_key.secret_keys[lca_index].is_none() { lca_index = 0; } Ok(lca_index) } #[inline] fn find_ciphertext_pos( &self, lca: NodeIndex, resolved: NodeIndex, excluding: &[LeafIndex], ) -> Result { let reso = self.tree_kem_public.nodes.get_resolution_index(lca)?; let (ct_pos, _) = reso .iter() .filter(|idx| **idx % 2 == 1 || !excluding.contains(&LeafIndex(**idx / 2))) .find_position(|idx| idx == &&resolved) .ok_or(MlsError::UpdateErrorNoSecretKey)?; Ok(ct_pos) } } #[cfg(test)] mod tests { use super::{tree_math, TreeKem}; use crate::{ cipher_suite::CipherSuite, client::test_utils::TEST_CIPHER_SUITE, crypto::test_utils::{test_cipher_suite_provider, TestCryptoProvider}, extension::test_utils::TestExtension, group::test_utils::{get_test_group_context, random_bytes}, identity::basic::BasicIdentityProvider, tree_kem::{ leaf_node::{ test_utils::{get_basic_test_node_sig_key, get_test_capabilities}, ConfigProperties, }, node::LeafIndex, Capabilities, TreeKemPrivate, TreeKemPublic, UpdatePath, ValidatedUpdatePath, }, ExtensionList, }; use alloc::{format, vec, vec::Vec}; use mls_rs_codec::MlsEncode; use mls_rs_core::crypto::CipherSuiteProvider; use tree_math::TreeIndex; // Verify that the tree is in the correct state after generating an update path fn verify_tree_update_path( tree: &TreeKemPublic, update_path: &UpdatePath, index: LeafIndex, capabilities: Option, extensions: Option, ) { // Make sure the update path is based on the direct path of the sender let direct_path = tree.nodes.direct_copath(index); for (i, n) in direct_path.iter().enumerate() { assert_eq!( *tree .nodes .borrow_node(n.path) .unwrap() .as_ref() .unwrap() .public_key(), update_path.nodes[i].public_key ); } // Verify that the leaf from the update path has been installed assert_eq!( tree.nodes.borrow_as_leaf(index).unwrap(), &update_path.leaf_node ); // Verify that updated capabilities were installed if let Some(capabilities) = capabilities { assert_eq!(update_path.leaf_node.capabilities, capabilities); } // Verify that update extensions were installed if let Some(extensions) = extensions { assert_eq!(update_path.leaf_node.extensions, extensions); } // Verify that we have a public keys up to the root let root = tree.total_leaf_count().root(); assert!(tree.nodes.borrow_node(root).unwrap().is_some()); } #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] async fn verify_tree_private_path( cipher_suite: &CipherSuite, public_tree: &TreeKemPublic, private_tree: &TreeKemPrivate, index: LeafIndex, ) { let provider = test_cipher_suite_provider(*cipher_suite); assert_eq!(private_tree.self_index, index); // Make sure we have private values along the direct path, and the public keys match let path_iter = public_tree .nodes .direct_copath(index) .into_iter() .enumerate(); for (i, n) in path_iter { let secret_key = private_tree.secret_keys[i + 1].as_ref().unwrap(); let public_key = public_tree .nodes .borrow_node(n.path) .unwrap() .as_ref() .unwrap() .public_key(); let test_data = random_bytes(32); let sealed = provider .hpke_seal(public_key, &[], None, &test_data) .await .unwrap(); let opened = provider .hpke_open(&sealed, secret_key, public_key, &[], None) .await .unwrap(); assert_eq!(test_data, opened); } } #[cfg_attr(not(mls_build_async), maybe_async::must_be_sync)] async fn encap_decap( cipher_suite: CipherSuite, size: usize, capabilities: Option, extensions: Option, ) { let cipher_suite_provider = test_cipher_suite_provider(cipher_suite); // Generate signing keys and key package generations, and private keys for multiple // participants in order to set up state let mut leaf_nodes = Vec::new(); let mut private_keys = Vec::new(); for index in 1..size { let (leaf_node, hpke_secret, _) = get_basic_test_node_sig_key(cipher_suite, &format!("{index}")).await; let private_key = TreeKemPrivate::new_self_leaf(LeafIndex(index as u32), hpke_secret); leaf_nodes.push(leaf_node); private_keys.push(private_key); } let (encap_node, encap_hpke_secret, encap_signer) = get_basic_test_node_sig_key(cipher_suite, "encap").await; // Build a test tree we can clone for all leaf nodes let (mut test_tree, mut encap_private_key) = TreeKemPublic::derive( encap_node, encap_hpke_secret, &BasicIdentityProvider, &Default::default(), ) .await .unwrap(); test_tree .add_leaves(leaf_nodes, &BasicIdentityProvider, &cipher_suite_provider) .await .unwrap(); // Clone the tree for the first leaf, generate a new key package for that leaf let mut encap_tree = test_tree.clone(); let update_leaf_properties = ConfigProperties { capabilities: capabilities.clone().unwrap_or_else(get_test_capabilities), extensions: extensions.clone().unwrap_or_default(), }; // Perform the encap function let encap_gen = TreeKem::new(&mut encap_tree, &mut encap_private_key) .encap( &mut get_test_group_context(42, cipher_suite).await, &[], &encap_signer, update_leaf_properties, None, &cipher_suite_provider, #[cfg(test)] &Default::default(), ) .await .unwrap(); // Verify that the state of the tree matches the produced update path verify_tree_update_path( &encap_tree, &encap_gen.update_path, LeafIndex(0), capabilities, extensions, ); // Verify that the private key matches the data in the public key verify_tree_private_path(&cipher_suite, &encap_tree, &encap_private_key, LeafIndex(0)) .await; let filtered = test_tree.nodes.filtered(LeafIndex(0)).unwrap(); let mut unfiltered_nodes = vec![None; filtered.len()]; filtered .into_iter() .enumerate() .filter(|(_, f)| !*f) .zip(encap_gen.update_path.nodes.iter()) .for_each(|((i, _), node)| { unfiltered_nodes[i] = Some(node.clone()); }); // Apply the update path to the rest of the leaf nodes using the decap function let validated_update_path = ValidatedUpdatePath { leaf_node: encap_gen.update_path.leaf_node, nodes: unfiltered_nodes, }; encap_tree .update_hashes(&[LeafIndex(0)], &cipher_suite_provider) .await .unwrap(); let mut receiver_trees: Vec = (1..size).map(|_| test_tree.clone()).collect(); for (i, tree) in receiver_trees.iter_mut().enumerate() { tree.apply_update_path( LeafIndex(0), &validated_update_path, &Default::default(), BasicIdentityProvider, &cipher_suite_provider, ) .await .unwrap(); let mut context = get_test_group_context(42, cipher_suite).await; context.tree_hash = tree.tree_hash(&cipher_suite_provider).await.unwrap(); TreeKem::new(tree, &mut private_keys[i]) .decap( LeafIndex(0), &validated_update_path, &[], &context.mls_encode_to_vec().unwrap(), &cipher_suite_provider, ) .await .unwrap(); tree.update_hashes(&[LeafIndex(0)], &cipher_suite_provider) .await .unwrap(); assert_eq!(tree, &encap_tree); } } #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] async fn test_encap_decap() { for cipher_suite in TestCryptoProvider::all_supported_cipher_suites() { encap_decap(cipher_suite, 10, None, None).await; } } #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] async fn test_encap_capabilities() { let cipher_suite = TEST_CIPHER_SUITE; let mut capabilities = get_test_capabilities(); capabilities.extensions.push(42.into()); encap_decap(cipher_suite, 10, Some(capabilities.clone()), None).await; } #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] async fn test_encap_extensions() { let cipher_suite = TEST_CIPHER_SUITE; let mut extensions = ExtensionList::default(); extensions.set_from(TestExtension { foo: 10 }).unwrap(); encap_decap(cipher_suite, 10, None, Some(extensions)).await; } #[maybe_async::test(not(mls_build_async), async(mls_build_async, crate::futures_test))] async fn test_encap_capabilities_extensions() { let cipher_suite = TEST_CIPHER_SUITE; let mut capabilities = get_test_capabilities(); capabilities.extensions.push(42.into()); let mut extensions = ExtensionList::default(); extensions.set_from(TestExtension { foo: 10 }).unwrap(); encap_decap(cipher_suite, 10, Some(capabilities), Some(extensions)).await; } }